当前位置:   article > 正文

YOLO10+OCR识别电子发票的指定文字内容_yolo 文字识别

yolo 文字识别

本文概述

实验室里经常有大量的发票需要报销,每次都需要人工一张一张的去手动核对发票上的关键信息是否符合要求,于是我打算使用yolo+ocr的技术去实现自动核对电子发票上的关键信息。ps:因为发票信息可能比较敏感,因此本文中提到的发票数据集和合成逼真发票图片的代码将不被提供

YOLO部分

一、准备训练所需的数据集

因为今年广西才开始全面实施电子发票,所以我手头上的电子发票只有百来张,这点数据量还是太少了,因此我按照真实数据和合成数据,一比五的比例去制作了YOLO的训练数据集,然后按照训练集:测试集=8:2的比例去划分数据集(共800张图片)

其中批量合成的电子发票图片,是使用一张真实发票在抹除部分文字后按照原电子发票的格式随机填充相同字号相同字体去合成的。标签是使用labelimg去标注的

二、克隆yolov10项目并修改配置文件

创建并修改数据集配置

首先我们在yolov10-main的ultralytics/cfg/datasets目录下,创建一个新的数据集配置文件并命名为Fapiao.yml

然后将train、val改为你们自定义数据集的训练集路径和测试集路径,将nc改为类别数,然后为每个类别赋予名字

创建train.py

在yolov10-main的根目录下创建一个名为train.py的文件,然后将模型配置文件和数据集配置文件的路径修改正确,然后超参数可以按照我这里的来进行设置或者根据你们数据集的特点进行设置,我这里将部分数据增强都给关了,并且将dfl损失函数的权重设置为0,将cls的权重设置为1.0

  1. # coding:utf-8
  2. from ultralytics import YOLOv10
  3. # 模型配置文件
  4. model_yaml_path = "ultralytics/cfg/models/v10/yolov10m.yaml"
  5. # 数据集配置文件
  6. data_yaml_path = 'ultralytics/cfg/datasets/Fapiao.yaml'
  7. if __name__ == '__main__':
  8. # 加载预训练模型
  9. model = YOLOv10(model_yaml_path)
  10. # 训练模型
  11. results = model.train(data=data_yaml_path, epochs=100, batch=32,cls=1.0,dfl=0.0, name='train_v10',cos_lr=False,imgsz=800,lr0=0.01,lrf=0.0001,
  12. translate=0,scale=0,fliplr=0,mosaic=0,erasing=0,hsv_h=0.1,hsv_v=0.1,hsv_s=0.1)

三、执行训练

到这步默认大家已经完成了所有yolo必须环境的配置和安装,关于环境方面不做赘述。

直接cd进入yolov10-main的目录下执行python train.py即可开始训练

快进.............

训练完后的模型评估结果如上图所示,预测框的精度达到了99%,mAP50-95的值也达到了94.5%,可以说训练效果非常好了

  1. # coding:utf-8
  2. import torch
  3. from ultralytics import YOLOv10
  4. # 模型配置文件
  5. model_yaml_path = "ultralytics/cfg/models/v10/yolov10m.yaml"
  6. # 数据集配置文件
  7. data_yaml_path = 'ultralytics/cfg/datasets/Fapiao.yaml'
  8. if __name__ == '__main__':
  9. # 加载预训练模型
  10. model = YOLOv10(model_yaml_path).load("/home/ma-user/ocr/yolov10-main/runs/detect/train_v1015/weights/best.pt")
  11. # 训练模型
  12. results = model.val(data=data_yaml_path, batch=32, name='train_v10',imgsz=800,save_txt=True,iou=0.6,conf=0.001,max_det=300,rect=False)

这个是验证用的val.py文件的代码,需要强调的一点是,在官方的代码中val阶段是默认设置rect=True的,如果你训练阶段没有开启rect=True的话,请在val阶段也将rect设置为False,否则得出的指标是不准确的,或者部分类别无法识别。

