当前位置:   article > 正文

【STD文本检测项目】之 DBNet++(一)用MMOCR在自己的数据集上训练进行文本检测_dbnet网络训练自己数据集

dbnet网络训练自己数据集

目录

前言

环境

准备环境

准备数据集

配置数据集

配置模型

训练

测试验证效果

使用模型

后续


前言

STD的定义,来自chatGPT

Scene Text Detection(场景文本检测):也被称为文本定位。它是指从自然场景图像中自动检测和定位出存在的文本区域。场景文本检测的目标是找到图像中包含文本的矩形边界框,以便后续的文字识别或其他文本分析任务。

 一般是作为STR(Scene Text Recogition 场景文本识别)的上游任务,STD负责把图中的文本区域圈出来,STR负责识别圈出来的文本内容。

STD有很多开源项目,可以参考Scene Text Detection | Papers With Code

 这里选用的是当前较为流行,效果比较好,支持检测不规则文本行的DBNet++,在MMOCR框架上进行训练和推理。

环境

设备 RTX 3060 6G 独显笔记本

Windows 10

Python 3.10.9

Pytorch 1.12.1

CUDA 11.6

MMEngine 0.7.0

MMCV 2.0.0rc4

MMDet 3.0.0rc6

MMOCR 1.0.0rc6

准备环境

安装anaconda,准备python环境

  1. conda create -n mmlab python=3.10
  2. conda activate mmlab

安装项目的python依赖,torch和torchvision建议用官网方式装,尽量用pip装,用conda直接装似乎有坑。(这里我安装的是cuda11.6下的torch1.12.1)

pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116

安装MMOCR环境

  1. pip install -U openmim
  2. mim install mmengine==0.7.0
  3. mim install mmcv==2.0.0rc4
  4. mim install mmdet==3.0.0rc6

克隆MMOCR项目代码并安装

  1. git clone https://github.com/open-mmlab/mmocr.git
  2. cd mmocr
  3. pip install -v -e .

更多使用说明可以查看官方文档欢迎来到 MMOCR 的中文文档! — MMOCR 1.0.0 文档

准备数据集

这里我用的是Label Studio GitHub - heartexlabs/label-studio: Label Studio is a multi-type data labeling and annotation tool with standardized output format

按官网的教程安装即可,安装完会启动一个网页服务,在网页上进行标注。

这里建议用conda另外开一个虚拟环境,label-studio支持的python版本不能超过3.9。

  1. conda create -n label python=3.9
  2. conda activate label
  3. # Requires Python >=3.7 <=3.9
  4. pip install label-studio
  5. # Start the server at http://localhost:8080
  6. label-studio

启动后随便注册一下,进入主页面创建一个项目。

进入项目,在Settings设置里修改下要标注的格式Labeling Interface,可以直接将以下的配置填写到code栏里。我的场景下文字是弯曲的,所以我用polygen(多边形)进行标注,如果文本行是直的,也可以直接用rectangle(矩形)进行标注。

配置如下,代表标注的内容是图片,标注的数据是多边形框和文本内容。更多配置可查看Label Studio官网。

  1. <View>
  2. <Image name="image"
  3. value="$image"
  4. zoom="true"
  5. smoothing="true"
  6. zoomControl="true"
  7. negativeZoom="true"
  8. crosshair="true"/>
  9. <Polygon name="poly" toName="image" strokeWidth="3" smartOnly="true"/>
  10. <TextArea name="transcription" toName="image"
  11. editable="true"
  12. perRegion="true"
  13. required="true"
  14. maxSubmissions="1"
  15. rows="5"
  16. placeholder="Recognized Text"
  17. displayMode="region-list"/>
  18. </View>

配置完后,点击Import上传图片数据。

标注后的效果如图,这里我以每个字的突出点作为一个多边形的角点。

数据都标注完后点击Export导出标注数据,导出格式选JSON-MIN。这里只会导出标注数据,不会打包导出图片。

