This article is shared from the Huawei Cloud community post "End-to-End Text Recognition with CTPN+CRNN", by HWCloudAI.
Optical Character Recognition (OCR) is the process of analyzing images of text to extract the characters and layout information they contain. It is a long-established and widely used technology, and as one of the earliest computer-vision fields to adopt deep learning, it has produced many strong models. A typical deep-learning OCR system splits the task into two stages: text region detection and character recognition.
Text region detection: locate the text appearing in an image, which may be in different languages, at different sizes, tilted at different angles, or partially occluded. The CTPN network combines a CNN with an LSTM and extracts proposals using fixed-width anchors, which lets it reliably detect horizontally arranged text regions in complex scenes. It handles text of varying length well and is one of the most widely used text detection algorithms.
Character recognition: read out the characters inside each detected text region. A deep neural network extracts features from the target region, and those features are then decoded and matched to produce the recognition result. This case uses the CRNN+CTC recognition network. CRNN (Convolutional Recurrent Neural Network) integrates feature extraction, sequence modeling, and transcription into a single framework. It recognizes text sequences of arbitrary length end to end: instead of segmenting individual characters first, it casts text recognition as a sequence-learning problem with temporal dependencies, i.e. image-based sequence recognition. As shown in the figure below, CRNN consists of three parts: convolutional layers, recurrent layers, and a transcription layer. CTC provides lexicon-free transcription, so the output is not restricted to a predefined vocabulary. A toy decoding example follows below.
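To make the transcription step concrete, here is a minimal, self-contained sketch of greedy CTC decoding (pick the most likely label per time step, collapse consecutive repeats, then drop blanks). The toy alphabet and blank index are made up for illustration; they are not the ones used by the CRNN model loaded later in this notebook, which uses its own `decode` helper.

```python
import numpy as np

def ctc_greedy_decode(probs, alphabet, blank):
    """Greedy CTC decoding: take the argmax label at each time step,
    collapse consecutive repeats, then remove the blank label."""
    best_path = probs.argmax(axis=1)  # (T,) best label per time step
    collapsed = [best_path[0]] + [
        cur for prev, cur in zip(best_path, best_path[1:]) if cur != prev
    ]
    return ''.join(alphabet[i] for i in collapsed if i != blank)

# Toy example: 6 time steps, alphabet 'ab' plus a blank at index 2.
alphabet = ['a', 'b']
probs = np.array([
    [0.90, 0.05, 0.05],  # 'a'
    [0.90, 0.05, 0.05],  # 'a' again: collapsed into the previous 'a'
    [0.10, 0.10, 0.80],  # blank: acts as a separator between repeated chars
    [0.90, 0.05, 0.05],  # 'a'
    [0.05, 0.90, 0.05],  # 'b'
    [0.10, 0.10, 0.80],  # blank
])
print(ctc_greedy_decode(probs, alphabet, blank=2))  # -> 'aab'
```

Note how the blank between the second and third 'a' frames is what lets CTC emit a genuinely repeated character instead of collapsing all three frames into one.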
With text region detection and character recognition covered, here is the complete end-to-end OCR flow in detail:
(1) Start with an original image containing text.
(2) Detect the text locations in the image; the detections may be axis-aligned rectangles or tilted rectangles.
(3) Crop the image patch for each text box from the original image and rotate it upright, producing horizontal text slices.
(4) Run character recognition on each text slice in turn; the per-slice results combined together form the recognition result for the whole image.
So the complete end-to-end OCR pipeline is: input image -> text detection -> text slicing -> character recognition -> result aggregation.
That concludes the theory. Now let's build the project hands-on in ModelArts!
Notes:
**Framework used in this case:** TensorFlow-1.8
**Hardware specification used in this case:** 8 vCPU + 64 GiB + 1 x Tesla V100-PCIE-32GB
**How to enter the runtime environment:** follow this link to AI Gallery and click the Run in ModelArts button to enter the ModelArts runtime; if you need a GPU, you can switch flavors in the workspace panel on the right of the ModelArts JupyterLab interface
**How to run the code:** click the triangular run button in the menu bar at the top of this page, or press Ctrl+Enter, to run the code in each cell
**Detailed JupyterLab usage:** see the "ModelArts JupyterLab User Guide"
**If you run into problems:** see the "ModelArts JupyterLab FAQ"
In this case, the CTPN and CRNN code and models have already been bundled together. First, download and unpack the bundle:
```python
import os
from modelarts.session import Session

sess = Session()

if sess.region_name == 'cn-north-1':
    bucket_path = "modelarts-labs/notebook/DL_ocr_crnn_sequence_recognition/E2E_ocr.zip"
elif sess.region_name == 'cn-north-4':
    bucket_path = "modelarts-labs-bj4/notebook/DL_ocr_crnn_sequence_recognition/E2E_ocr.zip"
else:
    print("Please switch the region to Beijing-1 (cn-north-1) or Beijing-4 (cn-north-4)")

# Download and unpack the bundled code and models
if not os.path.exists('E2E_ocr'):
    sess.download_data(bucket_path=bucket_path, path="./E2E_ocr.zip")

if os.path.exists('./E2E_ocr.zip'):
    status = os.system("unzip -q E2E_ocr.zip")
    if status == 0:
        os.system("rm E2E_ocr.zip")
```
Successfully download file modelarts-labs-bj4/notebook/DL_ocr_crnn_sequence_recognition/E2E_ocr.zip from OBS to local ./E2E_ocr.zip
```python
import shutil
import cv2
import numpy as np
import datetime
import os
import sys
import time
import json
import codecs
from PIL import Image
import tensorflow as tf

sys.path.append(os.getcwd() + '/E2E_ocr')
sys.path.append(os.getcwd() + '/E2E_ocr/CRNN/')
from collections import OrderedDict
from tensorflow.contrib import slim

from CTPN import data_provider as data_provider
from CTPN.model import mean_image_subtraction, Bilstm, lstm_fc, loss
from CTPN import vgg
from CTPN import model
from CTPN.utils.rpn_msr.proposal_layer import proposal_layer
from CTPN.utils.text_connector.detectors import TextDetector
from CTPN.utils.image import resize_image
```
(Output: a series of numpy FutureWarning messages from the TensorFlow 1.8 environment about `(type, 1)` dtype synonyms; they are harmless and can be ignored.)
```python
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple keras==2.1.6
!pip install -i https://pypi.tuna.tsinghua.edu.cn/simple keras_applications==1.0.5
```
(Output: pip reports that all requirements are already satisfied, plus a notice that a newer pip version is available; both can be ignored.)
```python
from keras.layers import Flatten, BatchNormalization, Permute, TimeDistributed, Dense, Bidirectional, GRU
from keras.layers import Input, Conv2D, MaxPooling2D, ZeroPadding2D, Lambda
from keras.models import Model
from keras.optimizers import SGD
from keras import backend as K

import keys as keys
from CRNN_model import decode
```
Using TensorFlow backend.
```python
checkpoint_path = './E2E_ocr/models/checkpoints/'  # path where trained models are saved
vgg_path = "./E2E_ocr/models/vgg_16.ckpt"          # VGG16 pretrained model
image_path = './E2E_ocr/data/CTW-200'              # training image directory

CHECKPOINT_PATH = './E2E_ocr/models/checkpoints'   # path of the model used for inference
os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # compute device: empty string means CPU, a number selects that GPU

tf.reset_default_graph()
# Placeholders for the model inputs
input_image = tf.placeholder(tf.float32, shape=[None, None, None, 3], name='input_image')
input_im_info = tf.placeholder(tf.float32, shape=[None, 3], name='input_im_info')
init_op = tf.global_variables_initializer()
# Global training step counter
global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)

# Build the model
bbox_pred, cls_pred, cls_prob = model.model(input_image)
variable_averages = tf.train.ExponentialMovingAverage(0.997, global_step)
# Register the moving-average shadow variables with a saver
saver = tf.train.Saver(variable_averages.variables_to_restore())

ctpn_sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
with ctpn_sess.as_default():
    # Restore the pretrained model weights
    ckpt_state = tf.train.get_checkpoint_state(CHECKPOINT_PATH)
    model_path = os.path.join(CHECKPOINT_PATH, os.path.basename(ckpt_state.model_checkpoint_path))
    saver.restore(ctpn_sess, model_path)
    print('CTPN model load success')
```
CTPN model load success
To detect text regions more robustly, CTPN uses anchors with a fixed width of 16 pixels and 10 heights: [11, 16, 23, 33, 48, 68, 97, 139, 198, 283], i.e. 10 anchors in total.
This design targets the horizontal extent of text: in text detection, localizing the horizontal range of text is harder than the vertical range. Fixing the anchor width and searching over only 10 heights makes it easier to delimit text, especially when facing multiple separated pieces of text.
The bounding boxes produced by the different anchors are then refined with NMS (non-maximum suppression) and bounding-box regression to obtain fine-grained text regions. A sketch of the anchor layout follows below.
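For intuition, here is a small illustrative sketch of what these fixed-width anchors look like at a single feature-map location. It is a simplification for this article, not the project's actual anchor code (which lives in CTPN/utils/rpn_msr):

```python
import numpy as np

def ctpn_anchors(center_x, center_y, width=16,
                 heights=(11, 16, 23, 33, 48, 68, 97, 139, 198, 283)):
    """Return the 10 CTPN anchors as (x1, y1, x2, y2) boxes: all share a
    fixed width of 16 pixels and are centered on (center_x, center_y)."""
    return np.array([[center_x - width / 2, center_y - h / 2,
                      center_x + width / 2, center_y + h / 2]
                     for h in heights])

# The 10 anchors at the center of the first 16x16 feature-map cell
print(ctpn_anchors(8, 8).astype(int))
```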
The following figure shows the CRNN architecture for reference:
```python
characters = keys.alphabet[:]
nclass = len(characters) + 1  # +1 for the CTC blank label

input = Input(shape=(32, None, 1), name='the_input')
# CNN (convolutional) part
m = Conv2D(64, kernel_size=(3, 3), activation='relu', padding='same', name='conv1')(input)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool1')(m)
m = Conv2D(128, kernel_size=(3, 3), activation='relu', padding='same', name='conv2')(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), name='pool2')(m)
m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv3')(m)
m = Conv2D(256, kernel_size=(3, 3), activation='relu', padding='same', name='conv4')(m)

m = ZeroPadding2D(padding=(0, 1))(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool3')(m)

m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv5')(m)
m = BatchNormalization(axis=1)(m)
m = Conv2D(512, kernel_size=(3, 3), activation='relu', padding='same', name='conv6')(m)
m = BatchNormalization(axis=1)(m)
m = ZeroPadding2D(padding=(0, 1))(m)
m = MaxPooling2D(pool_size=(2, 2), strides=(2, 1), padding='valid', name='pool4')(m)
m = Conv2D(512, kernel_size=(2, 2), activation='relu', padding='valid', name='conv7')(m)

m = Permute((2, 1, 3), name='permute')(m)
m = TimeDistributed(Flatten(), name='timedistrib')(m)
# RNN (recurrent) part
m = Bidirectional(GRU(256, return_sequences=True), name='blstm1')(m)
m = Dense(256, name='blstm1_out', activation='linear')(m)
m = Bidirectional(GRU(256, return_sequences=True), name='blstm2')(m)
y_pred = Dense(nclass, name='blstm2_out', activation='softmax')(m)

basemodel = Model(inputs=input, outputs=y_pred)
basemodel.load_weights('./E2E_ocr/CRNN/model_crnn.h5')
print("CRNN model load success")
```
CRNN model load success
```python
from CTPN.utils.text_connector.text_connect_cfg import Config as TextLineCfg

def ctpn_text_detection(img_path):
    """
    CTPN text detection function.
    :param img_path: path of the image
    :return: img: the image on which text detection was run
    :return: boxes: the text boxes detected in the image
    """
    try:
        im = cv2.imread(img_path)[:, :, ::-1]
    except Exception as e:
        raise Exception("Failed to open the image file, path:", img_path)
    img, (rh, rw) = resize_image(im)  # resize the image
    h, w, c = img.shape
    im_info = np.array([h, w, c]).reshape([1, 3])
    # Feed the image into the model; the outputs are text region coordinates and their scores
    bbox_pred_val, cls_prob_val = ctpn_sess.run([bbox_pred, cls_prob],
                                                feed_dict={input_image: [img], input_im_info: im_info})
    textsegs_total, _ = proposal_layer(cls_prob_val, bbox_pred_val, im_info)
    scores = textsegs_total[:, 0]
    textsegs = textsegs_total[:, 1:5]
    """Text box merging strategy"""
    TextLineCfg.MAX_HORIZONTAL_GAP = 50          # boxes closer than 50 px count as neighbors; the smaller the value, the stricter the merging
    TextLineCfg.TEXT_PROPOSALS_MIN_SCORE = 0.7   # only boxes above this confidence are merged; the larger the value, the more boxes are discarded
    TextLineCfg.TEXT_PROPOSALS_NMS_THRESH = 0.2  # non-maximum suppression threshold used to remove overlapping proposals
    TextLineCfg.MIN_V_OVERLAPS = 0.7             # boxes with vertical overlap above 0.7 count as neighbors; the larger the value, the less likely vertically offset boxes are merged
    textdetector = TextDetector(DETECT_MODE='H') # DETECT_MODE is 'H' or 'O': 'H' suits horizontal text, 'O' suits slightly tilted text
    """Text box merging strategy"""
    boxes = textdetector.detect(textsegs, scores[:, np.newaxis], img.shape[:2])
    boxes = np.array(boxes, dtype=np.int32)

    return img, boxes
```
```python
def img_transform_perspective(image, points, w_pad_rate=(0.0, 0.0), h_pad_rate=(0.0, 0.0)):
    """
    Perspective-transform the quadrilateral defined by four points into an axis-aligned rectangle.
    :param image: the source image
    :param points: the four reference points, in the order xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax
    :param w_pad_rate: pair (rate1, rate2), padding ratios added to the left and right of the width
    :param h_pad_rate: pair (rate1, rate2), padding ratios added to the top and bottom of the height
    :return: persp_img: the transformed image
    :return: points2: the four points after the transform
    """
    if not isinstance(points, np.ndarray):
        points = np.array(points)
    points = points.reshape((4, 2))
    widths = np.linalg.norm(points[::2] - points[1::2], axis=1)  # lengths of the top and bottom edges of the quadrilateral
    width = int(round(widths.mean()))
    heights = np.linalg.norm(points[:2] - points[3:1:-1], axis=1)  # lengths of the left and right edges of the quadrilateral
    height = int(round(heights.mean()))

    points2 = np.array([[0, 0], [width - 1, 0],
                        [width - 1, height - 1], [0, height - 1]], np.float32)
    points2 += np.array([int(width * w_pad_rate[0]), int(height * h_pad_rate[0])]).reshape(1, 2)
    size = (int(width * (1 + w_pad_rate[0] + w_pad_rate[1])),
            int(height * (1 + h_pad_rate[0] + h_pad_rate[1])))

    mat = cv2.getPerspectiveTransform(points.astype(np.float32), points2)
    persp_img = cv2.warpPerspective(image, mat, size,
                                    borderMode=cv2.BORDER_CONSTANT,
                                    borderValue=(255, 255, 255))

    return persp_img, points2
```
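As a quick usage illustration, the snippet below rectifies a hypothetical tilted box on a blank synthetic image; the corner coordinates are made up for the example, not real detections:

```python
# Hypothetical tilted text box, corner order: top-left, top-right,
# bottom-right, bottom-left (x1, y1, x2, y2, x3, y3, x4, y4).
demo_img = np.full((200, 500, 3), 255, dtype=np.uint8)  # blank white image
demo_box = [120, 40, 400, 60, 395, 100, 115, 80]
warped, _ = img_transform_perspective(demo_img, demo_box)
print(warped.shape)  # the axis-aligned crop that would be fed to crnn_ocr
```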
```python
def crnn_ocr(img):
    """
    CRNN character recognition function.
    :param img: the image to recognize
    :return: ocr_result: the recognized text, as a string
    """
    img = img.convert('L')  # convert the image to grayscale

    scale = img.size[1] * 1.0 / 32  # resize the image so its height is 32
    w = img.size[0] / scale
    w = int(w)
    img = img.resize((w, 32))
    img = np.array(img).astype(np.float32) / 255.0
    X = img.reshape((32, w, 1))
    X = np.array([X])
    y_pred = basemodel.predict(X)  # run the prediction
    ocr_result = decode(y_pred)    # decode the prediction into text

    return ocr_result
```
```python
img = Image.open('./E2E_ocr/test_dataset/text.png')
img
```
```python
test_dir = './E2E_ocr/test_dataset'  # directory of test images
save_results = True
output_dir = test_dir + '_output'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
ocr_results = OrderedDict()
files = os.listdir(test_dir)
for file_name in files:
    if not (file_name.endswith('jpg') or file_name.endswith('png')
            or file_name.endswith('JPG') or file_name.endswith('PNG')):
        continue
    print(file_name, 'ocr result:')
    file_path = os.path.join(test_dir, file_name)

    img, boxes = ctpn_text_detection(file_path)  # step 1: detect text locations
    # step 2: sort the text boxes, first by the y coordinate of the top-left corner, then by the x coordinate
    sorted_boxes = sorted(boxes.tolist(), key=lambda x: (x[1], x[0]))
    for index, box in enumerate(sorted_boxes):
        # step 3: crop each text block from the image and rectify tilted blocks into horizontal rectangles
        cut_text_img, _ = img_transform_perspective(img, box[:8])
        # step 4: run character recognition on each text block
        ocr_result = crnn_ocr(Image.fromarray(cut_text_img))
        ocr_results[str(index)] = ocr_result
        print(str(index) + ',', ocr_result)

    if save_results:
        draw_img = img[:, :, ::-1].copy()
        for i, box in enumerate(boxes):
            cv2.polylines(draw_img, [box[:8].astype(np.int32).reshape((-1, 1, 2))], True, color=(0, 0, 255), thickness=2)
        cv2.imwrite(os.path.join(output_dir, file_name), draw_img)
        # save the recognition results as JSON
        with codecs.open(os.path.join(output_dir, file_name.split('.')[0] + '.json'), 'w', 'utf-8') as f:
            json.dump(ocr_results, f, indent=4, ensure_ascii=False)
print('end')
```
text.png ocr result:
0, A1正在改变我们的生活,
1, 正在改变我们身边的各行各业,
2, 但是这条通往智能世界的路并不平坦,
3, 其中一个巨大鸿沟就是AI人才的稀缺。
4, 在中国庞大的I从业群体,
5, A开发者缺口达百万级。
6, A1将成为全民普及性的技能,
7, 所以今天华为云El为大家带来《2020华为云AI实战营》免费课程,
8, 大幅降低A1拳习门]椤,
9, 帮助庞大的软件开发者群体快速拳握A1技能,
10, 把AI用起来。
end