Training and testing PaddleOCR on your own dataset

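I. Environment setup

1. Training on a GPU requires paddlepaddle-gpu. Look up the install command that matches your operating system and CUDA version on the PaddlePaddle website (https://www.paddlepaddle.org.cn/).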

II. Image annotation

1. Method 1 (annotation only, without downloading the PaddleOCR project)

# Install the annotation tool
pip install PPOCRLabel
# Install paddlepaddle-gpu; look up the command for your platform at https://www.paddlepaddle.org.cn/install/quick?docurl=/documentation/docs/zh/install/pip/windows-pip.html
# Example: the command for Windows with CUDA 11.6
python -m pip install paddlepaddle-gpu==2.4.2.post116 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html
# Launch the annotation tool and start labeling
PPOCRLabel --lang ch
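
The wheel version in the command above ends in `.post116`, and the example pairs it with CUDA 11.6. As a small sanity check before installing, the sketch below decodes that suffix. It assumes the suffix is simply the CUDA major and minor digits written together, which is inferred from this one example rather than from any official rule, and `cuda_from_wheel_version` is a hypothetical helper name.

```python
import re


def cuda_from_wheel_version(version):
    """Return the CUDA version hinted at by a '.postNNN' suffix, or None.

    Assumption: the last digit is the minor version and the rest is the
    major version, e.g. 'post116' -> '11.6'.
    """
    match = re.search(r"\.post(\d{2,})$", version)
    if match is None:
        return None
    digits = match.group(1)
    return f"{digits[:-1]}.{digits[-1]}"


print(cuda_from_wheel_version("2.4.2.post116"))
```

If the decoded value does not match the CUDA toolkit installed on the machine, it is worth going back to the install page and picking a different command.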

2. Method 2

Download the PaddleOCR project:
https://github.com/PaddlePaddle/PaddleOCR

cd PPOCRLabel
python PPOCRLabel.py --lang ch
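
The annotation tool writes its results to Label.txt, which the later scripts in this post read line by line. The sketch below shows how one such line can be taken apart. It assumes each line is an image path, a tab, and a JSON list of boxes; the `transcription` and `points` keys are the ones used later in this post, while the `difficult` key in the sample and the helper name `parse_label_line` are illustrative assumptions.

```python
import json


def parse_label_line(line):
    """Split one Label.txt line into (image_path, [(text, points), ...])."""
    image_path, annotation_json = line.rstrip("\n").split("\t", 1)
    boxes = json.loads(annotation_json)
    return image_path, [(box["transcription"], box["points"]) for box in boxes]


# Hypothetical sample line, for illustration only
sample = ('data/img_1.jpg\t[{"transcription": "12.5", '
          '"points": [[10, 20], [110, 20], [110, 60], [10, 60]], '
          '"difficult": false}]')
path, boxes = parse_label_line(sample)
print(path, boxes)
```

Parsing with `json.loads` rather than `eval` keeps JSON literals such as `false` from raising an error.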

III. Training text detection on your own dataset

1. Splitting the dataset

import random

# Read every annotation line produced by PPOCRLabel
with open("Label.txt", "r", encoding="utf-8") as f:
    li_all = [line.strip("\n") for line in f]

count = len(li_all)
tra = int(0.9 * count)   # 90% of the samples go to the training set
print("Training samples:", tra)
print("Validation samples:", count - tra)
# Randomly pick tra indices for training; a set keeps the membership test fast
train = set(random.sample(range(count), tra))

# The with block closes both output files once the split is written
with open("train.txt", "w", encoding="utf-8") as train_txt, \
        open("val.txt", "w", encoding="utf-8") as val_txt:
    for i in range(count):
        if i in train:
            train_txt.write(li_all[i] + "\n")
        else:
            val_txt.write(li_all[i] + "\n")
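
The split above is different on every run. If you want to be able to reproduce it, one option is to shuffle with a fixed seed, roughly as sketched here. The function name `split_lines` and the default seed are illustrative choices, not part of PaddleOCR.

```python
import random


def split_lines(lines, train_ratio=0.9, seed=0):
    """Shuffle the lines with a fixed seed and cut them into (train, val)."""
    rng = random.Random(seed)      # private generator, global state untouched
    shuffled = list(lines)
    rng.shuffle(shuffled)
    cut = int(train_ratio * len(shuffled))
    return shuffled[:cut], shuffled[cut:]


train, val = split_lines([f"img_{i}.jpg" for i in range(10)])
print(len(train), len(val))
```

Calling it again with the same seed returns the same two lists, so train.txt and val.txt can be regenerated later without reshuffling the data.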

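2. Editing the config file

Config file: ./PaddleOCR/configs/det/det_mv3_db.yml
Note: in this setup the training images and the label files all live under ./data. The lines flagged with a ====== comment are the ones changed for the custom dataset.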

Global:
  use_gpu: True   # True by default
  use_xpu: false
  use_mlu: false
  epoch_num: 500  # ====================== changed ======================
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/db_mv3/    # ====================== changed ======================
  save_epoch_step: 100          # ====================== changed ======================
  # evaluation is run every 2000 iterations
  eval_batch_step: [0, 2000]
  cal_metric_during_train: False
  pretrained_model: ./pretrain_models/MobileNetV3_large_x0_5_pretrained/ch_ppocr_mobile_v2.0_det_train/best_accuracy    # ====================== changed ======================
  checkpoints:
  save_inference_dir:
  use_visualdl: False
  infer_img: doc/imgs_en/img_10.jpg
  save_res_path: ./output/det_db/predicts_db.txt

Architecture:
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3
    scale: 0.5
    model_name: large
  Neck:
    name: DBFPN
    out_channels: 256
  Head:
    name: DBHead
    k: 50

Loss:
  name: DBLoss
  balance_loss: true
  main_loss_type: DiceLoss
  alpha: 5
  beta: 10
  ohem_ratio: 3

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    learning_rate: 0.001
  regularizer:
    name: 'L2'
    factor: 0

PostProcess:
  name: DBPostProcess
  thresh: 0.3
  box_thresh: 0.6
  max_candidates: 1000
  unclip_ratio: 1.5

Metric:
  name: DetMetric
  main_indicator: hmean

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./     # ====================== changed ======================
    label_file_list:
      - ./data/train.txt    # ====================== changed ======================
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [640, 640]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 1    # default 16 ====================== changed ======================
    num_workers: 1    # ====================== changed ======================
    use_shared_memory: True

Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./     # ====================== changed ======================
    label_file_list:
      - ./data/val.txt    # ====================== changed ======================
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - DetResizeForTest:
          image_shape: [736, 1280]
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'shape', 'polys', 'ignore_tags']
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 1 # must be 1  ====================== changed ======================
    num_workers: 8    # ====================== changed ======================
    use_shared_memory: True

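
The config evaluates every 2000 iterations (`eval_batch_step: [0, 2000]`), while `epoch_num` and `save_epoch_step` count epochs. To relate the two it helps to know how many iterations one epoch takes. The arithmetic is sketched below; the dataset size of 900 images is a made-up figure, and treating the global batch as `batch_size_per_card` times the number of cards is an assumption about how the loader divides the data.

```python
import math


def iterations_per_epoch(num_samples, batch_size_per_card, num_cards=1, drop_last=False):
    """Number of training iterations needed to see the dataset once."""
    global_batch = batch_size_per_card * num_cards
    if drop_last:
        return num_samples // global_batch
    return math.ceil(num_samples / global_batch)


def epochs_between_evals(eval_every_iters, num_samples, batch_size_per_card,
                         num_cards=1, drop_last=False):
    """How many epochs pass between two evaluations."""
    per_epoch = iterations_per_epoch(num_samples, batch_size_per_card, num_cards, drop_last)
    return eval_every_iters / per_epoch


# Hypothetical dataset of 900 training images with the settings used above
print(iterations_per_epoch(900, batch_size_per_card=1))
print(round(epochs_between_evals(2000, 900, batch_size_per_card=1), 2))
```

With a small dataset and a large batch size, 2000 iterations can span many epochs, in which case lowering the second value of `eval_batch_step` gives more frequent feedback.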

3. Training on your own dataset

python tools/train.py -c configs/det/det_mv3_db.yml

4. Resuming training from a checkpoint

python tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./output/db_mv3_0606/latest_accuracy
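
The resume command passes a config override as `-o key=value`. If you launch training from a script, the command can be assembled the same way. The helper below is only a sketch: `build_train_command` is a hypothetical name, and it does nothing more than build the argument list used in the commands above.

```python
def build_train_command(config_path, overrides=None, script="tools/train.py"):
    """Build the argument list for a training run with optional -o overrides."""
    cmd = ["python", script, "-c", config_path]
    if overrides:
        cmd.append("-o")
        cmd.extend(f"{key}={value}" for key, value in overrides.items())
    return cmd


cmd = build_train_command(
    "configs/det/det_mv3_db.yml",
    {"Global.checkpoints": "./output/db_mv3_0606/latest_accuracy"},
)
print(" ".join(cmd))
# To actually launch it from the PaddleOCR directory:
# import subprocess; subprocess.run(cmd, check=True)
```

Keeping overrides on the command line leaves the YAML file untouched, which makes it easier to compare runs.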

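IV. Training text recognition on your own dataset

1. Building the dataset

Crop every annotated text region out of the labeled images and write the recognition label files, as follows: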

import json
import os
from PIL import Image

img_path = "E:/PycharmProjects/meter_detection/PaddleOCR/train_data/data/"  # directory of the annotated images
img_txt_path = "train_data/rec_ch/"    # image path prefix written into the label files
img_save_path = "E:/PycharmProjects/meter_detection/PaddleOCR/train_data/rec_ch/"   # output directory for the crops and label files, i.e. the training-set directory
os.makedirs(img_save_path, exist_ok=True)
li = ["train", "test"]   # annotation txt files to process
for txt in li:
    ocr_li = []   # labels already written; each distinct label is kept only once
    img_save = img_save_path + txt + "/"   # where the cropped images are saved
    os.makedirs(img_save, exist_ok=True)
    with open(f"E:/PycharmProjects/meter_detection/PaddleOCR/train_data/{txt}.txt", "r", encoding="utf-8") as f:
        data = f.readlines()
    # New recognition label file: one "image_path<TAB>label" line per crop
    with open(f"{img_save_path}rec_gt_{txt}.txt", "w", encoding="utf-8") as new_txt:
        for da in data:
            # The image path and the JSON annotation are separated by a tab
            img_name, img_info = da.strip("\n").split("\t", 1)
            img_name = img_name.split("/")[-1]
            img = Image.open(img_path + img_name).convert("RGB")  # RGB so the crop can be saved as JPEG
            img_info = json.loads(img_info)  # parse the annotation string into a list of dicts
            stem = os.path.splitext(img_name)[0]
            i = 1
            for di in img_info:
                new_name = stem + "_" + str(i) + ".jpg"
                img_new_path = img_txt_path + txt + "/" + new_name  # image path as written in the label file
                label = di["transcription"]
                points = di["points"]
                # x and y coordinates of the four corner points
                x_coordinates = [point[0] for point in points]
                y_coordinates = [point[1] for point in points]
                # Axis-aligned crop rectangle that encloses the four points
                left = min(x_coordinates)
                upper = min(y_coordinates)
                right = max(x_coordinates)
                lower = max(y_coordinates)
                if label not in ocr_li:
                    ocr_li.append(label)
                    # The recognition dataset loader splits each line on a tab
                    new_txt.write(img_new_path + "\t" + label + "\n")
                    new_img = img.crop((left, upper, right, lower))   # top-left and bottom-right corners
                    new_img.save(img_save + new_name)
                    i += 1

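Where the training images and txt files are stored: train_data/rec_ch/ holds the train/ and test/ folders of cropped images together with rec_gt_train.txt and rec_gt_test.txt.
Format of the txt files: one sample per line, the image path followed by a tab and the text label.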
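
Since a wrong separator in the label file is easy to miss, it can be worth scanning the generated file before training. The sketch below flags lines that do not split into exactly one path and one non-empty label. The helper names are hypothetical, and the tab delimiter is the assumption stated above.

```python
def parse_rec_label(line, delimiter="\t"):
    """Split one recognition label line into (image_path, label)."""
    parts = line.rstrip("\r\n").split(delimiter)
    if len(parts) != 2 or not parts[0] or not parts[1]:
        raise ValueError(f"malformed label line: {line!r}")
    return parts[0], parts[1]


def find_bad_lines(lines):
    """Return the 1-based numbers of lines that fail to parse."""
    bad = []
    for number, line in enumerate(lines, start=1):
        try:
            parse_rec_label(line)
        except ValueError:
            bad.append(number)
    return bad


# Hypothetical lines: the second uses a space, the third has no label
print(find_bad_lines(["a_1.jpg\t12.5", "b_1.jpg 7", "c_1.jpg\t"]))
```

To check a real file, pass the open file object, for example `find_bad_lines(open("train_data/rec_ch/rec_gt_train.txt", encoding="utf-8"))`.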

2. Editing the config file

Config file: configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

Global:
  debug: false
  use_gpu: true
  epoch_num: 500    # default 800 ====================== modified =====================
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_ppocr_v3_distillation
  save_epoch_step: 100   # default 3 ====================== modified =====================
  eval_batch_step: [0, 2000]
  cal_metric_during_train: true
  pretrained_model: pretrain_models/rec_train/ch_PP-OCRv2_rec_slim/ch_PP-OCRv3_rec_train/best_accuracy   # ====================== modified =====================
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_ppocrv3_distillation.txt


Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs : [700]
    values : [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05


Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: SVTR
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
        last_conv_stride: [1, 2]
        last_pool_type: avg
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 64
                depth: 2
                hidden_dims: 120
                use_guide: True
              Head:
                fc_decay: 0.00001
          - SARHead:
              enc_dim: 512
              max_text_length: *max_text_length
Loss:
  name: CombinedLoss
  loss_config_list:
  - DistillationDMLLoss:
      weight: 1.0
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: ctc
      name: dml_ctc
  - DistillationDMLLoss:
      weight: 0.5
      act: "softmax"
      use_log: true
      model_name_pairs:
      - ["Student", "Teacher"]
      key: head_out
      multi_head: True
      dis_head: sar
      name: dml_sar
  - DistillationDistanceLoss:
      weight: 1.0
      mode: "l2"
      model_name_pairs:
      - ["Student", "Teacher"]
      key: backbone_out
  - DistillationCTCLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True
  - DistillationSARLoss:
      weight: 1.0
      model_name_list: ["Student", "Teacher"]
      key: head_out
      multi_head: True

PostProcess:
  name: DistillationCTCLabelDecode
  model_name: ["Student", "Teacher"]
  key: head_out
  multi_head: True

Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: "Student"
  ignore_space: False

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./       # ====================== modified =====================
    ext_op_transform_idx: 1
    label_file_list:
    - ./train_data/rec_ch/rec_gt_train.txt   # ====================== modified =====================
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - RecConAug:
        prob: 0.5
        ext_data_num: 2
        image_shape: [48, 320, 3]
        max_text_length: *max_text_length
    - RecAug:
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: true
    batch_size_per_card: 8   # default 128 ====================== modified =====================
    drop_last: true
    num_workers: 4    # ====================== modified =====================
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./   # ====================== modified =====================
    label_file_list:
    - ./train_data/rec_ch/rec_gt_test.txt   # ====================== modified =====================
    transforms:
    - DecodeImage:
        img_mode: BGR
        channel_first: false
    - MultiLabelEncode:
    - RecResizeImg:
        image_shape: [3, 48, 320]
    - KeepKeys:
        keep_keys:
        - image
        - label_ctc
        - label_sar
        - length
        - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 8   # default 128 ====================== modified =====================
    num_workers: 4   # ====================== modified =====================

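
Two settings in this config constrain the labels: `character_dict_path` lists the characters the model can output and `max_text_length` is set to 25. A quick pre-training check of your own labels against both might look like the sketch below. It assumes the dictionary file holds one character per line, and `load_charset` and `check_labels` are hypothetical helper names.

```python
def load_charset(dict_path):
    """Read a dictionary file, assuming one character per line."""
    with open(dict_path, "r", encoding="utf-8") as f:
        return {line.rstrip("\n") for line in f if line.rstrip("\n")}


def check_labels(labels, charset, max_text_length=25, use_space_char=True):
    """Return (label, reason) pairs for labels the config may not handle."""
    allowed = set(charset)
    if use_space_char:
        allowed.add(" ")
    problems = []
    for label in labels:
        unknown = sorted(set(label) - allowed)
        if unknown:
            problems.append((label, "characters not in dictionary: " + "".join(unknown)))
        elif len(label) > max_text_length:
            problems.append((label, f"longer than {max_text_length} characters"))
    return problems


# Hypothetical digits-only dictionary, for illustration
for label, reason in check_labels(["12.5", "3 kW"], set("0123456789.")):
    print(repr(label), reason)
```

In practice the character set would come from `load_charset("ppocr/utils/ppocr_keys_v1.txt")` and the labels from the generated rec_gt_train.txt.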

3. Training on your own dataset

python tools/train.py -c configs/rec/PP-OCRv3/ch_PP-OCRv3_rec_distillation.yml

V. Testing

1. Converting the trained weights into an inference model

Run the following command:

python tools/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./output/db_mv3_0606/best_accuracy Global.save_inference_dir=./output/db_mv3_infer_0606/
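
After the export it is worth confirming that the inference directory actually contains model files before pointing the test script at it. The sketch below assumes the Paddle 2.x layout, in which the export writes `inference.pdmodel` and `inference.pdiparams`; other versions may use different file names, so treat the expected list as an assumption to adjust. `missing_inference_files` is a hypothetical helper.

```python
from pathlib import Path

# Assumption: file names written by a Paddle 2.x export
EXPECTED_FILES = ("inference.pdmodel", "inference.pdiparams")


def missing_inference_files(inference_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from inference_dir."""
    root = Path(inference_dir)
    return [name for name in expected if not (root / name).is_file()]


print(missing_inference_files("./output/db_mv3_infer_0606/"))
```

An empty list means every expected file is present; anything else names the files that were not written.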

2. Testing

import os

from paddleocr import PaddleOCR


def ocr_predict(ocr, img, img_name):
    '''Run detection and recognition on one image and print the result.'''
    result = ocr.ocr(img)
    print(img_name, result)


if __name__ == '__main__':
    # Build the OCR engine once and reuse it for every image
    # det_model_dir: text detection model
    # rec_model_dir: text recognition model
    ocr = PaddleOCR(det_model_dir="./output/det_test/ch_PP-OCRv3_det/ch_PP-OCRv3_det_infer/",
                    rec_model_dir="./output/rec_test/ch_PP-OCRv3_rec_infer/", lang='ch', use_angle_cls=True,
                    use_gpu=False)
    img_path = "E:/PycharmProjects/meter_detection/data/digital_meter/test_data/"
    files = os.listdir(img_path)
    for file in files:
        img = img_path + file
        ocr_predict(ocr, img, file)
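
The printed result is a nested list, which is hard to read for more than a few images. The sketch below pulls out just the recognized text and its confidence. It assumes the result holds one entry per image, and that each entry is either `None` or a list of `[box, (text, score)]` items; the exact shape has changed between paddleocr releases, so check it against your own printed output first. `extract_text` is a hypothetical helper, and the sample data is made up.

```python
def extract_text(result, min_score=0.0):
    """Collect (text, score) pairs from an OCR result of the assumed shape."""
    lines = []
    for page in result or []:                    # one entry per image
        for box, (text, score) in page or []:    # entry is None when nothing was found
            if score >= min_score:
                lines.append((text, score))
    return lines


# Made-up result with two detected regions in a single image
fake_result = [[
    [[[10, 20], [110, 20], [110, 60], [10, 60]], ("12.5", 0.98)],
    [[[10, 80], [60, 80], [60, 100], [10, 100]], ("kWh", 0.41)],
]]
print(extract_text(fake_result, min_score=0.5))
```

Raising `min_score` is a simple way to drop low-confidence readings before comparing the output with the ground truth.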

VI. Problems encountered during training

Error message:
Could not locate zlibwapi.dll. Please make sure it is in your library path
Fix:
The zlibwapi.dll file is missing. Download the missing files and copy them into the following directories:
put the lib file in C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\lib
put the dll file in C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\bin
Link: https://pan.baidu.com/s/1Q9VNmU3UN_yaP-hWAJJNgA?pwd=0921
Extraction code: 0921
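
To confirm that the copied file is now visible, you can look for it in the directories on `PATH`. The sketch below does only that: Windows also searches other locations when loading a DLL, so an empty result here is a hint rather than proof that loading will fail. `find_on_path` is a hypothetical helper name.

```python
import os


def find_on_path(filename, path_value=None):
    """Return every PATH directory entry that contains filename."""
    if path_value is None:
        path_value = os.environ.get("PATH", "")
    hits = []
    for directory in path_value.split(os.pathsep):
        candidate = os.path.join(directory, filename)
        if directory and os.path.isfile(candidate):
            hits.append(candidate)
    return hits


print(find_on_path("zlibwapi.dll"))
```

If the list is empty after copying the file, check that the CUDA bin directory is actually part of `PATH` in the terminal you train from.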
