当前位置:   article > 正文

从零写CRNN文字识别 —— (6)训练_crnn 训练

crnn 训练

目录

从零写CRNN文字识别 —— (1)准备工作
从零写CRNN文字识别 —— (2)准备配置文件
从零写CRNN文字识别 —— (3)数据加载器
从零写CRNN文字识别 —— (4)搭建模型
从零写CRNN文字识别 —— (5)优化器和Loss
从零写CRNN文字识别 —— (6)训练

前言

完整代码已经上传githubhttps://github.com/xmy0916/pytorch_crnn

训练

训练部分的代码逻辑如下:

for epoch in range(total_epoch):
  for data in dataloader:
    数据输入模型(前馈)
    根据输出计算loss
    loss反馈更新网络参数
  if epoch % eval_epoch == 0:
    评估数据输入模型(前馈)
    根据输出计算loss
    解码输出计算识别准确率
    if now_acc > best_acc:
      保存模型
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

对应的完整代码如下:

# 训练
    best_acc = 0.0
    for epoch in range(last_epoch,config.TRAIN.END_EPOCH):
      model.train()
      for i, (inp, idx) in enumerate(train_loader):
          # 前馈
          inp = inp.to(device)
          preds = model(inp).to(device)
          # 计算loss
          labels = get_batch_label(train_dataset, idx)
          batch_size = inp.size(0)
          text, length = encode(config.DICT,labels)
          preds_size = torch.IntTensor([preds.size(0)] * batch_size)
          loss = criterion(preds, text, preds_size, length)
          # 反馈
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          if i % config.PRINT_FREQ == 0:
            print("epoch:{} step:{} loss:{} lr:{}".format(epoch,i,loss.item(),lr_scheduler.get_lr()))
      # 每个epoch更新学习率
      lr_scheduler.step()

      # 每EVAL_FREQ评估一次并保存best模型
      if epoch % config.EVAL_FREQ == 0:
          model.eval()
          n_correct = 0
          test_num = len(val_loader) * config.TEST.BATCH_SIZE_PER_GPU
          with torch.no_grad():
              for i, (inp, idx) in enumerate(val_loader):
                  # 计算前馈
                  inp = inp.to(device)
                  preds = model(inp).cpu()
                  # 计算loss
                  labels = get_batch_label(val_dataset, idx)
                  batch_size = inp.size(0)
                  text, length = encode(config.DICT,labels)
                  preds_size = torch.IntTensor([preds.size(0)] * batch_size)
                  loss = criterion(preds, text, preds_size, length)
                  # 后处理解码
                  print("网络输出的preds的shape:",preds.cpu().detach().shape)
                  _, preds = preds.max(2)
                  print("max(2)的shape:",preds.cpu().detach().shape)
                  preds = preds.transpose(1, 0).contiguous().view(-1)
                  print("transpose的shape:",preds.cpu().detach().shape)
                  sim_preds = decode(preds.data, preds_size.data, config.DICT,raw=False)
                  for pred, target in zip(sim_preds, labels):
                    if pred == target:
                      n_correct += 1

              
          # 抓一个batch来显示
          raw_preds = decode(preds.data, preds_size.data, config.DICT, raw=True)[:config.TEST.NUM_TEST_DISP]
          for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
              print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))
          print("preds:",preds.cpu().detach().numpy())
          print("preds_shape:",preds.cpu().detach().shape)
          print("dict:",config.DICT)
          now_acc = n_correct * 1.0 / test_num
          print("best_acc:{} correct:{}".format(now_acc,n_correct))
          if now_acc >= best_acc:
              torch.save(
                    {
                        "state_dict": model.state_dict(),
                        "epoch": epoch + 1,
                        # "optimizer": optimizer.state_dict(),
                        # "lr_scheduler": lr_scheduler.state_dict(),
                        "best_acc": best_acc,
                    },  os.path.join(config.OUTPUT_DIR, "checkpoint_{}_acc_{:.4f}.pth".format(epoch, now_acc)))
              best_acc = now_acc
              print("save_model!")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71

看看评估过程(摘一段代码出来):

preds = model(inp).cpu()
# 计算loss
labels = get_batch_label(val_dataset, idx)
batch_size = inp.size(0)
text, length = encode(config.DICT,labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
# 后处理解码
print("网络输出的preds的shape:",preds.cpu().detach().shape)
_, preds = preds.max(2)
print("max(2)的shape:",preds.cpu().detach().shape)
preds = preds.transpose(1, 0).contiguous().view(-1)
print("transpose的shape:",preds.cpu().detach().shape)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

打印结果:
在这里插入图片描述
稍微解释下:
preds的shape[41,16,109]:

  • 41是卷积后的长度
  • 16是测试时的batch_size大小
  • 109是字典的类别数

preds.max(2)得到了从属于那一类的向量,2表示在109的纬度上取所以输出的shape是[41,16]
transpose是把二维向量拉平,656=41*16
这里注意一点,测试的时候每个batch_size是16,但是我们数据集不一定是16的整数倍,所以最后一个batch的大小不一定有16,例如我们这里最后一个batch的大小是14:
在这里插入图片描述
在代码中我将最后一个batch的测试图片可视化的打印了,结果如下:
在这里插入图片描述
这是第一个epoch训练的输出,
在这里插入图片描述
上图的横杠是设置的空字符的占位符,在config/config.yml中设置这个字符BLANK_CHAR
在这里插入图片描述
上图一共574个0,574 = 41 * 14因为是最后一个batch所以不够16个,上图理论上可以解码成574个字符,因为这是第一个epoch训练的结果,网络参数基本不对所以没有输出。

第16个epoch输出如下:
在这里插入图片描述
第一行的37这个值就是dict中L的位置
在这里插入图片描述

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号