赞
踩
在网上有很多可用的公开的数据集,根据自己的需要,下载相应的数据集,可以用来训练网络,测试网络模型的精度。
[数据集转载来源] 深度学习中的遥感影像数据集
Pascal VOC网址:http://host.robots.ox.ac.uk/pascal/VOC/
转载的一篇包含了比较多的数据集的一篇博文,可以参考一下。
但有些时候,我们需要根据我们自己的需求,根据自己的研究方向和类型,设置自己的数据集,以下,简单的阐述了设置数据集的一些步骤。
在pytorch中,官方文档简单的介绍了创建数据集的简单步骤。
# ================================================================== # # 5. Input pipeline for custom dataset # # ================================================================== # # You should build your custom dataset as below. class CustomDataset(torch.utils.data.Dataset): def __init__(self): # TODO # 1. Initialize file paths or a list of file names. # 设置文件和标签的路径,或者文件名list,最关键的就是设置好数据集的路径,以及初始化一些数据集的属性 pass def __getitem__(self, index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). # 通过上述的数据集路径,读取文件,并且对文件进行预处理操作,返回真实的文件数据,比如image and label pass def __len__(self): # You should change 0 to the total size of your dataset. # 比较简单,只是设置数据集的长度,返回一个值 return 0 # You can then use the prebuilt data loader. custom_dataset = CustomDataset() train_loader = torch.utils.data.DataLoader(dataset=custom_dataset, batch_size=64, shuffle=True)
所以说,最关键的就是初始化文件路径和读取文件,以及文件的预处理。
其他的一些需要用到的属性和方法,在需要的时候加上就行。比如如何进行数据读取、如何进行预处理等。
在实际应用中,创建数据集的基本步骤也大致如此,只需要把相应的方法写全即可,下面以目标检测的数据集为例。
1.数据准备
首先,我们拿到目标检测的遥感图像,放到一个总的文件夹中。再使用标签工具labelImg进行标注,将标注好的xml标签文件同样放到同一个标签文件夹中。(下图仅为部分数据的截图)
这里有个小问题,就是使用不同的标注工具,得到的bonding box的格式会有不同,在后期读取的时候,可能会报错。
以下是图像和标签数据的截图实例:
再创建一个类别文件,设置不同的分类的地物名称,以及一个类别对应的JSON文件,不同类别对应不同的key和value。
将上述文件都放在同一个文件夹中,再将这些数据随机分成训练集和测试集,代码如下。
import os import random def train_val_txt(files_path,val_rate,output_train_path,output_val_path): ''' :param files_path: 保存的所有图片文件的目录 :param val_rate: 选择测试集相对于总体的比率 :param output_train_path: 输出的train的filename的txt目录 :param output_val_path: 输出的val的filename的txt目录 ''' if not os.path.exists(files_path): print("文件夹不存在") exit(1) # 获取文件目录下的所有文件名,返回列表格式 files_name = sorted([file.split('.')[0] for file in os.listdir(files_path)]) files_num = len(files_name) # 设置采样的序号,从[0,files_num] 中随机抽取k个数 val_index = random.sample(range(0, files_num), k=int(files_num * val_rate)) train_files = [] val_files = [] for index, file_name in enumerate(files_name): if index in val_index: val_files.append(file_name) else: train_files.append(file_name) try: with open(output_train_path,'x') as f: f.write('\n'.join(train_files)) with open(output_val_path, 'x') as f: f.write('\n'.join(val_files)) except Exception as e: print(e) exit(1)
根据注释,设置路径和分类比,运行后可以得到train.txt和val.txt文件。
文本文件中保存着训练集或测试集的样本名称,在后续操作中,直接读取不同的样本名称,就可以加载不同的数据。
最终效果如下:
这样数据就准备好了。
2.设置数据集
按照官方文档的框架,自定义数据集。
在init中,主要是初始化用户数据集的目录,包括设置标签目录,遥感影像目录,以及预处理。
def __init__(self, data_root, transforms, train=True): #设置不同的路径,分别设置成图片路径和标签路径 self.root = os.path.join(data_root, "data") self.img_root = os.path.join(self.root, "JPEGImages") self.annotations_root = os.path.join(self.root, "Annotations") """读取训练集/测试集,txt_list是路径""" if train: txt_list = os.path.join(self.root, "ImageSets", "Main", "train_1.txt") else: txt_list = os.path.join(self.root, "ImageSets", "Main", "val_1.txt") with open(txt_list) as read: self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml") for line in read.readlines()] # 读取分类索引 try: json_file = open('./data/classes.json', 'r') self.class_dict = json.load(json_file) except Exception as e: print(e) exit(-1) # 定义预处理方式 self.transforms = transforms
len方法主要是返回数据集的个数,即有多少张图像(图像和标签是对应的)。该方法比较简单,直接返回即可。
def __len__(self):
"""返回训练集/测试集中图片的个数"""
return len(self.xml_list)
在getitem中,传入index,即对不同index的图像和标签进行处理,返回一个image和target(包含boxes、label、image_id等信息)。
对于不同的需求,设置不同的方法,这里只是以目标检测为例,故需要返回image、label和boxes边界框等信息。
def __getitem__(self, idx): # read xml xml_path = self.xml_list[idx] # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str with open(xml_path) as fid: xml_str = fid.read() # xml = etree.fromstring(xml_str) xml = etree.fromstring(xml_str.encode('utf-8')) # 读取xml文件的内容 data = self.parse_xml_to_dict(xml)["annotation"] img_path = os.path.join(self.img_root, data["filename"]) # 从xml文件中得到img文件路径 image = Image.open(img_path) if image.format != "JPEG": raise ValueError("Image format not JPEG") boxes = [] labels = [] iscrowd = [] # 是否难检测,crowd为0表示单目标 for obj in data["object"]: """得到训练集边框坐标,分类和难易程度""" xmin = float(obj["bndbox"]["xmin"]) xmax = float(obj["bndbox"]["xmax"]) ymin = float(obj["bndbox"]["ymin"]) ymax = float(obj["bndbox"]["ymax"]) boxes.append([xmin, ymin, xmax, ymax]) labels.append(self.class_dict[obj["name"]]) iscrowd.append(int(obj["difficult"])) # convert everything into a torch.Tensor boxes = torch.as_tensor(boxes, dtype=torch.float32) labels = torch.as_tensor(labels, dtype=torch.int64) iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64) image_id = torch.tensor([idx]) # 当前数据对应的索引值 area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # 框的面积:长*宽 target = {} target["boxes"] = boxes target["labels"] = labels target["image_id"] = image_id target["area"] = area target["iscrowd"] = iscrowd if self.transforms is not None: image, target = self.transforms(image, target) return image, target
除了以上三个方法外,我们可以根据自己的需求,增加不同的方法,在数据集设置阶段对数据的处理上,会比之后计算得出的要快一些。
这里增加一个标签索引值的处理方法。
#官方的方法:将标签的索引值存储为字典
def parse_xml_to_dict(self, xml):
if len(xml) == 0: # 说明已经遍历到底层,直接返回tag对应的信息
return {xml.tag: xml.text}
result = {}
for child in xml:
child_result = self.parse_xml_to_dict(child) # 递归 遍历标签信息
if child.tag != 'object':
result[child.tag] = child_result[child.tag]
else:
if child.tag not in result: # 因为object可能有多个,所以需要放入列表里
result[child.tag] = []
result[child.tag].append(child_result[child.tag])
return {xml.tag: result}
最终的显示效果如下:
# read class_indict category_index = {} try: json_file = open('./data/classes.json', 'r') class_dict = json.load(json_file) category_index = {v: k for k, v in class_dict.items()} except Exception as e: print(e) exit(-1) data_transform = { "train": transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(0.5)]), "val": transforms.Compose([transforms.ToTensor()]) } # load train data set train_data_set = SelfDataSet(os.getcwd(), data_transform["train"], True) print(len(train_data_set))
以及图像的显示:
测试后,可以加载出图像和train_data_set,即数据集创建成功。
完整的代码示例:
这个案例是以Skysat数据为例设置的数据集,只需要修改图像和标签的路径即可。
from torch.utils.data import Dataset import os import torch import json from PIL import Image from lxml import etree #设置数据集 class SelfDataSet(Dataset): # 根目录,预处理方式,训练集/验证集 def __init__(self, data_root, transforms, train=True): #设置不同的路径,分别设置成图片路径和标签路径 self.root = os.path.join(data_root, "SkysatData") self.img_root = os.path.join(self.root, "JPEGImages") self.annotations_root = os.path.join(self.root, "Annotations") """读取训练集/测试集,txt_list是路径""" if train: txt_list = os.path.join(self.root, "ImageSets", "Main", "train.txt") else: txt_list = os.path.join(self.root, "ImageSets", "Main", "val.txt") with open(txt_list) as read: self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml") for line in read.readlines()] # 读取分类索引 try: json_file = open('./SkysatData/classex.json', 'r') self.class_dict = json.load(json_file) except Exception as e: print(e) exit(-1) # 定义预处理方式 self.transforms = transforms def __len__(self): """返回训练集/测试集中图片的个数""" return len(self.xml_list) def __getitem__(self, idx): # read xml xml_path = self.xml_list[idx] # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str with open(xml_path) as fid: xml_str = fid.read() # xml = etree.fromstring(xml_str) xml = etree.fromstring(xml_str.encode('utf-8')) # 读取xml文件的内容 data = self.parse_xml_to_dict(xml)["annotation"] img_path = os.path.join(self.img_root, data["filename"]) # 从xml文件中得到img文件路径 image = Image.open(img_path) if image.format != "JPEG": raise ValueError("Image format not JPEG") boxes = [] labels = [] iscrowd = [] # 是否难检测,crowd为0表示单目标 for obj in data["object"]: """得到训练集边框坐标,分类和难易程度""" xmin = float(obj["bndbox"]["xmin"]) xmax = float(obj["bndbox"]["xmax"]) ymin = float(obj["bndbox"]["ymin"]) ymax = float(obj["bndbox"]["ymax"]) boxes.append([xmin, ymin, xmax, ymax]) labels.append(self.class_dict[obj["name"]]) iscrowd.append(int(obj["difficult"])) # convert everything into a torch.Tensor boxes = torch.as_tensor(boxes, dtype=torch.float32) labels = torch.as_tensor(labels, dtype=torch.int64) iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64) image_id = torch.tensor([idx]) # 当前数据对应的索引值 area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) # 框的面积:长*宽 target = {} target["boxes"] = boxes target["labels"] = labels target["image_id"] = image_id target["area"] = area target["iscrowd"] = iscrowd if self.transforms is not None: image, target = self.transforms(image, target) return image, target def get_height_and_width(self, idx): # read xml,每个xml xml_path = self.xml_list[idx] with open(xml_path) as fid: xml_str = fid.read() xml = etree.fromstring(xml_str) data = self.parse_xml_to_dict(xml)["annotation"] data_height = int(data["size"]["height"]) data_width = int(data["size"]["width"]) return data_height, data_width #官方的方法:将标签的索引值存储为字典 def parse_xml_to_dict(self, xml): if len(xml) == 0: # 说明已经遍历到底层,直接返回tag对应的信息 return {xml.tag: xml.text} result = {} for child in xml: child_result = self.parse_xml_to_dict(child) # 递归 遍历标签信息 if child.tag != 'object': result[child.tag] = child_result[child.tag] else: if child.tag not in result: # 因为object可能有多个,所以需要放入列表里 result[child.tag] = [] result[child.tag].append(child_result[child.tag]) return {xml.tag: result} @staticmethod def collate_fn(batch): return tuple(zip(*batch)) import transforms from draw_box_utils import draw_box from PIL import Image import json import matplotlib.pyplot as plt import torchvision.transforms as ts import random # read class_indict category_index = {} try: json_file = open('./SkysatData/classex.json', 'r') class_dict = json.load(json_file) category_index = {v: k for k, v in class_dict.items()} except Exception as e: print(e) exit(-1) data_transform = { "train": transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(0.5)]), "val": transforms.Compose([transforms.ToTensor()]) } # load train data set train_data_set = SelfDataSet(os.getcwd(), data_transform["train"], True) print(len(train_data_set)) # index = 40 for index in random.sample(range(0, len(train_data_set)), k=5): img, target = train_data_set[index] img = ts.ToPILImage()(img) draw_box(img, target["boxes"].numpy(), target["labels"].numpy(), [1 for i in range(len(target["labels"].numpy()))], category_index, thresh=0.5, line_thickness=1) Image._show(img)
效果如下:
直接上代码了,根据YOLOv5的数据集设置,提取核心的数据集设置代码,代码如下:
import glob import os from pathlib import Path import cv2 import numpy as np import torch class SkysatDataset(torch.utils.data.Dataset): # 设置基本的文件路径 def __init__(self, path, imgsz, prefix=''): self.path = path self.imgsz = imgsz # set the file path try: f = [] for p in path if isinstance(path, list) else [path]: p = Path(p) if p.is_dir(): f += glob.glob(str(p / '**' / '*.*'), recursive=True) self.img_files = sorted([x.replace('/', os.sep) for x in f]) except Exception as e: raise Exception(f'{prefix}Error loading data from {path}: {e}') self.label_files = img2label_paths(self.img_files) # labels self.n = len(self.img_files) def __len__(self): return self.n # 通过getitem获得img和label def __getitem__(self, index): img_path, label_path = self.img_files[index], self.label_files[index] img = cv2.imread(img_path) label = [] with open(label_path, 'r') as f: for each in f.readlines(): cls, x, y, w, h = each.replace('\n', '').split(' ') label.append([cls,x,y,w,h]) label = np.array(label).astype(np.float32) label = xywh2xyxy(label[:,1:5])*self.imgsz return img, label # 通过img路径得到label路径 def img2label_paths(img_paths): # Define label paths as a function of image paths sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep # /images/, /labels/ substrings return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths] # 可视化,将坐标改变格式 def xywh2xyxy(x): # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x) y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y return y # 可视化操作 def vis(img, boxes): for i in range(len(boxes)): box = boxes[i] x0 = int(box[0]) y0 = int(box[1]) x1 = int(box[2]) y1 = int(box[3]) cv2.rectangle(img, (x0, y0), (x1, y1), (0, 255, 0), 1) return img if __name__ == '__main__': dataset = SkysatDataset(path=r'D:\DATA\Models\customize\YOLOv5-6.0-St\dataset\skysat\images\train', imgsz=512) img, label = dataset[2] img = vis(img, label) cv2.imshow('img', img) cv2.waitKey(0) cv2.destroyWindow()
图片文件和标签存放格式如下:
label的存储格式如下:(cls, x, y, w, h)并且对x, y, w, h进行了归一化处理
按照这个方式存放文件,可以得到如下的效果图:
这个核心代码比较简洁,可以直接使用制作自定义数据集。
本文主要为读书笔记,根据学习资料中的案例,使用自己的例子进行数据集创建,读者仅作参考,如有错误或补充,还请评论批评指正,谢谢!
当然,这只是自定义的一种方式,一般的Github都会有自己的数据集设置方式,按照项目中的修改即可。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。