ABCNet: A Real-Time End-to-End Network for Arbitrary-Shape Scene Text Detection and Recognition with Adaptive Bezier Curves
Paper: https://arxiv.org/abs/2002.10200
Official code: https://github.com/aim-uofa/AdelaiDet
Authors | Yuliang Liu, Hao Chen, Chunhua Shen, Tong He, Lianwen Jin, Liangwei Wang
Affiliations | South China University of Technology; The University of Adelaide
Code | https://github.com/Yuliang-Liu/bezier_curve_text_spotting
Venue | CVPR 2020 Oral
Write-up | https://zhuanlan.zhihu.com/p/146276834
The paper was accepted to CVPR 2020. Its two contributions: 1) fitting arbitrarily shaped text with Bezier curves, and 2) BezierAlign, which extracts text-instance features more accurately.
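Concretely, each long side of a text instance is modeled as a cubic Bezier curve determined by four control points. A minimal sketch of what that parameterization means (the control points below are invented for illustration):
- import numpy as np
- from scipy.special import comb
-
- def cubic_bezier(ctrl_pts, ts):
-     """Evaluate a cubic Bezier curve B(t) = sum_k C(3,k) t^k (1-t)^(3-k) P_k."""
-     basis = np.stack([comb(3, k) * ts ** k * (1 - ts) ** (3 - k) for k in range(4)], axis=1)
-     return basis.dot(ctrl_pts)  # (N, 2) points along the curve
-
- # hypothetical control points for the top side of a curved word
- top = cubic_bezier(np.array([[0., 0.], [30., -15.], [70., -15.], [100., 0.]]), np.linspace(0, 1, 20))
BezierAlign then samples features along such curved boundaries rather than axis-aligned boxes.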
For the background theory, the Zhihu write-up linked above is sufficient; I won't repeat it here. Let's get to the practical part: environment setup, testing, and training.
- # env
- torch 1.4.0
- torchvision 0.5.0
- py362 cuda10.1
-
- # 1. First install Detectron2
- git clone https://github.com/facebookresearch/detectron2.git
- python -m pip install -e detectron2
-
- # 2. Then build AdelaiDet
- cd AdelaiDet
- python setup.py build develop
As shown above, the first step is to install Detectron2, following the official guide: INSTALL.md.
The second step is to build AdelaiDet itself; just compile with the commands above.
The environment above was set up in an Anaconda virtual environment with torch 1.4.0, torchvision 0.5.0, Python 3.6.2, and CUDA 10.1; the machine itself actually has CUDA 10.2 installed.
One important note: the torch and torchvision versions must match each other, or the build will fail halfway with all sorts of baffling errors. I learned this the hard way.
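A quick way to verify the versions before building (my own check, not from the original post):
- python -c "import torch, torchvision; print(torch.__version__, torchvision.__version__, torch.version.cuda, torch.cuda.is_available())"
- # expected here: 1.4.0 0.5.0 10.1 True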
Finally, an overview of the complete environment:
- ---------------------- ---------------------------------------------------------------------------------------------------------
- sys.platform linux
- Python 3.6.2 |Continuum Analytics, Inc.| (default, Jul 20 2017, 13:51:32) [GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
- numpy 1.19.1
- detectron2 0.2.1 @/home/gavin/MyProj/tempwork/ocr/AdelaiDet/detectron2/detectron2
- Compiler GCC 5.4
- CUDA compiler CUDA 10.2
- detectron2 arch flags 6.1
- DETECTRON2_ENV_MODULE <not set>
- PyTorch 1.4.0 @/home/gavin/miniconda3/envs/py362/lib/python3.6/site-packages/torch
- PyTorch debug build False
- GPU available True
- GPU 0 GeForce GTX 1080 Ti (arch=6.1)
- CUDA_HOME /usr/local/cuda-10.2
- Pillow 4.2.1
- torchvision 0.5.0 @/home/gavin/miniconda3/envs/py362/lib/python3.6/site-packages/torchvision
- torchvision arch flags 3.5, 5.0, 6.0, 7.0, 7.5
- fvcore 0.1.1.post20200716
- cv2 4.4.0
- ---------------------- ---------------------------------------------------------------------------------------------------------
- PyTorch built with:
- - GCC 7.3
- - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
- - Intel(R) MKL-DNN v0.21.1 (Git Hash 7d2fd500bc78936d1d648ca713b901012f470dbc)
- - OpenMP 201511 (a.k.a. OpenMP 4.5)
- - NNPACK is enabled
- - CUDA Runtime 10.1
- - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
- - CuDNN 7.6.3
- - Magma 2.5.1
- - Build settings: BLAS=MKL, BUILD_NAMEDTENSOR=OFF, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -fopenmp -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -O2 -fPIC -Wno-narrowing -Wall -Wextra -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Wno-stringop-overflow, DISABLE_NUMA=1, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF,

Assuming your environment is set up, run the following test:
- # test
- python demo/demo.py \
- --config-file configs/BAText/TotalText/attn_R_50.yaml \
- --input datasets/totaltext/test_images/ \
- --opts MODEL.WEIGHTS tt_attn_R_50.pth
-
-
- e.g.:
-
- python demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input /media/gavin/home/gavin/DataSet/ocr/test/bdz --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth
-
-
- python3 demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth
Test results with the Total-Text model are shown below.
The detections themselves are fairly accurate, but quite a lot of text is missed; that may be down to the data, since I tested on images from ICDAR 2015.
The weights used for testing are the officially trained ones.
With the CTW1500 model I trained myself, far less text is missed; those results are shown at the end.
Now comes the most critical, and most tedious, part: preparing the training data.
Method 1: annotate your own data directly with the windows_label_tool. Then only one conversion is needed, from the windows_label_tool format to the JSON format ABCNet trains on.
See the reference below.
- # Building a custom ABCNet dataset: convert ICDAR15 to the ABCNet annotation format. Reference: https://github.com/Yuliang-Liu/Curve-Text-Detector/tree/master/data
-
- # 1. Convert labelme annotations to the windows_label_tool format, as below: the first line is the number of annotations; each following line is one annotation, with 28/2 = 14 point coordinates followed by the text content
-
- 4
- 45,73,59,67,74,61,89,56,104,60,119,67,135,73,130,84,116,79,102,74,88,68,75,73,61,79,48,84,"DOUGLASTON"
- 50,119,58,119,66,119,74,119,82,119,90,119,98,119,98,137,90,137,82,137,74,137,66,137,58,137,51,137,"E-313"
- 41,137,48,136,56,136,64,136,71,136,79,136,87,136,89,155,81,155,73,155,65,155,57,155,49,155,41,155,"L164"
- 39,166,56,166,74,166,92,167,110,167,128,167,146,168,140,196,123,195,107,195,90,194,74,194,57,193,41,193,"F.D.N.Y."
-
-
- # 2. convert_ann_to_json: convert the windows_label_tool annotations to the JSON format ABCNet trains on
- python convert_ann_to_json.py \
- --ann-dir /path/to/gt \
- --image-dir /path/to/image \
- --dst-json-path train.json
-
- e.g.:
- python convert_ann_to_json.py --ann-dir /media/gavin/home/gavin/DataSet/ocr/ctw/ctw1500_e2e_annos/ctw1500_e2e_test \
- --image-dir /media/gavin/home/gavin/DataSet/ocr/ctw/ctw1500/test/text_image \
- --dst-json-path ./abc_json/test.json
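To make the windows_label_tool format concrete, here is a minimal parser sketch (my own illustration; the official converter below does the real work):
- import numpy as np
-
- def read_windows_label(path):
-     """Parse one windows_label_tool gt file into (points, text) pairs."""
-     with open(path) as f:
-         lines = [l.strip() for l in f if l.strip()]
-     n = int(lines[0])  # first line: number of annotations
-     anns = []
-     for line in lines[1:1 + n]:
-         parts = line.split(',')
-         pts = np.array(parts[:28], dtype=np.float32).reshape(14, 2)  # 14 (x, y) points
-         text = ','.join(parts[28:]).strip('"')  # the transcription may itself contain commas
-         anns.append((pts, text))
-     return anns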

The JSON conversion script:
- # -*- coding: utf-8 -*-
- """
- @File : convert_ann_to_json.py
- @Time : 2020-8-17 16:13
- @Author : yizuotian
- @Description : Convert windows_label_tool annotations to the JSON format used for ABCNet training
- """
- import argparse
- import json
- import os
- import sys
- import cv2
- import bezier_utils
- import numpy as np
-
-
- def gen_abc_json(abc_gt_dir, abc_json_path, image_dir, classes_path):
-     """
-     Generate COCO-style JSON annotations from windows_label_tool GT files.
-     :param abc_gt_dir: directory of annotation files produced by the windows_label_tool
-     :param abc_json_path: output path of the JSON annotations ABCNet trains on
-     :param image_dir: directory of the corresponding images
-     :param classes_path: path to the class-name file
-     :return:
-     """
-     # Desktop Latin_embed.
-     cV2 = [' ', '!', '"', '#', '$', '%', '&', '\'', '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4',
-            '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J',
-            'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', '\\', ']', '^', '_',
-            '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u',
-            'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']
-
-     dataset = {
-         'licenses': [],
-         'info': {},
-         'categories': [],
-         'images': [],
-         'annotations': []
-     }
-     with open(classes_path) as f:
-         classes = f.read().strip().split()
-     for i, cls in enumerate(classes, 1):
-         dataset['categories'].append({
-             'id': i,
-             'name': cls,
-             'supercategory': 'beverage',
-             'keypoints': ['mean', 'xmin', 'x2', 'x3', 'xmax', 'ymin', 'y2', 'y3', 'ymax', 'cross']  # only for BDN
-         })
-
-     def get_category_id(cls):
-         for category in dataset['categories']:
-             if category['name'] == cls:
-                 return category['id']
-
-     # iterate over the windows_label_tool txt annotation files
-     indexes = sorted([f.split('.')[0]
-                       for f in os.listdir(abc_gt_dir)])
-     print(indexes)
-
-     j = 1  # running annotation id
-     for index in indexes:
-         # if int(index) > 3: continue
-         # print('Processing: ' + index)
-         im = cv2.imread(os.path.join(image_dir, '{}.jpg'.format(index)))
-         im_height, im_width = im.shape[:2]
-         dataset['images'].append({
-             'coco_url': '',
-             'date_captured': '',
-             'file_name': index + '.jpg',
-             'flickr_url': '',
-             'id': int(index.split('_')[-1]),  # e.g. 'img_1' -> 1
-             'license': 0,
-             'width': im_width,
-             'height': im_height
-         })
-         anno_file = os.path.join(abc_gt_dir, '{}.txt'.format(index))
-
-         with open(anno_file) as f:
-             lines = [line for line in f.readlines() if line.strip()]
-             # skip files with no usable annotations
-             if len(lines) <= 1:
-                 continue
-             for i, line in enumerate(lines[1:]):
-                 elements = line.strip().split(',')
-                 polygon = np.array(elements[:28]).reshape((-1, 2)).astype(np.float32)  # [14, (x, y)]
-                 control_points = bezier_utils.polygon_to_bezier_pts(polygon, im)  # [8, (x, y)]
-                 ct = ','.join(elements[28:]).replace('"', '').strip()  # transcription may itself contain commas
-
-                 cls = 'text'
-                 # segs = [float(kkpart) for kkpart in parts[:16]]
-                 segs = [float(kkpart) for kkpart in control_points.flatten()]
-                 xt = [segs[ikpart] for ikpart in range(0, len(segs), 2)]
-                 yt = [segs[ikpart] for ikpart in range(1, len(segs), 2)]
-
-                 # filter out boxes that fall outside the image
-                 if max(xt) > im_width or max(yt) > im_height:
-                     print('The annotation bounding box is outside of the image: {}'.format(index))
-                     print("max x:{}, max y:{}, w:{}, h:{}".format(max(xt), max(yt), im_width, im_height))
-                     continue
-                 xmin = min([xt[0], xt[3], xt[4], xt[7]])
-                 ymin = min([yt[0], yt[3], yt[4], yt[7]])
-                 xmax = max([xt[0], xt[3], xt[4], xt[7]])
-                 ymax = max([yt[0], yt[3], yt[4], yt[7]])
-                 width = max(0, xmax - xmin + 1)
-                 height = max(0, ymax - ymin + 1)
-                 if width == 0 or height == 0:
-                     continue
-
-                 max_len = 100
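-                 # 'rec' encodes the transcription as a fixed-length sequence of indices
-                 # into cV2: 0..94 are the printable characters above, len(cV2) = 95 marks
-                 # an out-of-vocabulary character, and len(cV2) + 1 = 96 is the padding value.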
-                 recs = [len(cV2) + 1 for ir in range(max_len)]
-
-                 ct = str(ct)
-                 # print('rec', ct)
-
-                 for ix, ict in enumerate(ct):
-                     if ix >= max_len:
-                         continue
-                     if ict in cV2:
-                         recs[ix] = cV2.index(ict)
-                     else:
-                         recs[ix] = len(cV2)
-
-                 dataset['annotations'].append({
-                     'area': width * height,
-                     'bbox': [xmin, ymin, width, height],
-                     'category_id': get_category_id(cls),
-                     'id': j,
-                     'image_id': int(index.split('_')[-1]),  # e.g. 'img_1' -> 1
-                     'iscrowd': 0,
-                     'bezier_pts': segs,
-                     'rec': recs
-                 })
-                 j += 1
-
-     # write the JSON file
-     folder = os.path.dirname(abc_json_path)
-     if folder and not os.path.exists(folder):
-         os.makedirs(folder)
-     with open(abc_json_path, 'w') as f:
-         json.dump(dataset, f)
-
-
- def main(args):
-     gen_abc_json(args.ann_dir, args.dst_json_path, args.image_dir, args.classes_path)
-
-
- if __name__ == '__main__':
-     """
-     Usage: python convert_ann_to_json.py \
-         --ann-dir /path/to/gt \
-         --image-dir /path/to/image \
-         --dst-json-path train.json
-     """
-     parse = argparse.ArgumentParser()
-     parse.add_argument("--ann-dir", type=str, default=None)
-     parse.add_argument("--image-dir", type=str, default=None)
-     parse.add_argument("--dst-json-path", type=str, default=None)
-     parse.add_argument("--classes-path", type=str, default='./classes.txt')
-     arguments = parse.parse_args()  # sys.argv[1:]
-     main(arguments)
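After converting, a quick sanity check of the generated JSON (my own snippet, assuming the train.json output name from the usage above) can save a failed training run:
- import json
-
- with open('train.json') as f:
-     coco = json.load(f)
- print(len(coco['images']), 'images,', len(coco['annotations']), 'annotations')
- # every annotation should carry 16 Bezier coords (8 control points) and a length-100 rec
- assert all(len(a['bezier_pts']) == 16 and len(a['rec']) == 100 for a in coco['annotations'])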

Method 2:
Convert the labelme annotations to the windows_label_tool format, then run the JSON conversion from Method 1.
Converting labelme annotations to the windows_label_tool format:
1. Convert labelme JSON annotations to ABCNet gt annotations (skip this step if you annotate directly with the windows_label_tool):
- # coding=utf-8
- # Convert labelme JSON annotations to ABCNet gt annotations; skip this step
- # if you annotate directly with the windows_label_tool.
-
- import glob
- import json
- import math
- import os
-
- import matplotlib.pyplot as plt
- import numpy as np
- import torch
- from scipy.special import comb as n_over_k
- from shapely.geometry import *
- from sklearn.linear_model import LinearRegression
- from sklearn.metrics import mean_squared_error
- from torch import nn
-
-
- class Bezier(nn.Module):
-     def __init__(self, ps, ctps):
-         """
-         ps: numpy array of points
-         """
-         super(Bezier, self).__init__()
-         self.x1 = nn.Parameter(torch.as_tensor(ctps[0], dtype=torch.float64))
-         self.x2 = nn.Parameter(torch.as_tensor(ctps[2], dtype=torch.float64))
-         self.y1 = nn.Parameter(torch.as_tensor(ctps[1], dtype=torch.float64))
-         self.y2 = nn.Parameter(torch.as_tensor(ctps[3], dtype=torch.float64))
-         self.x0 = ps[0, 0]
-         self.x3 = ps[-1, 0]
-         self.y0 = ps[0, 1]
-         self.y3 = ps[-1, 1]
-         self.inner_ps = torch.as_tensor(ps[1:-1, :], dtype=torch.float64)
-         self.t = torch.as_tensor(np.linspace(0, 1, 81))
-
-     def forward(self):
-         # loss: sum of distances from the inner polyline points to their
-         # nearest sample on the Bezier curve
-         x0, x1, x2, x3, y0, y1, y2, y3 = self.control_points()
-         t = self.t
-         bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
-         bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
-         bezier = torch.stack((bezier_x, bezier_y), dim=1)
-         diffs = bezier.unsqueeze(0) - self.inner_ps.unsqueeze(1)
-         sdiffs = diffs ** 2
-         dists = sdiffs.sum(dim=2).sqrt()
-         min_dists, min_inds = dists.min(dim=1)
-         return min_dists.sum()
-
-     def control_points(self):
-         return self.x0, self.x1, self.x2, self.x3, self.y0, self.y1, self.y2, self.y3
-
-     def control_points_f(self):
-         return self.x0, self.x1.item(), self.x2.item(), self.x3, self.y0, self.y1.item(), self.y2.item(), self.y3
-
-
- def train(x, y, ctps, lr):
-     # Note: in this version no optimization loop is actually run; `lr` is unused
-     # and the control points returned are exactly the initialization passed in.
-     x, y = np.array(x), np.array(y)
-     ps = np.vstack((x, y)).transpose()
-     bezier = Bezier(ps, ctps)
-
-     return bezier.control_points_f()
-
- def draw(ps, control_points, t):
-     x = ps[:, 0]
-     y = ps[:, 1]
-     x0, x1, x2, x3, y0, y1, y2, y3 = control_points
-     fig = plt.figure()
-     ax = fig.add_subplot(111)
-     ax.plot(x, y, color='m', linestyle='', marker='.')
-     bezier_x = (1-t)*((1-t)*((1-t)*x0+t*x1)+t*((1-t)*x1+t*x2))+t*((1-t)*((1-t)*x1+t*x2)+t*((1-t)*x2+t*x3))
-     bezier_y = (1-t)*((1-t)*((1-t)*y0+t*y1)+t*((1-t)*y1+t*y2))+t*((1-t)*((1-t)*y1+t*y2)+t*((1-t)*y2+t*y3))
-
-     plt.plot(bezier_x, bezier_y, 'g-')
-     plt.draw()
-     plt.pause(1)  # <-------
-     input("<Hit Enter To Close>")  # raw_input is Python 2; use input under Python 3
-     plt.close(fig)
-
-
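- # Mtk is the cubic Bernstein basis term B_{k,n}(t) = C(n,k) * t^k * (1-t)^(n-k);
- # BezierCoeff(ts) builds the (len(ts), 4) design matrix used for the
- # least-squares Bezier fit below.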
- Mtk = lambda n, t, k: t**k * (1-t)**(n-k) * n_over_k(n,k)
- BezierCoeff = lambda ts: [[Mtk(3,t,k) for k in range(4)] for t in ts]
-
-
- def bezier_fit(x, y):
-     # chord-length parameterization of the input points
-     dy = y[1:] - y[:-1]
-     dx = x[1:] - x[:-1]
-     dt = (dx ** 2 + dy ** 2) ** 0.5
-     t = dt / dt.sum()
-     t = np.hstack(([0], t))
-     t = t.cumsum()
-
-     data = np.column_stack((x, y))
-     Pseudoinverse = np.linalg.pinv(BezierCoeff(t))  # (9,4) -> (4,9)
-
-     control_points = Pseudoinverse.dot(data)  # (4,9)*(9,2) -> (4,2)
-     medi_ctp = control_points[1:-1, :].flatten().tolist()
-     return medi_ctp
-
- def bezier_fitv2(x, y):
-     # fallback: place the two middle control points at thirds of the chord
-     xc01 = (2 * x[0] + x[-1]) / 3.0
-     yc01 = (2 * y[0] + y[-1]) / 3.0
-     xc02 = (x[0] + 2 * x[-1]) / 3.0
-     yc02 = (y[0] + 2 * y[-1]) / 3.0
-     control_points = [xc01, yc01, xc02, yc02]
-     return control_points
-
- def is_close_to_line(xs, ys, thres):
-     regression_model = LinearRegression()
-     # Fit the data (train the model)
-     regression_model.fit(xs.reshape(-1, 1), ys.reshape(-1, 1))
-     # Predict
-     y_predicted = regression_model.predict(xs.reshape(-1, 1))
-
-     # model evaluation
-     rmse = mean_squared_error(ys.reshape(-1, 1) ** 2, y_predicted ** 2)
-     rmse = rmse / (ys.reshape(-1, 1) ** 2 - y_predicted ** 2).max() ** 2
-
-     if rmse > thres:
-         return 0.0
-     else:
-         return 2.0
-
- def is_close_to_linev2(xs, ys, size, thres=0.05):
-     pts = []
-     nor_pixel = int(size ** 0.5)
-     for i in range(len(xs)):
-         pts.append(Point([xs[i], ys[i]]))
-     # slope between each consecutive pair of points
-     slopes = [(second.y - first.y) / (second.x - first.x) if not (second.x - first.x) == 0.0
-               else math.inf * np.sign(second.y - first.y) for first, second in zip(pts, pts[1:])]
-     st_slope = (ys[-1] - ys[0]) / (xs[-1] - xs[0])
-     max_dis = ((ys[-1] - ys[0]) ** 2 + (xs[-1] - xs[0]) ** 2) ** 0.5
-
-     diffs = np.abs(np.array(slopes) - st_slope)  # plain list minus float would raise; go through numpy
-     score = diffs.sum() * max_dis / nor_pixel
-
-     if score < thres:
-         return 0.0
-     else:
-         return 3.0
-
-
- labels = glob.glob("dataset/json/*.json")
- labels.sort()
-
- if not os.path.isdir('abcnet_gen_labels'):
-     os.mkdir('abcnet_gen_labels')
-
- for il, label in enumerate(labels):
-     print('Processing: ' + label)
-     imgdir = label.replace('json/', 'image/').replace('.json', '.jpg')
-
-     outgt = open(label.replace('dataset/json/', 'abcnet_gen_labels/').replace('.json', '.txt'), 'w')
-
-     data = []
-     cts = []
-     with open(label, "r") as f:
-         jdata = json.loads(f.read())
-         boxes = jdata["shapes"]
-         for ib, box in enumerate(boxes):
-             line, ct = box["points"], box["label"]
-             pts = []
-             for p in line:
-                 pts.extend(p)
-             # for 4- and 6-point polygons, insert edge midpoints so each side
-             # of the text has enough points (3 or 4) for the Bezier fit
-             if len(line) == 4:
-                 pts = line[0] + [(line[0][0] + line[1][0]) / 2, (line[0][1] + line[1][1]) / 2] + line[1] \
-                     + line[2] + [(line[2][0] + line[3][0]) / 2, (line[2][1] + line[3][1]) / 2] + line[3]
-             if len(line) == 6:
-                 if abs(line[0][0] - line[1][0]) > abs(line[1][0] - line[2][0]):
-                     pts = line[0] + [(line[0][0] + line[1][0]) / 2, (line[0][1] + line[1][1]) / 2] + line[1] + line[2]
-                     pts += line[3] + [(line[3][0] + line[4][0]) / 2, (line[3][1] + line[4][1]) / 2] + line[4] + line[5]
-                 else:
-                     pts = line[0] + line[1] + [(line[1][0] + line[2][0]) / 2, (line[1][1] + line[2][1]) / 2] + line[2]
-                     pts += line[3] + line[4] + [(line[4][0] + line[5][0]) / 2, (line[4][1] + line[5][1]) / 2] + line[5]
-             data.append(np.array([float(x) for x in pts]))
-             cts.append(ct)
-
-     ############## top
-     img = plt.imread(imgdir)
-
-     for iid, ddata in enumerate(data):
-         lh = len(data[iid])
-         assert lh % 4 == 0
-         lhc2 = int(lh / 2)
-         lhc4 = int(lh / 4)
-         xcors = [data[iid][i] for i in range(0, len(data[iid]), 2)]
-         ycors = [data[iid][i + 1] for i in range(0, len(data[iid]), 2)]
-
-         curve_data_top = data[iid][0:lhc2].reshape(lhc4, 2)
-         curve_data_bottom = data[iid][lhc2:].reshape(lhc4, 2)
-
-         left_vertex_x = [curve_data_top[0, 0], curve_data_bottom[lhc4 - 1, 0]]
-         left_vertex_y = [curve_data_top[0, 1], curve_data_bottom[lhc4 - 1, 1]]
-         right_vertex_x = [curve_data_top[lhc4 - 1, 0], curve_data_bottom[0, 0]]
-         right_vertex_y = [curve_data_top[lhc4 - 1, 1], curve_data_bottom[0, 1]]
-
-         x_data = curve_data_top[:, 0]
-         y_data = curve_data_top[:, 1]
-
-         init_control_points = bezier_fit(x_data, y_data)
-
-         learning_rate = is_close_to_linev2(x_data, y_data, img.size)
-
-         x0, x1, x2, x3, y0, y1, y2, y3 = train(x_data, y_data, init_control_points, learning_rate)
-         control_points = np.array([
-             [x0, y0],
-             [x1, y1],
-             [x2, y2],
-             [x3, y3]
-         ])
-
-         x_data_b = curve_data_bottom[:, 0]
-         y_data_b = curve_data_bottom[:, 1]
-
-         init_control_points_b = bezier_fit(x_data_b, y_data_b)
-
-         learning_rate = is_close_to_linev2(x_data_b, y_data_b, img.size)
-
-         x0_b, x1_b, x2_b, x3_b, y0_b, y1_b, y2_b, y3_b = train(x_data_b, y_data_b, init_control_points_b, learning_rate)
-         control_points_b = np.array([
-             [x0_b, y0_b],
-             [x1_b, y1_b],
-             [x2_b, y2_b],
-             [x3_b, y3_b]
-         ])
-
-         t_plot = np.linspace(0, 1, 81)
-         Bezier_top = np.array(BezierCoeff(t_plot)).dot(control_points)
-         Bezier_bottom = np.array(BezierCoeff(t_plot)).dot(control_points_b)
-
-         plt.plot(Bezier_top[:, 0], Bezier_top[:, 1], 'g-', label='fit', linewidth=1)
-         plt.plot(Bezier_bottom[:, 0], Bezier_bottom[:, 1], 'g-', label='fit', linewidth=1)
-         plt.plot(control_points[:, 0], control_points[:, 1], 'r.:', fillstyle='none', linewidth=1)
-         plt.plot(control_points_b[:, 0], control_points_b[:, 1], 'r.:', fillstyle='none', linewidth=1)
-
-         plt.plot(left_vertex_x, left_vertex_y, 'g-', linewidth=1)
-         plt.plot(right_vertex_x, right_vertex_y, 'g-', linewidth=1)
-
-         # 8 control points (top curve then bottom curve), '||||', then the text
-         outstr = '{},{},{},{},{},{},{},{},{},{},{},{},{},{},{},{}||||{}\n'.format(
-             round(x0, 2), round(y0, 2),
-             round(x1, 2), round(y1, 2),
-             round(x2, 2), round(y2, 2),
-             round(x3, 2), round(y3, 2),
-             round(x0_b, 2), round(y0_b, 2),
-             round(x1_b, 2), round(y1_b, 2),
-             round(x2_b, 2), round(y2_b, 2),
-             round(x3_b, 2), round(y3_b, 2),
-             cts[iid])
-         outgt.writelines(outstr)
-     outgt.close()
-
-     plt.imshow(img)
-     plt.axis('off')
-
-     if not os.path.isdir('abcnet_vis'):
-         os.mkdir('abcnet_vis')
-     plt.savefig('abcnet_vis/' + os.path.basename(imgdir), bbox_inches='tight', dpi=400)
-     plt.clf()
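To sanity-check the fitting above, you can feed bezier_fit points sampled from a known curve; the recovered middle control points should come out close to the true ones (an illustrative check of mine, not part of the original script):
- ts = np.linspace(0, 1, 9)
- ctrl = np.array([[0., 0.], [30., -20.], [70., -20.], [100., 0.]])
- sampled = np.array(BezierCoeff(ts)).dot(ctrl)  # 9 points on the known curve
- mid = bezier_fit(sampled[:, 0], sampled[:, 1])
- print(np.round(mid, 2))  # should be close to [30, -20, 70, -20]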

2. Convert the ABCNet gt annotations to ABCNet JSON (using the converter from Method 1), then register the new dataset splits in AdelaiDet:
- _PREDEFINED_SPLITS_TEXT = {
-     "totaltext_train": ("totaltext/train_images", "totaltext/train.json"),
-     "totaltext_val": ("totaltext/test_images", "totaltext/test.json"),
-     ...
-     "abcnet_train": ("data/train", "data/annotations/train.json"),
-     "abcnet_test": ("data/test", "data/annotations/test.json"),
- }
Taking configs/BAText/CTW1500/Base-CTW1500.yaml as an example, point the config at the new splits:
- DATASETS:
-   # detail cfg: AdelaiDet/adet/data/builtin.py
-   TRAIN: ("abcnet_train",)
-   TEST: ("abcnet_test",)
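A quick way to confirm the registration took effect (my own sanity check; the import path assumes the builtin.py above and may differ across AdelaiDet versions):
- from detectron2.data import DatasetCatalog
- import adet.data.builtin  # noqa: F401 - importing it executes the split registration
-
- dicts = DatasetCatalog.get("abcnet_train")
- print(len(dicts), "images,", sum(len(d["annotations"]) for d in dicts), "annotations")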
The training commands are as follows:
- # train custom
-
- # 1. Pretraining with synthetic data:
-
- OMP_NUM_THREADS=1 python tools/train_net.py \
- --config-file configs/BAText/Pretrain/attn_R_50.yaml \
- --num-gpus 4 \
- OUTPUT_DIR text_pretraining/attn_R_50
-
- # 2. Finetuning
-
- OMP_NUM_THREADS=1 python tools/train_net.py \
- --config-file configs/BAText/CTW1500/attn_R_50.yaml \
- --num-gpus 4 \
- MODEL.WEIGHTS text_pretraining/attn_R_50/model_final.pth
-
-
- e.g.:
- # 1.
- OMP_NUM_THREADS=1 python tools/train_net.py --config-file configs/BAText/CTW1500/attn_R_50.yaml --num-gpus 1
-
- # 2. Finetuning on CTW1500:
-
- OMP_NUM_THREADS=1 python tools/train_net.py \
- --config-file configs/BAText/CTW1500/attn_R_50.yaml \
- --num-gpus 1 \
- MODEL.WEIGHTS text_pretraining/attn_R_50/model_final.pth
-
- e.g.:
- OMP_NUM_THREADS=1 python tools/train_net.py --config-file \
- configs/BAText/CTW1500/attn_R_50.yaml --num-gpus 1 MODEL.WEIGHTS \
- output/batext/ctw1500/attn_R_50/model_final.pth
-
- # eval:
- python tools/train_net.py \
- --config-file configs/BAText/CTW1500/attn_R_50.yaml \
- --eval-only \
- MODEL.WEIGHTS ctw1500_attn_R_50.pth
-
- # test:
- python demo/demo.py --config-file configs/BAText/CTW1500/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS output/batext/ctw1500/attn_R_50/model_0119999.pth
-
- python3 demo/demo.py --config-file configs/BAText/TotalText/attn_R_50.yaml --input demo/demo_images/ --output outputs/ --opts MODEL.WEIGHTS tt_attn_R_50.pth

The above covers training on English letters and digits; training on Chinese text, and the modifications it requires, will be covered in a follow-up.
After training on CTW1500, the test results are as follows:
Looking at the last image, nothing is missed, which beats the Total-Text-trained results, even though I only trained for 120,000 + 120,000 steps.