Label Studio标注的JSON-MIN数据格式和MMOCR要求的数据格式不同,需要手动转一下,并分出训练集和测试集。转换代码如下

  1. import json
  2. import numpy as np
  3. import os, random, shutil
  4. output_path = "输出的目标文件夹路径"
  5. # 图片数据格式
  6. image_type = "jpg"
  7. input_img_path = f"输入图片文件夹路径"
  8. input_anno_file = f"label studio导出的标注json文件路径"
  9. output_img_path = f"{output_path}/textdet_imgs"
  10. train_img_path = f"{output_img_path}/train"
  11. test_img_path = f"{output_img_path}/test"
  12. output_train_anno = f"{output_path}/textdet_train.json"
  13. output_test_anno = f"{output_path}/textdet_test.json"
  14. os.makedirs(train_img_path, exist_ok=True)
  15. os.makedirs(test_img_path, exist_ok=True)
  16. def convert(data_list, type):
  17. res_json = {
  18. "metainfo": {
  19. "dataset_type": "TextDetDataset",
  20. "task_name": "textdet",
  21. "category": [{"id": 0, "name": "text"}],
  22. },
  23. "data_list": [],
  24. }
  25. # 每张图片
  26. for res in data_list:
  27. polys = res["poly"]
  28. img_w = res["poly"][0]["original_width"]
  29. img_h = res["poly"][0]["original_height"]
  30. scale_w = img_w / 100
  31. scale_h = img_h / 100
  32. img_uri = res["image"]
  33. img_path = img_uri.split("/")[-1].split("-", maxsplit=1)[1]
  34. lines = []
  35. # 每个poly
  36. for index, poly in enumerate(polys):
  37. points = poly["points"]
  38. o_points = (
  39. np.float32(points) * np.float32((scale_w, 1)) * np.float32((1, scale_h))
  40. )
  41. points = np.int0(o_points)
  42. line = points.flatten().tolist()
  43. x1 = int(np.min(points[:, 0]))
  44. y1 = int(np.min(points[:, 1]))
  45. x2 = int(np.max(points[:, 0]))
  46. y2 = int(np.max(points[:, 1]))
  47. lines.append(
  48. {
  49. "polygon": line,
  50. "bbox": [x1, y1, x2, y2],
  51. "bbox_label": 0,
  52. "ignore": False,
  53. }
  54. )
  55. res_json["data_list"].append(
  56. {
  57. "instances": lines,
  58. "img_path": f"textdet_imgs/{type}/{img_path}",
  59. "height": img_h,
  60. "width": img_w,
  61. }
  62. )
  63. shutil.copy(
  64. f"{input_img_path}/{img_path.replace('_', ' ')}",
  65. f"{output_img_path}/{type}/{img_path}",
  66. )
  67. if type == "train":
  68. with open(output_train_anno, "w") as anno:
  69. json.dump(res_json, anno)
  70. else:
  71. with open(output_test_anno, "w") as anno:
  72. json.dump(res_json, anno)
  73. with open(input_anno_file) as f:
  74. data = json.load(f)
  75. random.shuffle(data) # 随机打乱顺序
  76. split_index = int(0.8 * len(data)) # 计算分割点
  77. train_list = data[:split_index] # 取前80%作为训练集
  78. test_list = data[split_index:] # 取后20%作为测试集
  79. convert(train_list, "train")
  80. convert(test_list, "test")

转完之后文件夹结构入下

配置数据集

创建一个数据集配置文件,如命名为dataset.py

  1. data_root = "数据集文件夹路径"
  2. data_textdet_train = dict(
  3. type="OCRDataset",
  4. data_root=data_root,
  5. ann_file="textdet_train.json",
  6. filter_cfg=dict(filter_empty_gt=True, min_size=32),
  7. pipeline=None,
  8. )
  9. data_textdet_test = dict(
  10. type="OCRDataset",
  11. data_root=data_root,
  12. ann_file="textdet_test.json",
  13. test_mode=True,
  14. pipeline=None,
  15. )

用MMOCR项目下的tools/analysis_tools/browse_dataset.py数据集预览工具进行验证,看数据转换及配置是否正确。

python tools/analysis_tools/browse_dataset.py 数据集配置.py

如果能正常预览数据标注情况,及为配置完成。

配置模型

