
CRNN Text Recognition and a TensorFlow Implementation

1. Introduction

    Text recognition takes an image of text and converts the characters in it into a text string, i.e. into something a computer can actually process. Two text detection methods were covered in earlier posts, "CTPN Text Detection with TensorFlow" (《CTPN文本检测与tensorflow实现》) and "EAST Text Detection with Keras" (《EAST文本检测与Keras实现》). Detection gives us the position of each piece of text in an image; we can then crop each text region and apply an affine transformation, obtaining text images like the ones in Figure 1. At this point the computer still cannot tell which characters the image contains, so a recognition step is needed to turn the image into plain text. The captcha recognition we see every day is also a form of text recognition.

Figure 1: Text fragments cropped from natural scene images

    Earlier text recognition models typically slid a window across the image and recognized the text under each window step by step; this handles different fonts poorly, especially for Chinese text. Other models rely on alignment: every frame of the image is labeled with a character and an encoder-decoder style network performs the recognition, but producing frame-level alignments takes a large amount of manual work, and labeling becomes especially tedious when the text is padded with blanks at both ends. This post therefore introduces CRNN, a model that performs comparatively well on text recognition: it needs no alignment labels, takes a text image as input, and directly outputs the recognized string with high accuracy.

2. The Model

2.1 Model Architecture

    The CRNN architecture consists of three parts: the convolutional layers, the recurrent (RNN) layers, and the transcription layer, as shown in Figure 2.

Figure 2: CRNN model architecture

    In the convolutional part, every input image is first rescaled to a fixed height and then passed through the convolutional layers. The resulting feature maps are converted into the input feature sequence of the RNN layers: the feature maps are scanned column by column from left to right, and for each column the vectors from all feature maps are concatenated into a single feature vector, which becomes the RNN input at that time step. Because each column of the feature maps corresponds to a rectangular region of the original image, every vector in the resulting feature sequence corresponds to such a region as well, and these regions are ordered from left to right, so the feature vectors form a genuine temporal sequence, as shown in Figure 3. A small sketch of this column-wise construction follows the figure.

Figure 3: Correspondence between the feature sequence from the convolutional layers and regions of the original image
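To make this map-to-sequence step concrete, below is a minimal TensorFlow sketch (not the exact layer configuration of this project; the function name and shapes are illustrative) that turns a batch of feature maps of shape [batch, height, width, channels] into a sequence of per-column feature vectors for the RNN layers.

```python
import tensorflow as tf


def feature_maps_to_sequence(feature_maps):
    """Reshape CNN feature maps [N, H, W, C] into a feature sequence [N, T=W, H*C]."""
    shape = tf.shape(feature_maps)                    # dynamic shape [N, H, W, C]
    n, h, w = shape[0], shape[1], shape[2]
    c = feature_maps.get_shape().as_list()[3]         # static channel count
    # move the width axis to the "time" position: [N, W, H, C]
    transposed = tf.transpose(feature_maps, perm=[0, 2, 1, 3])
    # every column of the feature maps becomes one time step,
    # concatenating that column's H x C activations into a single vector
    sequence = tf.reshape(transposed, [n, w, h * c])
    return sequence
```

In the actual CRNN the convolution and pooling layers shrink the feature-map height to a small value, so each time step is essentially the channel vector of one column.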

    Next comes the RNN part. As noted above, the vectors in the feature sequence produced by the convolutional layers are not independent but ordered in time, so an RNN is a natural choice; the paper uses a deep bidirectional recurrent network with LSTM units, as shown in Figure 4. Introducing the RNN brings three benefits: (1) some wide characters span several columns, and the RNN can remember the preceding context; in addition, some characters are easier to tell apart when their neighbours are available for height comparison, for example 'i' and 'l'; (2) the RNN can back-propagate its error into the CNN, so the convolutional and recurrent parameters are trained jointly; (3) the RNN can handle text sequences of arbitrary length. A sketch of the stacked bidirectional LSTM follows the figure.

Figure 4: LSTM unit and deep bidirectional RNN
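The sketch below shows what such a stacked bidirectional LSTM over the feature sequence can look like with the TensorFlow 1.x API this project also relies on (BasicLSTMCell); the hidden size, number of layers and number of classes are placeholders rather than the project's actual configuration.

```python
import tensorflow as tf
from tensorflow.contrib.rnn import BasicLSTMCell


def bidirectional_rnn_layers(sequence, seq_len, num_hidden=256, num_classes=100):
    """Two stacked bidirectional LSTM layers followed by a per-step projection."""
    outputs = sequence                                      # [N, T, D]
    for layer in range(2):
        with tf.variable_scope('bilstm_%d' % layer):
            fw_cell = BasicLSTMCell(num_hidden)
            bw_cell = BasicLSTMCell(num_hidden)
            (out_fw, out_bw), _ = tf.nn.bidirectional_dynamic_rnn(
                fw_cell, bw_cell, outputs, sequence_length=seq_len, dtype=tf.float32)
            outputs = tf.concat([out_fw, out_bw], axis=-1)  # [N, T, 2 * num_hidden]
    # per-time-step distribution over |L'| labels (all characters plus the blank)
    logits = tf.layers.dense(outputs, num_classes)          # [N, T, num_classes]
    return logits
```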

    Let the feature sequence produced by the convolutional layers be $\mathbf{x} = x_1, \dots, x_T$. For the input $x_t$ at each time step, the RNN outputs a distribution $y_t$ over the character classes, so the length of $y_t$ equals the number of character classes. Denote the output sequence of the RNN layers by $\mathbf{y} = y_1, \dots, y_T$, where $T$ is the sequence length, $y_t \in \Re^{|\mathcal{L}'|}$ is the class probability distribution at step $t$, and $\mathcal{L}' = \mathcal{L} \cup \{\text{blank}\}$ is the set of all character classes plus the blank symbol. One might ask whether we could proceed as in machine translation: mark start and end symbols on the output sequence and cut the predicted label sequence directly out of the per-step outputs. That would work, but it requires labeling in advance the character under each time step's receptive field, which is exactly the tedious manual alignment we want to avoid. The authors therefore use a transcription method instead, the transcription layer of the model.

    In the transcription layer the authors introduce a mapping $\mathcal{B}$: for a character sequence $\pi \in \mathcal{L}'^T$, $\mathcal{B}$ first merges repeated characters and then removes the blanks, yielding the final label sequence $\mathbf{l}$. For example, for the predicted sequence "--hh-e-l-ll-oo--", where "-" denotes the blank, $\mathcal{B}$ outputs "hello"; note that two identical characters separated by a blank are not merged. The conditional probability of $\mathbf{l}$ is therefore the sum of the probabilities of all sequences $\pi$ that map to $\mathbf{l}$ under $\mathcal{B}$:

$$p(\mathbf{l} \mid \mathbf{y}) = \sum_{\pi : \mathcal{B}(\pi) = \mathbf{l}} p(\pi \mid \mathbf{y})$$

where $p(\pi \mid \mathbf{y}) = \prod_{t=1}^{T} y_{\pi_t}^{t}$ is the product of the per-step character probabilities along the path, and $y_{\pi_t}^{t}$ is the probability of character $\pi_t$ at time step $t$. Computed naively this sum is extremely expensive, so the authors adopt the forward-backward algorithm from CTC to evaluate it efficiently.
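The following plain-Python sketch spells out the $\mathcal{B}$ mapping and the brute-force definition of $p(\mathbf{l} \mid \mathbf{y})$; enumerating all paths is exponential in $T$, so this is purely an illustration of the formula, while training relies on CTC's forward-backward algorithm instead.

```python
import itertools

import numpy as np

BLANK = '-'  # the blank symbol


def B(pi):
    """Collapse repeated characters, then drop blanks: B('--hh-e-l-ll-oo--') == 'hello'."""
    collapsed = [c for i, c in enumerate(pi) if i == 0 or c != pi[i - 1]]
    return ''.join(c for c in collapsed if c != BLANK)


def prob_of_label(label, y, alphabet):
    """Brute-force p(l|y): sum p(pi|y) over all length-T paths pi with B(pi) == l.

    y is a [T, |L'|] array of per-step class distributions and alphabet lists the
    characters (including the blank) in the same column order as y.
    """
    T = y.shape[0]
    total = 0.0
    for path in itertools.product(range(len(alphabet)), repeat=T):
        if B(''.join(alphabet[k] for k in path)) == label:
            total += np.prod([y[t, k] for t, k in enumerate(path)])
    return total
```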

    For the forward-backward algorithm used in CTC, see my other post "An Introduction to CTC" (《CTC原理介绍》); it is not expanded on here.

    There are two transcription modes: lexicon-free transcription and lexicon-based transcription.

    Lexicon-free transcription is computed as follows:

$$\mathbf{l}^{*} \approx \mathcal{B}\left( \arg\max_{\pi} p(\pi \mid \mathbf{y}) \right)$$

that is, take the most probable character at every time step and apply the $\mathcal{B}$ mapping to the resulting sequence to obtain $\mathbf{l}$.
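A minimal sketch of this greedy, lexicon-free decoding with TensorFlow 1.x's built-in CTC decoder might look as follows (tensor names and shapes are illustrative); tf.nn.ctc_greedy_decoder performs the per-step argmax and the $\mathcal{B}$ collapsing internally.

```python
import tensorflow as tf


def greedy_decode(logits, seq_len):
    """Best-path CTC decoding: argmax per time step, then merge repeats and drop blanks."""
    time_major_logits = tf.transpose(logits, perm=[1, 0, 2])  # CTC ops expect [T, N, C]
    decoded, log_prob = tf.nn.ctc_greedy_decoder(time_major_logits, seq_len)
    # decoded[0] is a SparseTensor of label indices; make it a dense matrix for convenience
    dense = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)
    return dense, log_prob
```

Note that TensorFlow's CTC ops reserve the last class index for the blank, so the character-to-index mapping has to follow that convention.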

    Lexicon-based transcription builds a lexicon, computes the probability of every character sequence in it, and picks the most probable one as the final transcription:

$$\mathbf{l}^{*} = \arg\max_{\mathbf{l} \in \mathcal{D}} p(\mathbf{l} \mid \mathbf{y})$$

where $\mathcal{D}$ is the lexicon. The drawback is that the computation becomes expensive when the lexicon is large. The authors therefore propose an improvement: they observe that the lexicon-free transcription is usually already close to the ground truth, so they first obtain the lexicon-free result $\mathbf{l}'$, then use a BK-tree to search the lexicon for all entries whose edit distance to $\mathbf{l}'$ is at most $\delta$ (for the notion of edit distance see "Edit Distance (编辑距离)"), denoted $\mathcal{N}_{\delta}(\mathbf{l}')$, and finally compute the probability of each candidate in this neighbourhood and pick the most probable one as the transcription:

$$\mathbf{l}^{*} = \arg\max_{\mathbf{l} \in \mathcal{N}_{\delta}(\mathbf{l}')} p(\mathbf{l} \mid \mathbf{y})$$
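A simple sketch of this re-scoring idea is given below. For brevity it scans the whole lexicon with a plain dynamic-programming edit distance instead of a real BK-tree, and `score_fn` stands for any function that returns $p(\mathbf{l} \mid \mathbf{y})$ for a candidate string; both helper names are hypothetical and not part of the project code.

```python
def edit_distance(a, b):
    """Levenshtein distance via the standard dynamic-programming recurrence."""
    dp = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, cb in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1,          # deletion
                                     dp[j - 1] + 1,      # insertion
                                     prev + (ca != cb))  # substitution
    return dp[-1]


def lexicon_decode(greedy_text, lexicon, score_fn, delta=3):
    """Re-score only the lexicon entries within edit distance `delta` of the greedy result."""
    candidates = [w for w in lexicon if edit_distance(greedy_text, w) <= delta]
    return max(candidates, key=score_fn) if candidates else greedy_text
```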

2.2 Loss Function

    CRNN is trained with a negative log-likelihood loss. Let the training set be $\mathcal{X} = \{ I_i, \mathbf{l}_i \}_i$, where $I_i$ is an input image and $\mathbf{l}_i$ its ground-truth character sequence. The loss is:

$$\mathcal{O} = - \sum_{I_i, \mathbf{l}_i \in \mathcal{X}} \log p\left( \mathbf{l}_i \mid \mathbf{y}_i \right)$$
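In practice this negative log-likelihood is evaluated with CTC rather than by enumerating paths; a minimal TensorFlow 1.x sketch (names and shapes are illustrative, not taken from this project's model class) could look like this:

```python
import tensorflow as tf


def ctc_loss_op(labels_sparse, logits, seq_len):
    """CTC negative log-likelihood; labels_sparse is a SparseTensor of character indices."""
    time_major_logits = tf.transpose(logits, perm=[1, 0, 2])  # tf.nn.ctc_loss expects [T, N, C]
    neg_log_likelihood = tf.nn.ctc_loss(labels=labels_sparse,
                                        inputs=time_major_logits,
                                        sequence_length=seq_len)
    # the paper sums over the whole training set; per batch it is usually averaged
    return tf.reduce_mean(neg_log_likelihood)
```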

3. TensorFlow Implementation

    This post reproduces CRNN with TensorFlow. The project layout is shown in Figure 5; each module is described below.

Figure 5: Project structure

    First, the data directory holds the training and test sets: train_images contains the training data and test_images the test data. The data comes from two sources, the ICPR competition dataset and synthetically generated images.

Figure 6: Layout of the data directory

    The dict directory holds the character set files. Three are available: chinese.txt contains the 3,000 most commonly used Chinese characters, english.txt contains the English letters plus some punctuation, and english_chinese.txt is the union of the two. With english_chinese.txt the model supports both Chinese and English text; it is the default used for training in this post.

Figure 7: Character set files

    The fonts directory holds the font files used to generate the synthetic data. On Windows they can usually be found under C:\Windows\Fonts; pick whichever fonts you like.

Figure 8: Font files

    images_base holds the background images for the synthetic data, and the models directory holds the trained model files. Next, the individual Python scripts. charset_generate.py builds the character set file: it extracts the set of characters from the image labels and writes the generated charset.txt under the data directory. Its code is as follows:


 
 
```python
import tqdm

from crnn import config as crnn_config


def generate_charset(labels_path, charset_path):
    """
    Generate the character dictionary from the text labels.
    :param labels_path: path of your text label file
    :param charset_path: path where the char dict is saved
    :return:
    """
    with open(labels_path, 'r', encoding='utf-8') as fr:
        lines = fr.read().split('\n')
    dic = str()
    for label in tqdm.tqdm(lines[:-1]):
        for char in label:
            if char in dic:
                continue
            else:
                dic += char
    with open(charset_path, 'w', encoding='utf-8') as fw:
        fw.write(dic)


if __name__ == '__main__':
    label_path = crnn_config.train_label_path
    char_dict_path = crnn_config.charset_path
    generate_charset(label_path, char_dict_path)
```

    Next is data_provider.py. On one hand it crops the text regions out of the natural scene images, applies an affine transformation, and saves the crops into the training and test directories under data for use during training and testing; on the other hand it generates the synthetic data, which is likewise stored under the training directory.


 
 
```python
import os
import cv2
import math
import random
import shutil
import numpy as np
from tqdm import trange
from collections import Counter
from crnn import charset_generate
from multiprocessing import Process
from crnn import config as crnn_config
from PIL import Image, ImageDraw, ImageFont


class TextCut(object):
    def __init__(self,
                 org_images_path,
                 org_labels_path,
                 cut_train_images_path,
                 cut_train_labels_path,
                 cut_test_images_path,
                 cut_test_labels_path,
                 train_test_ratio=0.8,
                 filter_ratio=1.5,
                 filter_height=25,
                 is_transform=True,
                 angle_range=[-15.0, 15.0],
                 write_mode='w',
                 use_blank=False,
                 num_process=1):
        """
        Cut text regions out of the original ICPR images.
        :param org_images_path: path of the original ICPR images, [str]
        :param org_labels_path: path of the original ICPR labels, [str]
        :param cut_train_images_path: save path for the cropped training images, [str]
        :param cut_train_labels_path: save path for the labels of the cropped training images, [str]
        :param cut_test_images_path: save path for the cropped test images, [str]
        :param cut_test_labels_path: save path for the labels of the cropped test images, [str]
        :param train_test_ratio: train/test split ratio, [float]
        :param filter_ratio: height/width ratio threshold; crops above it are discarded, default 1.5, [float]
        :param filter_height: height threshold; crops lower than this are discarded, [int]
        :param is_transform: whether to apply an affine transformation, default True, [boolean]
        :param angle_range: angle range within which no affine transformation is applied, default [-15.0, 15.0], [list]
        :param write_mode: write mode, 'w': overwrite, 'a': append, [str]
        :param use_blank: whether to keep blank characters, [boolean]
        :param num_process: number of parallel processes
        :return:
        """
        self.org_images_path = org_images_path
        self.org_labels_path = org_labels_path
        self.cut_train_images_path = cut_train_images_path
        self.cut_train_labels_path = cut_train_labels_path
        self.cut_test_images_path = cut_test_images_path
        self.cut_test_labels_path = cut_test_labels_path
        self.train_test_ratio = train_test_ratio
        self.filter_ratio = filter_ratio
        self.filter_height = filter_height
        self.is_transform = is_transform
        self.angle_range = angle_range
        assert write_mode in ['w', 'a'], "write mode should be 'w'(write) or 'a'(add)"
        self.write_mode = write_mode
        self.use_blank = use_blank
        self.num_process = num_process
        self.org_labels_list = None
        super().__init__()

    def data_load(self, org_images_list):
        """
        Cut the text regions out of the ICPR images.
        :param org_images_list: file names of the original images
        :return:
        """
        data_len = len(org_images_list)
        train_test_offset = data_len * self.train_test_ratio
        for data_i in range(len(org_images_list)):
            org_image_path = org_images_list[data_i]
            org_image_name = os.path.basename(org_image_path)[:-4]
            org_label_path = org_image_name + ".txt"
            if org_label_path not in self.org_labels_list:
                continue
            org_image = Image.open(os.path.join(self.org_images_path, org_image_path))
            with open(os.path.join(self.org_labels_path, org_label_path), 'r', encoding='utf-8') as fr:
                org_label = fr.read().split('\n')
            cut_images_list, cut_labels_list = self.cut_text(org_image, org_label,
                                                             self.filter_ratio,
                                                             self.is_transform,
                                                             self.angle_range)
            if data_i < train_test_offset:
                img_save_path = self.cut_train_images_path
                label_save_path = self.cut_train_labels_path
            else:
                img_save_path = self.cut_test_images_path
                label_save_path = self.cut_test_labels_path
            for i in range(len(cut_images_list)):
                cut_img = cut_images_list[i]
                if cut_img.shape[0] >= self.filter_height:
                    cut_img = Image.fromarray(cut_img)
                    cut_img = cut_img.convert('RGB')
                    cut_label = cut_labels_list[i]
                    cut_img_name = org_image_name + '_' + str(i) + '.jpg'
                    cut_img.save(os.path.join(img_save_path, cut_img_name))
                    with open(label_save_path, 'a', encoding='utf-8') as fa:
                        fa.write(cut_img_name + '\t' + cut_label + '\n')

    def data_load_multi_process(self, num_process=None):
        """
        Cut the ICPR images with multiple processes.
        :param num_process: number of processes, [int]
        :return:
        """
        if num_process is None:
            num_process = self.num_process
        org_images_list = os.listdir(self.org_images_path)
        self.org_labels_list = os.listdir(self.org_labels_path)
        # clear label.txt at first step
        check_path([self.cut_train_images_path,
                    self.cut_train_labels_path,
                    self.cut_test_images_path,
                    self.cut_test_labels_path])
        if self.write_mode == 'w':
            clear_content([self.cut_train_images_path,
                           self.cut_train_labels_path,
                           self.cut_test_images_path,
                           self.cut_test_labels_path])
        all_data_len = len(org_images_list)
        data_offset = all_data_len // num_process
        processes = list()
        for data_i in trange(0, all_data_len, data_offset):
            if data_i + data_offset >= all_data_len:
                processes.append(Process(target=self.data_load, args=(org_images_list[data_i:],)))
            else:
                processes.append(Process(target=self.data_load, args=(org_images_list[data_i:data_i + data_offset],)))
        for process in processes:
            process.start()
        for process in processes:
            process.join()

    def cut_text(self, image, labels, filter_ratio, is_transform, angle_range):
        """
        Cut the text regions out of one image.
        :param image: original image, [array]
        :param labels: text labels of the image, [str]
        :param filter_ratio: height/width ratio threshold; crops above it are discarded, e.g. 1.5, [float]
        :param is_transform: whether to apply an affine transformation, [boolean]
        :param angle_range: angle range within which no affine transformation is applied, e.g. [-15.0, 15.0], [list]
        :return:
        """
        cut_images = list()
        cut_labels = list()
        w, h = image.size
        for label in labels:
            if label == '':
                continue
            label_text = label.split(',')
            text = label_text[-1]
            if not self.use_blank:
                text = text.replace(' ', '')
            if text == '###' or text == '★' or text == '':
                continue
            position = self.reorder_vertexes(
                np.array([[round(float(label_text[i])), round(float(label_text[i + 1]))] for i in range(0, 8, 2)]))
            position = np.reshape(position, 8).tolist()
            left = max(min([position[i] for i in range(0, 8, 2)]), 0)
            right = min(max([position[i] for i in range(0, 8, 2)]), w)
            top = max(min([position[i] for i in range(1, 8, 2)]), 0)
            bottom = min(max([position[i] for i in range(1, 8, 2)]), h)
            if (bottom - top) / (right - left + 1e-3) > filter_ratio:
                continue
            image = np.asarray(image)
            cut_image = image[top:bottom, left:right]
            if is_transform:
                trans_img = self.transform(image, position, angle_range)
                if trans_img is not None:
                    cut_image = trans_img
            cut_images.append(cut_image)
            cut_labels.append(text)
        return cut_images, cut_labels

    def transform(self, image, position, angle_range):
        """
        Affine transformation.
        :param image: original image, [array]
        :param position: position of the text, e.g. [x0,y0,x1,y1,x2,y2], [list]
        :param angle_range: angle range within which no affine transformation is applied, e.g. [-15.0, 15.0], [list]
        :return: the transformed image
        """
        from_points = [position[2:4], position[4:6]]
        width = round(float(self.calc_dis(position[2:4], position[4:6])))
        height = round(float(self.calc_dis(position[2:4], position[0:2])))
        to_points = [[0, 0], [width, 0]]
        from_mat = self.list2col_matrix(from_points)
        to_mat = self.list2col_matrix(to_points)
        tran_m, tran_b = self.get_transform(from_mat, to_mat)
        probe_vec = np.matrix([1.0, 0.0]).transpose()
        probe_vec = tran_m * probe_vec
        scale = np.linalg.norm(probe_vec)
        angle = 180.0 / np.pi * math.atan2(probe_vec[1, 0], probe_vec[0, 0])
        if (angle > angle_range[0]) and (angle < angle_range[1]):
            return None
        else:
            from_center = position[2:4]
            to_center = [0, 0]
            dx = to_center[0] - from_center[0]
            dy = to_center[1] - from_center[1]
            trans_m = cv2.getRotationMatrix2D((from_center[0], from_center[1]), -1 * angle, scale)
            trans_m[0][2] += dx
            trans_m[1][2] += dy
            dst = cv2.warpAffine(image, trans_m, (int(width), int(height)))
            return dst

    def get_transform(self, from_shape, to_shape):
        """
        Compute the transformation matrix A such that y = A * x.
        :param from_shape: shape x before the transformation, as a column matrix, [list]
        :param to_shape: shape y after the transformation, as a column matrix, [list]
        :return: A
        """
        assert from_shape.shape[0] == to_shape.shape[0] and from_shape.shape[0] % 2 == 0
        sigma_from = 0.0
        sigma_to = 0.0
        cov = np.matrix([[0.0, 0.0], [0.0, 0.0]])
        # compute the mean and cov
        from_shape_points = from_shape.reshape(from_shape.shape[0] // 2, 2)
        to_shape_points = to_shape.reshape(to_shape.shape[0] // 2, 2)
        mean_from = from_shape_points.mean(axis=0)
        mean_to = to_shape_points.mean(axis=0)
        for i in range(from_shape_points.shape[0]):
            temp_dis = np.linalg.norm(from_shape_points[i] - mean_from)
            sigma_from += temp_dis * temp_dis
            temp_dis = np.linalg.norm(to_shape_points[i] - mean_to)
            sigma_to += temp_dis * temp_dis
            cov += (to_shape_points[i].transpose() - mean_to.transpose()) * (from_shape_points[i] - mean_from)
        sigma_from = sigma_from / to_shape_points.shape[0]
        sigma_to = sigma_to / to_shape_points.shape[0]
        cov = cov / to_shape_points.shape[0]
        # compute the affine matrix
        s = np.matrix([[1.0, 0.0], [0.0, 1.0]])
        u, d, vt = np.linalg.svd(cov)
        if np.linalg.det(cov) < 0:
            if d[1] < d[0]:
                s[1, 1] = -1
            else:
                s[0, 0] = -1
        r = u * s * vt
        c = 1.0
        if sigma_from != 0:
            c = 1.0 / sigma_from * np.trace(np.diag(d) * s)
        tran_b = mean_to.transpose() - c * r * mean_from.transpose()
        tran_m = c * r
        return tran_m, tran_b

    def list2col_matrix(self, pts_list):
        """
        Convert a list of points into a column matrix.
        :param pts_list: list of points, e.g. [x0,y0,x1,y1,x2,y1], [list]
        :return:
        """
        assert len(pts_list) > 0
        col_mat = []
        for i in range(len(pts_list)):
            col_mat.append(pts_list[i][0])
            col_mat.append(pts_list[i][1])
        col_mat = np.matrix(col_mat).transpose()
        return col_mat

    def calc_dis(self, point1, point2):
        """
        Euclidean distance between two points.
        :param point1: 2-D coordinate, e.g. [12.3, 34.1], list
        :param point2: 2-D coordinate, e.g. [12.3, 34.1], list
        :return: the Euclidean distance between the two points
        """
        return np.sqrt((point2[1] - point1[1]) ** 2 + (point2[0] - point1[0]) ** 2)

    def reorder_vertexes(self, xy_list):
        """
        Reorder the four vertices of a text line counter-clockwise.
        :param xy_list: the four vertices of a text line, [array]
        :return:
        """
        reorder_xy_list = np.zeros_like(xy_list)
        # pick the vertex with the smallest x coordinate as the first vertex
        ordered = np.argsort(xy_list, axis=0)
        xmin1_index = ordered[0, 0]
        xmin2_index = ordered[1, 0]
        if xy_list[xmin1_index, 0] == xy_list[xmin2_index, 0]:
            if xy_list[xmin1_index, 1] <= xy_list[xmin2_index, 1]:
                reorder_xy_list[0] = xy_list[xmin1_index]
                first_v = xmin1_index
            else:
                reorder_xy_list[0] = xy_list[xmin2_index]
                first_v = xmin2_index
        else:
            reorder_xy_list[0] = xy_list[xmin1_index]
            first_v = xmin1_index
        # compute the slopes between the other three vertices and the first one,
        # and take the vertex with the middle slope as the third vertex
        others = list(range(4))
        others.remove(first_v)
        k = np.zeros((len(others),))
        for index, i in zip(others, range(len(others))):
            k[i] = (xy_list[index, 1] - xy_list[first_v, 1]) \
                   / (xy_list[index, 0] - xy_list[first_v, 0] + crnn_config.epsilon)
        k_mid = np.argsort(k)[1]
        third_v = others[k_mid]
        reorder_xy_list[2] = xy_list[third_v]
        # compare the remaining two vertices with the line through the first and third vertices:
        # the one above the line becomes the second vertex, the other the fourth
        others.remove(third_v)
        b_mid = xy_list[first_v, 1] - k[k_mid] * xy_list[first_v, 0]
        second_v, fourth_v = 0, 0
        for index, i in zip(others, range(len(others))):
            # delta = y - (k * x + b)
            delta_y = xy_list[index, 1] - (k[k_mid] * xy_list[index, 0] + b_mid)
            if delta_y > 0:
                second_v = index
            else:
                fourth_v = index
        reorder_xy_list[1] = xy_list[second_v]
        reorder_xy_list[3] = xy_list[fourth_v]
        # if the first vertex turns out to be the bottom-left one,
        # rotate the vertices counter-clockwise by one position
        k13 = k[k_mid]
        k24 = (xy_list[second_v, 1] - xy_list[fourth_v, 1]) / (
                xy_list[second_v, 0] - xy_list[fourth_v, 0] + crnn_config.epsilon)
        if k13 < k24:
            tmp_x, tmp_y = reorder_xy_list[3, 0], reorder_xy_list[3, 1]
            for i in range(2, -1, -1):
                reorder_xy_list[i + 1] = reorder_xy_list[i]
            reorder_xy_list[0, 0], reorder_xy_list[0, 1] = tmp_x, tmp_y
        return [reorder_xy_list[1], reorder_xy_list[0], reorder_xy_list[3], reorder_xy_list[2]]


class ImageGenerate(object):
    def __init__(self,
                 img_base_path,
                 font_style_path,
                 text_size_limit,
                 font_size,
                 font_color,
                 train_images_path,
                 train_labels_path,
                 test_images_path,
                 test_labels_path,
                 train_test_ratio,
                 num_samples,
                 dictionary_file,
                 margin=20,
                 write_mode='w',
                 use_blank=False,
                 num_process=1):
        """
        Generate synthetic text images.
        :param img_base_path: directory of the background images, [str]
        :param font_style_path: directories of the font files for Chinese and English, [dict]
        :param text_size_limit: range of the number of characters, e.g. [1, 8], [list]
        :param font_size: list of font sizes, e.g. [24, 32, 36], [list]
        :param font_color: list of font colors, e.g. [[0, 0, 0], [255, 36, 36]], [list]
        :param train_images_path: save path for the training images, [str]
        :param train_labels_path: save path for the training labels, [str]
        :param test_images_path: save path for the test images, [str]
        :param test_labels_path: save path for the test labels, [str]
        :param train_test_ratio: train/test split ratio, [float]
        :param num_samples: total number of generated samples, [int]
        :param dictionary_file: path of the dictionary file, [str]
        :param margin: margin between the text and the border of the background image
        :param write_mode: write mode, 'w': overwrite, 'a': append, [str]
        :param use_blank: whether to use blank characters, [boolean]
        :param num_process: number of parallel processes
        """
        self.img_base_path = img_base_path
        self.font_style_path = font_style_path
        self.text_size_limit = text_size_limit
        self.font_size = font_size
        self.font_color = font_color
        self.train_images_path = train_images_path
        self.train_labels_path = train_labels_path
        self.test_images_path = test_images_path
        self.test_labels_path = test_labels_path
        self.train_test_ratio = train_test_ratio
        self.num_samples = num_samples
        self.dictionary_file = dictionary_file
        assert write_mode in ['w', 'a'], "write mode should be 'w'(write) or 'a'(add)"
        self.write_mode = write_mode
        self.use_blank = use_blank
        self.num_process = num_process
        self.margin = margin
        self.base_image_paths = None
        self.list_words = None
        self.used_ch_word = list()
        self.ch_fonts_list = os.listdir(self.font_style_path['ch'])
        self.en_fonts_list = os.listdir(self.font_style_path['en'])
        super().__init__()

    def generate_image(self, start_end):
        """
        Generate and save the synthetic images.
        :param start_end: list of the start and end sample IDs, [list]
        :return:
        """
        # check dir and files
        train_test_offset = start_end[0] + (start_end[1] - start_end[0]) * self.train_test_ratio
        for i in range(start_end[0], start_end[1]):
            # get base image by order
            base_img_path = self.base_image_paths[
                (i - start_end[0]) * len(self.base_image_paths) // (start_end[1] - start_end[0])]
            # choice font_color depend on base image
            if os.path.basename(base_img_path).split('_')[1] == '0':
                font_color = random.choice(self.font_color[3:])
            elif os.path.basename(base_img_path).split('_')[1] == '1':
                font_color = random.choice(self.font_color[0:6] + self.font_color[12:])
            elif os.path.basename(base_img_path).split('_')[1] == '2':
                font_color = random.choice(self.font_color[0:12] + self.font_color[15:])
            elif os.path.basename(base_img_path).split('_')[1] == '3':
                font_color = random.choice(self.font_color[0:16])
            # create image draw
            base_img = Image.open(base_img_path)
            base_img_width, base_img_height = base_img.size
            draw = ImageDraw.Draw(base_img)
            while 1:
                try:
                    # randomly choice font size
                    font_size = random.choice(self.font_size)
                    # randomly choice words str
                    words_str_len = random.randint(self.text_size_limit[0], self.text_size_limit[1])
                    only_latin, words_str = self.get_word_str(words_str_len)
                    # randomly choice font style
                    if only_latin:
                        font_style_path = random.choice(self.en_fonts_list)
                        font_style_path = os.path.join(self.font_style_path['en'], font_style_path)
                    else:
                        font_style_path = random.choice(self.ch_fonts_list)
                        font_style_path = os.path.join(self.font_style_path['ch'], font_style_path)
                    font = ImageFont.truetype(font_style_path, font_size)
                    words_str_width, words_str_height = draw.textsize(words_str, font)
                    x0 = random.randint(self.margin, base_img_width - self.margin - words_str_width)
                    y0 = random.randint(self.margin, base_img_height - self.margin - words_str_height)
                    draw.text((x0, y0), words_str, tuple(font_color), font=font)
                    # save Image
                    x_left = x0 - random.randint(0, self.margin)
                    y_top = y0 - random.randint(0, self.margin)
                    x_right = x0 + words_str_width + random.randint(0, self.margin)
                    y_bottom = y0 + words_str_height + random.randint(0, self.margin)
                    base_img = np.asarray(base_img)[:, :, 0:3]
                    image = base_img[y_top:y_bottom, x_left:x_right]
                    image = Image.fromarray(image)
                    if i < train_test_offset:
                        image_dir = self.train_images_path
                        labels_path = self.train_labels_path
                    else:
                        image_dir = self.test_images_path
                        labels_path = self.test_labels_path
                    image_name = 'img_' + str(i).zfill(len(str(self.num_samples))) + '.jpg'
                    image_save_path = os.path.join(image_dir, image_name)
                    image.save(image_save_path)
                    # save labels
                    with open(labels_path, 'a', encoding='utf-8') as fa:
                        fa.write(image_name + '\t' + words_str + '\n')
                    break
                except Exception as e:
                    continue

    def generate_image_multi_process(self, num_process=None):
        """
        Generate and save the synthetic images with multiple processes.
        :return:
        """
        if num_process is None:
            num_process = self.num_process
        self.base_image_paths = [os.path.join(self.img_base_path, img) for img in
                                 os.listdir(self.img_base_path)]
        words = [Counter(extract_words_i) for extract_words_i in
                 self.extract_words(open(self.dictionary_file, encoding="utf-8").read())]
        self.list_words = [list(words_i.keys()) for words_i in words]
        # check dir and files
        check_path([self.train_images_path,
                    self.train_labels_path,
                    self.test_images_path,
                    self.test_labels_path])
        if self.write_mode == 'w':
            clear_content([self.train_images_path,
                           self.train_labels_path,
                           self.test_images_path,
                           self.test_labels_path])
        data_offset = self.num_samples // num_process
        processes = list()
        for i in trange(0, self.num_samples, data_offset):
            if i + data_offset >= self.num_samples:
                processes.append(Process(target=self.generate_image, args=([i, self.num_samples],)))
            else:
                processes.append(Process(target=self.generate_image, args=([i, i + data_offset],)))
        for process in processes:
            process.start()
        for process in processes:
            process.join()

    def extract_words(self, text):
        """
        Extract the characters.
        :param text: all chars (English and Chinese) separated by \n
        :return: word_list, e.g. [['1','2',..],['a','b',...,'A','B',...],[',','!',...],['甲','风',...]]
        """
        words_list = text.split('\n')
        words_list = [i.replace(' ', '') for i in words_list]
        words_list = [[j for j in i] for i in words_list]
        if self.use_blank:
            words_list.append([' '])
        return words_list

    def get_word_str(self, length):
        """
        generate word str randomly
        :param length: length of word str
        :return:
        """
        word_str = ''
        self.used_ch_word = list()
        only_latin = False
        # only latin char
        if random.random() < 0.2:
            for i in range(length):
                if self.use_blank and (i == 0 or i == length - 1):
                    words_list_i = random.choice(self.list_words[:3])
                else:
                    if self.use_blank and random.random() < 0.2:
                        words_list_i = random.choice(self.list_words[:3] + self.list_words[-1])
                    else:
                        words_list_i = random.choice(self.list_words[:3])
                word_str += random.choice(words_list_i)
            only_latin = True
        else:
            for i in range(length):
                if self.use_blank and (i == 0 or i == length - 1):
                    words_list_i = random.choice(self.list_words[:-1])
                else:
                    if self.use_blank and random.random() < 0.2:
                        words_list_i = random.choice(self.list_words)
                    else:
                        words_list_i = random.choice(self.list_words[:-1])
                word_str += random.choice(words_list_i)
        return only_latin, word_str


def check_path(path_list):
    """
    Check whether each path in the list exists; if not, create the directory or file.
    :param path_list: path list, [list]
    :return:
    """
    for path in path_list:
        if not os.path.exists(path) and '.' not in path[2:]:
            os.mkdir(path)
        elif not os.path.exists(path) and '.' in path[2:]:
            with open(path, 'w', encoding='utf-8') as fw:
                fw.write('')


def clear_content(path_list):
    """
    Clear the contents of the given directories and files.
    :param path_list: path list, [list]
    :return:
    """
    for path in path_list:
        if os.path.isdir(path):
            shutil.rmtree(path)
            os.mkdir(path)
        elif os.path.isfile(path):
            os.remove(path)
            with open(path, 'w', encoding='utf-8') as fw:
                fw.write('')


def do_text_cut(write_mode):
    print("{0}".format('text cutting...').center(100, '='))
    print('train_test_ratio={0}\nfilter_ratio={1}\nfilter_height={2}'
          '\nis_transform={3}\nangle_range={4}\nwrite_mode={5}\nuse_blank={6}\nnum_process={7}'.format(
        crnn_config.train_test_ratio,
        crnn_config.filter_ratio,
        crnn_config.filter_height,
        crnn_config.is_transform,
        crnn_config.angle_range,
        write_mode,
        crnn_config.use_blank,
        crnn_config.num_process))
    print('=' * 100)
    text_cut = TextCut(org_images_path=crnn_config.org_images_path,
                       org_labels_path=crnn_config.org_labels_path,
                       cut_train_images_path=crnn_config.cut_train_images_path,
                       cut_train_labels_path=crnn_config.cut_train_labels_path,
                       cut_test_images_path=crnn_config.cut_test_images_path,
                       cut_test_labels_path=crnn_config.cut_test_labels_path,
                       train_test_ratio=crnn_config.train_test_ratio,
                       filter_ratio=crnn_config.filter_ratio,
                       filter_height=crnn_config.filter_height,
                       is_transform=crnn_config.is_transform,
                       angle_range=crnn_config.angle_range,
                       write_mode=write_mode,
                       use_blank=crnn_config.use_blank,
                       num_process=crnn_config.num_process
                       )
    text_cut.data_load_multi_process()


def do_image_generate(write_mode):
    print("{0}".format('image generating...').center(100, '='))
    print('train_test_ratio={0}\nnum_samples={1}\nmargin={2}\nwrite_mode={3}\nuse_blank={4}\nnum_process={5}'
          .format(crnn_config.train_test_ratio, crnn_config.num_samples, crnn_config.margin, write_mode,
                  crnn_config.use_blank, crnn_config.num_process))
    image_generate = ImageGenerate(img_base_path=crnn_config.base_img_dir,
                                   font_style_path=crnn_config.font_style_path,
                                   text_size_limit=crnn_config.text_size_limit,
                                   font_size=crnn_config.font_size,
                                   font_color=crnn_config.font_color,
                                   train_images_path=crnn_config.train_images_path,
                                   train_labels_path=crnn_config.train_label_path,
                                   test_images_path=crnn_config.test_images_path,
                                   test_labels_path=crnn_config.test_label_path,
                                   train_test_ratio=crnn_config.train_test_ratio,
                                   num_samples=crnn_config.num_samples,
                                   dictionary_file=crnn_config.dictionary_file,
                                   margin=crnn_config.margin,
                                   write_mode=write_mode,
                                   use_blank=crnn_config.use_blank,
                                   num_process=crnn_config.num_process)
    image_generate.generate_image_multi_process()


def do_generate_charset(label_path, charset_path):
    """
    Generate the charset file.
    :param label_path: path of the training labels
    :param charset_path: path of the charset file
    :return:
    """
    print("{0}".format('charset generating...').center(100, '='))
    print('label_path={0}\ncharset_path={1}'.format(label_path, charset_path))
    print('=' * 100)
    charset_generate.generate_charset(label_path, charset_path)


if __name__ == '__main__':
    do_text_cut(write_mode='w')
    do_image_generate(write_mode='a')
    # do_generate_charset(crnn_config.train_label_path, crnn_config.charset_path)
```

    data_generator.py contains the data preprocessing functions used during training and testing.


 
 
```python
import re
import os
import PIL
import math
import numpy as np
from PIL import Image
from crnn.config import seed
from captcha.image import ImageCaptcha


def get_img_label(label_path, images_path):
    """
    Get the list of image paths and the list of image labels.
    :param label_path: path of the file that stores the image names and labels. [str]
    :param images_path: directory of the images. [str]
    :return:
    """
    with open(label_path, 'r', encoding='utf-8') as f:
        lines = f.read()
        lines = lines.split('\n')
    img_path_list = []
    img_label_list = []
    for line in lines[:-1]:
        this_img_path, this_img_label = line.split('\t')
        this_img_path = os.path.join(images_path, this_img_path)
        img_path_list.append(this_img_path)
        img_label_list.append(this_img_label)
    return img_path_list, img_label_list


def get_charsets(dict=None, mode=1, charset_path=None):
    """
    Build the character set.
    :param mode: with mode=1, captchas are generated on the fly for training and the charset is read
                 from charsets.txt under the dict directory; with mode=2, real scene images are used
                 for training and the charset is collected and deduplicated from all text labels in
                 label.txt under the data directory
    :param dict: path of the charset file
    :param charset_path: path of the generated charset file, only used with mode=2
    :return:
    """
    if mode == 1:
        with open(dict, 'r', encoding='utf-8') as f:
            lines = f.readlines()
        charsets = ''.join(lines)
    else:
        with open(charset_path, 'r', encoding='utf-8') as fr:
            charsets = fr.read()
    charsets = re.sub('\n|\t|', '', charsets)
    charsets = list(set(list(charsets)))
    charsets = sorted(charsets)
    charsets = ''.join(charsets)
    charsets = charsets.encode('utf-8').decode('utf-8')
    return charsets


def gen_random_text(charsets, min_len, max_len):
    """
    Generate a random text whose length lies between min_len and max_len.
    :param charsets: character set. [str]
    :param min_len: minimum text length. [int]
    :param max_len: maximum text length. [int]
    :return: the encoded index sequence and the text string
    """
    length = seed.random_integers(low=min_len, high=max_len)
    idxs = seed.randint(low=0, high=len(charsets), size=length)
    str = ''.join([charsets[i] for i in idxs])
    return idxs, str


def captcha_gen_img(text, image_shape, fonts):
    """
    Render the text into a captcha image.
    :param text: input text. [str]
    :param image_shape: image shape. [list]
    :param fonts: list of font file paths. [list]
    :return:
    """
    image = ImageCaptcha(height=image_shape[0], width=image_shape[1], fonts=fonts)
    data = image.generate_image(text)
    data = np.reshape(np.frombuffer(data.tobytes(), dtype=np.uint8), image_shape)
    return data


def captcha_batch_gen(batch_size, charsets, min_len, max_len, image_shape, blank_symbol, fonts):
    """
    Generate one batch of captcha data; each batch contains the images, the width of each image and the labels.
    :param batch_size: batch_size
    :param charsets: character set
    :param min_len: minimum text length
    :param max_len: maximum text length
    :param image_shape: shape of the generated images
    :param blank_symbol: padding value appended to labels shorter than the maximum length
    :param fonts: list of font file paths
    :return:
    """
    batch_labels = []
    batch_images = []
    batch_image_widths = []
    for _ in range(batch_size):
        idxs, text = gen_random_text(charsets, min_len, max_len)
        image = captcha_gen_img(text, image_shape, fonts)
        image = image / 255
        pad_size = max_len - len(idxs)
        if pad_size > 0:
            idxs = np.pad(idxs, pad_width=(0, pad_size), mode='constant', constant_values=blank_symbol)
        batch_image_widths.append(image.shape[1])
        batch_labels.append(idxs)
        batch_images.append(image)
    batch_labels = np.array(batch_labels, dtype=np.int32)
    batch_images = np.array(batch_images, dtype=np.float32)
    batch_image_widths = np.array(batch_image_widths, dtype=np.int32)
    return batch_images, batch_image_widths, batch_labels


def scence_batch_gen(batch_img_list, batch_img_label_list,
                     charsets, image_shape, max_len, blank_symbol):
    """
    Generate one batch of real-scene data; each batch contains the images, the width of each image and the labels.
    :param batch_img_list: list of image paths
    :param batch_img_label_list: list of image labels
    :param charsets: character set string
    :param image_shape: shape of the generated images
    :param max_len: maximum length of the text sequence
    :param blank_symbol: padding value appended to labels shorter than the maximum length
    :return:
    """
    batch_labels = []
    batch_image_widths = []
    batch_size = len(batch_img_label_list)
    batch_images = np.zeros(shape=(batch_size, image_shape[0], image_shape[1], image_shape[2]), dtype=np.float32)
    for i, path, label in zip(range(batch_size), batch_img_list, batch_img_label_list):
        # resize the image
        image = Image.open(path)
        img_size = image.size
        height_ratio = image_shape[0] / img_size[1]
        if int(img_size[0] * height_ratio) > image_shape[1]:
            new_img_size = (image_shape[1], image_shape[0])
            image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
            image = np.array(image, np.float32)
            image = image / 255
            batch_images[i, :, :, :] = image
        else:
            new_img_size = (int(img_size[0] * height_ratio), image_shape[0])
            image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
            image = np.array(image, np.float32)
            image = image / 255
            batch_images[i, :image.shape[0], :image.shape[1], :] = image
        # encode the label
        if len(label) > max_len:
            label = label[:max_len]
        idxs = [charsets.index(i) for i in label]
        # pad the label
        pad_size = max_len - len(idxs)
        if pad_size > 0:
            idxs = np.pad(idxs, pad_width=(0, pad_size), mode='constant', constant_values=blank_symbol)
        batch_image_widths.append(image_shape[1])
        batch_labels.append(idxs)
    batch_labels = np.array(batch_labels, dtype=np.int32)
    batch_image_widths = np.array(batch_image_widths, dtype=np.int32)
    return batch_images, batch_image_widths, batch_labels


def load_images(batch_img_list, image_shape):
    """
    Load one batch of images for inference, returning the images and the width of each image.
    :param batch_img_list: list of image paths or list of images, [list]
    :param image_shape: shape of the generated images
    :return:
    """
    # the argument is a list of image paths
    if isinstance(batch_img_list[0], str):
        batch_size = len(batch_img_list)
        batch_image_widths = []
        batch_images = np.zeros(shape=(batch_size, image_shape[0], image_shape[1], image_shape[2]), dtype=np.float32)
        for i, path in zip(range(batch_size), batch_img_list):
            # resize the image
            image = Image.open(path)
            img_size = image.size
            height_ratio = image_shape[0] / img_size[1]
            if int(img_size[0] * height_ratio) > image_shape[1]:
                new_img_size = (image_shape[1], image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :, :, :] = image
            else:
                new_img_size = (int(img_size[0] * height_ratio), image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :image.shape[0], :image.shape[1], :] = image
            batch_image_widths.append(image_shape[1])
    # the argument is a list of images
    elif isinstance(batch_img_list[0], PIL.Image.Image):
        batch_size = len(batch_img_list)
        batch_image_widths = []
        batch_images = np.zeros(shape=(batch_size, image_shape[0], image_shape[1], image_shape[2]), dtype=np.float32)
        for i in range(batch_size):
            # resize the image
            image = batch_img_list[i]
            img_size = image.size
            height_ratio = image_shape[0] / img_size[1]
            if int(img_size[0] * height_ratio) > image_shape[1]:
                new_img_size = (image_shape[1], image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :, :, :] = image
            else:
                new_img_size = (int(img_size[0] * height_ratio), image_shape[0])
                image = image.resize(new_img_size, Image.ANTIALIAS).convert('RGB')
                image = np.array(image, np.float32)
                image = image / 255
                batch_images[i, :image.shape[0], :image.shape[1], :] = image
            batch_image_widths.append(image_shape[1])
    return batch_images, batch_image_widths
```

    Finally, the model class file defines the network structure, the loss function and the training routine; its code begins as follows:


 
 
```python
import os
import random
import numpy as np
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow.contrib.rnn import BasicLSTMCell
from crnn.data_generator import get_charsets, captcha_batch_gen, scence_batch_gen, get_img_label


class CRNN(object):
    def __init__(self,
                 image_shape,
                 min_len,
                 max_len,
                 lstm_hidden,
                 pool_size,
                 learning_decay_rate,
                 learning_rate,
                 learning_decay_steps,
                 mode,
                 dict,
                 is_training,
                 train_label_path,
                 train_images_path,
                 charset_path):
        self.m
```