当前位置:   article > 正文

[Pytorch]将自己的数据集载入dataloader_dataloader如何加载本地数据

dataloader如何加载本地数据

一、概述

        初始化DataLoader类时必须注入一个参数dataset,而dataset为自己定义。DataSet类可以继承,但是必须重载__len__()__getitem__

        使用Pytoch封装的DataLoader有以下好处:

                ①可以自动实现多进程加载

                ②自动惰性加载,不会占用过多内存

                ③封装有数据预处理和数据增强等操作,避免重复造轮子

二、自定义DataSet

        以Faster R-CNN为例,一般建议至少传入以下参数,方便后续使用:

  1. class FRCNNDataset(Dataset):
  2. def __init__(self, annotation_lines, input_shape = [600, 600], train = True):
  3. self.annotation_lines = annotation_lines #数据集列表
  4. self.length = len(annotation_lines) #数据集大小
  5. self.input_shape = input_shape #输出尺寸
  6. self.train = train #是否训练

        然后重载__len__()__getitem__

  1. def __len__(self):
  2. return self.length #直接返回长度
  1. def __getitem__(self, index):
  2. index = index % self.length
  3. #训练时候对数据进行随机增强,但验证时不进行
  4. image, y = self.get_random_data(self.annotation_lines[index], self.input_shape[0:2], random = self.train)
  5. #将图片转换成矩阵
  6. image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
  7. #编码先验框
  8. box_data = np.zeros((len(y), 5))
  9. if len(y) > 0:
  10. box_data[:len(y)] = y
  11. box = box_data[:, :4]
  12. label = box_data[:, -1]
  13. return image, box, label

        关于数据增强函数get_random_data(),其中还包含了对图片的无变形缩放功能

  1. def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
  2. # 数据经过处理后格式为:地址——(空格)——预测框,使用split函数即可切割出地址和先验框
  3. line = annotation_line.split()
  4. # 读取图像并转换为RGB格式
  5. image = Image.open(line[0])
  6. image = cvtColor(image)
  7. # 获得图像的高宽与目标高宽
  8. iw, ih = image.size
  9. h, w = input_shape
  10. # 读取先验框
  11. box = np.array([np.array(list(map(int,box.split(',')))) for box in line[1:]])

                仅缩放的无变形缩放功(非训练模式)

  1. # 在不进行随机数据增强的情况下(非训练模式),直接变形后输出
  2. if not random:
  3. #获取变形比例
  4. scale = min(w/iw, h/ih)
  5. nw = int(iw*scale)
  6. nh = int(ih*scale)
  7. dx = (w-nw)//2
  8. dy = (h-nh)//2
  9. # 将图像多余的部分加上灰条
  10. image = image.resize((nw,nh), Image.BICUBIC)
  11. new_image = Image.new('RGB', (w,h), (128,128,128))
  12. new_image.paste(image, (dx, dy))
  13. image_data = np.array(new_image, np.float32)
  14. # 对真实框进行调整
  15. if len(box)>0:
  16. np.random.shuffle(box)
  17. box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
  18. box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
  19. box[:, 0:2][box[:, 0:2]<0] = 0
  20. box[:, 2][box[:, 2]>w] = w
  21. box[:, 3][box[:, 3]>h] = h
  22. box_w = box[:, 2] - box[:, 0]
  23. box_h = box[:, 3] - box[:, 1]
  24. box = box[np.logical_and(box_w>1, box_h>1)] # discard invalid box
  25. #返回图片和先验框
  26. return image_data, box

                带数据增强的无变形缩放(训练模式)

  1. # 对图像进行缩放并且进行长和宽的扭曲
  2. new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
  3. scale = self.rand(.25, 2)
  4. if new_ar < 1:
  5. nh = int(scale*h)
  6. nw = int(nh*new_ar)
  7. else:
  8. nw = int(scale*w)
  9. nh = int(nw/new_ar)
  10. image = image.resize((nw,nh), Image.BICUBIC)
  11. # 将图像多余的部分加上灰条
  12. dx = int(self.rand(0, w-nw))
  13. dy = int(self.rand(0, h-nh))
  14. new_image = Image.new('RGB', (w,h), (128,128,128))
  15. new_image.paste(image, (dx, dy))
  16. image = new_image
  17. # 翻转图像
  18. flip = self.rand()<.5
  19. if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)
  20. image_data = np.array(image, np.uint8)
  21. # 对图像进行色域变换
  22. # 计算色域变换的参数
  23. r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
  24. # 将图像转到HSV上
  25. hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
  26. dtype = image_data.dtype
  27. # 应用变换
  28. x = np.arange(0, 256, dtype=r.dtype)
  29. lut_hue = ((x * r[0]) % 180).astype(dtype)
  30. lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
  31. lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
  32. image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
  33. image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)
  34. # 对真实框进行调整
  35. if len(box)>0:
  36. np.random.shuffle(box)
  37. box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
  38. box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
  39. if flip: box[:, [0,2]] = w - box[:, [2,0]]
  40. box[:, 0:2][box[:, 0:2]<0] = 0
  41. box[:, 2][box[:, 2]>w] = w
  42. box[:, 3][box[:, 3]>h] = h
  43. box_w = box[:, 2] - box[:, 0]
  44. box_h = box[:, 3] - box[:, 1]
  45. box = box[np.logical_and(box_w>1, box_h>1)]
  46. return image_data, box

                关于collate_fn参数

                        __getitem__一般返回(image,label)样本对,而DataLoader需要一个batch_size用于处理batch样本,以便于批量训练。

                        默认的default_collate(batch)函数仅能对尺寸一致且batch_size相同的image进行整理,如将(img0,lbl0),(img1,lbl1),(img2,lbl2)整合为([img0,img1,img2],[lbl0,lbl1,lbl2]),如图像中含有box等参数则需要自定义处理

  1. def frcnn_dataset_collate(batch):
  2. images = []
  3. bboxes = []
  4. labels = []
  5. for img, box, label in batch:
  6. images.append(img)
  7. bboxes.append(box)
  8. labels.append(label)
  9. images = torch.from_numpy(np.array(images))
  10. return images, bboxes, labels

