当前位置:   article > 正文

数据预处理和数据集的设置——以目标检测数据集为例_from draw_box_utils import draw_box

from draw_box_utils import draw_box

数据集

在网上有很多可用的公开的数据集,根据自己的需要,下载相应的数据集,可以用来训练网络,测试网络模型的精度。

[数据集转载来源] 深度学习中的遥感影像数据集

Pascal VOC网址http://host.robots.ox.ac.uk/pascal/VOC/

转载的一篇包含了比较多的数据集的一篇博文,可以参考一下。

但有些时候,我们需要根据我们自己的需求,根据自己的研究方向和类型,设置自己的数据集,以下,简单的阐述了设置数据集的一些步骤。

创建数据集

pytorch中,官方文档简单的介绍了创建数据集的简单步骤。

# ================================================================== #
#                5. Input pipeline for custom dataset                 #
# ================================================================== #

# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self):
        # TODO
        # 1. Initialize file paths or a list of file names. 
        # 设置文件和标签的路径,或者文件名list,最关键的就是设置好数据集的路径,以及初始化一些数据集的属性
        pass
    def __getitem__(self, index):
        # TODO
        # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
        # 2. Preprocess the data (e.g. torchvision.Transform).
        # 3. Return a data pair (e.g. image and label).
        # 通过上述的数据集路径,读取文件,并且对文件进行预处理操作,返回真实的文件数据,比如image and label
        pass
    def __len__(self):
        # You should change 0 to the total size of your dataset.
        # 比较简单,只是设置数据集的长度,返回一个值
        return 0 

# You can then use the prebuilt data loader. 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
                                           batch_size=64, 
                                           shuffle=True)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

所以说,最关键的就是初始化文件路径和读取文件,以及文件的预处理。

其他的一些需要用到的属性和方法,在需要的时候加上就行。比如如何进行数据读取、如何进行预处理等。

在实际应用中,创建数据集的基本步骤也大致如此,只需要把相应的方法写全即可,下面以目标检测的数据集为例。

栗子

1.数据准备

首先,我们拿到目标检测的遥感图像,放到一个总的文件夹中。再使用标签工具labelImg进行标注,将标注好的xml标签文件同样放到同一个标签文件夹中。(下图仅为部分数据的截图)

这里有个小问题,就是使用不同的标注工具,得到的bonding box的格式会有不同,在后期读取的时候,可能会报错。

在这里插入图片描述
在这里插入图片描述

以下是图像和标签数据的截图实例:

在这里插入图片描述 在这里插入图片描述

再创建一个类别文件,设置不同的分类的地物名称,以及一个类别对应的JSON文件,不同类别对应不同的key和value。

在这里插入图片描述 在这里插入图片描述
将上述文件都放在同一个文件夹中,再将这些数据随机分成训练集和测试集,代码如下。

import os
import random


def train_val_txt(files_path,val_rate,output_train_path,output_val_path):
    '''
    :param files_path: 保存的所有图片文件的目录
    :param val_rate: 选择测试集相对于总体的比率
    :param output_train_path: 输出的train的filename的txt目录
    :param output_val_path: 输出的val的filename的txt目录
    '''

    if not os.path.exists(files_path):
        print("文件夹不存在")
        exit(1)

    # 获取文件目录下的所有文件名,返回列表格式
    files_name = sorted([file.split('.')[0] for file in os.listdir(files_path)])

    files_num = len(files_name)

    # 设置采样的序号,从[0,files_num] 中随机抽取k个数
    val_index = random.sample(range(0, files_num), k=int(files_num * val_rate))
    train_files = []
    val_files = []
    for index, file_name in enumerate(files_name):
        if index in val_index:
            val_files.append(file_name)
        else:
            train_files.append(file_name)

    try:
        with open(output_train_path,'x') as f:
            f.write('\n'.join(train_files))
        with open(output_val_path, 'x') as f:
            f.write('\n'.join(val_files))
    except Exception as e:
        print(e)
        exit(1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39

根据注释,设置路径和分类比,运行后可以得到train.txt和val.txt文件

文本文件中保存着训练集或测试集的样本名称,在后续操作中,直接读取不同的样本名称,就可以加载不同的数据。

最终效果如下:

在这里插入图片描述

这样数据就准备好了。

2.设置数据集

按照官方文档的框架,自定义数据集。

在init中,主要是初始化用户数据集的目录,包括设置标签目录,遥感影像目录,以及预处理。

def __init__(self, data_root, transforms, train=True):
    #设置不同的路径,分别设置成图片路径和标签路径
    self.root = os.path.join(data_root, "data")
    self.img_root = os.path.join(self.root, "JPEGImages")
    self.annotations_root = os.path.join(self.root, "Annotations")

    """读取训练集/测试集,txt_list是路径"""
    if train:
        txt_list = os.path.join(self.root, "ImageSets", "Main", "train_1.txt")
    else:
        txt_list = os.path.join(self.root, "ImageSets", "Main", "val_1.txt")

    with open(txt_list) as read:
        self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                         for line in read.readlines()]

    # 读取分类索引
    try:
        json_file = open('./data/classes.json', 'r')
        self.class_dict = json.load(json_file)
    except Exception as e:
        print(e)
        exit(-1)

    # 定义预处理方式
    self.transforms = transforms 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

