当前位置:   article > 正文

【Python数据增强】图像数据集扩充

【Python数据增强】图像数据集扩充

 前言:该脚本用于图像数据增强,特别是目标检测任务中的图像和标签数据增强。通过应用一系列数据增强技术(如旋转、平移、裁剪、加噪声、改变亮度、cutout、翻转等),生成多样化的图像数据集,以提高目标检测模型的鲁棒性和准确性。

效果:img存的原始图像168张图片,img2扩充的数量为5040张图片

目录

1.环境准备

2.显示图片函数

3.数据增强类

3.1类初始化

3.2数据增强方法

3.3 数据增强主方法

4.XML解析工具类 

4.1 解析XML

4.2 保存图片 

4.3 保存XML 

5. 主函数

完整程序


1.环境准备

这段代码导入了脚本所需的库,用于图像处理(cv2、numpy)、随机操作(random)、文件操作(os)、XML解析(etree)等。

  1. # -*- coding=utf-8 -*-
  2. import time
  3. import random
  4. import copy
  5. import cv2
  6. import os
  7. import math
  8. import numpy as np
  9. from skimage.util import random_noise
  10. from lxml import etree, objectify
  11. import xml.etree.ElementTree as ET
  12. import argparse

2.显示图片函数

该函数用于显示图片,并在图片上绘制边界框(bounding box)。

  1. def show_pic(img, bboxes=None):
  2. '''
  3. 输入:
  4. img: 图像array
  5. bboxes: 图像的所有bounding box list, 格式为[[x_min, y_min, x_max, y_max]....]
  6. '''
  7. for i in range(len(bboxes)):
  8. bbox = bboxes[i]
  9. x_min = bbox[0]
  10. y_min = bbox[1]
  11. x_max = bbox[2]
  12. y_max = bbox[3]
  13. cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
  14. cv2.namedWindow('pic', 0)
  15. cv2.moveWindow('pic', 0, 0)
  16. cv2.resizeWindow('pic', 1200, 800)
  17. cv2.imshow('pic', img)
  18. cv2.waitKey(0)
  19. cv2.destroyAllWindows()

3.数据增强类

3.1类初始化

该类初始化函数设置了数据增强的各种参数和是否启用某种增强方式的标志。

  1. class DataAugmentForObjectDetection():
  2. def __init__(self, rotation_rate=0.5, max_rotation_angle=5,
  3. crop_rate=0.5, shift_rate=0.5, change_light_rate=0.5,
  4. add_noise_rate=0.5, flip_rate=0.5,
  5. cutout_rate=0.5, cut_out_length=50, cut_out_holes=1, cut_out_threshold=0.5,
  6. is_addNoise=True, is_changeLight=True, is_cutout=True, is_rotate_img_bbox=True,
  7. is_crop_img_bboxes=True, is_shift_pic_bboxes=True, is_filp_pic_bboxes=True):
  8. self.rotation_rate = rotation_rate
  9. self.max_rotation_angle = max_rotation_angle
  10. self.crop_rate = crop_rate
  11. self.shift_rate = shift_rate
  12. self.change_light_rate = change_light_rate
  13. self.add_noise_rate = add_noise_rate
  14. self.flip_rate = flip_rate
  15. self.cutout_rate = cutout_rate
  16. self.cut_out_length = cut_out_length
  17. self.cut_out_holes = cut_out_holes
  18. self.cut_out_threshold = cut_out_threshold
  19. self.is_addNoise = is_addNoise
  20. self.is_changeLight = is_changeLight
  21. self.is_cutout = is_cutout
  22. self.is_rotate_img_bbox = is_rotate_img_bbox
  23. self.is_crop_img_bboxes = is_crop_img_bboxes
  24. self.is_shift_pic_bboxes = is_shift_pic_bboxes
  25. self.is_filp_pic_bboxes = is_filp_pic_bboxes

3.2数据增强方法

加噪声。为图像添加高斯噪声。

  1. def _addNoise(self, img):
  2. return random_noise(img, mode='gaussian', clip=True) * 255

改变亮度。随机改变图像亮度。

  1. def _changeLight(self, img):
  2. alpha = random.uniform(0.35, 1)
  3. blank = np.zeros(img.shape, img.dtype)
  4. return cv2.addWeighted(img, alpha, blank, 1 - alpha, 0)

