当前位置:   article > 正文

paddleocr文本检测模型的训练_paddle训练文字识别模型

paddle训练文字识别模型

1、环境的安装和开源项目的下载

        首先我个人建议,玩深度学习的话,不管是工作还是学习,最起码要配一个有GPU的电脑。我个人有着血淋淋的教训,我本人是电气工程的一名学生,本科期间一点深度学习和机器学习的基础都没有,读研的时候就带着自己大一的时候买的笔记本电脑(没有GPU)去了读研的学校。我的实验室是大家带上自己的电脑工作学习的。因为没有GPU很多实验跑的特别慢,学习起来也很痛苦。后面实习的时候搞得是计算机视觉相关的工作,没有GPU简直没有办法学习了,后面毅然花钱买了一个好点的电脑,现在学习起来的确快了很多很多。所以建议大家配一个最少带有GPU的电脑去玩计算机视觉。

1.1 paddlepaddle深度学习环境配置

        paddlepaddle环境的配置可以参考我的博客,利用Anaconda安装pytorch和paddle深度学习环境+pycharm安装---免额外安装CUDA和cudnn(适合小白的保姆级教学)https://blog.csdn.net/didiaopao/article/details/119787139?spm=1001.2014.3001.5501https://blog.csdn.net/didiaopao/article/details/119787139?spm=1001.2014.3001.5501

1.2、克隆paddleocr开源项目

        现在很多计算机视觉的项目都是开源的,开源的项目基本上都在github上,因此我们可以去GitHub上去学习别的大佬的项目经验。paddleocr的项目地址。打开以后就可以看到如下的界面了。我们只需要点如图的那个位置,下载我们的项目代码就行,由于项目代码是一个压缩包,就需要我们去解压。这里要注意一点是克隆develop这个分支版本的,develop:基于Paddle静态图开发的分支,推荐使用paddle1.8 或者2.0版本,该分支具备完善的模型训练、预测、推理部署、量化裁剪等功能,领先于release/1.1分支。

2、算法的介绍

        OCR文字的识别一般分为两个步骤,一般首先要检测到图片中字符的位置,如果字符的位置检测不到就更不用提文字的识别了。所以如果要训练自己的OCR文字识别模型的话。首先要训练文本检测模型,再去训练文本识别模型,最后将两个模型组合一起就可以进行OCR文字识别了。

2.1、文本检测算法

        paddleocr所用的文本检测算法如下图所示,需要了解算法原理的可以自己去看算法的相关论文。

         文本检测的模型的骨干网络有两个系列,一个是ResNet系列,一个是MobileNet系列。总体来看同一网络模型不同的骨干网络,ResNet系列的效果要由于MobileNet系列的。

 2.2、数据准备

        深度学习训练自己的模型就需要收集自己的数据集,利用paddlepaddle深度学习框架训练模型的话,就需要利用PPOCRLabel这个工具去给训练图片打标签。所以该次实验就用icdar2015数据集,icdar2015数据集可以从官网下载到,首次下载需注册。因为下载还需要注册,而且该数据集不符合我们的Padleocr的数据集格式,虽然官方提供了数据格式转换脚本,可以将官网 label 转换支持的数据格式。但是还是比较的麻烦的。

        为此我这里有一份可以直接用的icdar数据集,链接:百度云盘链接提取码:7j0u 

         这个数据集里面包含着文本检测和文本识别的数据集,该博客讲述的文本检测模型的训练,文本识别后面会讲,所以也会用到这个数据集。数据集解压后如下图所示:

         打开det文件夹后其中有四个文件,如下图所示,j将图中的文件夹解压到对应的文件夹中。(解压后可以选择将压缩包删除)

         解压后的文件夹如下图所示:

       这四个文件分别对应如下信息:

  └─ train/             icdar数据集的训练数据
  └─ test/              icdar数据集的测试数据
  └─ train_label.txt    icdar数据集的训练标注
  └─ test_label.txt     icdar数据集的测试标注

         提供的标注文件格式如下,中间用"\t"分隔:json.dumps编码前的图像标注信息是包含多个字典的list,字典中的 points 表示文本框的四个点的坐标(x, y),从左上角的点开始顺时针排列。 transcription 表示当前文本框的文字,当其内容为“###”时,表示该文本框无效,在训练时会跳过。如果您想在其他数据集上训练,可以按照上述形式构建标注文件。

