
OCR Series Training Notes (1): Training ABCNet

Introduction

ABCNet: a real-time, end-to-end network for natural scene text detection and recognition based on adaptive Bezier curves.

Recommended paper: ABCNet

Paper: https://arxiv.org/abs/2002.10200
Official code: https://github.com/aim-uofa/AdelaiDet

ABCNet: Real-Time Scene Text Spotting With Adaptive Bezier-Curve Network

Authors | Yuliang Liu, Hao Chen, Chunhua Shen, Tong He, Lianwen Jin, Liangwei Wang

Affiliations | South China University of Technology; The University of Adelaide

Code | https://github.com/Yuliang-Liu/bezier_curve_text_spotting

Venue | CVPR 2020 Oral

Explainer (Chinese) | https://zhuanlan.zhihu.com/p/146276834

The paper was accepted to CVPR 2020. Its two main contributions: 1) fitting arbitrarily shaped text with Bezier curves, and 2) a BezierAlign scheme that extracts text-instance features more accurately.
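To make contribution 1) concrete: each of the two long sides of a text instance is modeled as a cubic Bezier curve, so an arbitrarily curved boundary is pinned down by just four control points per side. A minimal sketch of evaluating such a curve (my own illustration, not the paper's code):

import numpy as np

def cubic_bezier(ctrl, ts):
    """Evaluate a cubic Bezier curve with 4 control points ctrl (shape 4x2) at parameters ts."""
    ts = np.asarray(ts)[:, None]
    # Bernstein form: B(t) = (1-t)^3 P0 + 3(1-t)^2 t P1 + 3(1-t) t^2 P2 + t^3 P3
    return ((1 - ts) ** 3 * ctrl[0] + 3 * (1 - ts) ** 2 * ts * ctrl[1]
            + 3 * (1 - ts) * ts ** 2 * ctrl[2] + ts ** 3 * ctrl[3])

# 5 points sampled along one curved text border
print(cubic_bezier(np.array([[0., 0.], [1., 2.], [3., 2.], [4., 0.]]), np.linspace(0, 1, 5)))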

For an overview, the Zhihu explainer above is enough; I won't repeat the theory here. Now to the main event: environment setup, testing, and training.

Environment

# env
torch 1.4.0
torchvision 0.5.0
python 3.6.2, cuda 10.1
# 1. First install Detectron2
git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2
# 2. Then build AdelaiDet
git clone https://github.com/aim-uofa/AdelaiDet.git
cd AdelaiDet
python setup.py build develop

 

As shown above, first install the Detectron2 component, following the official guide: INSTALL.md.

The second step installs AdelaiDet itself; just compile it with the commands above.

The environment above was set up in an anaconda virtual environment with torch 1.4.0 + torchvision 0.5.0, python 3.6.2 and cuda 10.1; the CUDA actually installed on the machine is 10.2.

One special note: the torch and torchvision versions must be a matched pair, otherwise the build fails midway with all kinds of baffling errors. I learned this the hard way.
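Pinning a matched pair explicitly avoids the problem; for example, with the versions used here (torch 1.4.0 officially pairs with torchvision 0.5.0):

# pin a known-compatible pair before building
pip install torch==1.4.0 torchvision==0.5.0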

Finally, an overview of the complete environment:

----------------------  ---------------------------------------------------------------------------------------------------------
sys.platform             linux
Python                   3.6.2 |Continuum Analytics, Inc.| (default, Jul 20 2017, 13:51:32) [GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
numpy                    1.19.1
detectron2               0.2.1 @/home/gavin/MyProj/tempwork/ocr/AdelaiDet/detectron2/detectron2
Compiler                 GCC 5.4
CUDA compiler            CUDA 10.2
detectron2 arch flags    6.1
DETECTRON2_ENV_MODULE    <not set>
PyTorch                  1.4.0 @/home/gavin/miniconda3/envs/py362/lib/python3.6/site-packages/torch
PyTorch debug build      False
GPU available            True
GPU 0                    GeForce GTX 1080 Ti (arch=6.1)
CUDA_HOME                /usr/local/cuda-10.2
Pillow                   4.2.1
torchvision              0.5.0 @/home/gavin/miniconda3/envs/py362/lib/python3.6/site-packages/torchvision
torchvision arch flags   3.5, 5.0, 6.0, 7.0, 7.5
fvcore                   0.1.1.post20200716
cv2                      4.4.0
----------------------  ---------------------------------------------------------------------------------------------------------
PyTorch built with:
  - GCC 7.3
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CUDA Runtime 10.1
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.3
  - Magma 2.5.1
  - Build settings: BLAS=MKL, BUILD_NAMEDTENSOR=OFF, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Wno-stringop-overflow, DISABLE_NUMA=1, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,

 

Testing

Assuming the environment is ready, run the following to test:

# test
python demo/demo.py \
    --config-file configs/BAText/TotalText/attn_R_50.yaml \
    --input datasets/totaltext/test_images/ \
    --opts MODEL.WEIGHTS tt_attn_R_50.pth
# e.g.:
python demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input /media/gavin/home/gavin/DataSet/ocr/test/bdz --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth
python3 demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth

Results with the TotalText-trained model:

[detection result image omitted]

The detections themselves are fairly accurate, but there are quite a few misses. That may be down to the data: I tested on images from ICDAR2015.

The weights used for testing are the officially trained ones.

With the CTW1500 model I trained myself (covered below), the miss rate is much lower. Results are shown at the end.

 

Training

Now for the most critical, and most troublesome, part.

Dataset preparation

Method 1: data annotated directly with the windows_label_tool. If you labeled your own data with this tool, a single conversion suffices: from the windows_label_tool format to the JSON format that ABCNet trains on.

See the reference linked in the snippet below.

 

# Making a custom ABCNet dataset (e.g. converting ICDAR15 to ABCNet's annotation format).
# Reference: https://github.com/Yuliang-Liu/Curve-Text-Detector/tree/master/data
# 1. Convert labelme annotations to the windows_label_tool format, shown below: the first
#    line is the number of annotations; each following line holds 28/2 = 14 point
#    coordinates, then the text content.
4
45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84,"DOUGLASTON"
50,119,58,119,66,119,74,119,82,119,90,119,98,119,98,137,90,137,82,137,74,137,66,137,58,137,51,137,"E-313"
41,137,48,136,56,136,64,136,71,136,79,136,87,136,89,155,81,155,73,155,65,155,57,155,49,155,41,155,"L164"
39,166,56,166,74,166,92,167,110,167,128,167,146,168,140,196,123,195,107,195,90,194,74,194,57,193,41,193,"F.D.N.Y."
# 2. convert_ann_to_json: convert windows_label_tool annotations to ABCNet training JSON
python convert_ann_to_json.py \
    --ann-dir /path/to/gt \
    --image-dir /path/to/image \
    --dst-json-path train.json
# e.g.:
python convert_ann_to_json.py --ann-dir /media/gavin/home/gavin/DataSet/ocr/ctw/ctw1500_e2e_annos/ctw1500_e2e_test \
    --image-dir /media/gavin/home/gavin/DataSet/ocr/ctw/ctw1500/test/text_image \
    --dst-json-path ./abc_json/test.json
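To make the annotation format concrete, here is a minimal reader for one such file (a hypothetical helper of my own, not part of the toolchain; it assumes the text field itself contains no comma):

def read_windows_label_file(path):
    """Return a list of (points, text), where points is a list of 14 (x, y) tuples."""
    with open(path) as f:
        lines = [ln.strip() for ln in f if ln.strip()]
    annotations = []
    for ln in lines[1:1 + int(lines[0])]:     # first line holds the annotation count
        coords, text = ln.rsplit(',', 1)      # last field is the quoted text content
        values = [float(v) for v in coords.split(',')]  # 28 numbers = 14 (x, y) points
        annotations.append((list(zip(values[0::2], values[1::2])), text.strip('"')))
    return annotations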

The JSON conversion script:

# -*- coding: utf-8 -*-
"""
@File : convert_ann_to_json.py
@Time : 2020-8-17 16:13
@Author : yizuotian
@Description : Convert windows_label_tool annotations to the JSON format ABCNet trains on
"""
import argparse
import json
import os

import cv2
import numpy as np

import bezier_utils  # helper from the reference repo above (provides polygon_to_bezier_pts)


def gen_abc_json(abc_gt_dir, abc_json_path, image_dir, classes_path):
    """
    Generate COCO-style JSON annotations from windows_label_tool gt files.
    :param abc_gt_dir: directory of annotation files produced by windows_label_tool
    :param abc_json_path: output path of the JSON annotations ABCNet trains on
    :param image_dir: directory of the images
    :param classes_path: path of the class-name file
    :return:
    """
    # Desktop Latin_embed.
    cV2 = [' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4',
           '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
           'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
           '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',
           'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']
    dataset = {
        'licenses': [],
        'info': {},
        'categories': [],
        'images': [],
        'annotations': []
    }
    with open(classes_path) as f:
        classes = f.read().strip().split()
    for i, cls in enumerate(classes, 1):
        dataset['categories'].append({
            'id': i,
            'name': cls,
            'supercategory': 'beverage',
            'keypoints': ['mean',
                          'xmin',
                          'x2',
                          'x3',
                          'xmax',
                          'ymin',
                          'y2',
                          'y3',
                          'ymax',
                          'cross']  # only for BDN
        })

    def get_category_id(cls):
        for category in dataset['categories']:
            if category['name'] == cls:
                return category['id']

    # iterate over the windows_label_tool txt annotations
    indexes = sorted([f.split('.')[0]
                      for f in os.listdir(abc_gt_dir)])
    print(indexes)
    j = 1  # running annotation (bounding box) id
    for index in indexes:
        # if int(index) > 3: continue
        # print('Processing: ' + index)
        im = cv2.imread(os.path.join(image_dir, '{}.jpg'.format(index)))
        im_height, im_width = im.shape[:2]
        dataset['images'].append({
            'coco_url': '',
            'date_captured': '',
            'file_name': index + '.jpg',
            'flickr_url': '',
            'id': int(index.split('_')[-1]),  # img_1
            'license': 0,
            'width': im_width,
            'height': im_height
        })
        anno_file = os.path.join(abc_gt_dir, '{}.txt'.format(index))
        with open(anno_file) as f:
            lines = [line for line in f.readlines() if line.strip()]
        # no usable annotations, skip
        if len(lines) <= 1:
            continue
        for i, line in enumerate(lines[1:]):
            elements = line.strip().split(',')
            polygon = np.array(elements[:28]).reshape((-1, 2)).astype(np.float32)  # [14,(x,y)]
            control_points = bezier_utils.polygon_to_bezier_pts(polygon, im)  # [8,(x,y)]
            ct = elements[-1].replace('"', '').strip()
            cls = 'text'
            # segs = [float(kkpart) for kkpart in parts[:16]]
            segs = [float(kkpart) for kkpart in control_points.flatten()]
            xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]
            yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]
            # filter out boxes that fall outside the image
            if max(xt) > im_width or max(yt) > im_height:
                print('The annotation bounding box is outside of the image:{}'.format(index))
                print("max x:{},max y:{},w:{},h:{}".format(max(xt), max(yt), im_width, im_height))
                continue
            xmin = min([xt[0], xt[3], xt[4], xt[7]])
            ymin = min([yt[0], yt[3], yt[4], yt[7]])
            xmax = max([xt[0], xt[3], xt[4], xt[7]])
            ymax = max([yt[0], yt[3], yt[4], yt[7]])
            width = max(0, xmax - xmin + 1)
            height = max(0, ymax - ymin + 1)
            if width == 0 or height == 0:
                continue
            max_len = 100
            recs = [len(cV2) + 1 for ir in range(max_len)]
            ct = str(ct)
            # print('rec', ct)
            for ix, ict in enumerate(ct):
                if ix >= max_len:
                    continue
                if ict in cV2:
                    recs[ix] = cV2.index(ict)
                else:
                    recs[ix] = len(cV2)
            dataset['annotations'].append({
                'area': width * height,
                'bbox': [xmin, ymin, width, height],
                'category_id': get_category_id(cls),
                'id': j,
                'image_id': int(index.split('_')[-1]),  # img_1
                'iscrowd': 0,
                'bezier_pts': segs,
                'rec': recs
            })
            j += 1
    # write the json file
    folder = os.path.dirname(abc_json_path)
    if folder and not os.path.exists(folder):  # guard against a bare filename (empty dirname)
        os.makedirs(folder)
    with open(abc_json_path, 'w') as f:
        json.dump(dataset, f)


def main(args):
    gen_abc_json(args.ann_dir, args.dst_json_path, args.image_dir, args.classes_path)


if __name__ == '__main__':
    """
    Usage: python convert_ann_to_json.py \
              --ann-dir /path/to/gt \
              --image-dir /path/to/image \
              --dst-json-path train.json
    """
    parse = argparse.ArgumentParser()
    parse.add_argument("--ann-dir", type=str, default=None)
    parse.add_argument("--image-dir", type=str, default=None)
    parse.add_argument("--dst-json-path", type=str, default=None)
    parse.add_argument("--classes-path", type=str, default='./classes.txt')
    arguments = parse.parse_args()  # sys.argv[1:]
    main(arguments)
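After conversion it is worth a quick sanity check of the output (my own sketch, not part of the script): each annotation should carry 8 Bezier control points, i.e. 16 bezier_pts values.

import json

with open('./abc_json/test.json') as f:  # path from the example above
    coco = json.load(f)
print(len(coco['images']), 'images,', len(coco['annotations']), 'annotations')
assert all(len(a['bezier_pts']) == 16 for a in coco['annotations'])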

 

Method 2:

Convert labelme annotations to the windows_label_tool format, then run the JSON conversion from Method 1.

Converting labelme annotations to the windows_label_tool format:

1. Convert labelme-annotated JSON files to ABCNet gt annotations (skip this step if you annotate directly with the windows_label_tool):
# coding=utf-8
# Convert labelme-annotated JSON files to ABCNet gt annotations.
# Skip this step if you annotate directly with the windows_label_tool.
import glob
import json
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from scipy.special import comb as n_over_k
from shapely.geometry import Point
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from torch import nn


class Bezier(nn.Module):
    def __init__(self, ps, ctps):
        """
        ps: numpy array of points
        """
        super(Bezier, self).__init__()
        self.x1 = nn.Parameter(torch.as_tensor(ctps[0], dtype=torch.float64))
        self.x2 = nn.Parameter(torch.as_tensor(ctps[2], dtype=torch.float64))
        self.y1 = nn.Parameter(torch.as_tensor(ctps[1], dtype=torch.float64))
        self.y2 = nn.Parameter(torch.as_tensor(ctps[3], dtype=torch.float64))
        self.x0 = ps[0, 0]
        self.x3 = ps[-1, 0]
        self.y0 = ps[0, 1]
        self.y3 = ps[-1, 1]
        self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
        self.t = torch.as_tensor(np.linspace(0, 1, 81))

    def forward(self):
        # sum of distances from the inner polygon points to their nearest curve samples
        x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
        t = self.t
        bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
        bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
        bezier = torch.stack((bezier_x, bezier_y), dim=1)
        diffs = bezier.unsqueeze(0) - self.inner_ps.unsqueeze(1)
        sdiffs = diffs ** 2
        dists = sdiffs.sum(dim=2).sqrt()
        min_dists, min_inds = dists.min(dim=1)
        return min_dists.sum()

    def control_points(self):
        return self.x0, self.x1, self.x2, self.x3, self.y0, self.y1, self.y2, self.y3

    def control_points_f(self):
        return self.x0, self.x1.item(), self.x2.item(), self.x3, self.y0, self.y1.item(), self.y2.item(), self.y3


def train(x, y, ctps, lr):
    # NOTE: despite the name, this version runs no optimization loop; it wraps the
    # initial least-squares control points in a Bezier module and returns them (lr unused).
    x, y = np.array(x), np.array(y)
    ps = np.vstack((x, y)).transpose()
    bezier = Bezier(ps, ctps)
    return bezier.control_points_f()


def draw(ps, control_points, t):
    x = ps[:, 0]
    y = ps[:, 1]
    x0, x1, x2, x3, y0, y1, y2, y3 = control_points
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.plot(x, y, color='m', linestyle='', marker='.')
    bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
    bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
    plt.plot(bezier_x, bezier_y, 'g-')
    plt.draw()
    plt.pause(1)
    input("<Hit Enter To Close>")  # was raw_input (Python 2); input() for Python 3
    plt.close(fig)


# Bernstein basis and the coefficient matrix of a cubic Bezier curve
Mtk = lambda n, t, k: t**k * (1-t)**(n-k) * n_over_k(n, k)
BezierCoeff = lambda ts: [[Mtk(3, t, k) for k in range(4)] for t in ts]


def bezier_fit(x, y):
    # chord-length parameterization of the sample points
    dy = y[1:] - y[:-1]
    dx = x[1:] - x[:-1]
    dt = (dx ** 2 + dy ** 2) ** 0.5
    t = dt / dt.sum()
    t = np.hstack(([0], t))
    t = t.cumsum()
    data = np.column_stack((x, y))
    # least-squares control points via the pseudo-inverse of the Bernstein matrix
    Pseudoinverse = np.linalg.pinv(BezierCoeff(t))  # (9,4) -> (4,9)
    control_points = Pseudoinverse.dot(data)  # (4,9)*(9,2) -> (4,2)
    medi_ctp = control_points[1:-1, :].flatten().tolist()  # the two inner control points
    return medi_ctp


def bezier_fitv2(x, y):
    # degenerate fit: place the two inner control points at thirds of the end-point chord
    xc01 = (2 * x[0] + x[-1]) / 3.0
    yc01 = (2 * y[0] + y[-1]) / 3.0
    xc02 = (x[0] + 2 * x[-1]) / 3.0
    yc02 = (y[0] + 2 * y[-1]) / 3.0
    control_points = [xc01, yc01, xc02, yc02]
    return control_points


def is_close_to_line(xs, ys, thres):
    regression_model = LinearRegression()
    # Fit the data (train the model)
    regression_model.fit(xs.reshape(-1, 1), ys.reshape(-1, 1))
    # Predict
    y_predicted = regression_model.predict(xs.reshape(-1, 1))
    # Model evaluation
    rmse = mean_squared_error(ys.reshape(-1, 1) ** 2, y_predicted ** 2)
    rmse = rmse / (ys.reshape(-1, 1) ** 2 - y_predicted ** 2).max() ** 2
    if rmse > thres:
        return 0.0
    else:
        return 2.0


def is_close_to_linev2(xs, ys, size, thres=0.05):
    pts = []
    nor_pixel = int(size ** 0.5)
    for i in range(len(xs)):
        pts.append(Point([xs[i], ys[i]]))
    # slope between each consecutive pair of points
    slopes = [(second.y - first.y) / (second.x - first.x) if not (second.x - first.x) == 0.0
              else math.inf * np.sign((second.y - first.y))
              for first, second in zip(pts, pts[1:])]
    st_slope = (ys[-1] - ys[0]) / (xs[-1] - xs[0])
    max_dis = ((ys[-1] - ys[0]) ** 2 + (xs[-1] - xs[0]) ** 2) ** 0.5
    diffs = np.abs(np.array(slopes) - st_slope)  # fixed: the original subtracted a float from a list
    score = diffs.sum() * max_dis / nor_pixel
    if score < thres:
        return 0.0
    else:
        return 3.0


labels = glob.glob("dataset/json/*.json")
labels.sort()
if not os.path.isdir('abcnet_gen_labels'):
    os.mkdir('abcnet_gen_labels')

for il, label in enumerate(labels):
    print('Processing: ' + label)
    imgdir = label.replace('json/', 'image/').replace('.json', '.jpg')
    outgt = open(label.replace('dataset/json/', 'abcnet_gen_labels/').replace('.json', '.txt'), 'w')
    data = []
    cts = []
    with open(label, "r") as f:
        jdata = json.loads(f.read())
        boxes = jdata["shapes"]
        for il, box in enumerate(boxes):
            line, ct = box["points"], box["label"]
            pts = []
            [pts.extend(p) for p in line]
            # pad 4-point and 6-point polygons by inserting edge midpoints so each
            # boundary has enough points for the fit
            if len(line) == 4:
                pts = line[0] + [(line[0][0] + line[1][0]) // 2, (line[0][1] + line[1][1]) // 2] + line[1] + line[2] + [(line[2][0] + line[3][0]) / 2, (line[2][1] + line[3][1]) / 2] + line[3]
            if len(line) == 6:
                if abs(line[0][0] - line[1][0]) > abs(line[1][0] - line[2][0]):
                    pts = line[0] + [(line[0][0] + line[1][0]) // 2, (line[0][1] + line[1][1]) // 2] + line[1] + line[2]
                    pts += line[3] + [(line[3][0] + line[4][0]) // 2, (line[3][1] + line[4][1]) // 2] + line[4] + line[5]
                else:
                    pts = line[0] + line[1] + [(line[1][0] + line[2][0]) // 2, (line[1][1] + line[2][1]) // 2] + line[2]
                    pts += line[3] + line[4] + [(line[4][0] + line[5][0]) // 2, (line[4][1] + line[5][1]) // 2] + line[5]
            data.append(np.array([float(x) for x in pts]))
            cts.append(ct)
    ############## top
    img = plt.imread(imgdir)
    for iid, ddata in enumerate(data):
        lh = len(data[iid])
        assert (lh % 4 == 0)
        lhc2 = int(lh / 2)
        lhc4 = int(lh / 4)
        xcors = [data[iid][i] for i in range(0, len(data[iid]), 2)]
        ycors = [data[iid][i + 1] for i in range(0, len(data[iid]), 2)]
        curve_data_top = data[iid][0:lhc2].reshape(lhc4, 2)
        curve_data_bottom = data[iid][lhc2:].reshape(lhc4, 2)
        left_vertex_x = [curve_data_top[0, 0], curve_data_bottom[lhc4 - 1, 0]]
        left_vertex_y = [curve_data_top[0, 1], curve_data_bottom[lhc4 - 1, 1]]
        right_vertex_x = [curve_data_top[lhc4 - 1, 0], curve_data_bottom[0, 0]]
        right_vertex_y = [curve_data_top[lhc4 - 1, 1], curve_data_bottom[0, 1]]
        # fit the top boundary
        x_data = curve_data_top[:, 0]
        y_data = curve_data_top[:, 1]
        init_control_points = bezier_fit(x_data, y_data)
        learning_rate = is_close_to_linev2(x_data, y_data, img.size)
        x0, x1, x2, x3, y0, y1, y2, y3 = train(x_data, y_data, init_control_points, learning_rate)
        control_points = np.array([
            [x0, y0],
            [x1, y1],
            [x2, y2],
            [x3, y3]
        ])
        # fit the bottom boundary
        x_data_b = curve_data_bottom[:, 0]
        y_data_b = curve_data_bottom[:, 1]
        init_control_points_b = bezier_fit(x_data_b, y_data_b)
        learning_rate = is_close_to_linev2(x_data_b, y_data_b, img.size)
        x0_b, x1_b, x2_b, x3_b, y0_b, y1_b, y2_b, y3_b = train(x_data_b, y_data_b, init_control_points_b, learning_rate)
        control_points_b = np.array([
            [x0_b, y0_b],
            [x1_b, y1_b],
            [x2_b, y2_b],
            [x3_b, y3_b]
        ])
        # visualize the fitted curves over the image
        t_plot = np.linspace(0, 1, 81)
        Bezier_top = np.array(BezierCoeff(t_plot)).dot(control_points)
        Bezier_bottom = np.array(BezierCoeff(t_plot)).dot(control_points_b)
        plt.plot(Bezier_top[:, 0], Bezier_top[:, 1], 'g-', label='fit', linewidth=1)
        plt.plot(Bezier_bottom[:, 0], Bezier_bottom[:, 1], 'g-', label='fit', linewidth=1)
        plt.plot(control_points[:, 0], control_points[:, 1], 'r.:', fillstyle='none', linewidth=1)
        plt.plot(control_points_b[:, 0], control_points_b[:, 1], 'r.:', fillstyle='none', linewidth=1)
        plt.plot(left_vertex_x, left_vertex_y, 'g-', linewidth=1)
        plt.plot(right_vertex_x, right_vertex_y, 'g-', linewidth=1)
        # 16 control-point coordinates, then the text, separated by '||||'
        outstr = '{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}||||{}\n'.format(
            round(x0, 2), round(y0, 2),
            round(x1, 2), round(y1, 2),
            round(x2, 2), round(y2, 2),
            round(x3, 2), round(y3, 2),
            round(x0_b, 2), round(y0_b, 2),
            round(x1_b, 2), round(y1_b, 2),
            round(x2_b, 2), round(y2_b, 2),
            round(x3_b, 2), round(y3_b, 2),
            cts[iid])
        outgt.writelines(outstr)
    outgt.close()
    plt.imshow(img)
    plt.axis('off')
    if not os.path.isdir('abcnet_vis'):
        os.mkdir('abcnet_vis')
    plt.savefig('abcnet_vis/' + os.path.basename(imgdir), bbox_inches='tight', dpi=400)
    plt.clf()
2. Convert the generated ABCNet gt annotations to ABCNet JSON, using the script from Method 1.
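The core of the script above is bezier_fit: it chord-length-parameterizes the boundary samples and solves for the cubic Bezier control points by least squares through the pseudo-inverse of the Bernstein matrix. A small self-contained check (my own sketch) that this behaves as expected:

import numpy as np
from scipy.special import comb as n_over_k

Mtk = lambda n, t, k: t**k * (1-t)**(n-k) * n_over_k(n, k)
BezierCoeff = lambda ts: [[Mtk(3, t, k) for k in range(4)] for t in ts]

ctrl = np.array([[0., 0.], [1., 2.], [3., 2.], [4., 0.]])        # ground-truth control points
samples = np.array(BezierCoeff(np.linspace(0, 1, 9))).dot(ctrl)  # 9 points on that curve

# chord-length parameterization of the samples, as in bezier_fit
d = np.hstack(([0], np.hypot(*np.diff(samples, axis=0).T)))
t = d.cumsum() / d.sum()
fitted = np.linalg.pinv(BezierCoeff(t)).dot(samples)             # least-squares control points
# approximately recovers ctrl (chord length only approximates the true parameter)
print(np.round(fitted, 2))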

 

Configuration changes

  • Modify the relevant configuration files for training:
    • Put the prepared data directory under "AdelaiDet/datasets".
    • Edit the _PREDEFINED_SPLITS_TEXT dict in "adet/data/builtin.py" to register your train/test data. Note that paths are rooted at datasets/ by default, so the relative paths all start from the directory below it (a quick sanity check follows this list).
    _PREDEFINED_SPLITS_TEXT = {
        "totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
        "totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
        ...
        "abcnet_train": ("data/train", "data/annotations/train.json"),
        "abcnet_test": ("data/test", "data/annotations/test.json"),
    }
  • Then point the training config at these datasets. Taking configs/BAText/CTW1500/Base-CTW1500.yaml as an example:
    DATASETS:
      # detail cfg: AdelaiDet/adet/data/builtin.py
      TRAIN: ("abcnet_train",)
      TEST: ("abcnet_test",)
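To verify the registration took effect, one option (a hedged sketch; in the version I used, importing the adet data module runs the registrations as a side effect) is to list what detectron2's DatasetCatalog sees:

from detectron2.data import DatasetCatalog
import adet.data.builtin  # noqa: F401  (importing triggers the dataset registration)

print([name for name in DatasetCatalog.list() if name.startswith('abcnet')])
# expected: ['abcnet_train', 'abcnet_test']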

 

The training commands are as follows:

# train custom
# 1. Pretraining with synthetic data:
OMP_NUM_THREADS=1 python tools/train_net.py \
    --config-file configs/BAText/Pretrain/attn_R_50.yaml \
    --num-gpus 4 \
    OUTPUT_DIR text_pretraining/attn_R_50
# 2. Finetuning:
OMP_NUM_THREADS=1 python tools/train_net.py \
    --config-file configs/BAText/CTW1500/attn_R_50.yaml \
    --num-gpus 4 \
    MODEL.WEIGHTS text_pretraining/attn_R_50/model_final.pth
# e.g.:
# 1.
OMP_NUM_THREADS=1 python tools/train_net.py --config-file configs/BAText/CTW1500/attn_R_50.yaml --num-gpus 1
# 2. Finetuning on CTW1500:
OMP_NUM_THREADS=1 python tools/train_net.py \
    --config-file configs/BAText/CTW1500/attn_R_50.yaml \
    --num-gpus 1 \
    MODEL.WEIGHTS text_pretraining/attn_R_50/model_final.pth
# e.g.:
OMP_NUM_THREADS=1 python tools/train_net.py --config-file \
    configs/BAText/CTW1500/attn_R_50.yaml --num-gpus 1 MODEL.WEIGHTS \
    output/batext/ctw1500/attn_R_50/model_final.pth
# eval:
python tools/train_net.py \
    --config-file configs/BAText/CTW1500/attn_R_50.yaml \
    --eval-only \
    MODEL.WEIGHTS ctw1500_attn_R_50.pth
# test:
python demo/demo.py --config-file configs/BAText/CTW1500/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS output/batext/ctw1500/attn_R_50/model_0119999.pth
python3 demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth

 

The above covers training on English letters and digits; training on Chinese, and the changes it requires, will be covered in a follow-up.

 

Test results after training on CTW1500:

[test result images omitted]

Looking at the last image, misses are down to zero, better than the TotalText-trained results, and here I only trained for 120,000 + 120,000 steps.

 

 

 