cutout。随机在图像中遮挡某些部分(cutout),避免遮挡太多目标。

  1. def _cutout(self, img, bboxes, length=100, n_holes=1, threshold=0.5):
  2. def cal_iou(boxA, boxB):
  3. xA = max(boxA[0], boxB[0])
  4. yA = max(boxA[1], boxB[1])
  5. xB = min(boxA[2], boxB[2])
  6. yB = min(boxA[3], boxB[3])
  7. if xB <= xA or yB <= yA:
  8. return 0.0
  9. interArea = (xB - xA + 1) * (yB - yA + 1)
  10. boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
  11. boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
  12. iou = interArea / float(boxBArea)
  13. return iou
  14. if img.ndim == 3:
  15. h, w, c = img.shape
  16. else:
  17. _, h, w, c = img.shape
  18. mask = np.ones((h, w, c), np.float32)
  19. for n in range(n_holes):
  20. chongdie = True
  21. while chongdie:
  22. y = np.random.randint(h)
  23. x = np.random.randint(w)
  24. y1 = np.clip(y - length // 2, 0, h)
  25. y2 = np.clip(y + length // 2, 0, h)
  26. x1 = np.clip(x - length // 2, 0, w)
  27. x2 = np.clip(x + length // 2, 0, w)
  28. chongdie = False
  29. for box in bboxes:
  30. if cal_iou([x1, y1, x2, y2], box) > threshold:
  31. chongdie = True
  32. break
  33. mask[y1: y2, x1: x2, :] = 0.
  34. img = img * mask
  35. return img

旋转。旋转图像和对应的边界框。

  1. def _rotate_img_bbox(self, img, bboxes, angle=5, scale=1.):
  2. w, h = img.shape[1], img.shape[0]
  3. rangle = np.deg2rad(angle)
  4. nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
  5. nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale
  6. rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
  7. rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
  8. rot_mat[0, 2] += rot_move[0]
  9. rot_mat[1, 2] += rot_move[1]
  10. rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
  11. rot_bboxes = []
  12. for bbox in bboxes:
  13. points = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]])
  14. new_points = cv2.transform(points[None, :, :], rot_mat)[0]
  15. rx, ry, rw, rh = cv2.boundingRect(new_points)
  16. corrected_bbox = [max(0, rx), max(0, ry), min(nw, rx + rw), min(nh, ry + rh)]
  17. corrected_bbox = [int(val) for val in corrected_bbox]
  18. rot_bboxes.append(corrected_bbox)
  19. return rot_img, rot_bboxes

裁剪。随机裁剪图像,同时裁剪对应的边界框。

  1. def _crop_img_bboxes(self, img, bboxes):
  2. w = img.shape[1]
  3. h = img.shape[0]
  4. x_min = w
  5. x_max = 0
  6. y_min = h
  7. y_max = 0
  8. for bbox in bboxes:
  9. x_min = min(x_min, bbox[0])
  10. y_min = min(y_min, bbox[1])
  11. x_max = max(x_max, bbox[2])
  12. y_max = max(y_max, bbox[3])
  13. d_to_left = x_min
  14. d_to_right = w - x_max
  15. d_to_top = y_min
  16. d_to_bottom = h - y_max
  17. crop_x_min = int(x_min - random.uniform(0, d_to_left))
  18. crop_y_min = int(y_min - random.uniform(0, d_to_top))
  19. crop_x_max = int(x_max + random.uniform(0, d_to_right))
  20. crop_y_max = int(y_max + random.uniform(0, d_to_bottom))
  21. crop_x_min = max(0, crop_x_min)
  22. crop_y_min = max(0, crop_y_min)
  23. crop_x_max = min(w, crop_x_max)
  24. crop_y_max = min(h, crop_y_max)
  25. crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]
  26. crop_bboxes = list()
  27. for bbox in bboxes:
  28. crop_bboxes.append([bbox[0] - crop_x_min, bbox[1] - crop_y_min, bbox[2] - crop_x_min, bbox[3] - crop_y_min])
  29. return crop_img, crop_bboxes

