赞
踩
个人理解学习,仅供参考!
总体流程如下:
备注如下:
if __name__ == '__main__':
    # NOTE(review): in this listing recognition() appears after this block;
    # in a runnable script it must be defined before this block executes.

    # Initialise configuration; parse_arg() yields the config object and the
    # parsed command-line arguments.
    config, args = parse_arg()

    # Use the first CUDA device when available, otherwise fall back to CPU.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # Build the model and move it to the selected device.
    model = crnn.get_crnn(config).to(device)

    print('loading pretrained model from {0}'.format(args.checkpoint))
    # map_location remaps stored tensors onto the selected device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    checkpoint = torch.load(args.checkpoint, map_location=device)
    # The checkpoint is either a dict wrapping the weights under 'state_dict'
    # or the state dict itself.
    if 'state_dict' in checkpoint:
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    started = time.time()

    # Read the input image; cv2.imread returns None when the file cannot be
    # read, so fail early with a clear message instead of inside cvtColor.
    img = cv2.imread(args.image_path)
    if img is None:
        raise FileNotFoundError('cannot read image: {0}'.format(args.image_path))
    # Convert the image to single-channel grayscale.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    # Converter built from the configured alphabet (the character dictionary).
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)

    # Run recognition; the result is printed inside recognition().
    recognition(config, img, model, converter, device)

    finished = time.time()
    print('elapsed time: {0}'.format(finished - started))
def recognition(config, img, model, converter, device):
    """Run the model on one grayscale image, decode and print the result.

    Args:
        config: Configuration object. Reads MODEL.IMAGE_SIZE.H, .W and .OW,
            and DATASET.MEAN and DATASET.STD.
        img: Single-channel image as a 2-D array; its shape is unpacked as
            (height, width), so a multi-channel image raises ValueError.
        model: Network invoked as ``model(tensor)``; it is switched to eval
            mode here. Presumably returns (seq_len, batch, num_classes),
            since the argmax is taken over dim 2 -- TODO confirm.
        converter: Object providing ``decode(preds, preds_size, raw=False)``.
        device: torch device the input tensor is moved to.

    Returns:
        The value produced by ``converter.decode`` (also printed).
    """
    # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
    target_h = config.MODEL.IMAGE_SIZE.H
    h, w = img.shape

    # First step: resize the image to the configured height, keeping the
    # aspect ratio (same scale factor on both axes).
    scale = target_h / h
    img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)

    # Second step: rescale the width by W / OW so the text's aspect ratio
    # matches the one used in training. See the linked issue above for
    # adapting this step to variable-length text.
    h, w = img.shape
    w_cur = int(w / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
    img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=1.0, interpolation=cv2.INTER_CUBIC)
    img = np.reshape(img, (target_h, w_cur, 1))

    # Normalize: scale to [0, 1], then standardise with the dataset mean/std.
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
    # HWC -> CHW, then add a leading batch dimension to get a 4-D tensor.
    img = img.transpose([2, 0, 1])
    img = torch.from_numpy(img).to(device)
    img = img.unsqueeze(0)

    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        preds = model(img)
    print(preds.shape)

    # Take the arg-max class index along dim 2; these are still encoded
    # labels and need decoding into characters.
    _, preds = preds.max(2)
    preds = preds.transpose(1, 0).contiguous().view(-1)
    preds_size = torch.IntTensor([preds.size(0)])

    # Decode the label indices into the recognised string.
    sim_pred = converter.decode(preds, preds_size, raw=False)
    print('results: {0}'.format(sim_pred))
    return sim_pred
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。