len方法主要是返回数据集的个数,即有多少张图像(图像和标签是对应的)。该方法比较简单,直接返回即可。

def __len__(self):
    """返回训练集/测试集中图片的个数"""
    return len(self.xml_list)
  • 1
  • 2
  • 3

在getitem中,传入index,即对不同index的图像和标签进行处理,返回一个image和target(包含boxes、label、image_id等信息)。

对于不同的需求,设置不同的方法,这里只是以目标检测为例,故需要返回image、label和boxes边界框等信息。

def __getitem__(self, idx):
    # read xml
    xml_path = self.xml_list[idx]  # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str
    with open(xml_path) as fid:
        xml_str = fid.read()
    # xml = etree.fromstring(xml_str)
    xml = etree.fromstring(xml_str.encode('utf-8'))  # 读取xml文件的内容
    data = self.parse_xml_to_dict(xml)["annotation"]
    img_path = os.path.join(self.img_root, data["filename"])  # 从xml文件中得到img文件路径
    image = Image.open(img_path)
    if image.format != "JPEG":
        raise ValueError("Image format not JPEG")
    boxes = []
    labels = []
    iscrowd = []  # 是否难检测,crowd为0表示单目标
    for obj in data["object"]:
        """得到训练集边框坐标,分类和难易程度"""
        xmin = float(obj["bndbox"]["xmin"])
        xmax = float(obj["bndbox"]["xmax"])
        ymin = float(obj["bndbox"]["ymin"])
        ymax = float(obj["bndbox"]["ymax"])
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(self.class_dict[obj["name"]])
        iscrowd.append(int(obj["difficult"]))

    # convert everything into a torch.Tensor
    boxes = torch.as_tensor(boxes, dtype=torch.float32)
    labels = torch.as_tensor(labels, dtype=torch.int64)
    iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
    image_id = torch.tensor([idx])  # 当前数据对应的索引值
    area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])  # 框的面积:长*宽

    target = {}
    target["boxes"] = boxes
    target["labels"] = labels
    target["image_id"] = image_id
    target["area"] = area
    target["iscrowd"] = iscrowd

    if self.transforms is not None:
        image, target = self.transforms(image, target)

    return image, target
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43

除了以上三个方法外,我们可以根据自己的需求,增加不同的方法,在数据集设置阶段对数据的处理上,会比之后计算得出的要快一些。

这里增加一个标签索引值的处理方法。

#官方的方法:将标签的索引值存储为字典
def parse_xml_to_dict(self, xml):
    if len(xml) == 0:  # 说明已经遍历到底层,直接返回tag对应的信息
        return {xml.tag: xml.text}

    result = {}
    for child in xml:
        child_result = self.parse_xml_to_dict(child)  # 递归 遍历标签信息
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

最终的显示效果如下:

# read class_indict
category_index = {}
try:
    json_file = open('./data/classes.json', 'r')
    class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
    print(e)
    exit(-1)

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}

# load train data set
train_data_set = SelfDataSet(os.getcwd(), data_transform["train"], True)
print(len(train_data_set))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

在这里插入图片描述
以及图像的显示:

