当前位置:   article > 正文

【深度学习】OCR中的Shrink操作详解

【深度学习】OCR中的Shrink操作详解

OCR中的Shrink操作详解

光学字符识别(OCR)中,shrink操作用于对文本框多边形进行缩放,以生成用于训练和检测的特征图。本文将介绍shrink操作的背景、实现方法及其应用。以下是用户提供的代码,详细展示了如何实现这一过程。

背景介绍

在OCR任务中,文本通常以多边形的形式标注于图像中。为了更好地训练检测模型,通常需要将这些多边形进行一定比例的缩放(shrink),以生成不同大小的特征图,从而提高模型的泛化能力和精度。shrink操作的目标是将文本框缩小,以减少噪声对检测结果的影响。

代码实现

以下是实现shrink操作的详细代码:

import numpy as np
import cv2
import pyclipper
from shapely.geometry import Polygon

def shrink_polygon_py(polygon, shrink_ratio):
    """
    对框进行缩放,返回去的比例为1/shrink_ratio 即可
    """
    cx = polygon[:, 0].mean()
    cy = polygon[:, 1].mean()
    polygon[:, 0] = cx + (polygon[:, 0] - cx) * shrink_ratio
    polygon[:, 1] = cy + (polygon[:, 1] - cy) * shrink_ratio
    return polygon

def shrink_polygon_pyclipper(polygon, shrink_ratio):
    polygon_shape = Polygon(polygon)
    distance = (
        polygon_shape.area * (1 - np.power(shrink_ratio, 2)) / polygon_shape.length
    )
    subject = [tuple(l) for l in polygon]
    padding = pyclipper.PyclipperOffset()
    padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    shrinked = padding.Execute(-distance)
    if shrinked == []:
        shrinked = np.array(shrinked)
    else:
        shrinked = np.array(shrinked[0]).reshape(-1, 2)
    return shrinked

class MakeShrinkMap:
    def __init__(self, min_text_size=8, shrink_ratio=0.4, shrink_type="pyclipper"):
        shrink_func_dict = {
            "py": shrink_polygon_py,
            "pyclipper": shrink_polygon_pyclipper,
        }
        self.shrink_func = shrink_func_dict[shrink_type]
        self.min_text_size = min_text_size
        self.shrink_ratio = shrink_ratio

    def __call__(self, data: dict) -> dict:
        image = data["img"]
        text_polys = data["text_polys"]
        ignore_tags = data["ignore_tags"]

        h, w = image.shape[:2]
        text_polys, ignore_tags = self.validate_polygons(text_polys, ignore_tags, h, w)
        gt = np.zeros((h, w), dtype=np.float32)
        mask = np.ones((h, w), dtype=np.float32)
        shrinked_polygons = []
        for i in range(len(text_polys)):
            polygon = text_polys[i]
            height = max(polygon[:, 1]) - min(polygon[:, 1])
            width = max(polygon[:, 0]) - min(polygon[:, 0])
            if ignore_tags[i] or min(height, width) < self.min_text_size:
                cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
                ignore_tags[i] = True
            else:
                shrinked = self.shrink_func(polygon, self.shrink_ratio)
                shrinked_polygons.append(shrinked)
                if shrinked.size == 0:
                    cv2.fillPoly(mask, polygon.astype(np.int32)[np.newaxis, :, :], 0)
                    ignore_tags[i] = True
                    continue
                cv2.fillPoly(gt, [shrinked.astype(np.int32)], 1)

        data["shrink_map"] = gt
        data["shrink_mask"] = mask
        data["shrinked_polygons"] = shrinked_polygons
        return data

    def validate_polygons(self, polygons, ignore_tags, h, w):
        if len(polygons) == 0:
            return polygons, ignore_tags
        assert len(polygons) == len(ignore_tags)
        for polygon in polygons:
            polygon[:, 0] = np.clip(polygon[:, 0], 0, w - 1)
            polygon[:, 1] = np.clip(polygon[:, 1], 0, h - 1)

        for i in range(len(polygons)):
            area = self.polygon_area(polygons[i])
            if abs(area) < 1:
                ignore_tags[i] = True
            if area > 0:
                polygons[i] = polygons[i][::-1, :]
        return polygons, ignore_tags

    def polygon_area(self, polygon):
        return cv2.contourArea(polygon)

