当前位置:   article > 正文

PaddleOCR手写体训练摸索

paddleocr手写

手写OCR识别

蓦然回首,那人却在灯火阑珊处。
此文从20220421开始着手,
至20220425为初稿,
从基础理论看起,
以官方文档为根本,
经过自己的一番摸索,
也倒是大致了解了整个过程,
最后却发现最有用的却是已经被官方总结好了
放在这里让后来人少走弯路。
虽然自己摸爬滚打许多天不如一篇官方文章,
但是,
真的没有收获吗?
我觉得未必。

编辑记录
————————————————
20220421 起稿
20220425 21:49 一稿发布
20220426 编辑
20220427 已成功能够开始训练,涉及到手写体数据库格式转换和配置过程后面补齐
20220428 摸索服务器训练
20220502 尝试使用epoch31的训练数据识别(epoch: [31/500], iter: 10105, lr: 0.000993, loss: 7.327022, acc: 0.558589,)
20220503 修补训练数据集的生成与配置使用;测试训练模型的识别能力
202205__ 有使用epoch117数据识别,应该为第10节的测试
202205__ 未完待续…
20220512 搞清train与test数据的相关性——不同作者书写,不相关

一:官方支持的数据格式?

1.官方文档

1.1 PaddleOCR 支持两种数据格式:

  • lmdb 用于训练以lmdb格式存储的数据集(LMDBDataSet);
  • 通用数据 用于训练以文本文件存储的数据集(SimpleDataSet);

1.2 训练数据的默认存储路径

PaddleOCR/train_data/

1.3 自定义数据集的准备

1.3.1 通用数据集

见官网文档
在这里插入图片描述

1.3.2 lmdb数据集