如果你在训练阶段中rect设置为False的话,建议你将model.py中val的rect手动改为false,否则可能导致评估指标不准确

四、模型量化

1、导出onnx格式的模型(export.py)

执行export.py将.pt格式的模型文件转化为.onnx格式的模型文件

  1. import torch
  2. from ultralytics import YOLOv10
  3. # 加载预训练的YOLOv10模型
  4. model = YOLOv10('/home/ma-user/ocr/yolov10-main/runs/detect/train_v1015/weights/best.pt')
  5. # 导出模型为ONNX格式
  6. model.export(format='onnx',name='yolov10m',ops=11)

2、安装tensorRT8.5.3.1进行静态量化为int8(quantization.py)

执行quantization.py完成量化

  1. import json
  2. import os
  3. import pathlib
  4. from datetime import datetime
  5. import cv2 as cv
  6. import numpy as np
  7. import onnxruntime
  8. from ultralytics import YOLOv10
  9. import tensorrt as trt
  10. from tqdm import tqdm
  11. from polygraphy.backend.trt import NetworkFromOnnxPath, CreateConfig, EngineFromNetwork
  12. from polygraphy.backend.trt import Calibrator
  13. def _get_metadata():
  14. description = f'Ultralytics YOLOv8X model'
  15. names = {'0': '购买方名称', '1': '纳税人识别号', '2': '项目名称', '3': '数量', '4': '金额', '5': '发票号码'} # 各个检测类别索引和名字的对应关系
  16. metadata = {
  17. 'description': description,
  18. 'author': 'Ultralytics',
  19. 'license': 'AGPL-3.0 https://ultralytics.com/license',
  20. 'date': datetime.now().isoformat(),
  21. 'version': '10.0',
  22. 'stride': 32,
  23. 'task': 'detect',
  24. 'batch': 1,
  25. 'imgsz': [800, 800],
  26. 'names': names
  27. }
  28. return metadata
  29. def _calib_data_yolo8(onnx_input_name, onnx_input_shape, calibration_images_quantity, calibration_images_folder):
  30. print(f' {onnx_input_shape= }') #
  31. if onnx_input_shape[1] != 3: # ONNX 输入的形状可以是:1, 3, 1504, 1504。第一维度是深度通道。
  32. raise ValueError(f'Error, expected input depth is 3, '
  33. f'but {onnx_input_shape= }')
  34. calibration_images_folder = pathlib.Path(calibration_images_folder).expanduser().resolve()
  35. if not calibration_images_folder.exists():
  36. raise FileNotFoundError(f'{calibration_images_folder} does not exist.')
  37. print(f'{calibration_images_folder= }')
  38. batch_size = onnx_input_shape[0]
  39. required_height = onnx_input_shape[2]
  40. required_width = onnx_input_shape[3]
  41. # 初始化第 0 批数据。标定时必须给 engine 输入 FP32 格式的数据。
  42. output_images = np.zeros(shape=onnx_input_shape, dtype=np.float32)
  43. # 如果图片总数不够,则使用所有图片进行标定。
  44. calibration_images_quantity = min(calibration_images_quantity,
  45. len(os.listdir(calibration_images_folder)))
  46. print(f'Calibration images quantity: {calibration_images_quantity}')
  47. print(f'Calibrating ...')
  48. # 创建一个进度条。
  49. tqdm_images_folder = tqdm(calibration_images_folder.iterdir(),
  50. total=calibration_images_quantity, ncols=80)
  51. for i, one_image_path in enumerate(tqdm_images_folder):
  52. # 只有一个循环完整结束后,tqdm 进度条才会前进一格。因此要在 for 循环的开头
  53. # 使用 i == calibration_images_quantity 作为停止条件,才能看到完整的 tqdm 进度条
  54. if i == calibration_images_quantity:
  55. break
  56. bgr_image = cv.imread(str(one_image_path)) # noqa
  57. # 改变图片尺寸,注意是宽度 width 在前。
  58. bgr_image = cv.resize(bgr_image, (required_width, required_height)) # noqa
  59. one_rgb_image = bgr_image[..., ::-1] # 从 bgr 转换到 rgb
  60. one_image = one_rgb_image / 255 # 归一化,转换到 [0, 1]
  61. one_image = one_image.transpose(2, 0, 1) # 形状变为 depth, height, width
  62. batch_index = i % batch_size # 该批次数据中的索引位置
  63. output_images[batch_index] = one_image # 把该图片放入到该批次数据的对应位置。
  64. if batch_index == (batch_size - 1): # 此时一个 batch 的数据已经准备完成
  65. one_batch_data = {onnx_input_name: output_images}
  66. yield one_batch_data # 以生成器 generator 的形式输出数据
  67. output_images = np.zeros_like(output_images) # 初始化下一批次数据。
  68. def onnx_2_trt_by_polygraphy(onnx_file, conversion_target='int8', engine_suffix='engine',
  69. calibration_method='min-max', calibration_images_quantity=64,
  70. calibration_images_folder=None, onnx_input_shape=None):
  71. if conversion_target.lower() not in ['int8', 'fp16', 'fp32']:
  72. raise ValueError(f"The conversion_target must be one of ['int8', 'fp16', 'fp32'], "
  73. f"but get {conversion_target= }")
  74. if engine_suffix not in ['plan', 'engine', 'trt']:
  75. raise ValueError(f"The engine_suffix must be one of ['plan', 'engine', 'trt'], "
  76. f"but get {engine_suffix= }")
  77. onnx_file = pathlib.Path(onnx_file).expanduser().resolve()
  78. if not onnx_file.exists():
  79. raise FileNotFoundError(f'Onnx file not found: {onnx_file}')
  80. print(f"Succeeded finding ONNX file! {onnx_file= }")
  81. print(f'Polygraphy inspecting model:')
  82. os.system(f"polygraphy inspect model {onnx_file}") # 用 polygraphy 查看 ONNX 模型
  83. network = NetworkFromOnnxPath(str(onnx_file)) # 必须输入字符串给 NetworkFromOnnxPath
  84. TRT_LOGGER = trt.Logger()
  85. builder = trt.Builder(TRT_LOGGER)
  86. # 1. 准备转换 engine 文件时的配置。包括 flag 等。
  87. builder_config = builder.create_builder_config()
  88. print(f'{builder_config= }')
  89. converted_trt_name = (f"{onnx_file.stem}_optimization_{conversion_target}")
  90. if conversion_target.lower() == 'fp16':
  91. builder_config.flags |= 1 << int(trt.BuilderFlag.FP16)
  92. print(f'{builder_config.flags= }')
  93. elif conversion_target.lower() == 'int8':
  94. # 2. 准备 int8 量化所需的 5 个配置。
  95. # 2.1 设置 INT8 的 flag
  96. builder_config.set_flag(trt.BuilderFlag.INT8)
  97. print(f'{builder_config.flags= }')
  98. # 2.2 用 onnxruntime 获取模型输入的名字和形状.
  99. session = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
  100. onnx_input_name = session.get_inputs()[0].name
  101. if onnx_input_shape is None: # 查询 ONNX 中的输入张量形状。
  102. onnx_input_shape = session.get_inputs()[0].shape
  103. # 2.3 准备标定用的 cache 文件。
  104. calibration_cache_file = f"./{onnx_file.stem}_int8.cache"
  105. calibration_cache_file = pathlib.Path(calibration_cache_file).expanduser().resolve()
  106. if calibration_cache_file.exists(): # 始终使用一个新的 cache,才能每次都生成新的 TensorRT 模型。
  107. os.remove(calibration_cache_file)
  108. # 2.4 设置标定方法。
  109. if calibration_method == 'min-max':
  110. calibrator_class = trt.IInt8MinMaxCalibrator
  111. else:
  112. # 默认使用 entropy 方法,该方法通过减少量化时的信息损失 information loss,对模型进行标定。
  113. calibrator_class = trt.IInt8EntropyCalibrator2
  114. # 2.5 在 Calibrator 类中,传入标定方法,标定数据和 cache 等。
  115. builder_config.int8_calibrator = Calibrator(
  116. BaseClass=calibrator_class,
  117. data_loader=_calib_data_yolo8(onnx_input_name=onnx_input_name, onnx_input_shape=onnx_input_shape,
  118. calibration_images_quantity=calibration_images_quantity,
  119. calibration_images_folder=calibration_images_folder),
  120. cache=calibration_cache_file)
  121. int8_suffix = f'_{calibration_method}_images{calibration_images_quantity}'
  122. converted_trt_name = converted_trt_name + int8_suffix
  123. converted_trt = onnx_file.parent / (converted_trt_name + f'.{engine_suffix}')
  124. print('Building the engine ...')
  125. # 3. 按照前面的配置 config,设置 engine。注意 EngineFromNetwork 返回的是一个可调用对象 callable。
  126. build_engine = EngineFromNetwork(network, config=builder_config)
  127. # 4. 调用一次 build_engine,即可生成 engine,然后保存 TensorRT 模型即可。
  128. with build_engine() as engine, open(converted_trt, 'wb') as t:
  129. yolo8_metadata = _get_metadata() # 需要创建 YOLOv8 的原数据 metadata
  130. meta = json.dumps(yolo8_metadata) # 转换为 json 格式的字符串
  131. # 保存 TensorRT 模型时,必须先写入 metadata,然后再写入模型的数据。
  132. t.write(len(meta).to_bytes(4, byteorder='little', signed=True))
  133. t.write(meta.encode())
  134. t.write(engine.serialize())
  135. engine_saved = ''
  136. if not pathlib.Path(converted_trt).exists():
  137. engine_saved = 'not '
  138. print(f'Done! {converted_trt} is {engine_saved.upper()}saved.')
  139. return str(converted_trt)
  140. def validate_model(model_path, conf, iou, imgsz, dataset_split, agnostic_nms,
  141. batch_size=1, simplify_names=True, **kwargs):
  142. model_path = pathlib.Path(model_path).expanduser().resolve()
  143. if not model_path.exists():
  144. raise FileNotFoundError(f'Model not found: {model_path}')
  145. print(f'{model_path= }')
  146. print(f'{conf= }, {iou= }, {imgsz= }')
  147. model = YOLOv10(model_path, task='detect') # 须在创建模型时设置 task。
  148. detect_data = 'ultralytics/cfg/datasets/Fapiao.yaml'
  149. if (model_path.suffix == '.pt') and simplify_names:
  150. # model.names 只对 pt 模型有效,对 engine 模型无效。
  151. model.names[0] = 'foo' # 可以把类别的名字进行简化
  152. model.names[1] = 'bar'
  153. metrics = model.val(split=dataset_split, save=False,
  154. data=detect_data,
  155. agnostic_nms=agnostic_nms, batch=batch_size,
  156. conf=conf, iou=iou, imgsz=imgsz,
  157. **kwargs)
  158. map50 = round(metrics.box.map50, 3)
  159. print(f'{dataset_split} mAP50= {map50}')
  160. def main():
  161. onnx_file = '/home/ma-user/ocr/yolov10-main/runs/detect/train_v1015/weights/best.onnx'
  162. calibration_images = 64 # 也可以尝试 100, 32 等其它图片数量进行标定。
  163. calibration_images_folder = '/home/ma-user/ocr/yolov10-main/ultralytics/cfg/datasets/Fapiao/train/images' # 使用训练集的图片进行标定。
  164. saved_engine = onnx_2_trt_by_polygraphy(
  165. onnx_file=onnx_file, conversion_target='int8',
  166. engine_suffix='engine', calibration_images_quantity=calibration_images,
  167. calibration_images_folder=calibration_images_folder)
  168. # 3. 用验证集和测试集,检查 int8 量化后的模型指标。
  169. # 也可以输入 pt_model_path 验证 PyTorch 模型的指标。
  170. validate_model(model_path=saved_engine,dataset_split='val',imgsz=800,conf=0.5, iou=0.4,agnostic_nms=False,rect=False)
  171. if __name__ == '__main__':
  172. main()