平移。随机平移图像和对应的边界框。

  1. def _shift_pic_bboxes(self, img, bboxes):
  2. h, w = img.shape[:2]
  3. x = random.uniform(-w * 0.2, w * 0.2)
  4. y = random.uniform(-h * 0.2, h * 0.2)
  5. M = np.float32([[1, 0, x], [0, 1, y]])
  6. shift_img = cv2.warpAffine(img, M, (w, h))
  7. shift_bboxes = []
  8. for bbox in bboxes:
  9. new_bbox = [bbox[0] + x, bbox[1] + y, bbox[2] + x, bbox[3] + y]
  10. corrected_bbox = [max(0, new_bbox[0]), max(0, new_bbox[1]), min(w, new_bbox[2]), min(h, new_bbox[3])]
  11. corrected_bbox = [int(val) for val in corrected_bbox]
  12. shift_bboxes.append(corrected_bbox)
  13. return shift_img, shift_bboxes

 翻转。随机翻转图像和对应的边界框。

  1. def _filp_pic_bboxes(self, img, bboxes):
  2. flipCode = random.choice([-1, 0, 1])
  3. flip_img = cv2.flip(img, flipCode)
  4. h, w, _ = img.shape
  5. flip_bboxes = []
  6. for bbox in bboxes:
  7. x_min, y_min, x_max, y_max = bbox
  8. if flipCode == 0:
  9. new_bbox = [x_min, h - y_max, x_max, h - y_min]
  10. elif flipCode == 1:
  11. new_bbox = [w - x_max, y_min, w - x_min, y_max]
  12. else:
  13. new_bbox = [w - x_max, h - y_max, w - x_min, h - y_min]
  14. flip_bboxes.append(new_bbox)
  15. return flip_img, flip_bboxes

3.3 数据增强主方法

综合应用各种数据增强方法,对输入图像和边界框进行增强。

  1. def dataAugment(self, img, bboxes):
  2. change_num = 0
  3. while change_num < 1:
  4. if self.is_rotate_img_bbox:
  5. if random.random() > self.rotation_rate:
  6. change_num += 1
  7. angle = random.uniform(-self.max_rotation_angle, self.max_rotation_angle)
  8. scale = random.uniform(0.7, 0.8)
  9. img, bboxes = self._rotate_img_bbox(img, bboxes, angle, scale)
  10. if self.is_shift_pic_bboxes:
  11. if random.random() < self.shift_rate:
  12. change_num += 1
  13. img, bboxes = self._shift_pic_bboxes(img, bboxes)
  14. if self.is_changeLight:
  15. if random.random() > self.change_light_rate:
  16. change_num += 1
  17. img = self._changeLight(img)
  18. if self.is_addNoise:
  19. if random.random() < self.add_noise_rate:
  20. change_num += 1
  21. img = self._addNoise(img)
  22. if self.is_cutout:
  23. if random.random() < self.cutout_rate:
  24. change_num += 1
  25. img = self._cutout(img, bboxes, length=self.cut_out_length, n_holes=self.cut_out_holes,
  26. threshold=self.cut_out_threshold)
  27. if self.is_filp_pic_bboxes:
  28. if random.random() < self.flip_rate:
  29. change_num += 1
  30. img, bboxes = self._filp_pic_bboxes(img, bboxes)
  31. return img, bboxes

4.XML解析工具类 

4.1 解析XML

从XML文件中提取边界框信息。

  1. class ToolHelper():
  2. def parse_xml(self, path):
  3. tree = ET.parse(path)
  4. root = tree.getroot()
  5. objs = root.findall('object')
  6. coords = list()
  7. for ix, obj in enumerate(objs):
  8. name = obj.find('name').text
  9. box = obj.find('bndbox')
  10. x_min = int(box[0].text)
  11. y_min = int(box[1].text)
  12. x_max = int(box[2].text)
  13. y_max = int(box[3].text)
  14. coords.append([x_min, y_min, x_max, y_max, name])
  15. return coords

4.2 保存图片 

保存增强后的图片。

  1. def save_img(self, file_name, save_folder, img):
  2. cv2.imwrite(os.path.join(save_folder, file_name), img)

4.3 保存XML 

保存增强后的XML文件。

  1. def save_xml(self, file_name, save_folder, img_info, height, width, channel, bboxs_info):
  2. folder_name, img_name = img_info
  3. E = objectify.ElementMaker(annotate=False)
  4. anno_tree = E.annotation(
  5. E.folder(folder_name),
  6. E.filename(img_name),
  7. E.path(os.path.join(folder_name, img_name)),
  8. E.source(
  9. E.database('Unknown'),
  10. ),
  11. E.size(
  12. E.width(width),
  13. E.height(height),
  14. E.depth(channel)
  15. ),
  16. E.segmented(0),
  17. )
  18. labels, bboxs = bboxs_info
  19. for label, box in zip(labels, bboxs):
  20. anno_tree.append(
  21. E.object(
  22. E.name(label),
  23. E.pose('Unspecified'),
  24. E.truncated('0'),
  25. E.difficult('0'),
  26. E.bndbox(
  27. E.xmin(box[0]),
  28. E.ymin(box[1]),
  29. E.xmax(box[2]),
  30. E.ymax(box[3])
  31. )
  32. ))
  33. etree.ElementTree(anno_tree).write(os.path.join(save_folder, file_name), pretty_print=True)