(引用——by程序员阿德

**什么是lmdb数据集?**
1.英文全名:Lightning Memory-Mapped Database (LMDB);
2.对应中文名:轻量级内存映射数据库。
3.因为最开始 Caffe 就是使用的这个数据库,所以网上的大多数关于 LMDB 的教程都通过 Caffe 实现的
4.LMDB属于key-value数据库,而不是关系型数据库( 比如 MySQL ),LMDB提供 key-value 存储,其中每个键值对都是我们数据集中的一个样本。
  LMDB的主要作用是提供数据管理,可以将各种各样的原始数据转换为统一的key-value存储。
5.LMDB的文件结构很简单,一个文件夹,里面是一个数据文件和一个锁文件。
  数据随意复制,随意传输。
  它的访问简单,不需要单独的数据管理进程。
  只要在访问代码里引用LMDB库,访问时给文件路径即可。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
1.3.2.1 lmdb基本函数:
  1. env = lmdb.open():创建 lmdb 环境
  2. txn = env.begin():建立事务
  3. txn.put(key, value):进行插入和修改
  4. txn.delete(key):进行删除
  5. txn.get(key):进行查询
  6. txn.cursor():进行遍历
  7. txn.commit():提交更改
1.3.2.2 创建一个 lmdb 环境:
# 安装:pip install lmdb
import lmdb

env = lmdb.open(lmdb_path, map_size=1099511627776)
# lmdb_path:指定存放生成的lmdb数据库的文件夹路径,如果没有该文件夹则自动创建。
# map_size: 指定创建的新数据库所需磁盘空间的最小值,1099511627776B=1T。

# 会在指定路径下创建 data.mdb 和 lock.mdb 两个文件,一是个数据文件,一个是锁文件。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
1.3.2.3 修改数据库内容:
# 创建一个事务(transaction) 对象 txn,所有的操作都必须经过这个事务对象。
# 因为我们要对数据库进行写入操作,所以将 write 参数置为 True,默认其为 False。
txn = env.begin(write=True)

# 使用 .put(key, value) 对数据库进行插入和修改操作,传入的参数为键值对。
# 需要在键值字符串后加 .encode() 改变其编码格式,
# 将 str 转换为 bytes 格式,否则会报该错误:# TypeError: Won't implicitly convert Unicode to bytes; use .encode()。
# 在后面使用 .decode() 对其进行解码得到原数据。
# insert/modify
txn.put(str(1).encode(), "Alice".encode())
txn.put(str(2).encode(), "Bob".encode())

# 使用 .delete(key) 删除指定键值对。
# delete
txn.delete(str(1).encode())

# 对LMDB的读写操作在事务中执行,需要使用 commit 方法提交待处理的事务。
txn.commit()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
1.3.2.4 查询数据库内容:
# 每次 commit() 之后都要用 env.begin() 更新 txn(得到最新的lmdb数据库)。
txn = env.begin()

# 使用 .get(key) 查询数据库中的单条记录。
print(txn.get(str(2).encode()))

# 使用 .cursor() 遍历数据库中的所有记录,
# 其返回一个可迭代对象,相当于关系数据库中的游标,每读取一次,游标下移一位。
for key, value in txn.cursor():
    print(key, value)

env.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
1.3.2.5 完整的demo如下:
import lmdb
import os, sys

def initialize():
    env = lmdb.open("lmdb_dir")
    return env

def insert(env, sid, name):
    txn = env.begin(write=True)
    txn.put(str(sid).encode(), name.encode())
    txn.commit()

def delete(env, sid):
    txn = env.begin(write=True)
    txn.delete(str(sid).encode())
    txn.commit()

def update(env, sid, name):
    txn = env.begin(write=True)
    txn.put(str(sid).encode(), name.encode())
    txn.commit()

def search(env, sid):
    txn = env.begin()
    name = txn.get(str(sid).encode())
    return name

def display(env):
    txn = env.begin()
    cur = txn.cursor()
    for key, value in cur:
        print(key, value)


env = initialize()

print("Insert 3 records.")
insert(env, 1, "Alice")
insert(env, 2, "Bob")
insert(env, 3, "Peter")
display(env)

print("Delete the record where sid = 1.")
delete(env, 1)
display(env)

print("Update the record where sid = 3.")
update(env, 3, "Mark")
display(env)

print("Get the name of student whose sid = 3.")
name = search(env, 3)
print(name)

# 最后需要关闭关闭lmdb数据库
env.close()

# 执行系统命令
os.system("rm -r lmdb_dir")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
1.3.2.6 将图片和对应的文本标签存放到lmdb数据库:
在这里插入代码片
  • 1
1.3.2.7 从lmdb数据库中读取图片数据:
在这里插入代码片
  • 1

做OCR,在搜索中常常碰到一个优质博主:冠军的试炼(【OCR技术系列之八】端到端不定长文本识别CRNN代码实现)
这篇文章对理解CRNN训练的数据输入有帮助

1.3.3文字标签数字化(涉及到解码部分,训练输入的数据不用编码,按一般格式输入即可):

在数据准备部分还有一个操作需要强调的,那就是文字标签数字化,即我们用数字来表示每一个文字(汉字,英文字母,标点符号)。

比如“我”字对应的id是1,
“l”对应的id是1000,
“?”对应的id是90,
如此类推,这种编解码工作使用字典数据结构存储即可,训练时先把标签编码(encode),预测时就将网络输出结果解码(decode)成文字输出
参考代码:

# 定义str to label 类
class strLabelConverter(object):
    """Convert between str and label.
    	转换str和label
    	
    NOTE:
        Insert `blank` to the alphabet for CTC.
       
    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=True): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # for `-1` index

        self.dict = {}
        for i, char in enumerate(alphabet):
            # NOTE: 0 is reserved for 'blank' required by wrap_ctc
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """

        length = []
        result = []
        for item in text:
            item = item.decode('utf-8', 'strict')

            length.append(len(item))
            for char in item:

                index = self.dict[char]
                result.append(index)

        text = result
        # print(text,length)
        return (torch.IntTensor(text), torch.IntTensor(length))

    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and its length does not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                         length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # batch mode
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 所以要看PaddleOCR这一部分设置的对应关系

2.FAQ中的相关问题

2.1 Q:对于图片中的密集文字,有什么好的处理方法?

A:可以先试用预训练模型测试一下,例如DB+CRNN,判断下密集文字图片中是检测还是识别的问题,然后针对性的改善。还有一种是如果图象中密集文字较小,可以尝试增大图像分辨率,对图像进行一定范围内的拉伸,将文字稀疏化,提高识别效果。

2.2 Q:文本行较紧密的情况下如何准确检测?

A:使用基于分割的方法,如DB,检测密集文本行时,最好收集一批数据进行训练,并且在训练时,并将生成二值图像的shrink_ratio参数调小一些。

2.3 Q:文档场景中,使用DB模型会出现整行漏检的情况应该怎么解决?

A:可以在预测时调小 det_db_box_thresh 阈值,默认为0.5, 可调小至0.3观察效果。

2.4 Q: 弯曲文本(如略微形变的文档图像)漏检问题

A: db后处理中计算文本框平均得分时,是求rectangle区域的平均分数,容易造成弯曲文本漏检,已新增求polygon区域的平均分数,会更准确,但速度有所降低,可按需选择,在相关pr中可查看可视化对比效果。该功能通过参数 det_db_score_mode进行选择,参数值可选[fast(默认)、slow],fast对应原始的rectangle方式,slow对应polygon方式。感谢用户buptlihang提pr帮助解决该问题

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/678664
推荐阅读
相关标签