3、量化后的指标

可以看到精度和mAP和量化前基本上没有区别

精度mAP50mAP50-90单张图片推理时间(V100)
量化前0.9910.9950.9450.1ms
int8量化后0.980.9950.9450.025ms

(量化后的推理结果,关键信息打码处理了)

OCR部分

一、克隆paddleOCR项目

GitHub - PaddlePaddle/PaddleOCR: Awesome multilingual OCR toolkits based on PaddlePaddle (practical ultra lightweight OCR system, support 80+ languages recognition, provide data annotation and synthesis tools, support training and deployment among server, mobile, embedded and IoT devices)

在github上克隆该项目到本地

二、准备数据

使用PPOCRLabel标注工具完成数据标注

https://github.com/PFCCLab/PPOCRLabel/blob/main/README_ch.md

三、微调PP-OCRv4文字检测模型

1、下载预训练模型

https://github.com/PaddlePaddle/PaddleOCR/blob/main/doc/doc_ch/models_list.md

2、修改训练配置

PaddleOCR-main/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml

将这个文件的配置进行修改

  1. Global:
  2. debug: false
  3. use_gpu: true
  4. epoch_num: &epoch_num 500
  5. log_smooth_window: 20
  6. print_batch_step: 10
  7. save_model_dir: ./output/ch_PP-OCRv4
  8. save_epoch_step: 10
  9. eval_batch_step:
  10. - 0
  11. - 10
  12. cal_metric_during_train: false
  13. checkpoints:
  14. pretrained_model: /home/ma-user/ocr/ch_PP-OCRv4_det_server_train/best_accuracy.pdparams
  15. save_inference_dir: null
  16. use_visualdl: false
  17. infer_img: doc/imgs_en/img_10.jpg
  18. save_res_path: ./checkpoints/det_db/predicts_db.txt
  19. distributed: true
  20. Architecture:
  21. model_type: det
  22. algorithm: DB
  23. Transform: null
  24. Backbone:
  25. name: PPHGNet_small
  26. det: True
  27. Neck:
  28. name: LKPAN
  29. out_channels: 256
  30. intracl: true
  31. Head:
  32. name: PFHeadLocal
  33. k: 50
  34. mode: "large"
  35. Loss:
  36. name: DBLoss
  37. balance_loss: true
  38. main_loss_type: DiceLoss
  39. alpha: 5
  40. beta: 10
  41. ohem_ratio: 3
  42. Optimizer:
  43. name: Adam
  44. beta1: 0.9
  45. beta2: 0.999
  46. lr:
  47. name: Cosine
  48. learning_rate: 0.0001 #(8*8c)
  49. warmup_epoch: 2
  50. regularizer:
  51. name: L2
  52. factor: 1e-6
  53. PostProcess:
  54. name: DBPostProcess
  55. thresh: 0.3
  56. box_thresh: 0.6
  57. max_candidates: 1000
  58. unclip_ratio: 1.5
  59. Metric:
  60. name: DetMetric
  61. main_indicator: hmean
  62. Train:
  63. dataset:
  64. name: SimpleDataSet
  65. data_dir: /home/ma-user/ocr/
  66. label_file_list:
  67. - /home/ma-user/ocr/TrainLabel.txt
  68. ratio_list: [1.0]
  69. transforms:
  70. - DecodeImage:
  71. img_mode: BGR
  72. channel_first: false
  73. - DetLabelEncode: null
  74. - CopyPaste: null
  75. - IaaAugment:
  76. augmenter_args:
  77. - type: Fliplr
  78. args:
  79. p: 0.5
  80. - type: Affine
  81. args:
  82. rotate:
  83. - -10
  84. - 10
  85. - type: Resize
  86. args:
  87. size:
  88. - 0.5
  89. - 3
  90. - EastRandomCropData:
  91. size:
  92. - 640
  93. - 640
  94. max_tries: 50
  95. keep_ratio: true
  96. - MakeBorderMap:
  97. shrink_ratio: 0.4
  98. thresh_min: 0.3
  99. thresh_max: 0.7
  100. total_epoch: *epoch_num
  101. - MakeShrinkMap:
  102. shrink_ratio: 0.4
  103. min_text_size: 8
  104. total_epoch: *epoch_num
  105. - NormalizeImage:
  106. scale: 1./255.
  107. mean:
  108. - 0.485
  109. - 0.456
  110. - 0.406
  111. std:
  112. - 0.229
  113. - 0.224
  114. - 0.225
  115. order: hwc
  116. - ToCHWImage: null
  117. - KeepKeys:
  118. keep_keys:
  119. - image
  120. - threshold_map
  121. - threshold_mask
  122. - shrink_map
  123. - shrink_mask
  124. loader:
  125. shuffle: true
  126. drop_last: false
  127. batch_size_per_card: 8
  128. num_workers: 8
  129. Eval:
  130. dataset:
  131. name: SimpleDataSet
  132. data_dir: /home/ma-user/ocr/
  133. label_file_list:
  134. - /home/ma-user/ocr/TestLabel.txt
  135. transforms:
  136. - DecodeImage:
  137. img_mode: BGR
  138. channel_first: false
  139. - DetLabelEncode: null
  140. - DetResizeForTest:
  141. - NormalizeImage:
  142. scale: 1./255.
  143. mean:
  144. - 0.485
  145. - 0.456
  146. - 0.406
  147. std:
  148. - 0.229
  149. - 0.224
  150. - 0.225
  151. order: hwc
  152. - ToCHWImage: null
  153. - KeepKeys:
  154. keep_keys:
  155. - image
  156. - shape
  157. - polys
  158. - ignore_tags
  159. loader:
  160. shuffle: false
  161. drop_last: false
  162. batch_size_per_card: 1
  163. num_workers: 2
  164. profiler_options: null