在这里插入图片描述
测试后,可以加载出图像和train_data_set,即数据集创建成功。

完整的代码示例:

这个案例是以Skysat数据为例设置的数据集,只需要修改图像和标签的路径即可。

from torch.utils.data import Dataset
import os
import torch
import json
from PIL import Image
from lxml import etree

#设置数据集
class SelfDataSet(Dataset):
    # 根目录,预处理方式,训练集/验证集
    def __init__(self, data_root, transforms, train=True):
        #设置不同的路径,分别设置成图片路径和标签路径
        self.root = os.path.join(data_root, "SkysatData")
        self.img_root = os.path.join(self.root, "JPEGImages")
        self.annotations_root = os.path.join(self.root, "Annotations")

        """读取训练集/测试集,txt_list是路径"""
        if train:
            txt_list = os.path.join(self.root, "ImageSets", "Main", "train.txt")
        else:
            txt_list = os.path.join(self.root, "ImageSets", "Main", "val.txt")

        with open(txt_list) as read:
            self.xml_list = [os.path.join(self.annotations_root, line.strip() + ".xml")
                             for line in read.readlines()]

        # 读取分类索引
        try:
            json_file = open('./SkysatData/classex.json', 'r')
            self.class_dict = json.load(json_file)
        except Exception as e:
            print(e)
            exit(-1)

        # 定义预处理方式
        self.transforms = transforms

    def __len__(self):
        """返回训练集/测试集中图片的个数"""
        return len(self.xml_list)

    def __getitem__(self, idx):
        # read xml
        xml_path = self.xml_list[idx]  # idx是xml_list文件中的索引,通过索引找到第idx个xml文件的路径xml_str
        with open(xml_path) as fid:
            xml_str = fid.read()
        # xml = etree.fromstring(xml_str)
        xml = etree.fromstring(xml_str.encode('utf-8'))  # 读取xml文件的内容
        data = self.parse_xml_to_dict(xml)["annotation"]
        img_path = os.path.join(self.img_root, data["filename"])  # 从xml文件中得到img文件路径
        image = Image.open(img_path)
        if image.format != "JPEG":
            raise ValueError("Image format not JPEG")
        boxes = []
        labels = []
        iscrowd = []  # 是否难检测,crowd为0表示单目标
        for obj in data["object"]:
            """得到训练集边框坐标,分类和难易程度"""
            xmin = float(obj["bndbox"]["xmin"])
            xmax = float(obj["bndbox"]["xmax"])
            ymin = float(obj["bndbox"]["ymin"])
            ymax = float(obj["bndbox"]["ymax"])
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self.class_dict[obj["name"]])
            iscrowd.append(int(obj["difficult"]))

        # convert everything into a torch.Tensor
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        iscrowd = torch.as_tensor(iscrowd, dtype=torch.int64)
        image_id = torch.tensor([idx])  # 当前数据对应的索引值
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])  # 框的面积:长*宽

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target


    def get_height_and_width(self, idx):
        # read xml,每个xml
        xml_path = self.xml_list[idx]
        with open(xml_path) as fid:
            xml_str = fid.read()
        xml = etree.fromstring(xml_str)
        data = self.parse_xml_to_dict(xml)["annotation"]
        data_height = int(data["size"]["height"])
        data_width = int(data["size"]["width"])
        return data_height, data_width

    #官方的方法:将标签的索引值存储为字典
    def parse_xml_to_dict(self, xml):
        if len(xml) == 0:  # 说明已经遍历到底层,直接返回tag对应的信息
            return {xml.tag: xml.text}

        result = {}
        for child in xml:
            child_result = self.parse_xml_to_dict(child)  # 递归 遍历标签信息
            if child.tag != 'object':
                result[child.tag] = child_result[child.tag]
            else:
                if child.tag not in result:  # 因为object可能有多个,所以需要放入列表里
                    result[child.tag] = []
                result[child.tag].append(child_result[child.tag])
        return {xml.tag: result}

    @staticmethod
    def collate_fn(batch):
        return tuple(zip(*batch))

import transforms
from draw_box_utils import draw_box
from PIL import Image
import json
import matplotlib.pyplot as plt
import torchvision.transforms as ts
import random