5. 主函数

首先新建几个文件夹,修改主函数里相应的文件路径,即可。

  • img 用于存放自己手里已有的数据集图片
  • img2 用于存放增强后的数据集图片
  • xml 用于存放自己手里已有的数据集图片对应的标签(这里必须是VOC格式)
  • xml2 用于存放增强后的数据集图片对应的标签
  • txt 用于存放将xml2中的voc格式的标签转换成txt格式(yolov识别txt格式的标签)

修改每个图片的增强次数即可决定增强图片的数量。 

主函数:

  • 解析命令行参数,获取图片和XML文件路径。
  • 创建保存路径文件夹(如果不存在)。
  • 遍历源图片路径,读取图片和对应的XML文件。
  • 应用数据增强,保存增强后的图片和XML文件。
  1. if __name__ == '__main__':
  2. need_aug_num = 30 # 每张图片需要增强的次数
  3. is_endwidth_dot = True # 文件是否以.jpg或者png结尾
  4. dataAug = DataAugmentForObjectDetection() # 数据增强工具类
  5. toolhelper = ToolHelper() # 工具
  6. # 获取相关参数
  7. parser = argparse.ArgumentParser()
  8. parser.add_argument('--source_img_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/img')
  9. parser.add_argument('--source_xml_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/xml')
  10. parser.add_argument('--save_img_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/img2')
  11. parser.add_argument('--save_xml_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/xml2')
  12. args = parser.parse_args()
  13. source_img_path = args.source_img_path # 图片原始位置
  14. source_xml_path = args.source_xml_path # xml的原始位置
  15. save_img_path = args.save_img_path # 图片增强结果保存文件
  16. save_xml_path = args.save_xml_path # xml增强结果保存文件
  17. if not os.path.exists(save_img_path):
  18. os.mkdir(save_img_path)
  19. if not os.path.exists(save_xml_path):
  20. os.mkdir(save_xml_path)
  21. for parent, _, files in os.walk(source_img_path):
  22. files.sort()
  23. for file in files:
  24. cnt = 0
  25. pic_path = os.path.join(parent, file)
  26. xml_path = os.path.join(source_xml_path, file[:-4] + '.xml')
  27. values = toolhelper.parse_xml(xml_path)
  28. coords = [v[:4] for v in values]
  29. labels = [v[-1] for v in values]
  30. if is_endwidth_dot:
  31. dot_index = file.rfind('.')
  32. _file_prefix = file[:dot_index]
  33. _file_suffix = file[dot_index:]
  34. img = cv2.imread(pic_path)
  35. while cnt < need_aug_num:
  36. auged_img, auged_bboxes = dataAug.dataAugment(img, coords)
  37. auged_bboxes_int = np.array(auged_bboxes).astype(np.int32)
  38. height, width, channel = auged_img.shape
  39. img_name = '{}_{}{}'.format(_file_prefix, cnt + 1, _file_suffix)
  40. tool

完整程序

