当前位置:   article > 正文

巧用PPOCRLabel制作DOC-VQA格式数据集

ppocrlabel

1. 项目背景

最近涉及到多模态“OCR” + “DOC-VQA”相关内容,一直使用XFUND数据集,但实际项目中需要训练真实数据才能达到更好的效果,那么如何制作DOC-VQA格式的数据集呢?

先了解一下几个简单的概念:

  • 文本检测:定位出输入图像中的文字区域。
  • 文本识别:识别出图像中的文字内容,一般输入来自于文本检测得到的文本框截取出的图像文字区域。
  • 关键信息提取(Key Information Extraction,KIE)是Document VQA中的一个重要任务,主要从图像中提取所需要的关键信息,如从身份证中提取出姓名和公民身份号码信息,这类信息的种类往往在特定任务下是固定的,但是在不同任务间是不同的。

KIE通常分为两个子任务进行研究

  • SER: 语义实体识别 (Semantic Entity Recognition), 可以完成对图像中的文本识别与分类。
SER测试效果图
对于XFUND数据集,有QUESTION, ANSWER, HEADER,OTHER 4种类别。图中在OCR检测框的左上方也标出了对应的类别和OCR识别结果。
  • RE: 关系抽取 (Relation Extraction),对每一个检测到的文本进行分类,如将其分为问题和的答案。然后对每一个问题找到对应的答案。基于 RE 任务,可以完成对图象中的文本内容的关系提取,如判断问题对(pair)。
RE预测效果图

图中红色框表示问题,蓝色框表示答案,问题和答案之间使用绿色线连接。在OCR检测框的左上方也标出了对应的类别和OCR识别结果。

2. XFUND数据集

我们来看一下XFUND数据集是怎么样?首先,要使用XFUND数据集进行训练或验证都需要先转换为“图片路径 JSON字符串”的形式,JSON字符串如:

