赞
踩
有一批数据的图片尺寸偏小、质量不佳。为了方便标注,使用数据增强将所有图片统一固定为 1080P(1920×1080)。
# -*- coding: UTF-8 -*-
"""
@Project :yolov5_relu_fire_smoke_v1.4
@IDE     :PyCharm
@Author  :沐枫
@Date    :2024/4/2 20:28

Data augmentation: letterbox every image onto a white canvas so that all
images end up at a fixed 1080P size (1920x1080), and rewrite the VOC XML
annotations to match the transformed images.
"""
import os
import multiprocessing
from concurrent import futures
from copy import deepcopy

import numpy as np
import xml.etree.ElementTree as ET
import xml.dom.minidom as minidom

# NOTE: OpenCV (cv2) is imported inside the functions that use it, so the pure
# numpy / XML helpers in this module stay importable without OpenCV installed.

# Image file extensions (lower-case) picked up by the directory walk.
IMAGE_SUFFIXES = (".jpg", ".jpeg", ".bmp", ".png")


def xyxy2xywh(x):
    """
    Convert boxes from corner format [x1, y1, x2, y2] to centre format
    [x, y, w, h], where (x1, y1) is top-left and (x2, y2) is bottom-right.

    The result has the same dtype as the input (it starts as ``np.copy(x)``),
    so pass a floating-point array if fractional centres must be kept.

    :param x: numpy array whose last axis holds [x1, y1, x2, y2]
    :return: new array of the same shape holding [x_center, y_center, w, h]
    """
    y = np.copy(x)
    y[..., 0] = (x[..., 0] + x[..., 2]) / 2  # x center
    y[..., 1] = (x[..., 1] + x[..., 3]) / 2  # y center
    y[..., 2] = x[..., 2] - x[..., 0]  # width
    y[..., 3] = x[..., 3] - x[..., 1]  # height
    return y


