当前位置:   article > 正文

深度学习——离线数据增强——图片resize

深度学习——离线数据增强——图片resize

因为有一批数据尺寸偏小、质量不佳,为了方便标注,使用数据增强(等比例缩放并填充白边)将所有图片统一为1080P(1920x1080)。

# -*- coding: UTF-8 -*-
"""
@Project :yolov5_relu_fire_smoke_v1.4 
@IDE     :PyCharm 
@Author  :沐枫
@Date    :2024/4/2 20:28

通过等比例缩放并添加白边(填充)做数据增强,最终所有图片尺寸统一为1080P(1920x1080)
"""
import os
import multiprocessing
from concurrent import futures
from copy import deepcopy

import cv2
import numpy as np
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom


def xyxy2xywh(x):
    """
    Convert boxes from corner form [x1, y1, x2, y2] to center form [cx, cy, w, h].

    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner. Only the
    first four entries of the last axis are rewritten, so an (n, 4) array or a single
    length-4 vector both work. The input is not modified.

    :param x: numpy array whose last axis holds x1, y1, x2, y2
    :return: new array (same shape and dtype as ``np.copy(x)``) holding cx, cy, w, h
    """
    left, top, right, bottom = (x[..., k] for k in range(4))
    out = np.copy(x)
    out[..., 2] = right - left  # width
    out[..., 3] = bottom - top  # height
    out[..., 0] = (left + right) / 2  # x center
    out[..., 1] = (top + bottom) / 2  # y center
    return out


def xywh2xyxy(x: np.ndarray):
    """
    Convert boxes from center form [cx, cy, w, h] to corner form [x1, y1, x2, y2].

    (x1, y1) is the top-left corner and (x2, y2) the bottom-right corner. Only the
    first four entries of the last axis are rewritten; the input is not modified.

    :param x: numpy array whose last axis holds cx, cy, w, h
    :return: new array (same shape and dtype as ``np.copy(x)``) holding x1, y1, x2, y2
    """
    cx, cy, w, h = (x[..., k] for k in range(4))
    half_w = w / 2
    half_h = h / 2
    corners = np.copy(x)
    corners[..., 0] = cx - half_w  # top-left x
    corners[..., 1] = cy - half_h  # top-left y
    corners[..., 2] = cx + half_w  # bottom-right x
    corners[..., 3] = cy + half_h  # bottom-right y
    return corners


