赞
踩
从零写CRNN文字识别 —— (1)准备工作
从零写CRNN文字识别 —— (2)准备配置文件
从零写CRNN文字识别 —— (3)数据加载器
从零写CRNN文字识别 —— (4)搭建模型
从零写CRNN文字识别 —— (5)优化器和Loss
从零写CRNN文字识别 —— (6)训练
完整代码已经上传github:https://github.com/xmy0916/pytorch_crnn
训练部分的代码逻辑如下:
for epoch in range(total_epoch):
for data in dataloader:
数据输入模型(前馈)
根据输出计算loss
loss反馈更新网络参数
if epoch % eval_epoch == 0:
评估数据输入模型(前馈)
根据输出计算loss
解码输出计算识别准确率
if now_acc > best_acc:
保存模型
对应的完整代码如下:
# 训练 best_acc = 0.0 for epoch in range(last_epoch,config.TRAIN.END_EPOCH): model.train() for i, (inp, idx) in enumerate(train_loader): # 前馈 inp = inp.to(device) preds = model(inp).to(device) # 计算loss labels = get_batch_label(train_dataset, idx) batch_size = inp.size(0) text, length = encode(config.DICT,labels) preds_size = torch.IntTensor([preds.size(0)] * batch_size) loss = criterion(preds, text, preds_size, length) # 反馈 optimizer.zero_grad() loss.backward() optimizer.step() if i % config.PRINT_FREQ == 0: print("epoch:{} step:{} loss:{} lr:{}".format(epoch,i,loss.item(),lr_scheduler.get_lr())) # 每个epoch更新学习率 lr_scheduler.step() # 每EVAL_FREQ评估一次并保存best模型 if epoch % config.EVAL_FREQ == 0: model.eval() n_correct = 0 test_num = len(val_loader) * config.TEST.BATCH_SIZE_PER_GPU with torch.no_grad(): for i, (inp, idx) in enumerate(val_loader): # 计算前馈 inp = inp.to(device) preds = model(inp).cpu() # 计算loss labels = get_batch_label(val_dataset, idx) batch_size = inp.size(0) text, length = encode(config.DICT,labels) preds_size = torch.IntTensor([preds.size(0)] * batch_size) loss = criterion(preds, text, preds_size, length) # 后处理解码 print("网络输出的preds的shape:",preds.cpu().detach().shape) _, preds = preds.max(2) print("max(2)的shape:",preds.cpu().detach().shape) preds = preds.transpose(1, 0).contiguous().view(-1) print("transpose的shape:",preds.cpu().detach().shape) sim_preds = decode(preds.data, preds_size.data, config.DICT,raw=False) for pred, target in zip(sim_preds, labels): if pred == target: n_correct += 1 # 抓一个batch来显示 raw_preds = decode(preds.data, preds_size.data, config.DICT, raw=True)[:config.TEST.NUM_TEST_DISP] for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels): print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt)) print("preds:",preds.cpu().detach().numpy()) print("preds_shape:",preds.cpu().detach().shape) print("dict:",config.DICT) now_acc = n_correct * 1.0 / test_num print("best_acc:{} correct:{}".format(now_acc,n_correct)) if now_acc >= best_acc: torch.save( { "state_dict": model.state_dict(), "epoch": epoch + 1, # "optimizer": optimizer.state_dict(), # "lr_scheduler": lr_scheduler.state_dict(), "best_acc": best_acc, }, os.path.join(config.OUTPUT_DIR, "checkpoint_{}_acc_{:.4f}.pth".format(epoch, now_acc))) best_acc = now_acc print("save_model!")
看看评估过程(摘一段代码出来):
preds = model(inp).cpu()
# 计算loss
labels = get_batch_label(val_dataset, idx)
batch_size = inp.size(0)
text, length = encode(config.DICT,labels)
preds_size = torch.IntTensor([preds.size(0)] * batch_size)
loss = criterion(preds, text, preds_size, length)
# 后处理解码
print("网络输出的preds的shape:",preds.cpu().detach().shape)
_, preds = preds.max(2)
print("max(2)的shape:",preds.cpu().detach().shape)
preds = preds.transpose(1, 0).contiguous().view(-1)
print("transpose的shape:",preds.cpu().detach().shape)
打印结果:
稍微解释下:
preds的shape[41,16,109]:
preds.max(2)得到了从属于那一类的向量,2表示在109的纬度上取所以输出的shape是[41,16]
transpose是把二维向量拉平,656=41*16
这里注意一点,测试的时候每个batch_size是16,但是我们数据集不一定是16的整数倍,所以最后一个batch的大小不一定有16,例如我们这里最后一个batch的大小是14:
在代码中我将最后一个batch的测试图片可视化的打印了,结果如下:
这是第一个epoch训练的输出,
上图的横杠是设置的空字符的占位符,在config/config.yml中设置这个字符BLANK_CHAR
上图一共574个0,574 = 41 * 14因为是最后一个batch所以不够16个,上图理论上可以解码成574个字符,因为这是第一个epoch训练的结果,网络参数基本不对所以没有输出。
第16个epoch输出如下:
第一行的37这个值就是dict中L的位置
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。