当前位置:   article > 正文

rcnn代码—demo.py

rcnn代码

rcnn代码—demo.py

个人理解学习,仅供参考!

一、总体流程

总体流程如下:
在这里插入图片描述

二、main主函数

备注如下:

if __name__ == '__main__':

    config, args = parse_arg()#初始化配置信息
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')#选用设备
    
    model = crnn.get_crnn(config).to(device)#构建model
    print('loading pretrained model from {0}'.format(args.checkpoint))
    checkpoint = torch.load(args.checkpoint)#导入训练好的权重
    if 'state_dict' in checkpoint.keys():
        model.load_state_dict(checkpoint['state_dict'])
    else:
        model.load_state_dict(checkpoint)

    started = time.time()

    img = cv2.imread(args.image_path)#读取原始输入图像
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)#对输入图像做灰度处理
    converter = utils.strLabelConverter(config.DATASET.ALPHABETS)#字典信息

    #sim_pred1=recognition(config, img, model, converter, device)
    #获取识别结果
    recognition(config, img, model, converter, device)
    #print('my resoult is =',sim_pred1)

    finished = time.time()
    print('elapsed time: {0}'.format(finished - started))

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

三、recognition(config, img, model, converter, device)函数

def recognition(config, img, model, converter, device):

    # github issues: https://github.com/Sierkinhane/CRNN_Chinese_Characters_Rec/issues/211
    h, w = img.shape
    # fisrt step: resize the height and width of image to (32, x)
    #按照高32调整尺寸
    img = cv2.resize(img, (0, 0), fx=config.MODEL.IMAGE_SIZE.H / h, fy=config.MODEL.IMAGE_SIZE.H / h, interpolation=cv2.INTER_CUBIC)

    # second step: keep the ratio of image's text same with training
    #这一步的尺寸调整是为了与train中的尺寸匹配,一方面这一步也固定了尺寸,不利于不定长文本的情况。修改方法参考上面的链接
    h, w = img.shape
    w_cur = int(img.shape[1] / (config.MODEL.IMAGE_SIZE.OW / config.MODEL.IMAGE_SIZE.W))
    img = cv2.resize(img, (0, 0), fx=w_cur / w, fy=1.0, interpolation=cv2.INTER_CUBIC)
    img = np.reshape(img, (config.MODEL.IMAGE_SIZE.H, w_cur, 1))

    # normalize数据归一化
    img = img.astype(np.float32)
    img = (img / 255. - config.DATASET.MEAN) / config.DATASET.STD
    img = img.transpose([2, 0, 1])
    img = torch.from_numpy(img)

    img = img.to(device)
    #img.view后,增加一维,img成为4维数据
    img = img.view(1, *img.size())
    model.eval()
    preds = model(img)
    print(preds.shape)
    _, preds = preds.max(2)
    #获取识别结果,但是现在还是编码结果,需要解码成为字符
    preds = preds.transpose(1, 0).contiguous().view(-1)

    preds_size = Variable(torch.IntTensor([preds.size(0)]))
    #解码识别结果,成为字符。
    sim_pred = converter.decode(preds.data, preds_size.data, raw=False)

    print('results: {0}'.format(sim_pred))
    #return sim_pred
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
'
运行
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/978355
推荐阅读
相关标签
  

闽ICP备14008679号