在MMOCR项目的configs\textdet\dbnetpp文件夹下创建一个模型配置文件,如config.py

  1. _base_ = [
  2. # 引用的dbnet++的模型配置
  3. "_base_dbnetpp_resnet50-dcnv2_fpnc.py",
  4. # 配置运行时的环境,打印方案,验证方案,可视化方案等。
  5. "../_base_/default_runtime.py",
  6. # 引用数据集配置
  7. "数据集配置文件.py",
  8. # 配置优化器方案
  9. "../_base_/schedules/schedule_sgd_1200e.py",
  10. ]
  11. # 加载预训练权重
  12. load_from = "https://download.openmmlab.com/mmocr/textdet/dbnetpp/tmp_1.0_pretrain/dbnetpp_r50dcnv2_fpnc_100k_iter_synthtext-20220502-352fec8a.pth"
  13. _base_.model.det_head = dict(
  14. type="DBHead",
  15. in_channels=256,
  16. module_loss=dict(type="DBModuleLoss"),
  17. # 配置后处理输出的结果
  18. postprocessor=dict(
  19. type="DBPostprocessor",
  20. # poly为多边形,quad为预测区域的最小外接矩形
  21. text_repr_type="poly",
  22. # 拟合出来的多边形的平滑程度,越小越平滑
  23. epsilon_ratio=0.002,
  24. # 预测的结果区域往外膨胀的大小
  25. unclip_ratio=4,
  26. ),
  27. )
  28. # dataset settings
  29. data_textdet_train = _base_.data_textdet_train
  30. data_textdet_test = _base_.data_textdet_test
  31. test_pipeline = [
  32. dict(
  33. type="LoadImageFromFile",
  34. color_type="color_ignore_orientation",
  35. ),
  36. dict(type="Resize", scale=(1280, 1280), keep_ratio=True),
  37. dict(type="LoadOCRAnnotations", with_polygon=True, with_bbox=True, with_label=True),
  38. dict(
  39. type="PackTextDetInputs",
  40. meta_keys=("img_path", "ori_shape", "img_shape", "scale_factor", "instances"),
  41. ),
  42. ]
  43. # pipeline settings
  44. data_textdet_train.pipeline = _base_.train_pipeline
  45. data_textdet_test.pipeline = test_pipeline
  46. train_dataloader = dict(
  47. batch_size=8,
  48. num_workers=1,
  49. persistent_workers=False,
  50. sampler=dict(type="DefaultSampler", shuffle=True),
  51. dataset=data_textdet_train,
  52. )
  53. val_dataloader = dict(
  54. batch_size=8,
  55. num_workers=1,
  56. persistent_workers=False,
  57. sampler=dict(type="DefaultSampler", shuffle=False),
  58. dataset=data_textdet_test,
  59. )
  60. test_dataloader = val_dataloader
  61. # 学习率
  62. _base_.optim_wrapper.optimizer.lr = 0.002
  63. # 训练多少轮在测试集上验证一次
  64. _base_.train_cfg.val_interval = 1
  65. # 训练多少轮保存一次权重
  66. _base_.default_hooks.checkpoint.interval = 2
  67. auto_scale_lr = dict(base_batch_size=8)
  68. param_scheduler = [
  69. dict(type="LinearLR", end=200, start_factor=0.001),
  70. dict(type="PolyLR", power=0.9, eta_min=1e-7, begin=200, end=1200),
  71. ]

其他配置详将官方文档。

训练

执行MMOCR项目下的tools/train.py脚本

  1. # amp 混合精度训练,减少显存暂用,提升速度。需要显卡支持
  2. python tools/train.py 模型配置.py --amp

训练到打印信息显示在测试集已经达到hmean达到1.0000或者接近即可。

默认模型输出路径在work_dirs/模型配置文件名/训练时间 下。

测试验证效果

选择一个打印信息里,测试集效果最好的模型,用项目下的python tools\test.py脚本验证。对比标注的和预测的结果是否一致。

python tools\test.py 模型配置文件.py 模型权重.pth --show

效果如下

使用模型

MMOCR已经封装的非常方便,只需要几行代码就可以使用训练好的模型,代码如下

  1. from mmocr.apis import TextDetInferencer
  2. infer = TextDetInferencer(
  3. model="模型配置文件.py",
  4. weights="模型权重文件.pth",
  5. device="cuda:0", # 显卡或CPU运行
  6. )
  7. det_res = infer(f"图片.jpg", show=True)
  8. print(det_res)

后续

下篇将讲解如何用MMDeploy将模型转换成ONNX和TensorRT,并在visual studio工程里用QT和C++调用,部署成性能最佳可供生产环境使用的版本。

再后续将讲解如何结合前几篇

【STR文字识别项目】之 最新SOTA项目PARSeq(一)训练自己的数据集,并转成onnx用C++调用
​​​​​【STR文字识别项目】之 最新SOTA项目PARSeq(二)转TensorRT并用C++调用

STR文本识别结合起来形成完整的文字识别流程。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/489749
推荐阅读
相关标签
  

闽ICP备14008679号