{
    "height": 3508, # 图像高度
    "width": 2480,  # 图像宽度
    "ocr_info": [
        {
            "text": "邮政地址:",  # 单个文本内容
            "label": "question", # 文本所属类别
            "bbox": [261, 802, 483, 859], # 单个文本框
            "id": 54,  # 文本索引
            "linking": [[54, 60]], # 当前文本和其他文本的关系 [question, answer]
            "words": []
        },
        {
            "text": "湖南省怀化市市辖区",
            "label": "answer",
            "bbox": [487, 810, 862, 859],
            "id": 60,
            "linking": [[54, 60]],
            "words": []
        }
    ]
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

其中的label一般分为:headerquestionanswerother,其中questionanswer如果是对应关系则都包含一个相同的linking。一个question可以包含多个answer

2.1 下载数据集

%cd /home/aistudio/
! mkdir /home/aistudio/XFUND && mkdir /home/aistudio/XFUND/zh_train && mkdir /home/aistudio/XFUND/zh_val
! unzip -q -o /home/aistudio/data/data140302/XFUND_ori.zip  -d /home/aistudio/data/data140302/
! mv /home/aistudio/data/data140302/XFUND_ori/zh.train /home/aistudio/XFUND/zh_train/image
! mv /home/aistudio/data/data140302/XFUND_ori/zh.val /home/aistudio/XFUND/zh_val/image
! rm -rf /data/data140302/XFUND_ori

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
/home/aistudio
mkdir: 无法创建目录"/home/aistudio/XFUND": 文件已存在
  • 1
  • 2

2.2 转为可训练的格式

! unzip -q -o PaddleOCR.zip 
# 如仍需安装or安装更新,可以执行以下步骤
#! git clone https://gitee.com/PaddlePaddle/PaddleOCR
  • 1
  • 2
  • 3
# 安装依赖包
! pip install -r /home/aistudio/PaddleOCR/requirements.txt > install.log
#! pip install paddleocr >> install.log
# 安装nlp及其他包
# ! pip install yacs gnureadline paddlenlp==2.2.1 >> install.log
# ! pip install xlsxwriter >> install.log
! pip install regex
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
[33mWARNING: You are using pip version 22.0.4; however, version 22.1.2 is available.
You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mLooking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Requirement already satisfied: regex in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (2022.6.2)
[33mWARNING: You are using pip version 22.0.4; however, version 22.1.2 is available.
You should consider upgrading via the '/opt/conda/envs/python35-paddle120-env/bin/python -m pip install --upgrade pip' command.[0m[33m
[0m
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
# 还可通过以下命令,生成文本检测和文本识别的训练集和验证集
%cd /home/aistudio/
! python trans_xfund_data.py
  • 1
  • 2
  • 3
/home/aistudio
Corrupt JPEG data: 18 extraneous bytes before marker 0xc4
Corrupt JPEG data: bad Huffman code
Corrupt JPEG data: premature end of data segment
  • 1
  • 2
  • 3
  • 4

3 自制数据集

3.1 解压图片数据

本文只用了以下数据集中的33张图片进行了标注。

%cd /home/aistudio/XTOWER
# 解压图片,本文只用了33张图片
! unzip -q -o image.zip 
# 解压 ”文本标注结果“的图片
! unzip -q -o crop_img.zip 

#更多图片可以用以下命令获取
# ! unzip -q -o /home/aistudio/data/data142101/Scan_0012_0004.zip -d /home/aistudio/XTOWER/image 
# ! unzip -q -o /home/aistudio/data/data142101/1234.zip -d /home/aistudio/XTOWER/image 

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
/home/aistudio/XTOWER
  • 1

3.2 利用PPOCRLabel来辅助标注

选择PPOCRLabel的原因:

  • 支持自动标注,可节省不少时间。
  • 支持kie模式,可以为文本指定一个分类。

启动 PPOCRLabel(kie模式)

PPOCRLabel --lang ch --kie True
  • 1

需要本地安装PaddleOCR,需要将图片从XTOWER下image目录打包下载到本地,使用PPOCRLabel自动标注全部图片,然后检查标注并可做一些调整。然后在PPOCRLabel上 “导出标记结果”和“导出识别结果”,“导出标记结果”会生成一个Label.txt文件(文字检测用),“导出识别结果”新建一个“crop_img”文件夹用于保存切割的图片以及rec_gt.txt文件(文字识别用)。

把以下文件和文件夹都上传到/home/aistudio/XTOWER目录:

  • crop_img
  • rec_gt.txt
  • Label.txt

3.3 自动关联QA及分割文本

使用“自动标注”后并检查微调识别的框位置,确定后保存的Label.txt,这个文件包含了图片上文字的位置信息,但还没有建立QA关系,可以用PPOCRLabel编辑“更改box关键字类别”建立RE关系,如: question_1,question_2,answer_1,answer_2,…还要标注出header。

可不可以根据question 自动匹配 answer呢?设置question都不用手动标注,答案是可以的。

我们可以先定义所要标注的表格的结构,如按行来区分,有哪些header,有哪些question? ,哪些是other?甚至我们还能定义部分question的几倍行高(相当于单行)。

注意:每行有多个question放在同一个list里,显示申明的other不会被当成answer去匹配question。

# 定义一个简单文档结构,不同的文档都可以这里定义
documents=[
    {
        "headers":["《塔类业务交付验收单》"],
        "questions":[ 
            ["客 户:"],
            ["需求名称","铁塔名称"],
            ["运营商区县","铁塔区县"],
            ["需求单号","站址编码"],
            ["所属批次","产品单元"],
            ["站点经度","站点纬度"],
            ["验收日期:"],
            ["塔型"],
            ["机房类型"],
            ["挂高"],
            ["场景"],
            ["运营商共享"],
            ["建设内容"],
            ["市电引入费用原值(元)"],
            ["存在问题及解决办法"],
            ["铁塔公司验收负责人:","运营商验收负责人:"],
            ["其他参加验收人员签字:"]
        ],
        "others":[],
        "style":{
            "铁塔名称":{"max_height":2},
            "建设内容":{"max_height":3}
        }
    },
    {
        "headers":["泰安电信“一站一案”需求线下确认单"],
        "questions":[
            ["需求名称","铁塔站点名称"],
            ["建设单位","设计单位"],
            ["站址编码","所属批次"],
            ["站点位置情况"],
            ["产品类型"],
            ["塔形、天线挂高"],
            ["供电类型"],
            ["建设类型"],
            ["场租类型"],
            ["电力引入费用预估"],#如果无answer,标记为other
            ["交付时间要求:","起租计费时间"],
            ["起租方式说明:"],
            ["新建/改造立项编码"],
            ["塔桅"],
            ["配套"],
            ["市电"]
        ],
        "others":["精确到小数点后6位","参考标准工期"],
    }
]
    
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
import json
from PIL import Image
import regex
import math
import copy
  • 1
  • 2
  • 3
  • 4
  • 5
def get_word_list(s1):
    # 把句子按字分开,中文按字分,英文按单词 
    regEx = regex.compile('\W]+') # 我们可以使用正则表达式来切分句子,切分的规则是除单词,数字外的任意字符串
    res = regex.compile(r"([\u4e00-\u9fa5\pZ\(\)\:\。\,\?\《\》\“\”])")    #  [\u4e00-\u9fa5]中文范围

    p1 = regEx.split(s1)
    str1_list = []
    for str in p1:
        if res.split(str) == None:
            str1_list.append(str)
        else:
            ret = res.split(str) 
            for ch in ret:
                str1_list.append(ch)
    # list_word1 = [w for w in str1_list if w in [" "," "] or len(w.strip()) > 0]  # 去掉为空的字符
    list_word1 = [w for w in str1_list if len(w) > 0]  

    return  list_word1
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
def isChinese(word):
    if '\u4e00' <= word <= '\u9fff':
        return True
    if word in "()《》“”‘’ 。,:【】「」?": #中文标点符号
        return True
    elif len(word) ==1 and 32 <= ord(word) <= 255:
        return False
    elif len(regex.findall('\p{Z}', word)) != 0: #中文符号
        return True
    return False
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
"""
 分割文字标注框,横向水平均分
 box = [[x1,y2],[x2,y2]]
"""
def splitBox(box,words):
    left_padding=2
    top_padding=0
    # words的box的四个点的顺序分便是:上左,上右,下右,下左,只需均分上和下两条直线即可,
    # 考虑到有倾斜的情况,垂直方向也需要做均分
    split_width=[]
    word_count=0
    for word in words:
        if isChinese(word):
            split_width.append(1) 
            word_count = word_count +1
        else:
            #大写,小写,数字 占多宽
            split_width.append(0.5*len(word))  #非中文一个字符占0.5个中文宽度
            word_count = word_count+ (0.5*len(word))
    
    dx = (int(box[1][0]) - int(box[0][0]) - left_padding)  / word_count #单个字的宽度
    dy = (int(box[2][1]) - int(box[3][1])) / word_count #单个字的高度   
        
    wordboxes = []
    i=0
    bx=left_padding
    by=0
    for word in words:
        x = dx * split_width[i] 
        if isChinese(word):
            px= x * 0.8 # 更容易取中文特征
        else:
            px= x
        y = dy * split_width[i]
        make_ocrinfo = {}
        make_ocrinfo['transcription']=word
        make_ocrinfo['key_cls']='word'
        make_ocrinfo['difficult']=False
        make_ocrinfo['points']=[
            [int(box[0][0]+bx),int(box[0][1]+by+top_padding)],
            [int(box[0][0]+bx+px),int(box[0][1]+by+top_padding)],
            [int(box[0][0]+bx+px),int(box[3][1]+by+y-top_padding)],
            [int(box[0][0]+bx),int(box[3][1]+by+y-top_padding)]
            ]
        # print(box)
        
        wordboxes.append(make_ocrinfo)
        # print(make_ocrinfo)
        bx=bx+x
        by=by+y
        i = i+1
    return wordboxes
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
# 根据PPOCRLabel自动标注转格式,无需更改标签类型
def make_qa_linking(labelfile,newlabelfile,documents=None,split_text=False,ignore_word=False):
    file = open(labelfile)
    newinfo = {}
    i = 0
    lines = ""
    json_lines = ""
    while True:
        line = file.readline()
        if not line:
            break
        image_path,ocr_info = line.split("\t");
        ocr_infos = json.loads(ocr_info)

        # 1.找出表格的header,判断是哪张表格
        curr_doument =None;
        for ocrinfo in ocr_infos:
            for document in documents:
                if(ocrinfo['transcription'] in document['headers']):
                    curr_doument = document;
        if curr_doument== None:
            print(image_path+'表格未匹配')
            continue


        # 排除word类型
        all_word_info=[]
        new_ocr_infos=[]
        for ocr_info in ocr_infos:
            if ocr_info['key_cls'] == "word" :
                if  ignore_word == False:
                    all_word_info.append(copy.deepcopy(ocr_info))
            else:
                new_ocr_infos.append(copy.deepcopy(ocr_info))
        
        for i,ocrinfo in enumerate(new_ocr_infos):   
            #如果忽略word,则删除word信息   
            # if ignore_word == True:
            #     continue
            words=[]
            # 分析word的从属关系
            for word_info in all_word_info:
                word_center_x = word_info['points'][0][0] + int((word_info['points'][1][0] - word_info['points'][0][0]) / 2)
                word_center_y = word_info['points'][0][1] + int((word_info['points'][3][1] - word_info['points'][0][1]) / 2)
                #如果word的中心点落在 question或answer标签内
                if ocrinfo['points'][0][0] < word_center_x < ocrinfo['points'][1][0] and  ocrinfo['points'][0][1] < word_center_y < ocrinfo['points'][3][1] :
                    words.append(word_info)
            ocrinfo['words'] = words
        
        
        # 2. 定位question,并修正首位符号为全角。因PPOCRLabel将首位中文符号“(、)、《、》、:“等识别为了半角的宽度
        questions= {}
        q=1
        ocr_infos=copy.deepcopy(new_ocr_infos)
        
        for ocrinfo in ocr_infos:  

            # **********
            # 本段是根据ppocrlabel自动识别的情况,在文字前后出现特殊中文符号时增加标注的长度,方便切割文字。
            # 当QA关系通过key_cls建立后,就不再调整特殊符号的标注框位置
            #
            if 'key_cls' not in ocrinfo or ocrinfo['key_cls'] == 'None':
                if ocrinfo['transcription'] == "客户:":
                    ocrinfo['transcription'] = "客 户:"
                if ocrinfo['transcription'][0] in ["《","("]:
                    ocrinfo['points'][0][0] = ocrinfo['points'][0][0] -20
                    ocrinfo['points'][3][0] = ocrinfo['points'][3][0] -20
                if ocrinfo['transcription'][-1] in ['》',')',':','。',',']:
                    ocrinfo['points'][1][0] = ocrinfo['points'][1][0] +20
                    ocrinfo['points'][2][0] = ocrinfo['points'][2][0] +20
            # ********* end
        
            row=1
            for row_questions in curr_doument['questions']:
                # print(row_questions)
                col=1
                col_count = len(row_questions)
                for one_quesion in row_questions:
                    if ocrinfo['transcription'] == one_quesion:
                        ocrinfo['key_cls'] = 'question_'+str(q)
                        ocrinfo['construct'] = [row,col,col_count]
                        questions['q_'+str(row)+'_'+str(col)]=ocrinfo
                        q = q+1
                    col = col +1
                row = row+1
            
        
        # 2.根据question 找出对应的answer(多行的话,暂时用表格多个答案,再分割word后再合并),other,header 保留
        for ocrinfo in ocr_infos:
            for bianhao in questions:
                question = questions[bianhao]
                question_points = question['points']
                [q_row,q_col,q_col_count] = question['construct']
                
                one_question_answers=[]
                label_points = ocrinfo['points']
                
                # 排除其他的label。满足:左边框在问题右边框的右边;下边框在问题上边框的下边;上边框在问题下边框的上边
                if ocrinfo['transcription']  in curr_doument['headers'] :
                    ocrinfo['key_cls'] = "header"
                elif ocrinfo['key_cls'] == "word":
                    continue
                elif "words" in ocrinfo and len(ocrinfo['words']) > 0:
                    continue
                else:
                    max_height = 1
                    style = curr_doument['questions']
                    if question['transcription'] in style:
                        if "max_height" in style[question['transcription']]:
                            max_height = style[question['transcription']]["max_height"]
                    #额外增加answer的判定范围
                    q_lineheight = abs(question_points[3][1]-question_points[0][1]) * (max_height-1)/2

                    if label_points[0][0] > question_points[1][0] and ( label_points[3][1] > question_points[1][1]-q_lineheight and label_points[0][1] < question_points[2][1] + q_lineheight ):
                        if q_col < q_col_count:
                            next_col=q_col+1
                            bh='q_'+str(q_row)+'_'+str(next_col)
                            if bh in questions:
                                next_question= questions[bh]
                                if label_points[1][0] < next_question['points'][0][0]:
                                    answer_name= copy.deepcopy(question['key_cls']).replace('question_','answer_')
                                    ocrinfo['key_cls'] = answer_name  #答案辅助标签,与question后的数字对应
                                    one_question_answers.append(copy.deepcopy(ocrinfo))
                        elif q_col == q_col_count:
                            answer_name= copy.deepcopy(question['key_cls']).replace('question_','answer_')
                            ocrinfo['key_cls'] = answer_name  #答案辅助标签,与question后的数字对应
                            one_question_answers.append(copy.deepcopy(ocrinfo))
                
                        question['answers'] = one_question_answers
        
        # 3.切割文字生成新的标注框question
        if split_text == True:
            for ocrinfo in ocr_infos:
                # 已分割的不再继续分割,方便手动微调。
                if ocrinfo['key_cls'] == "word":
                    continue
                if "words" in ocrinfo and len(ocrinfo['words']) > 0:
                    continue
                qwords =  get_word_list(ocrinfo['transcription'])
                qwordboxes = splitBox(ocrinfo['points'],qwords)
                ocrinfo['words']=qwordboxes
                
                if "answers" in ocrinfo:
                    for answer in ocrinfo['answers']:
                        if "words" in answer or answer['key_cls'][0:8] == "question" or answer['key_cls'][0:6] == "answer":
                            continue
                        words =  get_word_list(answer['transcription'])
                        wordboxes = splitBox(answer['points'],words)
                        answer['words']=wordboxes
            

        # 组合成ppocrlabel格式
        txt_document = []
        
        for ocrinfo in ocr_infos:
            if 'words' in ocrinfo:
                for word in ocrinfo['words']:
                    txt_document.append(word)
                del ocrinfo['words']
            
            if ocrinfo['key_cls'] in ["other","None"] :
                # None、other
                ocrinfo["key_cls"] = "other"
                 
            if "construct" in  ocrinfo:
                del ocrinfo["construct"]
            if "answers" in ocrinfo:
                del ocrinfo['answers']
            if "words" in ocrinfo:
                del ocrinfo['words']

            txt_document.append(ocrinfo)
        lines +=image_path+"\t"+json.dumps(txt_document,ensure_ascii=False)+"\n"

    with open(newlabelfile,'w+',encoding='utf-8') as f2:
        f2.writelines(lines)


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
# 切割文字和kie
make_qa_linking('/home/aistudio/XTOWER/Label.txt','/home/aistudio/XTOWER/Label_kie_word.txt',documents,split_text=True)
# 可忽略切割的文字,可用来导出文本识别结果。
make_qa_linking('/home/aistudio/XTOWER/Label.txt','/home/aistudio/XTOWER/Label_kie_no_word.txt',documents,split_text=False,ignore_word=True)

# 生成的新文件都是可以直接在PPOCRLabel中打开的
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

Label_kie_word.txt 也是文本检测格式,也可以直接在PPOCRLabel中打开,覆盖Label.txt即可。


切割文本后在PPOCRLabel中效果

标注技巧

  • 中英文和特殊符号的长文本可以分段标注,自动切割误差小。
  • 切割文本后的文件可导入到PPOCRLabel中进行微调(可直接覆盖Label.txt).
  • 如果header识别错误,需要先纠错,防止不能匹配到定义的文档结构上

Label_kie_not_word.txt 不包含切割的key_cls="word"的标注框,可用来PPOCRLabel“导出识别结果”功能制作文本识别数据集。
在这里插入图片描述

3.4 划分数据集

将QA重命名和切割文字的Label_kie_word.txt文件再划分数据集,这里就简单的划分为7:3

# 划分文本检测的训练集和验证集
import random
import os
import shutil

image_path = "/home/aistudio/XTOWER/"
det_gt_kie_file = "/home/aistudio/XTOWER/Label_kie_word.txt"
det_gt_train = "/home/aistudio/XTOWER/train_data/det_gt_train.txt" #此时还不是最终的目标检测格式
det_gt_val = "/home/aistudio/XTOWER/val_data/det_gt_val.txt"

train_dir = os.path.dirname(det_gt_train)
val_dir = os.path.dirname(det_gt_val)
if os.path.isdir(train_dir):
    shutil.rmtree(train_dir)
os.mkdir(train_dir)
os.mkdir(train_dir+'/image')
if os.path.isdir(val_dir):
    shutil.rmtree(val_dir)
os.mkdir(val_dir)
os.mkdir(val_dir+'/image')

newinfo = {}
i = 0
json_lines = ""
with open(det_gt_kie_file) as f:
    lines = f.readlines();
    random.shuffle (lines)

list_len = len(lines)
train_len= int(0.7 * list_len)
train_data = lines[:train_len]
val_data = lines[train_len:]

# print(list_len)
# print(len(train_data))
# print(len(val_data))
with open(det_gt_train,'w+',encoding='utf-8') as f1:
    f1.writelines(train_data)
    for line in train_data:
        image_file,_ = line.split("\t");
        shutil.move(image_path+image_file,train_dir+'/image/')
with open(det_gt_val,'w+',encoding='utf-8') as f2:
    f2.writelines(val_data)
    for line in val_data:
        image_file,_ = line.split("\t");
        shutil.move(image_path+image_file,val_dir+'/image/')
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
# 划分文本识别的训练集和验证集
import random
import os
import shutil

image_path = "/home/aistudio/XTOWER/"
rec_file = "/home/aistudio/XTOWER/rec_gt.txt"
rec_train = "/home/aistudio/XTOWER/rec_train.txt"
rec_val = "/home/aistudio/XTOWER/rec_val.txt"

train_dir = os.path.dirname(rec_train)
val_dir = os.path.dirname(rec_val)

newinfo = {}
i = 0
json_lines = ""
with open(rec_file) as f:
    lines = f.readlines();
    random.shuffle (lines)

list_len = len(lines)
train_len= int(0.8 * list_len)
train_data = lines[:train_len]
val_data = lines[train_len:]

# print(list_len)
# print(len(train_data))
# print(len(val_data))
with open(rec_train,'w+',encoding='utf-8') as f1:
    f1.writelines(train_data)
with open(rec_val,'w+',encoding='utf-8') as f2:
    f2.writelines(val_data)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

3.5 文本检测格式转DOC-VQA格式

import json
from PIL import Image
import re
import copy

# 文本检测格式转DOC-VQA格式
def det_gt_kie_vqa(det_gt_file = 'Label.txt',normalize_file='normalize.json',dataset_path='../train_data'):
    file = open(det_gt_file)
    
    newinfo = {}
    i = 0
    lines = ""
    
    while True:
        line = file.readline()
        if not line:
            break
        image_path,ocr_info = line.split("\t");
        # image_id = image_path[-8:][:4] #这里取文件明的后4位(不含扩展名)
        img = Image.open(dataset_path+"/"+image_path)
        newinfo['width'] = img.width
        newinfo['height'] = img.height
        ocr_infos = json.loads(ocr_info)
        
        # 排除word类型
        all_word_info=[]
        qa=[]
        for ocr_info in ocr_infos:

            if ocr_info['key_cls'] == "word":
                words_dict={
                    'box':[ocr_info['points'][0][0],
                        ocr_info['points'][0][1],
                        ocr_info['points'][2][0],
                        ocr_info['points'][2][1]],
                    'text':ocr_info['transcription']
                }
                all_word_info.append(copy.deepcopy(words_dict))
            else:
                qa.append(copy.deepcopy(ocr_info))
        
        links={}
        ocr_id=1
        # 分别提取qa
        q=[]
        a=[]
        for ocr_info in qa:
            ocr_info['id'] = ocr_id
            if ocr_info['key_cls'][0:8] == "question":
                question_id = ocr_info['key_cls'].replace("question_","")
                q.append([question_id,ocr_id])
            elif ocr_info['key_cls'][0:6] == "answer":
                 question_answer_id = ocr_info['key_cls'].replace("answer_","")
                 a.append([question_answer_id,ocr_id])

            ocr_id=ocr_id+1
        # qa关系
        for cls_id,ocrid in q:
            link=[]
            for cls_id2,ocrid2 in a:
                if cls_id == cls_id2:
                    link.append([ocrid,ocrid2])
            links[cls_id]=link
        
        newocrinfos = []
        ocr_id = 1
        for ocr_info in qa:
            question_id = 0
            newocrinfo={}
            newocrinfo['text'] = ocr_info['transcription']
            newocrinfo['bbox'] = [
                ocr_info['points'][0][0],
                ocr_info['points'][0][1],
                ocr_info['points'][2][0],
                ocr_info['points'][2][1],
            ]
            
            newocrinfo['id'] = ocr_id
            if ocr_info['key_cls'][0:8] == "question":
                question_id = ocr_info['key_cls'].replace("question_","")
                newocrinfo['label'] = 'question'
            elif ocr_info['key_cls'][0:6] == "answer":
                question_id = ocr_info['key_cls'].replace("answer_","")
                newocrinfo['label'] = "answer"
            elif ocr_info['key_cls'] == 'header':
                newocrinfo['label'] ="header"
            else:
                # newocrinfo['label'] = ocr_info['key_cls']
                newocrinfo['label'] = "other"
        
            if question_id !='0' and question_id in links :
                newocrinfo['linking'] = links[question_id]
            else:
                newocrinfo['linking'] = []
            
                
            # 分析word的从属关系
            words=[]
            for word_info in all_word_info:
                word_center_x = word_info['box'][0] + int((word_info['box'][2] - word_info['box'][0]) / 2)
                word_center_y = word_info['box'][1] + int((word_info['box'][3] - word_info['box'][1]) / 2)
                #如果word的中心点落在 question或answer标签内
                if newocrinfo['bbox'][0] < word_center_x < newocrinfo['bbox'][2] and  newocrinfo['bbox'][1] < word_center_y < newocrinfo['bbox'][3] :
                    words.append(word_info)
            newocrinfo['words'] = words
            newocrinfos.append(newocrinfo)
            
            # if newocrinfo['label'] =='answer':
            #     print('ocr_id=',ocr_id)
            #     print(newinfo)
            ocr_id = ocr_id+1
        newinfo['ocr_info'] = newocrinfos
        # print(newocrinfos)
        # break
        # print(all_word_info)
        
        # break
        lines += image_path +"\t" + json.dumps(newinfo,ensure_ascii=False) + "\n"
        i=i+1

    with open(normalize_file,'w+',encoding='utf-8') as f2:
    # file2.seek(0)  # 移动指针到开头
        f2.writelines(lines)


  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
# 生成ser和re可训练的格式
det_gt_kie_vqa(det_gt_file='/home/aistudio/XTOWER/train_data/det_gt_train.txt',normalize_file='/home/aistudio/XTOWER/train_data/normalize_train.json',dataset_path='/home/aistudio/XTOWER/train_data')
det_gt_kie_vqa(det_gt_file='/home/aistudio/XTOWER/val_data/det_gt_val.txt',normalize_file='/home/aistudio/XTOWER/val_data/normalize_val.json',dataset_path='/home/aistudio/XTOWER/val_data')
  • 1
  • 2
  • 3
import shutil

di_xtower = set()

def to_det_gt(filename,di=set()):
    """
    将kie标注格式转为的文本检测格式,并返回全部的文本
    """
    new_docs = ""
    with open(filename, "r", encoding='utf-8') as f:
        docs = f.readlines()    
        for doc in docs:
            image_file,ocr_info = doc.split("\t");
            ocr_infos = json.loads(ocr_info)
            txt_document=[]
            for ocr_info in ocr_infos:
                if "key_cls" not in ocr_info or ocr_info['key_cls'] != 'word':
                    txt_document.append({'transcription':ocr_info['transcription'],'points':ocr_info['points']})
                    # 字典
                    di = di | set(ocr_info["transcription"])
            new_docs += image_file+"\t"+json.dumps(txt_document,ensure_ascii=False)+"\n"

    with open(filename, "w", encoding='utf-8') as f:
        f.writelines(new_docs)
    return di

di_xtower=to_det_gt("/home/aistudio/XTOWER/train_data/det_gt_train.txt",di_xtower)
di_xtower=to_det_gt("/home/aistudio/XTOWER/val_data/det_gt_val.txt",di_xtower)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
# 字典处理
baseline_label = '/home/aistudio/PaddleOCR/ppocr/utils/ppocr_keys_v1.txt'
shutil.copyfile(baseline_label, '/home/aistudio/XTOWER/word_dict.txt')
with open(baseline_label, 'r', encoding='utf-8') as f:
    all_chars = f.read()

with open('/home/aistudio/XTOWER/xtower_dict.txt', 'w', encoding='utf-8') as f:
    for char in di_xtower:
        f.write(char+'\n')

with open('/home/aistudio/XTOWER/word_dict.txt', 'a', encoding='utf-8') as f:
    f.write('\n')
    for char in di_xtower:
        if char not in all_chars:
            f.write(char+'\n')
            print(char)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
️
☑
  • 1
  • 2

4. 查看效果

这里只显示预测效果,如何训练ser和re,可参考我写的《多模态技术在工业场景中的应用实践:表单识别重命名

在这里插入图片描述
)]

预测效果

5. 总结

5.1 优点

  • 大大减少标注时间,全中文文本,基本不用调整。
  • 可以在切割后的文件里继续标注,不会被重复切割。因已经设置特定的key_cls,不再切割标注的文本。
  • 切割后的文字会自动归属到所属的文本里。
  • 支持answer为多行文本时自动question关联。

5.2 缺点

  • 中英文及符号较多的句子效果不佳,还需要手动调整切割后word的位置。
  • 对倾斜度较大的文本自动切割效果不佳。
  • question和answer自动关联机制比较简单,通用性有局限。

5.3 改进

  • 通过训练识别单个字(汉字、中文符号,英文单词)的方式,定位单个字的坐标位置、角度等来进行自动标注。
  • 期待PPOCRLabel支持自动切割文本和RE关联。

作者介绍:tianxingxia

原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/4197468

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/132597
推荐阅读
相关标签
  

闽ICP备14008679号