该脚本用于对图像数据进行各种数据增强操作,并保存增强后的图像和标签数据。通过这些增强操作,可以生成大量多样化的训练数据,提升目标检测模型的鲁棒性和准确性。

  1. # -*- coding=utf-8 -*-
  2. import time
  3. import random
  4. import copy
  5. import cv2
  6. import os
  7. import math
  8. import numpy as np
  9. from skimage.util import random_noise
  10. from lxml import etree, objectify
  11. import xml.etree.ElementTree as ET
  12. import argparse
  13. # 显示图片
  14. def show_pic(img, bboxes=None):
  15. '''
  16. 输入:
  17. img:图像array
  18. bboxes:图像的所有boudning box list, 格式为[[x_min, y_min, x_max, y_max]....]
  19. names:每个box对应的名称
  20. '''
  21. for i in range(len(bboxes)):
  22. bbox = bboxes[i]
  23. x_min = bbox[0]
  24. y_min = bbox[1]
  25. x_max = bbox[2]
  26. y_max = bbox[3]
  27. cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
  28. cv2.namedWindow('pic', 0) # 1表示原图
  29. cv2.moveWindow('pic', 0, 0)
  30. cv2.resizeWindow('pic', 1200, 800) # 可视化的图片大小
  31. cv2.imshow('pic', img)
  32. cv2.waitKey(0)
  33. cv2.destroyAllWindows()
  34. # 图像均为cv2读取
  35. class DataAugmentForObjectDetection():
  36. def __init__(self, rotation_rate=0.5, max_rotation_angle=5,
  37. crop_rate=0.5, shift_rate=0.5, change_light_rate=0.5,
  38. add_noise_rate=0.5, flip_rate=0.5,
  39. cutout_rate=0.5, cut_out_length=50, cut_out_holes=1, cut_out_threshold=0.5,
  40. is_addNoise=True, is_changeLight=True, is_cutout=True, is_rotate_img_bbox=True,
  41. is_crop_img_bboxes=True, is_shift_pic_bboxes=True, is_filp_pic_bboxes=True):
  42. # 配置各个操作的属性
  43. self.rotation_rate = rotation_rate
  44. self.max_rotation_angle = max_rotation_angle
  45. self.crop_rate = crop_rate
  46. self.shift_rate = shift_rate
  47. self.change_light_rate = change_light_rate
  48. self.add_noise_rate = add_noise_rate
  49. self.flip_rate = flip_rate
  50. self.cutout_rate = cutout_rate
  51. self.cut_out_length = cut_out_length
  52. self.cut_out_holes = cut_out_holes
  53. self.cut_out_threshold = cut_out_threshold
  54. # 是否使用某种增强方式
  55. self.is_addNoise = is_addNoise
  56. self.is_changeLight = is_changeLight
  57. self.is_cutout = is_cutout
  58. self.is_rotate_img_bbox = is_rotate_img_bbox
  59. self.is_crop_img_bboxes = is_crop_img_bboxes
  60. self.is_shift_pic_bboxes = is_shift_pic_bboxes
  61. self.is_filp_pic_bboxes = is_filp_pic_bboxes
  62. # ----1.加噪声---- #
  63. def _addNoise(self, img):
  64. '''
  65. 输入:
  66. img:图像array
  67. 输出:
  68. 加噪声后的图像array,由于输出的像素是在[0,1]之间,所以得乘以255
  69. '''
  70. # return cv2.GaussianBlur(img, (11, 11), 0)
  71. return random_noise(img, mode='gaussian', clip=True) * 255
  72. # ---2.调整亮度--- #
  73. def _changeLight(self, img):
  74. alpha = random.uniform(0.35, 1)
  75. blank = np.zeros(img.shape, img.dtype)
  76. return cv2.addWeighted(img, alpha, blank, 1 - alpha, 0)
  77. # ---3.cutout--- #
  78. def _cutout(self, img, bboxes, length=100, n_holes=1, threshold=0.5):
  79. '''
  80. 原版本:https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
  81. Randomly mask out one or more patches from an image.
  82. Args:
  83. img : a 3D numpy array,(h,w,c)
  84. bboxes : 框的坐标
  85. n_holes (int): Number of patches to cut out of each image.
  86. length (int): The length (in pixels) of each square patch.
  87. '''
  88. def cal_iou(boxA, boxB):
  89. '''
  90. boxA, boxB为两个框,返回iou
  91. boxB为bouding box
  92. '''
  93. # determine the (x, y)-coordinates of the intersection rectangle
  94. xA = max(boxA[0], boxB[0])
  95. yA = max(boxA[1], boxB[1])
  96. xB = min(boxA[2], boxB[2])
  97. yB = min(boxA[3], boxB[3])
  98. if xB <= xA or yB <= yA:
  99. return 0.0
  100. # compute the area of intersection rectangle
  101. interArea = (xB - xA + 1) * (yB - yA + 1)
  102. # compute the area of both the prediction and ground-truth
  103. # rectangles
  104. boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
  105. boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
  106. iou = interArea / float(boxBArea)
  107. return iou
  108. # 得到h和w
  109. if img.ndim == 3:
  110. h, w, c = img.shape
  111. else:
  112. _, h, w, c = img.shape
  113. mask = np.ones((h, w, c), np.float32)
  114. for n in range(n_holes):
  115. chongdie = True # 看切割的区域是否与box重叠太多
  116. while chongdie:
  117. y = np.random.randint(h)
  118. x = np.random.randint(w)
  119. y1 = np.clip(y - length // 2, 0,
  120. h) # numpy.clip(a, a_min, a_max, out=None), clip这个函数将将数组中的元素限制在a_min, a_max之间,大于a_max的就使得它等于 a_max,小于a_min,的就使得它等于a_min
  121. y2 = np.clip(y + length // 2, 0, h)
  122. x1 = np.clip(x - length // 2, 0, w)
  123. x2 = np.clip(x + length // 2, 0, w)
  124. chongdie = False
  125. for box in bboxes:
  126. if cal_iou([x1, y1, x2, y2], box) > threshold:
  127. chongdie = True
  128. break
  129. mask[y1: y2, x1: x2, :] = 0.
  130. img = img * mask
  131. return img
  132. # ---4.旋转--- #
  133. def _rotate_img_bbox(self, img, bboxes, angle=5, scale=1.):
  134. w, h = img.shape[1], img.shape[0]
  135. rangle = np.deg2rad(angle) # angle in radians
  136. nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
  137. nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale
  138. rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
  139. rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
  140. rot_mat[0, 2] += rot_move[0]
  141. rot_mat[1, 2] += rot_move[1]
  142. rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)
  143. rot_bboxes = []
  144. for bbox in bboxes:
  145. points = np.array([[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[3]], [bbox[0], bbox[3]]])
  146. new_points = cv2.transform(points[None, :, :], rot_mat)[0]
  147. rx, ry, rw, rh = cv2.boundingRect(new_points)
  148. corrected_bbox = [max(0, rx), max(0, ry), min(nw, rx + rw), min(nh, ry + rh)]
  149. corrected_bbox = [int(val) for val in corrected_bbox] # Convert to int and correct order if necessary
  150. rot_bboxes.append(corrected_bbox)
  151. return rot_img, rot_bboxes
  152. # ---5.裁剪--- #
  153. def _crop_img_bboxes(self, img, bboxes):
  154. '''
  155. 裁剪后的图片要包含所有的框
  156. 输入:
  157. img:图像array
  158. bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
  159. 输出:
  160. crop_img:裁剪后的图像array
  161. crop_bboxes:裁剪后的bounding box的坐标list
  162. '''
  163. # 裁剪图像
  164. w = img.shape[1]
  165. h = img.shape[0]
  166. x_min = w # 裁剪后的包含所有目标框的最小的框
  167. x_max = 0
  168. y_min = h
  169. y_max = 0
  170. for bbox in bboxes:
  171. x_min = min(x_min, bbox[0])
  172. y_min = min(y_min, bbox[1])
  173. x_max = max(x_max, bbox[2])
  174. y_max = max(y_max, bbox[3])
  175. d_to_left = x_min # 包含所有目标框的最小框到左边的距离
  176. d_to_right = w - x_max # 包含所有目标框的最小框到右边的距离
  177. d_to_top = y_min # 包含所有目标框的最小框到顶端的距离
  178. d_to_bottom = h - y_max # 包含所有目标框的最小框到底部的距离
  179. # 随机扩展这个最小框
  180. crop_x_min = int(x_min - random.uniform(0, d_to_left))
  181. crop_y_min = int(y_min - random.uniform(0, d_to_top))
  182. crop_x_max = int(x_max + random.uniform(0, d_to_right))
  183. crop_y_max = int(y_max + random.uniform(0, d_to_bottom))
  184. # 随机扩展这个最小框 , 防止别裁的太小
  185. # crop_x_min = int(x_min - random.uniform(d_to_left//2, d_to_left))
  186. # crop_y_min = int(y_min - random.uniform(d_to_top//2, d_to_top))
  187. # crop_x_max = int(x_max + random.uniform(d_to_right//2, d_to_right))
  188. # crop_y_max = int(y_max + random.uniform(d_to_bottom//2, d_to_bottom))
  189. # 确保不要越界
  190. crop_x_min = max(0, crop_x_min)
  191. crop_y_min = max(0, crop_y_min)
  192. crop_x_max = min(w, crop_x_max)
  193. crop_y_max = min(h, crop_y_max)
  194. crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]
  195. # 裁剪boundingbox
  196. # 裁剪后的boundingbox坐标计算
  197. crop_bboxes = list()
  198. for bbox in bboxes:
  199. crop_bboxes.append([bbox[0] - crop_x_min, bbox[1] - crop_y_min, bbox[2] - crop_x_min, bbox[3] - crop_y_min])
  200. return crop_img, crop_bboxes
  201. # ---6.平移--- #
  202. def _shift_pic_bboxes(self, img, bboxes):
  203. h, w = img.shape[:2]
  204. x = random.uniform(-w * 0.2, w * 0.2)
  205. y = random.uniform(-h * 0.2, h * 0.2)
  206. M = np.float32([[1, 0, x], [0, 1, y]])
  207. shift_img = cv2.warpAffine(img, M, (w, h))
  208. shift_bboxes = []
  209. for bbox in bboxes:
  210. new_bbox = [bbox[0] + x, bbox[1] + y, bbox[2] + x, bbox[3] + y]
  211. corrected_bbox = [max(0, new_bbox[0]), max(0, new_bbox[1]), min(w, new_bbox[2]), min(h, new_bbox[3])]
  212. corrected_bbox = [int(val) for val in corrected_bbox] # Convert to int and correct order if necessary
  213. shift_bboxes.append(corrected_bbox)
  214. return shift_img, shift_bboxes
  215. # ---7.镜像--- #
  216. def _filp_pic_bboxes(self, img, bboxes):
  217. # Randomly decide the flip method
  218. flipCode = random.choice([-1, 0, 1]) # -1: both; 0: vertical; 1: horizontal
  219. flip_img = cv2.flip(img, flipCode) # Apply the flip
  220. h, w, _ = img.shape
  221. flip_bboxes = []
  222. for bbox in bboxes:
  223. x_min, y_min, x_max, y_max = bbox
  224. if flipCode == 0: # Vertical flip
  225. new_bbox = [x_min, h - y_max, x_max, h - y_min]
  226. elif flipCode == 1: # Horizontal flip
  227. new_bbox = [w - x_max, y_min, w - x_min, y_max]
  228. else: # Both flips
  229. new_bbox = [w - x_max, h - y_max, w - x_min, h - y_min]
  230. flip_bboxes.append(new_bbox)
  231. return flip_img, flip_bboxes
  232. # 图像增强方法
  233. def dataAugment(self, img, bboxes):
  234. '''
  235. 图像增强
  236. 输入:
  237. img:图像array
  238. bboxes:该图像的所有框坐标
  239. 输出:
  240. img:增强后的图像
  241. bboxes:增强后图片对应的box
  242. '''
  243. change_num = 0 # 改变的次数
  244. # print('------')
  245. while change_num < 1: # 默认至少有一种数据增强生效
  246. if self.is_rotate_img_bbox:
  247. if random.random() > self.rotation_rate: # 旋转
  248. change_num += 1
  249. angle = random.uniform(-self.max_rotation_angle, self.max_rotation_angle)
  250. scale = random.uniform(0.7, 0.8)
  251. img, bboxes = self._rotate_img_bbox(img, bboxes, angle, scale)
  252. if self.is_shift_pic_bboxes:
  253. if random.random() < self.shift_rate: # 平移
  254. change_num += 1
  255. img, bboxes = self._shift_pic_bboxes(img, bboxes)
  256. if self.is_changeLight:
  257. if random.random() > self.change_light_rate: # 改变亮度
  258. change_num += 1
  259. img = self._changeLight(img)
  260. if self.is_addNoise:
  261. if random.random() < self.add_noise_rate: # 加噪声
  262. change_num += 1
  263. img = self._addNoise(img)
  264. if self.is_cutout:
  265. if random.random() < self.cutout_rate: # cutout
  266. change_num += 1
  267. img = self._cutout(img, bboxes, length=self.cut_out_length, n_holes=self.cut_out_holes,
  268. threshold=self.cut_out_threshold)
  269. if self.is_filp_pic_bboxes:
  270. if random.random() < self.flip_rate: # 翻转
  271. change_num += 1
  272. img, bboxes = self._filp_pic_bboxes(img, bboxes)
  273. return img, bboxes
  274. # xml解析工具
  275. class ToolHelper():
  276. # 从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]]
  277. def parse_xml(self, path):
  278. '''
  279. 输入:
  280. xml_path: xml的文件路径
  281. 输出:
  282. 从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]]
  283. '''
  284. tree = ET.parse(path)
  285. root = tree.getroot()
  286. objs = root.findall('object')
  287. coords = list()
  288. for ix, obj in enumerate(objs):
  289. name = obj.find('name').text
  290. box = obj.find('bndbox')
  291. x_min = int(box[0].text)
  292. y_min = int(box[1].text)
  293. x_max = int(box[2].text)
  294. y_max = int(box[3].text)
  295. coords.append([x_min, y_min, x_max, y_max, name])
  296. return coords
  297. # 保存图片结果
  298. def save_img(self, file_name, save_folder, img):
  299. cv2.imwrite(os.path.join(save_folder, file_name), img)
  300. # 保持xml结果
  301. def save_xml(self, file_name, save_folder, img_info, height, width, channel, bboxs_info):
  302. '''
  303. :param file_name:文件名
  304. :param save_folder:#保存的xml文件的结果
  305. :param height:图片的信息
  306. :param width:图片的宽度
  307. :param channel:通道
  308. :return:
  309. '''
  310. folder_name, img_name = img_info # 得到图片的信息
  311. E = objectify.ElementMaker(annotate=False)
  312. anno_tree = E.annotation(
  313. E.folder(folder_name),
  314. E.filename(img_name),
  315. E.path(os.path.join(folder_name, img_name)),
  316. E.source(
  317. E.database('Unknown'),
  318. ),
  319. E.size(
  320. E.width(width),
  321. E.height(height),
  322. E.depth(channel)
  323. ),
  324. E.segmented(0),
  325. )
  326. labels, bboxs = bboxs_info # 得到边框和标签信息
  327. for label, box in zip(labels, bboxs):
  328. anno_tree.append(
  329. E.object(
  330. E.name(label),
  331. E.pose('Unspecified'),
  332. E.truncated('0'),
  333. E.difficult('0'),
  334. E.bndbox(
  335. E.xmin(box[0]),
  336. E.ymin(box[1]),
  337. E.xmax(box[2]),
  338. E.ymax(box[3])
  339. )
  340. ))
  341. etree.ElementTree(anno_tree).write(os.path.join(save_folder, file_name), pretty_print=True)
  342. if __name__ == '__main__':
  343. need_aug_num = 30 # 每张图片需要增强的次数
  344. is_endwidth_dot = True # 文件是否以.jpg或者png结尾
  345. dataAug = DataAugmentForObjectDetection() # 数据增强工具类
  346. toolhelper = ToolHelper() # 工具
  347. # 获取相关参数
  348. parser = argparse.ArgumentParser()
  349. parser.add_argument('--source_img_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/img')
  350. parser.add_argument('--source_xml_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/xml')
  351. parser.add_argument('--save_img_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/img2')
  352. parser.add_argument('--save_xml_path', type=str, default='D:/lenovo/Archie/shujukuochongv1.0/xml2')
  353. args = parser.parse_args()
  354. source_img_path = args.source_img_path # 图片原始位置
  355. source_xml_path = args.source_xml_path # xml的原始位置
  356. save_img_path = args.save_img_path # 图片增强结果保存文件
  357. save_xml_path = args.save_xml_path # xml增强结果保存文件
  358. # 如果保存文件夹不存在就创建
  359. if not os.path.exists(save_img_path):
  360. os.mkdir(save_img_path)
  361. if not os.path.exists(save_xml_path):
  362. os.mkdir(save_xml_path)
  363. for parent, _, files in os.walk(source_img_path):
  364. files.sort()
  365. for file in files:
  366. cnt = 0
  367. pic_path = os.path.join(parent, file)
  368. xml_path = os.path.join(source_xml_path, file[:-4] + '.xml')
  369. values = toolhelper.parse_xml(xml_path) # 解析得到box信息,格式为[[x_min,y_min,x_max,y_max,name]]
  370. coords = [v[:4] for v in values] # 得到框
  371. labels = [v[-1] for v in values] # 对象的标签
  372. # 如果图片是有后缀的
  373. if is_endwidth_dot:
  374. # 找到文件的最后名字
  375. dot_index = file.rfind('.')
  376. _file_prefix = file[:dot_index] # 文件名的前缀
  377. _file_suffix = file[dot_index:] # 文件名的后缀
  378. img = cv2.imread(pic_path)
  379. # show_pic(img, coords) # 显示原图
  380. while cnt < need_aug_num: # 继续增强
  381. auged_img, auged_bboxes = dataAug.dataAugment(img, coords)
  382. auged_bboxes_int = np.array(auged_bboxes).astype(np.int32)
  383. height, width, channel = auged_img.shape # 得到图片的属性
  384. img_name = '{}_{}{}'.format(_file_prefix, cnt + 1, _file_suffix) # 图片保存的信息
  385. toolhelper.save_img(img_name, save_img_path,
  386. auged_img) # 保存增强图片
  387. toolhelper.save_xml('{}_{}.xml'.format(_file_prefix, cnt + 1),
  388. save_xml_path, (save_img_path, img_name), height, width, channel,
  389. (labels, auged_bboxes_int)) # 保存xml文件
  390. # show_pic(auged_img, auged_bboxes) # 强化后的图
  391. print(img_name)
  392. cnt += 1 # 继续增强下一张
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/代码探险家/article/detail/884714
推荐阅读
相关标签
  

闽ICP备14008679号