" 图像文件名                    json.dumps编码的图像标注信息"
ch4_test_images/img_61.jpg    [{"transcription": "MASA", "points": [[310, 104], [416, 141], [418, 216], [312, 179]]}, {...}]

        我们来直观的看下真实的图片数据和label数据。

         最后将文件按如下的目录结构来,这样的目的是为了方便同一,修改训练参数的时候可以尽量少改参数。

        至此,数据准备阶段就全部结束,后面就可以开始训练模型了。

 三、文本检测模型的训练

3.1模型的快速训练

        首先下载模型backbone的pretrain model,PaddleOCR的检测模型目前支持两种backbone,分别是MobileNetV3、ResNet_vd系列, 您可以根据需求使用PaddleClas中的模型更换backbone, 对应的backbone预训练模型可以从PaddleClas repo 主页中找到下载链接。有些人问,我训练自己的数据集,我不用预训练模型行不行,答案是可以的,但是这样的模型的训练收敛起来比较慢,所以还是用预训模型来训练。如上的链接中有大量的文本检测的预训练权重,如下图所示,大家可以根据自己的需求去下载相应的权重,需要性能的就去下载小模型,需要精度就去下载大模型。

         我这里下载的是MobileNetV3_large_x0_5_pretrained.pdparams,这个模型,如下图所示:

3.2 修改yaml文件参数

        将在github上下载好的paddleocr开源代码用pycharm打开,将上面准备的数据集放到如下图的位置下,将我们下载的预训练权重放到pretrain_models(该文件夹没有,需要自己创建)文件夹下。

         找到configs这个目录打开,找到det_mv3_db.yml这个yml文件,修改其中的文件来训练我们的模型,我们选用的文字检测,MobileNetV3这个网络骨干,db文字检测算法来训练我们的网络参数。

         其中yaml文件中的参数如链接所示

         这里要特别注意如下几点:

1、填写预训练权重的那一行参数的位置,预训练权重是不需要填写后缀名的。

 2、填写训练数据集和验证数据集及其对应的标签路径的时候,首先最好填写相应路径的绝对路径,然后将绝对路径的\这个符号改成/,然后在填写数据集的路径的时候,只需要填写到数据图片目录最外面一层路径就可以了(如图到text_localization这层就可以了,这层后面加一个/就可以了)。标签的路径填写标签的绝对路径就行。

 3、我的电脑的显卡内存是6GB,显卡为3060系列的显卡,训练集和验证集的batch_size_per_card、num_workers这两个参数必须为如下两个参数值的大小,否则训练的时候会爆出显存溢出的情况,大家可以根据自己显卡的型号修改如下的参数的大小。

3.3  开始训练

        将如上的设置弄好了以后,就可以训练我们的数据了。打开pycharm的控制台。在控制台中输入如下的命令,这段命令的意思是,利用det_mv3_db.yml文件中的参数,进行文字检测的训练。

python tools/train.py -c configs/det/det_mv3_db.yml

        训练结束以后,会在output目录中生成权重文件,和一些其他的信息,如下图所示:

         如果训练程序中断,如果希望加载训练中断的模型从而恢复训练,可以通过指定Global.checkpoints指定要加载的模型路径:

python3 tools/train.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=./your/trained/model

        注意Global.checkpoints的优先级高于Global.pretrain_weights的优先级,即同时指定两个参数时,优先加载Global.checkpoints指定的模型,如果Global.checkpoints指定的模型路径有误,会加载Global.pretrain_weights指定的模型。

测试检测效果

测试单张图像的检测效果,打开pycharm的控制台,输入如下的命令。地址要填正确。

python tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy"

检测的效果如下图,可以看出,检测的效果还是很不错的。 

测试DB模型时,调整后处理阈值

python tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/img_10.jpg" Global.pretrained_model="./output/det_db/best_accuracy"  PostProcess.box_thresh=0.6 PostProcess.unclip_ratio=1.5

测试文件夹下所有图像的检测效果

python tools/infer_det.py -c configs/det/det_mv3_db.yml -o Global.infer_img="./doc/imgs_en/" Global.pretrained_model="./output/det_db/best_accuracy"

         

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号