def decodeVocAnnotation(voc_xml_path, class_index_dict):
    """
    Parse one Pascal VOC annotation file into a list of boxes.

    Every ``<object>`` element yields one tuple. Its class name (``<name>``, surrounding
    whitespace stripped) is mapped through ``class_index_dict``, so the first element of
    each tuple is the dict's value for that name (an int index in this script), not the
    name string. An empty list means the file has no objects, i.e. a background image.

    :param voc_xml_path: path to the ``.xml`` file (read as UTF-8)
    :param class_index_dict: mapping class name -> class index
    :return: [(cls_index, x1, y1, x2, y2), ...] with integer corner coordinates
    :raises AssertionError: if the path does not end with ``.xml``
    :raises KeyError: if an object's class name is not in ``class_index_dict``
    """
    assert voc_xml_path.endswith(".xml"), "voc_xml_path must endswith .xml"

    with open(voc_xml_path, 'r', encoding='utf-8') as xml_file:
        root = ET.ElementTree().parse(xml_file)

    information = []
    for obj in root.iter('object'):
        name = obj.find('name').text.strip()
        box = obj.find('bndbox')
        # Some labelling tools store fractional coordinates ("12.5"); int("12.5")
        # raises ValueError, so convert via float first (int() then truncates).
        coords = (int(float(box.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        information.append((class_index_dict[name], *coords))

    return information


def create_voc_xml(image_folder, image_filename, width: int, height: int, labels,
                   save_root, class_name_dict, conf_thresh_dict=None):
    """
    Write a Pascal VOC annotation xml for one image.

    The file is saved as ``<save_root>/<stem of image_filename>.xml`` and is always
    encoded as UTF-8, which is also how ``decodeVocAnnotation`` reads it back.

    :param image_folder: text stored in ``<folder>`` (the caller in this file passes the
        image path relative to the output image root)
    :param image_filename: image file name, e.g. ``000001.jpg``
    :param width: image width stored in ``<size>``
    :param height: image height stored in ``<size>``
    :param labels: one row per box, ``[class_index, xmin, ymin, xmax, ymax]``; when
        ``conf_thresh_dict`` is given every row has a trailing ``conf`` value
    :param save_root: directory receiving the xml (created if missing)
    :param class_name_dict: maps class index -> class name written to ``<name>``
    :param conf_thresh_dict: maps class index -> minimum confidence; rows whose conf is
        below their class threshold are skipped. ``None`` means rows carry no conf.
    :return: None
    """
    root = ET.Element("annotation")

    # Image information
    ET.SubElement(root, "folder").text = str(image_folder)
    ET.SubElement(root, "filename").text = os.fspath(image_filename)
    size = ET.SubElement(root, "size")
    ET.SubElement(size, "width").text = str(width)
    ET.SubElement(size, "height").text = str(height)
    ET.SubElement(size, "depth").text = "3"  # channel count, always written as 3

    # One <object> per kept box
    for label in labels:
        if conf_thresh_dict is None:
            # coordinates and class index must be written as integers
            class_index, x1, y1, x2, y2 = np.asarray(label).astype(np.int32)
        else:
            class_index, x1, y1, x2, y2, conf = label
            class_index, x1, y1, x2, y2 = np.array([class_index, x1, y1, x2, y2], dtype=np.int32)
            # drop boxes below their class-specific confidence threshold
            if conf < conf_thresh_dict[int(class_index)]:
                continue

        obj = ET.SubElement(root, "object")
        ET.SubElement(obj, "name").text = class_name_dict[int(class_index)]
        ET.SubElement(obj, "pose").text = "Unspecified"
        ET.SubElement(obj, "truncated").text = "0"
        ET.SubElement(obj, "difficult").text = "0"

        bndbox = ET.SubElement(obj, "bndbox")
        ET.SubElement(bndbox, "xmin").text = str(x1)
        ET.SubElement(bndbox, "ymin").text = str(y1)
        ET.SubElement(bndbox, "xmax").text = str(x2)
        ET.SubElement(bndbox, "ymax").text = str(y2)

    # Round-trip through minidom only to get 4-space indentation.
    xml_bytes = ET.tostring(root, encoding="utf-8")
    pretty_xml = minidom.parseString(xml_bytes).toprettyxml(indent=" " * 4)

    save_path = os.path.join(save_root, f"{os.path.splitext(image_filename)[0]}.xml")
    save_dir = os.path.dirname(save_path)
    if save_dir:  # os.makedirs("") raises, e.g. when save_root is ""
        os.makedirs(save_dir, exist_ok=True)
    # Explicit UTF-8: the default locale encoding (e.g. GBK on Chinese Windows) would
    # otherwise be used, but the output carries no encoding declaration and XML
    # parsers, including decodeVocAnnotation, read it as UTF-8.
    with open(save_path, "w", encoding="utf-8") as xml_file:
        xml_file.write(pretty_xml)


def resize_and_pad(image: np.ndarray, labels: np.ndarray, width=1920, height=1080):
    """
    Letterbox ``image`` onto a white ``width`` x ``height`` canvas and move ``labels`` with it.

    Processing order:
    1. If the image is strictly smaller than the target in both dimensions, it is
       first enlarged 2x (labels scaled by the same factor).
    2. If the result is still strictly smaller in both dimensions, it is centred on the
       canvas without scaling. Otherwise it is resized with its aspect ratio preserved
       so that it fits, then centred.
    Box centres are shifted by the padding offsets; widths/heights only scale.

    :param image: image array of shape (h, w) or (h, w, c)
    :param labels: boxes as rows ``(cls_id, x_center, y_center, w, h)`` in pixels
    :param width: target canvas width
    :param height: target canvas height
    :return: (new_image, new_labels); new_image is uint8 with the input's channel
        layout, new_labels is a new float64 array (the caller's array is not modified)
    """
    SCALE = 2  # enlargement factor for images smaller than the target

    # Work on a float copy: the caller's array stays untouched, and the fractional
    # ratio below can be applied even when integer labels are passed in.
    labels = np.array(labels, dtype=np.float64)

    img_h, img_w = image.shape[:2]
    if img_w < width and img_h < height:
        image = cv2.resize(image, (img_w * SCALE, img_h * SCALE))
        labels[..., 1:] *= SCALE  # centres and sizes grow by the same factor
        img_h, img_w = image.shape[:2]

    if img_w < width and img_h < height:
        # Still smaller than the canvas: only centre it.
        new_w, new_h = img_w, img_h
        placed = image
    else:
        # Shrink (or keep) with aspect ratio preserved so the image fits.
        ratio = min(width / img_w, height / img_h)
        new_w = max(1, int(img_w * ratio))  # guard against a 0-sized resize
        new_h = max(1, int(img_h * ratio))
        placed = cv2.resize(image, (new_w, new_h))
        labels[..., 1:] *= ratio

    dw = (width - new_w) // 2
    dh = (height - new_h) // 2

    # White canvas that keeps any channel axis of the input (grayscale or colour).
    canvas = np.full((height, width) + image.shape[2:], 255, dtype=np.uint8)
    canvas[dh:dh + new_h, dw:dw + new_w] = placed

    # Padding moves only the centre point, not the box size.
    labels[..., 1] += dw
    labels[..., 2] += dh

    return canvas, labels


def run(image_path, xml_root,
        image_root, save_image_root, save_xml_root,
        class_index_dict, class_name_dict):
    """
    Letterbox one image to 1920x1080 and write the image plus a matching VOC xml.

    The source xml path mirrors ``image_path`` from ``image_root`` into ``xml_root``.
    Outputs go to ``<save_image_root>/<parent dir name>/aug_<file>`` and the matching
    ``.xml`` path under ``save_xml_root``. Images without an xml, without boxes, that
    cannot be read, or whose output image cannot be written are skipped with a message.
    Any other exception is printed rather than raised, so one bad file does not abort
    the batch.

    :param image_path: absolute path of the source image
    :param xml_root: root of the source annotations
    :param image_root: root of the source images
    :param save_image_root: root for output images
    :param save_xml_root: root for output annotations
    :param class_index_dict: class name -> class index (for reading)
    :param class_name_dict: class index -> class name (for writing)
    :return: None
    """
    image_file = os.path.basename(image_path)
    # Swap only the trailing extension. str.replace(suffix, ".xml") would also rewrite
    # the suffix wherever else it appears in the path (e.g. a directory "x.jpg_set").
    xml_path = os.path.splitext(image_path.replace(image_root, xml_root))[0] + ".xml"

    if not os.path.exists(xml_path):
        print(f"\n{image_path} no xml\n")
        return

    try:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals an unreadable file with None, not an exception
            print(f"\n{image_path} cannot be read\n")
            return

        # [(cls_id, x1, y1, x2, y2), ...]
        labels = decodeVocAnnotation(xml_path, class_index_dict)
        if len(labels) == 0:
            print(f"\n{image_path} no label\n")
            return
        labels = np.array(labels, dtype=np.float64)

        # resize_and_pad works on (cls_id, x_center, y_center, w, h)
        labels[..., 1:] = xyxy2xywh(labels[..., 1:])
        new_image, new_labels = resize_and_pad(image, labels, width=1920, height=1080)
        new_img_h, new_img_w = new_image.shape[:2]
        # back to VOC corner format
        new_labels[..., 1:] = xywh2xyxy(new_labels[..., 1:])

        save_image_path = os.path.join(save_image_root,
                                       os.path.basename(os.path.dirname(image_path)),
                                       f"aug_{image_file}")
        save_xml_path = os.path.splitext(
            save_image_path.replace(save_image_root, save_xml_root))[0] + ".xml"

        os.makedirs(os.path.dirname(save_image_path), exist_ok=True)
        os.makedirs(os.path.dirname(save_xml_path), exist_ok=True)

        # cv2.imwrite reports failure via its return value; don't leave an orphan xml.
        if not cv2.imwrite(save_image_path, new_image):
            print(f"\n{image_path} failed to write {save_image_path}\n")
            return

        create_voc_xml(image_folder=save_image_path.replace(save_image_root + os.sep, ""),
                       image_filename=os.path.basename(save_image_path),
                       width=new_img_w,
                       height=new_img_h,
                       labels=np.array(new_labels, dtype=np.int32),
                       save_root=os.path.dirname(save_xml_path),
                       class_name_dict=class_name_dict, )
        print(f"\r{image_path}", end='')

    except Exception as e:
        # Worker boundary: report and continue with the next image.
        print(f"{image_path} {run.__name__}:{e}")


def run_process(root_file_list,
                image_root, xml_root, save_image_root, save_xml_root,
                class_index_dict, class_name_dict):
    """
    Process one batch of images inside a worker process with a 5-thread pool.

    Each ``(directory, file_name)`` pair becomes one ``run`` call; every other argument
    is forwarded to ``run`` unchanged. The function returns only after all submitted
    tasks finish, because leaving the executor's ``with`` block waits for them. The
    returned futures are not inspected; ``run`` reports its own errors.

    :param root_file_list: iterable of ``(directory, file_name)`` pairs
    :return: None
    """
    forwarded = (xml_root,
                 image_root, save_image_root, save_xml_root,
                 class_index_dict, class_name_dict)
    with futures.ThreadPoolExecutor(max_workers=5) as pool:
        for directory, name in root_file_list:
            pool.submit(run, os.path.join(directory, name), *forwarded)


if __name__ == '__main__':
    class_index_dict = {
        "fire": 0,
        "smoke": 1,
    }

    class_name_dict = {
        0: "fire",
        1: "smoke",
    }

    data_root = r"Z:\Datasets\FireSmoke_v4"
    data_root = os.path.abspath(data_root)
    # root of the original images
    image_root = os.path.join(data_root, "images")
    # root of the VOC xml annotations
    xml_root = os.path.join(data_root, "annotations")
    # output roots; note save_image_root lives INSIDE image_root
    save_image_root = os.path.join(image_root, "aug-pad")
    save_xml_root = os.path.join(xml_root, "aug-pad")

    # Sub-directories to skip. Each entry is substring-matched against the directory
    # path relative to image_root (prefixed with os.sep), so "\net" also skips e.g.
    # "\network". Duplicates of the original list removed.
    exclude_dirs = [
        os.sep + r"background",
        os.sep + r"candle_fire",
        os.sep + r"AUG",
        os.sep + r"smoke",
        os.sep + r"val",
        os.sep + r"aug-merge",
        os.sep + r"cut_aug",
        os.sep + r"miniFire",
        os.sep + r"net",
        os.sep + r"new_data",
        os.sep + r"realScenario",
        os.sep + r"TSMCandle",
    ]
    image_suffixes = (".jpg", ".jpeg", ".bmp", ".png")

    max_workers = 10  # number of worker processes
    print(f"max_workers:{max_workers}")
    # images handled by one run_process task
    max_file_num = 3000
    # arguments shared by every run_process task
    task_args = (image_root, xml_root, save_image_root, save_xml_root,
                 class_index_dict, class_name_dict)

    def _report_error(exc):
        # apply_async otherwise discards exceptions raised by a task
        print(f"\nrun_process failed: {exc!r}")

    # process pool; tune to the machine, too many workers slows things down
    pool = multiprocessing.Pool(processes=max_workers)
    batch = []

    for root, dirs, files in os.walk(image_root):
        # Never walk into our own output, otherwise a re-run re-augments aug_* images.
        dirs[:] = [d for d in dirs if os.path.join(root, d) != save_image_root]

        # Match excludes only below image_root, not against its ancestors.
        rel_root = os.sep + os.path.relpath(root, image_root)
        if any(excluded in rel_root for excluded in exclude_dirs):
            continue

        for file in files:
            if os.path.splitext(file)[1].lower() not in image_suffixes:
                continue

            batch.append((root, file))
            if len(batch) >= max_file_num:
                pool.apply_async(run_process, (batch, *task_args),
                                 error_callback=_report_error)
                # Rebind instead of clear(): the submitted list may be pickled later
                # by the pool's task-handler thread, so it must not be mutated.
                batch = []

    # submit the remaining partial batch, if any
    if batch:
        pool.apply_async(run_process, (batch, *task_args),
                         error_callback=_report_error)

    # no more tasks; wait for all workers to finish
    pool.close()
    pool.join()

    print("\nFinish ...")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314
  • 315
  • 316
  • 317
  • 318
  • 319
  • 320
  • 321
  • 322
  • 323
  • 324
  • 325
  • 326
  • 327
  • 328
  • 329
  • 330
  • 331
  • 332
  • 333
  • 334
  • 335
  • 336
  • 337
  • 338
  • 339
  • 340
  • 341
  • 342
  • 343
  • 344
  • 345
  • 346
  • 347
  • 348
  • 349
  • 350
  • 351
  • 352
  • 353
  • 354
  • 355
  • 356
  • 357
  • 358
  • 359
  • 360
  • 361
  • 362
  • 363
  • 364
  • 365
  • 366
  • 367
  • 368
  • 369
  • 370
  • 371
  • 372
  • 373
  • 374
  • 375
  • 376
  • 377
  • 378
  • 379
  • 380
  • 381
  • 382
  • 383
  • 384
  • 385
  • 386
  • 387
  • 388
  • 389
  • 390
  • 391
  • 392
  • 393
  • 394
  • 395
  • 396
  • 397
  • 398
  • 399
  • 400
  • 401
  • 402
  • 403
  • 404
  • 405
  • 406
  • 407
  • 408
  • 409
  • 410
  • 411
  • 412
  • 413
  • 414
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/438566
推荐阅读
  

闽ICP备14008679号