
[OCR] [Series] 4. Variable-Length Text Recognition with CRNN-CTC

Contents

1. Paper Reading

2. Code Implementation

3. Results and Discussion


1. Paper Reading

        In the previous post, "[OCR] Fixed-Length Text Recognition Based on Image Classification", fixed-length text was recognized by treating the image as a classification problem. This post reproduces the code of the classic OCR paper "An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" (CRNN) and analyzes the experimental results.
        The network architecture from the paper is shown below:

 Figure 1: CRNN-CTC network architecture

        The network consists of two main parts: a CNN and a BiLSTM. The CNN extracts visual features from the image, the BiLSTM models contextual information along the horizontal sequence, and a CTC loss aligns the variable-length frame predictions with the ground-truth text, so the model can be trained without per-character segmentation. Building on open-source code, this post reproduces the model on an existing dataset, runs small-scale experiments, and evaluates the results locally.
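        To make the role of CTC concrete: at inference time, the per-frame class predictions are collapsed into a text string by merging consecutive repeats and then dropping the blank symbol. The sketch below shows this greedy decoding rule, which is the same rule the accuracy functions in Section 2 apply via torch.unique_consecutive (the function name greedy_ctc_decode is my own, not from the paper or the original code):

import torch

def greedy_ctc_decode(log_probs, blank):
    # log_probs: (T, num_classes) frame-wise scores for a single image
    best_path = torch.argmax(log_probs, dim=-1)      # best class per frame
    collapsed = torch.unique_consecutive(best_path)  # merge repeated labels
    return collapsed[collapsed != blank]             # remove the blank symbol

# e.g. a per-frame best path [1, 1, 10, 2, 2, 10, 10, 3] with blank=10
# collapses to [1, 2, 3], i.e. the string "123"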

2. Code Implementation

        The code follows the structure of the previous post: the network is defined in a Model class, data loading is handled by a MyDataset class together with a collate_fn, and all settings are collected in a configs class. The CRNN-CTC model is implemented in PyTorch with torch.nn.CTCLoss as the loss function, and the vocabulary is built in order from a character string. The full implementation is below; adjust the config entries to your paths to reproduce the experiment.

from torch.utils.data import Dataset
from torch import nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import torch
from PIL import Image
from tqdm import tqdm
import numpy as np


class configs:
    def __init__(self):
        # Data
        self.data_dir = './captcha_datasets'
        self.train_dir = 'train-data'
        self.valid_dir = 'valid-data'
        self.test_dir = 'test-data-1'
        self.save_model_dir = 'models_ocr'
        self.get_lexicon_dir = './lbl2id_map.txt'
        self.img_transform = T.Compose([
            T.Resize((32, 100)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # self.lexicon = self.get_lexicon(lexicon_name=self.get_lexicon_dir)
        self.lexicon = "0123456789" + "_"  # '_' serves as the CTC blank
        self.all_chars = {v: k for k, v in enumerate(self.lexicon)}  # char -> index
        self.all_nums = {v: k for v, k in enumerate(self.lexicon)}   # index -> char
        self.class_num = len(self.lexicon)
        self.label_word_length = 4
        # Train
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch_size = 64
        self.epoch = 31
        self.save_model_fre_epoch = 1
        self.nh = 128  # number of LSTM hidden units
        self.istrain = True
        self.istest = True

    def get_lexicon(self, lexicon_name):
        '''
        Load the lexicon from 'lbl2id_map.txt'; each line is "char\tid", e.g.
        0\t0\n
        a\t1\n
        ...
        z\t63\n
        :param lexicon_name: path to the lexicon file
        :return: all characters concatenated into one string
        '''
        lexicons = open(lexicon_name, 'r', encoding='utf-8').readlines()
        lexicons_str = ''.join(word.split('\t')[0] for word in lexicons)
        return lexicons_str


cfg = configs()


# Model definition
class BidirectionalLSTM(nn.Module):
    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()
        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)
        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)
        return output


class Model(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(Model, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]
        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16
        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        # rnn features
        output = self.rnn(conv)
        # log_softmax over classes, as required by nn.CTCLoss
        output = F.log_softmax(output, dim=2)
        # every sample shares the same sequence length: the conv output width
        output_lengths = torch.full(size=(output.size(1),), fill_value=output.size(0),
                                    dtype=torch.long, device=cfg.device)
        return output, output_lengths

    def backward_hook(self, module, grad_input, grad_output):
        for g in grad_input:
            g[g != g] = 0  # replace all nan/inf in gradients with zero


# Dataset definition
class MyDataset(Dataset):
    def __init__(self, path: str, transform=None):
        if transform is None:
            self.transform = T.Compose([
                T.ToTensor()
            ])
        else:
            self.transform = transform
        self.path = path
        self.picture_list = list(os.walk(self.path))[0][-1]

    def __len__(self):
        return len(self.picture_list)

    def __getitem__(self, item):
        """
        :param item: index
        :return: (image, label); the label is parsed from the file name,
                 which is expected to look like "<id>_<text>.png"
        """
        picture_path_list = self._load_picture()
        img = Image.open(picture_path_list[item]).convert("RGB")
        img = self.transform(img)
        label = os.path.splitext(self.picture_list[item])[0].split("_")[1]
        label = [[cfg.all_chars[i]] for i in label]
        label = torch.as_tensor(label, dtype=torch.int64)
        return img, label

    def _load_picture(self):
        return [self.path + '/' + i for i in self.picture_list]


def collate_fn(batch):
    # pad labels to the longest sequence and images to the largest size in the batch
    sequence_lengths = []
    max_width, max_height = 0, 0
    for image, label in batch:
        if image.size(1) > max_height:
            max_height = image.size(1)
        if image.size(2) > max_width:
            max_width = image.size(2)
        sequence_lengths.append(label.size(0))
    seq_lengths = torch.LongTensor(sequence_lengths)
    seq_tensor = torch.zeros(seq_lengths.size(0), seq_lengths.max()).long()
    img_tensor = torch.zeros(seq_lengths.size(0), 3, max_height, max_width)
    for idx, (image, label) in enumerate(batch):
        seq_tensor[idx, :label.size(0)] = torch.squeeze(label)
        img_tensor[idx, :, :image.size(1), :image.size(2)] = image
    return img_tensor, seq_tensor, seq_lengths


class ocr():
    def train(self):
        model = Model(imgH=32, nc=3, nclass=cfg.class_num, nh=cfg.nh)
        model = model.to(cfg.device)
        criterion = torch.nn.CTCLoss(blank=cfg.class_num - 1, zero_infinity=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        model.train()
        os.makedirs(cfg.save_model_dir, exist_ok=True)  # make sure the checkpoint dir exists
        # train dataset
        train_dataset = MyDataset(os.path.join(cfg.data_dir, cfg.train_dir),
                                  transform=cfg.img_transform)  # training path and transform
        train_loader = DataLoader(dataset=train_dataset, batch_size=cfg.batch_size, shuffle=True,
                                  drop_last=True, num_workers=0, collate_fn=collate_fn)
        for epoch in range(cfg.epoch):
            bar = tqdm(enumerate(train_loader, 0))
            loss_sum = []
            total = 0
            correct = 0
            for idx, (images, labels, label_lengths) in bar:
                images, labels, label_lengths = images.to(cfg.device), \
                                                labels.to(cfg.device), \
                                                label_lengths.to(cfg.device)
                optimizer.zero_grad()
                outputs, output_lengths = model(images)
                loss = criterion(outputs, labels, output_lengths, label_lengths)
                loss.backward()
                optimizer.step()
                loss_sum.append(loss.item())
                c, t = self.calculate_train_acc(outputs, labels, label_lengths)
                correct += c
                total += t
                bar.set_description("epoch:{} idx:{},loss:{:.6f},acc:{:.6f}".format(
                    epoch, idx, np.mean(loss_sum), 100 * correct / total))
            if epoch % cfg.save_model_fre_epoch == 0:
                # save model weights and optimizer state
                torch.save(model.state_dict(),
                           os.path.join(cfg.save_model_dir, "epoch_" + str(epoch) + '.pkl'),
                           _use_new_zipfile_serialization=True)
                torch.save(optimizer.state_dict(),
                           os.path.join(cfg.save_model_dir, "epoch_" + str(epoch) + "_opti" + '.pkl'),
                           _use_new_zipfile_serialization=True)

    def infer(self):
        for modelname in os.listdir(cfg.save_model_dir):
            if "_opti" in modelname:
                continue  # skip optimizer checkpoints; they are not model weights
            # model definition
            train_weights_path = os.path.join(cfg.save_model_dir, modelname)
            train_weights_dict = torch.load(train_weights_path, map_location=cfg.device)
            model = Model(imgH=32, nc=3, nclass=cfg.class_num, nh=cfg.nh)
            model.load_state_dict(train_weights_dict, strict=True)
            model = model.to(cfg.device)
            model.eval()
            # test dataset
            test_dataset = MyDataset(os.path.join(cfg.data_dir, cfg.test_dir),
                                     transform=cfg.img_transform)  # test path and transform
            test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
            total = 0
            correct = 0
            results = []
            for idx, (images, labels, label_lengths) in enumerate(test_loader, 0):
                labels = torch.squeeze(labels).to(cfg.device)
                with torch.no_grad():
                    predicts, output_lengths = model(images.to(cfg.device))
                c, t, result = self.calculate_infer_acc(predicts, labels, label_lengths)
                correct += c
                total += t
                results.append(result)
            print("model name: " + modelname + '\t' + "|| acc: " + str(correct / total) + '\n')

    # compute training accuracy
    def calculate_train_acc(self, output, target, target_lengths):
        output = torch.argmax(output, dim=-1)
        output = output.permute(1, 0)
        correct_num = 0
        for predict, label, label_length in zip(output, target, target_lengths):
            predict = torch.unique_consecutive(predict)        # merge repeats
            predict = predict[predict != (cfg.class_num - 1)]  # drop blanks
            if (predict.size()[0] == label_length.item()
                    and (predict == label[:label_length.item()]).all()):
                correct_num += 1
        return correct_num, target.size(0)

    # compute inference accuracy (assumes batch_size=1 with squeezed labels)
    def calculate_infer_acc(self, output, target, target_lengths):
        output = torch.argmax(output, dim=-1)
        output = output.permute(1, 0)
        correct_num = 0
        total_num = 0
        predict_list = []
        for predict, label, label_length in zip(output, target, target_lengths):
            total_num += 1
            predict = torch.unique_consecutive(predict)        # merge repeats
            predict = predict[predict != (cfg.class_num - 1)]  # drop blanks
            predict_list = predict.cpu().tolist()
            label_list = target.cpu().tolist()  # with batch_size=1, target is the single label
            if predict_list == label_list:
                correct_num += 1
            if predict_list == []:
                predict_str = '____'
            else:
                predict_str = ''.join([cfg.all_nums[s] for s in predict_list])
            label_str = ''.join([cfg.all_nums[s] for s in label_list])
        return correct_num, total_num, ','.join([predict_str, label_str])


if __name__ == '__main__':
    myocr = ocr()
    if cfg.istrain:
        myocr.train()
    if cfg.istest:
        myocr.infer()
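
        A note on shapes: torch.nn.CTCLoss expects log-probabilities of shape (T, N, C), padded targets of shape (N, S), and per-sample input and target lengths, which is exactly what Model.forward and collate_fn produce above. Below is a minimal standalone sanity check with random tensors (T=26 is the CNN output width I work out for the 32x100 inputs used here; the other values mirror the config):

import torch

T_steps, batch, n_class, blank = 26, 2, 11, 10  # blank = '_' (last class)
ctc = torch.nn.CTCLoss(blank=blank, zero_infinity=True)

log_probs = torch.randn(T_steps, batch, n_class).log_softmax(2)  # (T, N, C)
targets = torch.randint(0, blank, (batch, 4))                    # (N, S), label ids only, no blanks
input_lengths = torch.full((batch,), T_steps, dtype=torch.long)  # all frames valid
target_lengths = torch.full((batch,), 4, dtype=torch.long)       # 4-character captchas

loss = ctc(log_probs, targets, input_lengths, target_lengths)
print(loss.item())  # a finite scalar confirms the shapes line up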

3. Results and Discussion

        The experiments use the captcha_datasets dataset, split into training : validation : test = 25000 : 10000 : 10000; the images are mainly numeric captchas. The model was trained for 30 epochs; the training CTC loss, training accuracy, and test accuracy are listed in the table below.

epoch   loss       train-acc     val/test-acc
1       2.772569   0             0
2       0.957933   0.45997596    0.7438
3       0.038466   0.96987179    0.9706
4       0.018337   0.984375      0.9653
5       0.01449    0.98766026    0.9836
10      0.008008   0.99246795    0.9714
15      0.002388   0.99759615    0.9941
20      0.004845   0.99583333    0.9952
25      0.001462   0.99863782    0.9867
30      0.003154   0.99767628    0.9949

        Some sample recognition results:

 Figure: Sample recognition results

        As the training log shows, the CTC-trained model already recognizes well after about 5 epochs. This is likely because the dataset is small and the data fairly uniform; how the approach performs on larger and more varied datasets remains to be tested.