三、语义分割与目标检测DataSet的区别

        ①在__getitem__中不需要获取box值,转而获取标志图png

  1. def __getitem__(self, index):
  2. annotation_line = self.annotation_lines[index]
  3. name = annotation_line.split()[0]
  4. # 从文件中读取图像
  5. jpg = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/JPEGImages"), name + ".jpg"))
  6. png = Image.open(os.path.join(os.path.join(self.dataset_path, "VOC2007/SegmentationClass"), name + ".png"))
  7. # 数据增强
  8. jpg, png = self.get_random_data(jpg, png, self.input_shape, random = self.train)
  9. jpg = np.transpose(preprocess_input(np.array(jpg, np.float64)), [2,0,1])
  10. png = np.array(png)
  11. png[png >= self.num_classes] = self.num_classes
  12. # 转化成one_hot的形式
  13. # 在这里需要+1是因为voc数据集有些标签具有白边部分
  14. seg_labels = np.eye(self.num_classes + 1)[png.reshape([-1])]
  15. seg_labels = seg_labels.reshape((int(self.input_shape[0]), int(self.input_shape[1]), self.num_classes + 1))
  16. return jpg, png, seg_labels

        ②get_random_data变形时需要对两张图做同样的变换

  1. if not random:
  2. iw, ih = image.size
  3. scale = min(w/iw, h/ih)
  4. nw = int(iw*scale)
  5. nh = int(ih*scale)
  6. image = image.resize((nw,nh), Image.BICUBIC)
  7. new_image = Image.new('RGB', [w, h], (128,128,128))
  8. new_image.paste(image, ((w-nw)//2, (h-nh)//2))
  9. label = label.resize((nw,nh), Image.NEAREST)
  10. new_label = Image.new('L', [w, h], (0))
  11. new_label.paste(label, ((w-nw)//2, (h-nh)//2))
  12. return new_image, new_label

        ③collate_fn需要进行修改

  1. def deeplab_dataset_collate(batch):
  2. images = []
  3. pngs = []
  4. seg_labels = []
  5. for img, png, labels in batch:
  6. images.append(img)
  7. pngs.append(png)
  8. seg_labels.append(labels)
  9. images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
  10. pngs = torch.from_numpy(np.array(pngs)).long()
  11. seg_labels = torch.from_numpy(np.array(seg_labels)).type(torch.FloatTensor)
  12. return images, pngs, seg_labels

四、在训练过程中的调用

         ①读取文件集(经处理的txt文件)

  1. with open(train_annotation_path, encoding='utf-8') as f:
  2. train_lines = f.readlines()
  3. with open(val_annotation_path, encoding='utf-8') as f:
  4. val_lines = f.readlines()
  5. #获取数据集长度
  6. num_train = len(train_lines)
  7. num_val = len(val_lines)

        ②检查数据集是否符合要求

                这里一般检查数据集是否足够大,也可不检查

        ③将数据集装入DataSet中

  1. train_dataset = MyDataset(train_lines, input_shape, anchors, batch_size, num_classes, train = True)
  2. val_dataset = MyDataset(val_lines, input_shape, anchors, batch_size, num_classes, train = False)

        ④将DataSet放入DataLoader中

                关于dataloader:一般有以下5个参数:

                        1.dataset:数据集对象,dataset型

                        2.batch_size:批大小,int型

                        3.shuffe:每一轮epoch是否重新洗牌,bool型

                        4.num_workers:多进程读取

                        5.drop_last:当样本不能被batch_size取整时,是否丢弃最后一批数据,bool型

  1. gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
  2. drop_last=True, collate_fn=ssd_dataset_collate, sampler=train_sampler)
  3. gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
  4. drop_last=True, collate_fn=ssd_dataset_collate, sampler=val_sampler)

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号