# read class_indict
category_index = {}
try:
    json_file = open('./SkysatData/classex.json', 'r')
    class_dict = json.load(json_file)
    category_index = {v: k for k, v in class_dict.items()}
except Exception as e:
    print(e)
    exit(-1)

data_transform = {
    "train": transforms.Compose([transforms.ToTensor(),
                                 transforms.RandomHorizontalFlip(0.5)]),
    "val": transforms.Compose([transforms.ToTensor()])
}

# load train data set
train_data_set = SelfDataSet(os.getcwd(), data_transform["train"], True)
print(len(train_data_set))

# index = 40
for index in random.sample(range(0, len(train_data_set)), k=5):
    img, target = train_data_set[index]
    img = ts.ToPILImage()(img)
    draw_box(img,
             target["boxes"].numpy(),
             target["labels"].numpy(),
             [1 for i in range(len(target["labels"].numpy()))],
             category_index,
             thresh=0.5,
             line_thickness=1)
    Image._show(img)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158

效果如下:

在这里插入图片描述

YOLOv5数据集设置

直接上代码了,根据YOLOv5的数据集设置,提取核心的数据集设置代码,代码如下:

import glob
import os
from pathlib import Path

import cv2
import numpy as np
import torch

class SkysatDataset(torch.utils.data.Dataset):
    # 设置基本的文件路径
    def __init__(self, path, imgsz, prefix=''):
        self.path = path
        self.imgsz = imgsz
        # set the file path
        try:
            f = []
            for p in path if isinstance(path, list) else [path]:
                p = Path(p)
                if p.is_dir():
                    f += glob.glob(str(p / '**' / '*.*'), recursive=True)
            self.img_files = sorted([x.replace('/', os.sep) for x in f])
        except Exception as e:
            raise Exception(f'{prefix}Error loading data from {path}: {e}')
        self.label_files = img2label_paths(self.img_files)  # labels
        self.n = len(self.img_files)

    def __len__(self):
        return self.n

    # 通过getitem获得img和label
    def __getitem__(self, index):
        img_path, label_path = self.img_files[index], self.label_files[index]
        img = cv2.imread(img_path)
        label = []
        with open(label_path, 'r') as f:
            for each in f.readlines():
                cls, x, y, w, h = each.replace('\n', '').split(' ')
                label.append([cls,x,y,w,h])
        label = np.array(label).astype(np.float32)
        label = xywh2xyxy(label[:,1:5])*self.imgsz
        return img, label

# 通过img路径得到label路径
def img2label_paths(img_paths):
    # Define label paths as a function of image paths
    sa, sb = os.sep + 'images' + os.sep, os.sep + 'labels' + os.sep  # /images/, /labels/ substrings
    return [sb.join(x.rsplit(sa, 1)).rsplit('.', 1)[0] + '.txt' for x in img_paths]

# 可视化,将坐标改变格式
def xywh2xyxy(x):
    # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
    y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

# 可视化操作
def vis(img, boxes):
    for i in range(len(boxes)):
        box = boxes[i]
        x0 = int(box[0])
        y0 = int(box[1])
        x1 = int(box[2])
        y1 = int(box[3])
        cv2.rectangle(img, (x0, y0), (x1, y1), (0, 255, 0), 1)
    return img

if __name__ == '__main__':
    dataset = SkysatDataset(path=r'D:\DATA\Models\customize\YOLOv5-6.0-St\dataset\skysat\images\train', imgsz=512)
    img, label = dataset[2]
    img = vis(img, label)
    cv2.imshow('img', img)
    cv2.waitKey(0)
    cv2.destroyWindow()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76

图片文件和标签存放格式如下:
在这里插入图片描述
label的存储格式如下:(cls, x, y, w, h)并且对x, y, w, h进行了归一化处理

在这里插入图片描述

按照这个方式存放文件,可以得到如下的效果图:

在这里插入图片描述

这个核心代码比较简洁,可以直接使用制作自定义数据集。

本文主要为读书笔记,根据学习资料中的案例,使用自己的例子进行数据集创建,读者仅作参考,如有错误或补充,还请评论批评指正,谢谢!

当然,这只是自定义的一种方式,一般的Github都会有自己的数据集设置方式,按照项目中的修改即可。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/148220
推荐阅读
相关标签
  

闽ICP备14008679号