
DBNet Multi-Class Text Detection Network Based on PaddleOCR

Contents

Purpose

Model Architecture Comparison

Code Implementation

1. Dataset Format

2. Config File Changes

3. Data Preprocessing

4. Model Code Changes

5. Adding the Multi-Class Loss

6. Modifying db_postprocess.py

7. Modifying train.py, eval.py, infer_det.py, and export_model.py

Done!

Purpose

I had long thought that if the DBNet text detection network could be extended with multi-class classification, we would get a model that is both very small and able to distinguish categories, which would give very high accuracy and efficiency for on-device deployment. Building on a PyTorch implementation of multi-class DBNet, this article implements a Paddle version of a multi-class DBNet text detection network. Note that this approach is not suitable when multiple classes overlap.

Model Architecture Comparison

Before vs. after: the comparison figure clearly shows an extra branch added for predicting the class.

[figure: DBNet architecture before and after adding the classification branch]

Code Implementation

Tested: the changes below also work on the official release/2.6 branch. The code on GitHub (download link at the end of this article) targets an older version and can serve as a reference.

1. Dataset Format

Add a label_list.txt file.

Change the value of the "transcription" field in the dataset annotations to the box's label_name, i.e. one of the class names in label_list.txt.
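For reference, here is a minimal sketch of both pieces. The class names and image path are made up; the annotation line follows PaddleOCR's standard detection label format (image path, a tab, then a JSON list of boxes), with "transcription" now holding the class name:

# label_list.txt: one class name per line
name
id_number
address

# one line of the detection label file:
train/img_001.jpg	[{"transcription": "name", "points": [[10, 20], [120, 20], [120, 50], [10, 50]]}]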

2. Config File Changes

Global:
  ...
  label_list: "../../2.4/train_data/sfz/label_list.txt"  # new: path to the class list file
  num_classes: 9                                         # new: number of classes
  ...

Train:
  dataset:
    ...
    transforms:
      ...
      - KeepKeys:
          keep_keys: [ 'image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask', 'class_mask' ]  # new: class_mask added
      ...
  ...

3. Data Preprocessing

Fill the different classes into the mask as integer values in the style [1, 2, 3, 4, ...]. Three files need changes:

label_ops.py

class DetLabelEncode(object):
    # def __init__(self, **kwargs):
    #     pass
    def __init__(self, label_list, num_classes=1, **kwargs):
        self.num_classes = num_classes
        self.label_list = []
        if label_list is not None:
            if isinstance(label_list, str):
                with open(label_list, "r", encoding="utf-8") as f:
                    for line in f.readlines():
                        self.label_list.append(line.replace("\n", ""))
            else:
                self.label_list = label_list
        assert num_classes == len(self.label_list), \
            "len(label_list) must match num_classes"

    def __call__(self, data):
        label = data['label']
        label = json.loads(label)
        nBox = len(label)
        boxes, txts, txt_tags = [], [], []
        classes = []
        for bno in range(0, nBox):
            box = label[bno]['points']
            txt = label[bno]['transcription']
            boxes.append(box)
            txts.append(txt)
            if txt in ['*', '###']:
                txt_tags.append(True)
                if self.num_classes > 1:
                    classes.append(-2)  # sentinel class for ignored boxes
            else:
                txt_tags.append(False)
                if self.num_classes > 1:
                    classes.append(int(self.label_list.index(txt)))
        if len(boxes) == 0:
            return None
        boxes = self.expand_points_num(boxes)
        boxes = np.array(boxes, dtype=np.float32)
        txt_tags = np.array(txt_tags, dtype=bool)
        # classes stays a plain Python list: the number of boxes varies per image
        data['polys'] = boxes
        data['texts'] = txts
        data['ignore_tags'] = txt_tags
        if self.num_classes > 1:
            data['classes'] = classes
        return data
make_shrink_map.py

random_crop_data.py
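The diffs for these two files are not reproduced here. In make_shrink_map.py, the idea is to render data['classes'] into the class_mask training target alongside the shrink map (random_crop_data.py then only needs to keep classes aligned with polys when cropping). Below is a minimal sketch of that rendering, assuming background pixels stay 0 and each kept polygon is filled with its class index offset by 1, matching the [1, 2, 3, 4, ...] filling described above; treat the function name and the offset convention as assumptions, not the repo's exact code:

import cv2
import numpy as np

def make_class_mask(polys, classes, ignore_tags, h, w):
    """Render per-box class indices into a (h, w) integer mask."""
    class_mask = np.zeros((h, w), dtype=np.int32)  # 0 = background
    for poly, cls, ignore in zip(polys, classes, ignore_tags):
        if ignore or cls < 0:  # skip '*'/'###' boxes (encoded as -2 above)
            continue
        # fill as [1, 2, 3, 4, ...] so index 0 from label_list becomes 1
        cv2.fillPoly(class_mask, [np.round(poly).astype(np.int32)], int(cls) + 1)
    return class_mask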

4. Model Code Changes

Add the new branch; only the head module needs to change. det_db_head.py:

class Head(nn.Layer):
    def __init__(self, in_channels, name_list, num_classes=1):
        super(Head, self).__init__()
        self.num_classes = num_classes
        ...
        # the final deconv now emits num_classes channels instead of 1
        self.conv3 = nn.Conv2DTranspose(
            in_channels=in_channels // 4,
            out_channels=num_classes,
            kernel_size=2,
            stride=2,
            weight_attr=ParamAttr(
                initializer=paddle.nn.initializer.KaimingUniform()),
            bias_attr=get_bias_attr(in_channels // 4), )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv_bn1(x)
        x = self.conv2(x)
        x = self.conv_bn2(x)
        x = self.conv3(x)
        if self.num_classes == 1:
            x = F.sigmoid(x)  # multi-class outputs stay as raw logits
        return x


class DBHead(nn.Layer):
    def __init__(self, in_channels, num_classes=1, k=50, **kwargs):
        super(DBHead, self).__init__()
        self.k = k
        self.num_classes = num_classes
        ...
        if num_classes != 1:
            # the extra classification branch reuses the Head structure
            self.classes = Head(in_channels, binarize_name_list,
                                num_classes=num_classes)
        else:
            self.classes = None

    def step_function(self, x, y):
        return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))

    def forward(self, x, targets=None):
        shrink_maps = self.binarize(x)
        if not self.training:
            if self.num_classes == 1:
                return {'maps': shrink_maps}
            else:
                # at inference, collapse the logits to a per-pixel class index
                classes = paddle.argmax(
                    self.classes(x), axis=1, keepdim=True, dtype='int32')
                return {'maps': shrink_maps, "classes": classes}
        threshold_maps = self.thresh(x)
        binary_maps = self.step_function(shrink_maps, threshold_maps)
        y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
        if self.num_classes == 1:
            return {'maps': y}
        else:
            return {'maps': y, "classes": self.classes(x)}
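For context, step_function implements DB's differentiable binarization: the approximate binary map is B = 1 / (1 + exp(-k * (P - T))), where P is the shrink (probability) map, T is the threshold map, and k = 50 amplifies the gradient near the boundary. Note that the classes head returns raw logits when num_classes > 1 (sigmoid is skipped), which is exactly what the cross-entropy loss in the next step expects; at inference time the per-pixel class is taken with argmax.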

5. Adding the Multi-Class Loss

Referencing the PaddleSeg code, add a CrossEntropyLoss class:

class CrossEntropyLoss(nn.Layer):
    def __init__(self,
                 weight=None,
                 ignore_index=255,
                 top_k_percent_pixels=1.0,
                 data_format='NCHW'):
        super(CrossEntropyLoss, self).__init__()
        self.ignore_index = ignore_index
        self.top_k_percent_pixels = top_k_percent_pixels
        self.EPS = 1e-8
        self.data_format = data_format
        if weight is not None:
            self.weight = paddle.to_tensor(weight, dtype='float32')
        else:
            self.weight = None

    def forward(self, logit, label, semantic_weights=None):
        channel_axis = 1 if self.data_format == 'NCHW' else -1
        if self.weight is not None and logit.shape[channel_axis] != len(
                self.weight):
            raise ValueError(
                'The number of weights = {} must be the same as the number of classes = {}.'
                .format(len(self.weight), logit.shape[channel_axis]))
        if channel_axis == 1:
            logit = paddle.transpose(logit, [0, 2, 3, 1])
        label = label.astype('int64')
        # In F.cross_entropy, the ignore_index is invalid, which needs to be fixed.
        # When there is 255 in the label and paddle version <= 2.1.3, the
        # cross_entropy OP reports an error; this is fixed in paddle develop.
        loss = F.cross_entropy(
            logit,
            label,
            ignore_index=self.ignore_index,
            reduction='none',
            weight=self.weight)
        return self._post_process_loss(logit, label, semantic_weights, loss)

    def _post_process_loss(self, logit, label, semantic_weights, loss):
        mask = label != self.ignore_index
        mask = paddle.cast(mask, 'float32')
        label.stop_gradient = True
        mask.stop_gradient = True
        if loss.ndim > mask.ndim:
            loss = paddle.squeeze(loss, axis=-1)
        loss = loss * mask
        if semantic_weights is not None:
            loss = loss * semantic_weights
        if self.weight is not None:
            _one_hot = F.one_hot(label, logit.shape[-1])
            coef = paddle.sum(_one_hot * self.weight, axis=-1)
        else:
            coef = paddle.ones_like(label)
        if self.top_k_percent_pixels == 1.0:
            avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
        else:
            loss = loss.reshape((-1, ))
            top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
            loss, indices = paddle.topk(loss, top_k_pixels)
            coef = coef.reshape((-1, ))
            coef = paddle.gather(coef, indices)
            coef.stop_gradient = True
            coef = coef.astype('float32')
            avg_loss = loss.mean() / (paddle.mean(coef) + self.EPS)
        return avg_loss


class DBLoss(nn.Layer):
    """
    Differentiable Binarization (DB) Loss Function
    args:
        param (dict): the super parameter for DB Loss
    """

    def __init__(self,
                 balance_loss=True,
                 main_loss_type='DiceLoss',
                 alpha=5,
                 beta=10,
                 ohem_ratio=3,
                 eps=1e-6,
                 num_classes=1,
                 **kwargs):
        super(DBLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.num_classes = num_classes
        self.dice_loss = DiceLoss(eps=eps)
        self.l1_loss = MaskL1Loss(eps=eps)
        self.bce_loss = BalanceLoss(
            balance_loss=balance_loss,
            main_loss_type=main_loss_type,
            negative_ratio=ohem_ratio)
        self.loss_func = CrossEntropyLoss()

    def forward(self, predicts, labels):
        predict_maps = predicts['maps']
        if self.num_classes > 1:
            predict_classes = predicts['classes']
            label_threshold_map, label_threshold_mask, label_shrink_map, \
                label_shrink_mask, class_mask = labels[1:]
        else:
            label_threshold_map, label_threshold_mask, label_shrink_map, \
                label_shrink_mask = labels[1:]
        shrink_maps = predict_maps[:, 0, :, :]
        threshold_maps = predict_maps[:, 1, :, :]
        binary_maps = predict_maps[:, 2, :, :]

        loss_shrink_maps = self.bce_loss(shrink_maps, label_shrink_map,
                                         label_shrink_mask)
        loss_threshold_maps = self.l1_loss(threshold_maps, label_threshold_map,
                                           label_threshold_mask)
        loss_binary_maps = self.dice_loss(binary_maps, label_shrink_map,
                                          label_shrink_mask)
        loss_shrink_maps = self.alpha * loss_shrink_maps
        loss_threshold_maps = self.beta * loss_threshold_maps

        # add the classification loss when running multi-class
        if self.num_classes > 1:
            loss_classes = self.loss_func(predict_classes, class_mask)
            loss_all = loss_shrink_maps + loss_threshold_maps \
                + loss_binary_maps + loss_classes
            losses = {'loss': loss_all,
                      "loss_shrink_maps": loss_shrink_maps,
                      "loss_threshold_maps": loss_threshold_maps,
                      "loss_binary_maps": loss_binary_maps,
                      "loss_classes": loss_classes}
        else:
            loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
            losses = {'loss': loss_all,
                      "loss_shrink_maps": loss_shrink_maps,
                      "loss_threshold_maps": loss_threshold_maps,
                      "loss_binary_maps": loss_binary_maps}
        return losses
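As a quick sanity check on the shape conventions (a hypothetical smoke test, not part of the patch): predict_classes holds raw logits of shape (N, num_classes, H, W), and class_mask holds per-pixel integer labels of shape (N, H, W), with ignore_index (255 by default) marking pixels to exclude:

import paddle

logits = paddle.randn([2, 9, 160, 160])       # (N, num_classes, H, W) raw logits
labels = paddle.randint(0, 9, [2, 160, 160])  # (N, H, W) per-pixel class indices
loss_fn = CrossEntropyLoss(ignore_index=255)
print(loss_fn(logits, labels))                # scalar average loss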

6. Modifying db_postprocess.py

class DBPostProcess(object):
    """
    The post process for Differentiable Binarization (DB).
    """

    def __init__(self,
                 thresh=0.3,
                 box_thresh=0.7,
                 max_candidates=1000,
                 unclip_ratio=2.0,
                 use_dilation=False,
                 score_mode="fast",
                 **kwargs):
        self.thresh = thresh
        self.box_thresh = box_thresh
        self.max_candidates = max_candidates
        self.unclip_ratio = unclip_ratio
        self.min_size = 3
        self.score_mode = score_mode
        assert score_mode in [
            "slow", "fast"
        ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
        self.dilation_kernel = None if not use_dilation else np.array(
            [[1, 1], [1, 1]])

    def boxes_from_bitmap(self, pred, _bitmap, classes, dest_width, dest_height):
        '''
        _bitmap: single map with shape (H, W),
            whose values are binarized as {0, 1}
        '''
        bitmap = _bitmap
        height, width = bitmap.shape

        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
                                cv2.CHAIN_APPROX_SIMPLE)
        if len(outs) == 3:  # OpenCV 3.x returns (img, contours, hierarchy)
            img, contours, _ = outs[0], outs[1], outs[2]
        elif len(outs) == 2:  # OpenCV 2.x / 4.x returns (contours, hierarchy)
            contours, _ = outs[0], outs[1]

        num_contours = min(len(contours), self.max_candidates)

        boxes = []
        scores = []
        class_indexes = []
        class_scores = []
        for index in range(num_contours):
            contour = contours[index]
            points, sside = self.get_mini_boxes(contour)
            if sside < self.min_size:
                continue
            points = np.array(points)
            if self.score_mode == "fast":
                score, class_index, class_score = self.box_score_fast(
                    pred, points.reshape(-1, 2), classes)
            else:
                score, class_index, class_score = self.box_score_slow(
                    pred, contour, classes)
            if self.box_thresh > score:
                continue

            box = self.unclip(points).reshape(-1, 1, 2)
            box, sside = self.get_mini_boxes(box)
            if sside < self.min_size + 2:
                continue
            box = np.array(box)

            box[:, 0] = np.clip(
                np.round(box[:, 0] / width * dest_width), 0, dest_width)
            box[:, 1] = np.clip(
                np.round(box[:, 1] / height * dest_height), 0, dest_height)
            boxes.append(box.astype(np.int16))
            scores.append(score)
            class_indexes.append(class_index)
            class_scores.append(class_score)

        if classes is None:
            return np.array(boxes, dtype=np.int16), scores
        else:
            return np.array(boxes, dtype=np.int16), scores, class_indexes, class_scores

    def unclip(self, box):
        unclip_ratio = self.unclip_ratio
        poly = Polygon(box)
        distance = poly.area * unclip_ratio / poly.length
        offset = pyclipper.PyclipperOffset()
        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = np.array(offset.Execute(distance))
        return expanded

    def get_mini_boxes(self, contour):
        bounding_box = cv2.minAreaRect(contour)
        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

        index_1, index_2, index_3, index_4 = 0, 1, 2, 3
        if points[1][1] > points[0][1]:
            index_1 = 0
            index_4 = 1
        else:
            index_1 = 1
            index_4 = 0
        if points[3][1] > points[2][1]:
            index_2 = 2
            index_3 = 3
        else:
            index_2 = 3
            index_3 = 2

        box = [
            points[index_1], points[index_2], points[index_3], points[index_4]
        ]
        return box, min(bounding_box[1])

    def box_score_fast(self, bitmap, _box, classes):
        '''
        box_score_fast: use bbox mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        box = _box.copy()
        xmin = np.clip(np.floor(box[:, 0].min()).astype(int), 0, w - 1)
        xmax = np.clip(np.ceil(box[:, 0].max()).astype(int), 0, w - 1)
        ymin = np.clip(np.floor(box[:, 1].min()).astype(int), 0, h - 1)
        ymax = np.clip(np.ceil(box[:, 1].max()).astype(int), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        box[:, 0] = box[:, 0] - xmin
        box[:, 1] = box[:, 1] - ymin
        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
        if classes is None:
            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None
        else:
            k = 999
            class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32)
            cv2.fillPoly(class_mask, box.reshape(1, -1, 2).astype(np.int32), 0)
            classes = classes[ymin:ymax + 1, xmin:xmax + 1]
            new_classes = classes + class_mask

            # flatten, then drop every pixel outside the polygon (value >= k)
            a = new_classes.reshape(-1)
            b = np.where(a >= k)
            classes = np.delete(a, b[0].tolist())

            class_index = np.argmax(np.bincount(classes))
            class_score = np.sum(classes == class_index) / len(classes)

            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score

    def box_score_slow(self, bitmap, contour, classes):
        '''
        box_score_slow: use polygon mean score as the mean score
        '''
        h, w = bitmap.shape[:2]
        contour = contour.copy()
        contour = np.reshape(contour, (-1, 2))

        xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
        xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
        ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
        ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)

        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
        contour[:, 0] = contour[:, 0] - xmin
        contour[:, 1] = contour[:, 1] - ymin
        cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype(np.int32), 1)
        if classes is None:
            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None
        else:
            k = 999
            class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32)
            cv2.fillPoly(class_mask, contour.reshape(1, -1, 2).astype(np.int32), 0)
            classes = classes[ymin:ymax + 1, xmin:xmax + 1]
            new_classes = classes + class_mask

            # flatten, then drop every pixel outside the polygon (value >= k)
            a = new_classes.reshape(-1)
            b = np.where(a >= k)
            classes = np.delete(a, b[0].tolist())

            class_index = np.argmax(np.bincount(classes))
            class_score = np.sum(classes == class_index) / len(classes)

            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score

    def __call__(self, outs_dict, shape_list):
        pred = outs_dict['maps']
        if isinstance(pred, paddle.Tensor):
            pred = pred.numpy()
        pred = pred[:, 0, :, :]
        segmentation = pred > self.thresh

        if "classes" in outs_dict:
            classes = outs_dict['classes']
            if isinstance(classes, paddle.Tensor):
                classes = classes.numpy()
            classes = classes[:, 0, :, :]  # (N, 1, H, W) -> (N, H, W), matching pred
        else:
            classes = None

        boxes_batch = []
        for batch_index in range(pred.shape[0]):
            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
            if self.dilation_kernel is not None:
                mask = cv2.dilate(
                    np.array(segmentation[batch_index]).astype(np.uint8),
                    self.dilation_kernel)
            else:
                mask = segmentation[batch_index]

            if classes is None:
                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
                                                       None, src_w, src_h)
                boxes_batch.append({'points': boxes})
            else:
                boxes, scores, class_indexes, class_scores = self.boxes_from_bitmap(
                    pred[batch_index], mask, classes[batch_index], src_w, src_h)
                boxes_batch.append({'points': boxes,
                                    "classes": class_indexes,
                                    "class_scores": class_scores})
        return boxes_batch
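A note on the k = 999 trick in box_score_fast and box_score_slow: class_mask starts filled with 999 and is zeroed inside the polygon, so adding it to the cropped class map pushes every pixel outside the polygon to a value >= 999. After flattening, those pixels are deleted, leaving only the in-polygon class indices; np.bincount then takes a majority vote, the winning index becomes the box's class_index, and its pixel fraction becomes class_score.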

7. Modifying train.py, eval.py, infer_det.py, and export_model.py

Add the following lines of code:

  1. if "num_classes" in global_config:
  2. config['Architecture']["Head"]['num_classes'] = global_config["num_classes"]
  3. config['Loss']['num_classes'] = global_config["num_classes"]
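These lines sit naturally right after the config has been loaded, once global_config = config['Global'] has been taken and before the model and loss are built; the exact insertion point differs slightly between the four scripts, so treat the placement as a guideline rather than an exact diff.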

Done!

With this, the whole network structure and core code are complete! This article covers only the Python-side implementation and deployment; if you need C++ or on-device deployment code, you can reach me on QQ (2952855968). Next, let's see how it performs in practice.

Several follow-up articles will walk through applications of multi-class DBNet, so stay tuned!

GitHub repo: GitHub - yangy996/PaddleOCR

Practical applications:

1. ID card recognition based on multi-class DBNet
