[OCR] [Topic Series] 4. Variable-Length Text Recognition Based on CRNN-CTC
In the previous post, "[OCR] Fixed-Length Text Recognition Based on Image Classification", fixed-length text was recognized by treating the whole image as a classification problem. This post reproduces the code of the classic OCR paper "An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" (CRNN) and analyzes the experimental results.
The network architecture proposed in the paper is shown in the figure below:
Figure 1. CRNN-CTC network architecture
The network consists of two main parts, a CNN and a BiLSTM: the CNN extracts visual features from the image, and the BiLSTM models the sequential context of those features. Training uses CTC loss, which handles the alignment between the frame-wise predictions and the variable-length target text, so no character-level segmentation is needed. Starting from open-source code, this post reproduces the model on an existing dataset, runs a small-scale experiment, and evaluates the results locally.
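Before walking through the full implementation, a minimal sketch of how torch.nn.CTCLoss and greedy CTC decoding work may be helpful. The shapes and values below are made up for illustration and are not part of the original code; the blank is placed at the last class index, matching the convention used later in this post.

import torch

# Toy shapes: T time steps (width of the CNN feature map), N batch size, C classes (10 digits + blank).
T, N, C = 26, 2, 11
blank = C - 1  # the implementation below also uses the last index as the CTC blank

log_probs = torch.randn(T, N, C).log_softmax(2)        # (T, N, C), as CTCLoss expects
targets = torch.randint(0, blank, (N, 4))              # every label is 4 digit indices
input_lengths = torch.full((N,), T, dtype=torch.long)  # all T frames are valid
target_lengths = torch.full((N,), 4, dtype=torch.long)

loss = torch.nn.CTCLoss(blank=blank, zero_infinity=True)(
    log_probs, targets, input_lengths, target_lengths)

# Greedy decoding: argmax per frame, collapse repeats, then drop the blank symbol.
best_path = log_probs.argmax(dim=-1).permute(1, 0)     # (N, T)
decoded = []
for path in best_path:
    path = torch.unique_consecutive(path)
    decoded.append(path[path != blank])
print(loss.item(), [d.tolist() for d in decoded])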
The code structure follows the previous post: the network is defined in the Model class, data loading is handled by the MyDataset class together with collate_fn, and all settings live in the configs class. The CRNN-CTC model is implemented in PyTorch, the loss function is torch.nn.CTCLoss, and the vocabulary is built from a character string in order. The full implementation is listed below; adjusting the corresponding configuration items is enough to reproduce the experiment.
from torch.utils.data import Dataset
from torch import nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader
import os
import torch
from PIL import Image
from tqdm import tqdm
import numpy as np


class configs():
    def __init__(self):
        # Data
        self.data_dir = './captcha_datasets'
        self.train_dir = 'train-data'
        self.valid_dir = 'valid-data'
        self.test_dir = 'test-data-1'
        self.save_model_dir = 'models_ocr'
        self.get_lexicon_dir = './lbl2id_map.txt'
        self.img_transform = T.Compose([
            T.Resize((32, 100)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        # self.lexicon = self.get_lexicon(lexicon_name=self.get_lexicon_dir)
        self.lexicon = "0123456789" + "_"  # the last character "_" serves as the CTC blank
        self.all_chars = {v: k for k, v in enumerate(self.lexicon)}  # character -> class index
        self.all_nums = {v: k for v, k in enumerate(self.lexicon)}   # class index -> character
        self.class_num = len(self.lexicon)
        self.label_word_length = 4

        # train
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.batch_size = 64
        self.epoch = 31
        self.save_model_fre_epoch = 1
        self.nh = 128  # number of hidden units in the LSTM

        self.istrain = True
        self.istest = True

    def get_lexicon(self, lexicon_name):
        '''
        Load the lexicon from 'lbl2id_map.txt', one "char\tid" pair per line, e.g.
        0\t0\n
        a\t1\n
        ...
        z\t63\n
        :param lexicon_name: path to the lexicon file
        :return: all characters concatenated in order
        '''
        lexicons = open(lexicon_name, 'r', encoding='utf-8').readlines()
        lexicons_str = ''.join(word.split('\t')[0] for word in lexicons)
        return lexicons_str


cfg = configs()

# model definition
class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output


class Model(nn.Module):
    def __init__(self, imgH, nc, nclass, nh, n_rnn=2, leakyRelu=False):
        super(Model, self).__init__()
        assert imgH % 16 == 0, 'imgH has to be a multiple of 16'

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        self.rnn = nn.Sequential(
            BidirectionalLSTM(512, nh, nh),
            BidirectionalLSTM(nh, nh, nclass))

    def forward(self, input):
        # conv features
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)  # [w, b, c]

        # rnn features
        output = self.rnn(conv)

        # log_softmax over the class dimension, as required by torch.nn.CTCLoss
        output = F.log_softmax(output, dim=2)
        # every sample in the batch uses the full sequence length (feature-map width)
        output_lengths = torch.full(size=(output.size(1),), fill_value=output.size(0), dtype=torch.long,
                                    device=cfg.device)

        return output, output_lengths

    def backward_hook(self, module, grad_input, grad_output):
        for g in grad_input:
            g[g != g] = 0  # replace NaN values in the gradients with zero

# dataset definition
class MyDataset(Dataset):

    def __init__(self, path: str, transform=None):
        if transform is None:
            self.transform = T.Compose(
                [
                    T.ToTensor()
                ])
        else:
            self.transform = transform
        self.path = path
        self.picture_list = list(os.walk(self.path))[0][-1]

    def __len__(self):
        return len(self.picture_list)

    def __getitem__(self, item):
        """
        :param item: sample index
        :return: (image, label)
        """
        picture_path_list = self._load_picture()
        img = Image.open(picture_path_list[item]).convert("RGB")
        img = self.transform(img)
        # the ground-truth text is the part of the file name after the underscore
        label = os.path.splitext(self.picture_list[item])[0].split("_")[1]

        label = [[cfg.all_chars[i]] for i in label]
        label = torch.as_tensor(label, dtype=torch.int64)

        return img, label

    def _load_picture(self):
        return [self.path + '/' + i for i in self.picture_list]


def collate_fn(batch):
    # pad labels and images of a batch to a common size
    sequence_lengths = []
    max_width, max_height = 0, 0
    for image, label in batch:
        if image.size(1) > max_height:
            max_height = image.size(1)
        if image.size(2) > max_width:
            max_width = image.size(2)
        sequence_lengths.append(label.size(0))
    seq_lengths = torch.LongTensor(sequence_lengths)
    seq_tensor = torch.zeros(seq_lengths.size(0), seq_lengths.max()).long()
    img_tensor = torch.zeros(seq_lengths.size(0), 3, max_height, max_width)
    for idx, (image, label) in enumerate(batch):
        seq_tensor[idx, :label.size(0)] = torch.squeeze(label)
        img_tensor[idx, :, :image.size(1), :image.size(2)] = image
    return img_tensor, seq_tensor, seq_lengths

class ocr():
    def train(self):
        model = Model(imgH=32, nc=3, nclass=cfg.class_num, nh=cfg.nh)
        model = model.to(cfg.device)
        criterion = torch.nn.CTCLoss(blank=cfg.class_num - 1, zero_infinity=True)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        model.train()

        # train dataset
        train_dataset = MyDataset(os.path.join(cfg.data_dir, cfg.train_dir),
                                  transform=cfg.img_transform)  # training path and transform
        train_loader = DataLoader(dataset=train_dataset, batch_size=cfg.batch_size, shuffle=True, drop_last=True,
                                  num_workers=0, collate_fn=collate_fn)

        for epoch in range(cfg.epoch):
            bar = tqdm(enumerate(train_loader, 0))
            loss_sum = []
            total = 0
            correct = 0
            for idx, (images, labels, label_lengths) in bar:
                images, labels, label_lengths = images.to(cfg.device), \
                                                labels.to(cfg.device), \
                                                label_lengths.to(cfg.device)
                optimizer.zero_grad()
                outputs, output_lengths = model(images)
                loss = criterion(outputs, labels, output_lengths, label_lengths)
                loss.backward()
                optimizer.step()
                loss_sum.append(loss.item())
                c, t = self.calculat_train_acc(outputs, labels, label_lengths)
                correct += c
                total += t
                bar.set_description("epoch:{} idx:{},loss:{:.6f},acc:{:.6f}".format(
                    epoch, idx, np.mean(loss_sum), 100 * correct / total))
            if epoch % cfg.save_model_fre_epoch == 0:
                torch.save(model.state_dict(),
                           os.path.join(cfg.save_model_dir, "epoch_" + str(epoch) + '.pkl'),
                           _use_new_zipfile_serialization=True)  # save model weights
                torch.save(optimizer.state_dict(),
                           os.path.join(cfg.save_model_dir, "epoch_" + str(epoch) + "_opti" + '.pkl'),
                           _use_new_zipfile_serialization=True)  # save optimizer state

    def infer(self):
        for modelname in os.listdir(cfg.save_model_dir):
            # model definition
            train_weights_path = os.path.join(cfg.save_model_dir, modelname)
            train_weights_dict = torch.load(train_weights_path)
            model = Model(imgH=32, nc=3, nclass=cfg.class_num, nh=cfg.nh)
            model.load_state_dict(train_weights_dict, strict=True)
            model = model.to(cfg.device)
            model.eval()

            # test dataset
            test_dataset = MyDataset(os.path.join(cfg.data_dir, cfg.test_dir),
                                     transform=cfg.img_transform)  # test path and transform
            test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

            total = 0
            correct = 0
            results = []
            for idx, (images, labels, label_lengths) in enumerate(test_loader, 0):
                labels = torch.squeeze(labels).to(cfg.device)

                with torch.no_grad():
                    predicts, output_lengths = model(images.to(cfg.device))
                    c, t, result = self.calculat_infer_acc(predicts, labels, label_lengths)
                    correct += c
                    total += t
                    results.append(result)
            print("model name: " + modelname + '\t' + "|| acc: " + str(correct / total) + '\n')

    # compute training accuracy
    def calculat_train_acc(self, output, target, target_lengths):
        output = torch.argmax(output, dim=-1)
        output = output.permute(1, 0)

        correct_num = 0
        for predict, label, label_length in zip(output, target, target_lengths):
            predict = torch.unique_consecutive(predict)         # collapse repeated predictions
            predict = predict[predict != (cfg.class_num - 1)]   # drop the blank symbol
            if (predict.size()[0] == label_length.item()
                    and (predict == label[:label_length.item()]).all()):
                correct_num += 1
        return correct_num, target.size(0)

    # compute inference accuracy
    def calculat_infer_acc(self, output, target, target_lengths):
        output = torch.argmax(output, dim=-1)
        output = output.permute(1, 0)

        correct_num = 0
        total_num = 0
        predict_list = []
        for predict, label, label_length in zip(output, target, target_lengths):
            total_num += 1
            predict = torch.unique_consecutive(predict)         # collapse repeated predictions
            predict = predict[predict != (cfg.class_num - 1)]   # drop the blank symbol
            predict_list = predict.cpu().tolist()
            label_list = target.cpu().tolist()
            if predict_list == label_list:
                correct_num += 1

        if predict_list == []:
            predict_str = '____'
        else:
            predict_str = ''.join([cfg.all_nums[s] for s in predict_list])
        label_str = ''.join([cfg.all_nums[s] for s in label_list])
        return correct_num, total_num, ','.join([predict_str, label_str])


if __name__ == '__main__':
    myocr = ocr()
    if cfg.istrain == True:
        myocr.train()
    if cfg.istest == True:
        myocr.infer()
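A note on the data layout: MyDataset reads the ground-truth text from the image file name, taking the part after the underscore (e.g. a file named 0001_3572.png would be labelled "3572"). If no such dataset is at hand, something along the lines of the sketch below can generate one. This is only an illustrative assumption using the third-party captcha package (pip install captcha); it is not part of the original experiment, and the file names and directory paths are only examples.

import os
import random
from captcha.image import ImageCaptcha  # assumed helper, not used in the original post


def make_captcha_split(out_dir, num_images, label_length=4):
    """Write digit-captcha images named '<index>_<label>.png' so MyDataset can parse the label."""
    os.makedirs(out_dir, exist_ok=True)
    generator = ImageCaptcha(width=100, height=32)
    for idx in range(num_images):
        label = ''.join(random.choices("0123456789", k=label_length))
        generator.write(label, os.path.join(out_dir, '{}_{}.png'.format(idx, label)))


make_captcha_split('./captcha_datasets/train-data', 25000)
make_captcha_split('./captcha_datasets/test-data-1', 10000)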
The experiment uses the captcha_datasets dataset, split into train:valid:test = 25000:10000:10000; the images are mainly digit captchas. The model was trained for 30 epochs, and the training CTC loss, training accuracy, and validation/test accuracy are shown in the table below.
| epoch | loss | train-acc | val/test-acc |
| --- | --- | --- | --- |
| 1 | 2.772569 | 0 | 0 |
| 2 | 0.957933 | 0.45997596 | 0.7438 |
| 3 | 0.038466 | 0.96987179 | 0.9706 |
| 4 | 0.018337 | 0.984375 | 0.9653 |
| 5 | 0.01449 | 0.98766026 | 0.9836 |
| 10 | 0.008008 | 0.99246795 | 0.9714 |
| 15 | 0.002388 | 0.99759615 | 0.9941 |
| 20 | 0.004845 | 0.99583333 | 0.9952 |
| 25 | 0.001462 | 0.99863782 | 0.9867 |
| 30 | 0.003154 | 0.99767628 | 0.9949 |
A few example recognition results are shown below:
Figure: example recognition results
As the training log shows, the CTC loss already yields good recognition accuracy after about 5 epochs. This is mainly because the dataset is small and fairly uniform; how well the model performs on larger and more varied datasets remains to be verified.