if __name__ == "__main__":
    # 示例图像
    image = np.ones((200, 200, 3), dtype=np.uint8) * 255

    # 示例文本框多边形
    text_polys = [
        np.array([[50, 50], [150, 50], [150, 100], [50, 100]]),
        np.array([[60, 120], [140, 120], [140, 160], [60, 160]])
    ]

    # 示例忽略标志
    ignore_tags = [False, False]

    # 构建输入数据字典
    data = {
        "img": image,
        "text_polys": text_polys,
        "ignore_tags": ignore_tags
    }

    # 初始化 MakeShrinkMap 类
    make_shrink_map = MakeShrinkMap(min_text_size=8, shrink_ratio=0.4, shrink_type="pyclipper")

    # 调用类处理数据
    result = make_shrink_map(data)

    # 获取生成的shrink_map和shrink_mask
    shrink_map = result["shrink_map"]
    shrink_mask = result["shrink_mask"]
    shrinked_polygons = result["shrinked_polygons"]

    # 在原图上绘制shrink前的多边形
    original_image = image.copy()
    for polygon in text_polys:
        cv2.polylines(original_image, [polygon.astype(np.int32)], True, (0, 0, 255), 2)

    # 在原图上绘制shrink后的多边形
    shrinked_image = image.copy()
    for polygon in shrinked_polygons:
        cv2.polylines(shrinked_image, [polygon.astype(np.int32)], True, (0, 255, 0), 2)

    # 保存结果图像
    cv2.imwrite("original_image.png", original_image)
    cv2.imwrite("shrinked_image.png", shrinked_image)
    cv2.imwrite("shrink_map.png", shrink_map * 255)  # 将shrink_map转换为图像
    cv2.imwrite("shrink_mask.png", shrink_mask * 255)  # 将shrink_mask转换为图像

    # 显示结果
    # cv2.imshow("Original Image", original_image)
    # cv2.imshow("Shrinked Image", shrinked_image)
    # cv2.imshow("Shrink Map", shrink_map)
    # cv2.imshow("Shrink Mask", shrink_mask)
    # cv2.waitKey(0)
    # cv2.destroyAllWindows()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144

代码详解

  1. Shrink算法实现

    代码中实现了两种不同的shrink算法:shrink_polygon_pyshrink_polygon_pyclipper

    • shrink_polygon_py:通过计算多边形的中心点,然后将多边形的每个点按照缩放比例向中心点收缩。
    • shrink_polygon_pyclipper:使用pyclipper库进行多边形缩放,计算更为精确,适用于复杂多边形。
  2. MakeShrinkMap类

    MakeShrinkMap类用于将图像中的文本多边形进行shrink操作。类的构造函数接受最小文本尺寸、缩放比例和缩放类型作为参数。__call__方法处理输入数据字典,并生成缩放后的特征图和掩码。

  3. 代码示例

    __main__部分,创建了一个示例图像和文本多边形,并使用MakeShrinkMap类进行shrink操作。结果图像包括原始多边形和缩放后的多边形,并将生成的特征图和掩码保存为图像文件。

应用

Shrink操作在OCR中有广泛的应用,如:

  • 文本检测:通过缩放文本框生成特征图,可以提高文本检测模型的准确性和鲁棒性。
  • 噪声过滤:缩小多边形可以减少背景噪声对检测结果的干扰。
  • 数据增强:生成不同缩放比例的特征图,有助于提升模型的泛化能力。

总结

本文介绍了OCR中shrink操作的实现方法和应用,通过详细的代码示例展示了如何对文本多边形进行缩放。shrink操作在提高OCR模型性能方面具有重要作用,是文本检测和识别过程中不可或缺的一

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号