def xywh2xyxy(x: np.ndarray):
    """
    Convert boxes from centre format [x, y, w, h] to corner format
    [x1, y1, x2, y2], where (x1, y1) is top-left and (x2, y2) is bottom-right.

    The result has the same dtype as the input (it starts as ``np.copy(x)``).

    :param x: numpy array whose last axis holds [x_center, y_center, w, h]
    :return: new array of the same shape holding [x1, y1, x2, y2]
    """
    y = np.copy(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # top left x
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # top left y
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # bottom right x
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # bottom right y
    return y


def decodeVocAnnotation(voc_xml_path, class_index_dict):
    """
    Parse one VOC-format XML annotation file into a list of objects.

    An empty return value means the file contains no ``object`` element,
    i.e. the image is a pure background image.

    :param voc_xml_path: path of the XML file (must end with ".xml")
    :param class_index_dict: mapping class name -> class index
    :return: [(cls_index, x1, y1, x2, y2), ...] with integer coordinates
    :raises ValueError: if the path does not end with ".xml"
    :raises KeyError: if an object's name is missing from class_index_dict
    """
    if not voc_xml_path.endswith(".xml"):
        raise ValueError("voc_xml_path must endswith .xml")

    # Let the XML parser read the bytes itself so that a BOM or an encoding
    # declaration in the file is honoured.
    root = ET.parse(voc_xml_path).getroot()

    information = []
    for obj in root.iter('object'):
        name = obj.find('name').text
        # VOC stores boxes in corner format.
        box = obj.find('bndbox')
        # int(float(...)) also accepts coordinates written as "12.0".
        x1, y1, x2, y2 = (int(float(box.find(tag).text))
                          for tag in ('xmin', 'ymin', 'xmax', 'ymax'))
        information.append((class_index_dict[name], x1, y1, x2, y2))

    return information


def _add_child(parent, tag, text=None):
    """Append a child element to *parent*, optionally with str(text) as text."""
    node = ET.SubElement(parent, tag)
    if text is not None:
        node.text = str(text)
    return node


def create_voc_xml(image_folder, image_filename, width: int, height: int, labels, save_root, class_name_dict,
                   conf_thresh_dict=None):
    """
    Write a VOC-format XML annotation file for one image.

    The file is saved as ``<save_root>/<image_filename without extension>.xml``
    and is always written as UTF-8.

    :param image_folder: relative folder of the image, stored in <folder>
    :param image_filename: image file name, e.g. 000001.jpg
    :param width: image width
    :param height: image height
    :param labels: boxes [[class_index, xmin, ymin, xmax, ymax], ...]; when
                   conf_thresh_dict is given every row carries a sixth value,
                   the confidence
    :param save_root: directory in which the XML file is saved
    :param class_name_dict: mapping cls_index -> cls_name
    :param conf_thresh_dict: mapping cls_index -> confidence threshold; rows
                             below their class threshold are skipped. None
                             means the rows carry no confidence value.
    :return: None
    """
    root = ET.Element("annotation")

    # Image information.
    _add_child(root, "folder", image_folder)
    _add_child(root, "filename", image_filename)
    size = _add_child(root, "size")
    _add_child(size, "width", width)
    _add_child(size, "height", height)
    _add_child(size, "depth", "3")  # number of channels

    # One <object> per kept box.
    for label in labels:
        if conf_thresh_dict is None:
            class_index, x1, y1, x2, y2 = (int(v) for v in np.asarray(label).astype(np.int32))
        else:
            class_index, x1, y1, x2, y2, conf = label
            class_index, x1, y1, x2, y2 = (int(v) for v in (class_index, x1, y1, x2, y2))
            # Drop boxes below the per-class confidence threshold.
            if conf < conf_thresh_dict[class_index]:
                continue

        obj = _add_child(root, "object")
        _add_child(obj, "name", class_name_dict[class_index])
        _add_child(obj, "pose", "Unspecified")
        _add_child(obj, "truncated", "0")
        _add_child(obj, "difficult", "0")
        bndbox = _add_child(obj, "bndbox")
        _add_child(bndbox, "xmin", x1)
        _add_child(bndbox, "ymin", y1)
        _add_child(bndbox, "xmax", x2)
        _add_child(bndbox, "ymax", y2)

    # Re-parse with minidom to get a 4-space indented, readable document.
    pretty_xml = minidom.parseString(ET.tostring(root, encoding="utf-8")).toprettyxml(indent=" " * 4)

    save_path = os.path.join(save_root, f"{os.path.splitext(image_filename)[0]}.xml")
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # Explicit UTF-8: the platform default (e.g. GBK on Windows) would produce
    # files that decodeVocAnnotation cannot read back when names are non-ASCII.
    with open(save_path, "w", encoding="utf-8") as xml_file:
        xml_file.write(pretty_xml)


def _fit_and_pad(image: np.ndarray, labels: np.ndarray, width, height):
    """
    Shrink *image* if it does not fit strictly inside (width, height), then
    centre it on a white canvas of exactly that size.

    :param image: H x W x 3 uint8 image
    :param labels: float array of (cls_id, x, y, w, h); modified in place
    :param width: canvas width
    :param height: canvas height
    :return: (canvas, labels) with labels in canvas coordinates
    """
    import cv2

    img_h, img_w = image.shape[:2]

    if img_w < width and img_h < height:
        # Small enough: keep the pixels as they are and only pad.
        ratio = 1.0
        new_w, new_h = img_w, img_h
        resized = image
    else:
        # Too large in at least one direction: scale down, keeping aspect ratio.
        ratio = min(width / img_w, height / img_h)
        # round() avoids losing a pixel to float error (e.g. 1919.999 -> 1919);
        # the clamps keep the size inside the canvas and at least one pixel.
        new_w = min(width, max(1, round(img_w * ratio)))
        new_h = min(height, max(1, round(img_h * ratio)))
        resized = cv2.resize(image, (new_w, new_h))

    # Offsets that centre the image on the canvas.
    dw = (width - new_w) // 2
    dh = (height - new_h) // 2

    canvas = np.full((height, width, 3), 255, dtype=np.uint8)  # pure white
    canvas[dh:dh + new_h, dw:dw + new_w, :] = resized

    # Centre and size both scale; only the centre is shifted by the padding.
    labels[..., 1:] *= ratio
    labels[..., 1] += dw
    labels[..., 2] += dh

    return canvas, labels


def resize_and_pad(image: np.ndarray, labels: np.ndarray, width=1920, height=1080):
    """
    Bring an image to exactly (width, height) on a white canvas.

    Images smaller than the target in both directions are first enlarged 2x;
    the result is then shrunk if it no longer fits and finally centred on the
    canvas. The labels receive the same scaling and translation.

    The caller's label array is not modified; a float64 copy is returned.

    :param image: H x W x 3 uint8 image
    :param labels: array of (cls_id, x, y, w, h) in pixels, centre format
    :param width: target width
    :param height: target height
    :return: (new_image, new_labels) with labels as (cls_id, x, y, w, h)
    """
    import cv2

    SCALE = 2

    # Work on a float copy: keeps the caller's array intact and makes the
    # in-place scaling below valid even when integer labels are passed in.
    labels = np.array(labels, dtype=np.float64)

    img_h, img_w = image.shape[:2]
    if img_w < width and img_h < height:
        # Small image: enlarge first; centre and size scale by the same factor.
        image = cv2.resize(image, (img_w * SCALE, img_h * SCALE))
        labels[..., 1:] *= SCALE

    return _fit_and_pad(image, labels, width=width, height=height)


def run(image_path, xml_root, image_root, save_image_root, save_xml_root, class_index_dict, class_name_dict):
    """
    Augment one image: read it with its VOC annotation, letterbox it to
    1920x1080 and save the new image ("aug_" prefix) plus a matching XML.

    Images without an XML file, unreadable images and images without any
    labelled object are skipped with a message. Any exception raised while
    processing is caught and printed so one bad file does not stop the batch.

    :param image_path: path of the source image
    :param xml_root: root directory of the source annotations
    :param image_root: root directory of the source images
    :param save_image_root: root directory for the augmented images
    :param save_xml_root: root directory for the augmented annotations
    :param class_index_dict: mapping class name -> class index
    :param class_name_dict: mapping class index -> class name
    :return: None
    """
    import cv2

    image_file = os.path.basename(image_path)

    # Mirror the image path under xml_root. Only the leading root is swapped
    # and only the real extension is replaced (a plain str.replace on the
    # suffix would also rewrite ".jpg" appearing elsewhere in the path).
    xml_path = os.path.splitext(image_path.replace(image_root, xml_root, 1))[0] + ".xml"

    if not os.path.exists(xml_path):
        print(f"\n{image_path} no xml\n")
        return

    try:
        image = cv2.imread(image_path)
        if image is None:
            print(f"\n{image_path} unreadable image\n")
            return

        # (cls_id, x1, y1, x2, y2)
        labels = decodeVocAnnotation(xml_path, class_index_dict)
        if not labels:
            print(f"\n{image_path} no label\n")
            return

        labels = np.array(labels, dtype=np.float64)
        if labels.ndim < 2:
            labels = labels[None, ...]

        # Corner format -> centre format for the geometric transform.
        labels[..., 1:] = xyxy2xywh(labels[..., 1:].copy())
        new_image, new_labels = resize_and_pad(image, labels.copy(), width=1920, height=1080)
        new_img_h, new_img_w = new_image.shape[:2]
        # Back to corner format for the VOC file.
        new_labels[..., 1:] = xywh2xyxy(new_labels[..., 1:].copy())

        # Output goes to <save root>/<parent dir name of the image>/aug_<name>.
        save_image_path = os.path.join(save_image_root,
                                       os.path.basename(os.path.dirname(image_path)),
                                       f"aug_{image_file}")
        save_xml_path = os.path.splitext(save_image_path.replace(save_image_root, save_xml_root, 1))[0] + ".xml"
        os.makedirs(os.path.dirname(save_image_path), exist_ok=True)
        os.makedirs(os.path.dirname(save_xml_path), exist_ok=True)

        # imwrite reports failure through its return value, not an exception;
        # do not leave an XML behind without its image.
        if not cv2.imwrite(save_image_path, new_image):
            print(f"\n{save_image_path} write failed\n")
            return

        create_voc_xml(image_folder=save_image_path.replace(save_image_root + os.sep, ""),
                       image_filename=os.path.basename(save_image_path),
                       width=new_img_w,
                       height=new_img_h,
                       labels=np.array(new_labels, dtype=np.int32),
                       save_root=os.path.dirname(save_xml_path),
                       class_name_dict=class_name_dict, )

        print(f"\r{image_path}", end='')

    except Exception as e:
        print(f"{image_path} {run.__name__}:{e}")


def run_process(root_file_list, image_root, xml_root, save_image_root, save_xml_root, class_index_dict,
                class_name_dict):
    """
    Process one batch of images inside a worker process, using a small
    thread pool (5 threads) that calls run() once per image.

    :param root_file_list: list of (directory, file name) pairs
    :param image_root: root directory of the source images
    :param xml_root: root directory of the source annotations
    :param save_image_root: root directory for the augmented images
    :param save_xml_root: root directory for the augmented annotations
    :param class_index_dict: mapping class name -> class index
    :param class_name_dict: mapping class index -> class name
    :return: None
    """
    # Leaving the with-block waits for every submitted task to finish.
    with futures.ThreadPoolExecutor(max_workers=5) as executor:
        for root, file in root_file_list:
            image_path = os.path.join(root, file)
            # NOTE: run() takes xml_root before image_root.
            executor.submit(run, image_path, xml_root, image_root, save_image_root, save_xml_root,
                            class_index_dict, class_name_dict)


def _report_worker_error(exc):
    """Print an exception raised by a batch so it is not silently dropped."""
    print(f"\nrun_process failed: {exc!r}")


def main():
    """Walk the image tree and augment every image in batches across processes."""
    class_index_dict = {
        "fire": 0,
        "smoke": 1,
    }
    class_name_dict = {
        0: "fire",
        1: "smoke",
    }
    data_root = r"Z:\Datasets\FireSmoke_v4"
    data_root = os.path.abspath(data_root)
    # Root directory of the original images.
    image_root = os.path.join(data_root, "images")
    # Root directory of the XML annotation files.
    xml_root = os.path.join(data_root, "annotations")
    # Output roots.
    output_dir_name = "aug-pad"
    save_image_root = os.path.join(image_root, output_dir_name)
    save_xml_root = os.path.join(xml_root, output_dir_name)

    # Directories to skip: a directory is skipped when its path contains any
    # of these fragments. The output directory lives inside image_root, so it
    # must be listed too, otherwise already augmented images are picked up
    # and augmented again.
    exclude_dirs = [
        os.sep + name for name in (
            "background",
            "candle_fire",
            "AUG",
            "smoke",
            "val",
            "aug-merge",
            "cut_aug",
            "miniFire",
            "net",
            "new_data",
            "realScenario",
            "TSMCandle",
            output_dir_name,
        )
    ]

    max_workers = 10  # number of worker processes
    print(f"max_workers:{max_workers}")

    # A batch is handed to a worker once it holds more than this many images.
    max_file_num = 3000

    # Pending (root, file) pairs of the batch being filled.
    root_file_list = list()

    # Adjust the pool size to the machine; too many processes slow things down.
    pool = multiprocessing.Pool(processes=max_workers)

    def submit_batch():
        # The task is pickled asynchronously, so hand over a copy before the
        # pending list is cleared.
        pool.apply_async(run_process,
                         (deepcopy(root_file_list), image_root, xml_root, save_image_root, save_xml_root,
                          class_index_dict, class_name_dict),
                         error_callback=_report_worker_error)
        root_file_list.clear()

    try:
        for root, _, files in os.walk(image_root):
            # Skip the directory as soon as one excluded fragment matches.
            if any(fragment in root for fragment in exclude_dirs):
                continue
            for file in files:
                suffix = os.path.splitext(file)[1]
                if suffix.lower() not in IMAGE_SUFFIXES:
                    continue
                root_file_list.append((root, file))
                if len(root_file_list) > max_file_num:
                    submit_batch()

        # Flush the last, partially filled batch.
        if root_file_list:
            submit_batch()
    finally:
        # Stop accepting tasks and wait for all workers to finish.
        pool.close()
        pool.join()

    print("\nFinish ...")


if __name__ == '__main__':
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。