要改的是训练和验证的数据路径

data_dir: /home/ma-user/ocr/
    label_file_list:
      - /home/ma-user/ocr/Label.txt

还有就是预训练模型的路径:/home/ma-user/ocr/ch_PP-OCRv4_det_server_train/best_accuracy.pdparams

3、开始训练

python tools/train.py -c /home/ma-user/ocr/PaddleOCR-main/configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_teacher.yml

训练完后评价指标如上所示。

推理结果如上所示

四、微调PP-OCRv4文字识别模型

1、下载预训练模型

https://github.com/PaddlePaddle/PaddleOCR/blob/main/doc/doc_ch/models_list.md

2、修改PaddleOCR-main/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_hgnet.yml配置,修改方法同上

3、开始训练

python tools/train.py -c /home/ma-user/ocr/PaddleOCR-main/configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_hgnet.yml

训练完成

然后YOLO+OCR混合推理即可将发票关键信息提取出来。

作者介绍

作者本人是一名人工智能炼丹师,目前在实验室主要研究的方向为生成式模型,对其它方向也略有了解,希望能够在CSDN这个平台上与同样爱好人工智能的小伙伴交流分享,一起进步。谢谢大家鸭~~~

 如果你觉得这篇文章对您有帮助,麻烦点赞、收藏或者评论一下,这是对作者工作的肯定和鼓励。  

尾言

 如果您觉得这篇文章对您有帮忙,请点赞、收藏。您的点赞是对作者工作的肯定和鼓励,这对作者来说真的非常重要。如果您对文章内容有任何疑惑和建议,欢迎在评论区里面进行评论,我将第一时间进行回复。 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/995578
推荐阅读
相关标签
  

闽ICP备14008679号