当前位置:   article > 正文

自定义DatasetLoad(数据加载器)以及一些图像增强方法--笔记LOG_load_dataset自定义数据集

load_dataset自定义数据集
  1. #!/usr/bin/env python
  2. # -*- coding: UTF-8 -*-
  3. """
  4. @Project :pt_tf_lea
  5. @Author :Anjou
  6. @Date :2023/5/15 13:38
  7. """
  8. import os
  9. from torch.utils.data import Dataset
  10. from PIL import Image
  11. import numpy as np
  12. import random
  13. import PIL
  14. import torch
  15. from torchvision import transforms as T
  16. from torchvision.transforms import functional as F
  def pad_if_smaller(img, size, fill=0):
      """Pad ``img`` so that both of its sides are at least ``size``.

      ``img.size`` is read as ``(width, height)``. When the shorter side is
      already ``size`` or more the image is returned untouched; otherwise it is
      padded with the value ``fill`` through ``F.pad`` using the padding tuple
      ``(0, 0, pad_w, pad_h)``.

      :param img: image object exposing a two-element ``size`` attribute
      :param size: minimum side length required
      :param fill: value used for the padded area
      :return: the original image, or the padded result of ``F.pad``
      """
      width, height = img.size
      if not min(width, height) < size:
          return img
      # Only the sides that fall short of ``size`` receive padding.
      pad_w = size - width if width < size else 0
      pad_h = size - height if height < size else 0
      return F.pad(img, (0, 0, pad_w, pad_h), fill=fill)
  class Compose(object):
      """Chain paired transforms into one callable pipeline.

      Every transform is invoked as ``t(image, target)`` and has to return the
      updated ``(image, target)`` pair, which is handed to the next transform.
      """

      def __init__(self, transforms):
          self.transforms = transforms

      def __call__(self, image, target):
          current = (image, target)
          for step in self.transforms:
              new_image, new_target = step(*current)
              current = (new_image, new_target)
          return current
  class RandomResize(object):
      """Resize image and target to one randomly drawn size.

      The size is an integer drawn with ``random.randint(min_size, max_size)``;
      ``max_size`` falls back to ``min_size`` when it is not given.
      """

      def __init__(self, min_size, max_size=None):
          self.min_size = min_size
          self.max_size = min_size if max_size is None else max_size

      def __call__(self, image, target):
          chosen = random.randint(self.min_size, self.max_size)
          # ``chosen`` is an int, so the shorter edge of the image is scaled to it.
          resized_image = F.resize(image, chosen)
          # NOTE: InterpolationMode.NEAREST only exists from torchvision 0.9.0 on;
          # earlier releases need PIL.Image.NEAREST, which is what is passed here.
          resized_target = F.resize(target, chosen, interpolation=PIL.Image.NEAREST)
          return resized_image, resized_target
  class RandomHorizontalFlip(object):
      """Flip image and target horizontally, together, with probability ``flip_prob``."""

      def __init__(self, flip_prob):
          self.flip_prob = flip_prob

      def __call__(self, image, target):
          should_flip = random.random() < self.flip_prob
          if not should_flip:
              return image, target
          # Both inputs are flipped so the target stays aligned with the image.
          flipped_image = F.hflip(image)
          flipped_target = F.hflip(target)
          return flipped_image, flipped_target
  class RandomCrop(object):
      """Take the same random ``size`` x ``size`` crop from image and target."""

      def __init__(self, size):
          self.size = size

      def __call__(self, image, target):
          side = self.size
          # Pad first so that neither input is smaller than the crop window; the
          # target is padded with 255 while the image uses the default fill of 0.
          image = pad_if_smaller(image, side)
          target = pad_if_smaller(target, side, fill=255)
          # Draw the crop parameters once and apply them to both inputs.
          window = T.RandomCrop.get_params(image, (side, side))
          cropped_image = F.crop(image, *window)
          cropped_target = F.crop(target, *window)
          return cropped_image, cropped_target
  class CenterCrop(object):
      """Apply ``F.center_crop`` with the same ``size`` to image and target."""

      def __init__(self, size):
          self.size = size

      def __call__(self, image, target):
          cropped_image = F.center_crop(image, self.size)
          cropped_target = F.center_crop(target, self.size)
          return cropped_image, cropped_target
  class ToTensor(object):
      """Convert the image with ``F.to_tensor`` and the target to an int64 tensor."""

      def __call__(self, image, target):
          image_tensor = F.to_tensor(image)
          # The target goes through a NumPy array and is stored as torch.int64.
          target_array = np.array(target)
          target_tensor = torch.as_tensor(target_array, dtype=torch.int64)
          return image_tensor, target_tensor
  class Normalize(object):
      """Normalize the image with a fixed ``mean`` and ``std`` via ``F.normalize``.

      The target is passed through unchanged.
      """

      def __init__(self, mean, std):
          self.mean, self.std = mean, std

      def __call__(self, image, target):
          normalized = F.normalize(image, mean=self.mean, std=self.std)
          return normalized, target
  class loadDataset(Dataset):
      """Custom dataset pairing image files with ground-truth file paths.

      The entries of both directories are listed and sorted, and the i-th image
      is paired with the i-th target. The pairing is therefore only correct when
      matching files sort to the same position in each directory.
      """

      def __init__(self, ROOT_IMAGE: str, ROOT_TARGET: str, TRANSFORM=None):
          """
          :param ROOT_IMAGE: directory holding the images
          :param ROOT_TARGET: directory holding the ground-truth (GT) files
          :param TRANSFORM: optional augmentation, called as ``TRANSFORM(image)``
          """
          # Sort both listings so that images and targets line up by index.
          self.imagePaths = sorted(os.path.join(ROOT_IMAGE, name) for name in os.listdir(ROOT_IMAGE))
          self.targetPaths = sorted(os.path.join(ROOT_TARGET, name) for name in os.listdir(ROOT_TARGET))
          self.transform = TRANSFORM

      def __getitem__(self, item):
          """Return ``(image, target_path)`` for index ``item``.

          :raises ValueError: if the opened image is not in RGB mode
          """
          image = Image.open(self.imagePaths[item])
          # Compare by value: ``is not`` tests object identity, which is not a
          # reliable way to compare strings.
          if image.mode != 'RGB':
              raise ValueError(f'{self.imagePaths[item]} is not RGB mode')
          # NOTE: the target is returned as a file path; it is not opened here.
          target = self.targetPaths[item]
          if self.transform:
              image = self.transform(image)
          return image, target

      def __len__(self):
          """Return the number of image files found."""
          return len(self.imagePaths)

      @staticmethod
      def collect_fn(batch):
          """Collate a batch whose images may differ in size.

          Images are padded with 0 into one tensor by ``cat_list``; the targets
          are returned as the tuple produced by unzipping the batch.
          """
          images, targets = list(zip(*batch))
          batched_imgs = cat_list(images, fillValue=0)
          # If the targets were masks they could be padded in the same way, e.g.
          # cat_list(targets, fillValue=255); as written they are left untouched.
          batched_targets = targets
          return batched_imgs, batched_targets
  def cat_list(images, fillValue=0):
      """Stack tensors of differing sizes into one batch padded with ``fillValue``.

      The batch tensor gets, for every dimension, the largest extent found among
      ``images``, with a leading dimension of ``len(images)``. Each image is
      copied into the leading corner of its slot along the last two dimensions;
      the remaining area keeps ``fillValue``.

      :param images: non-empty sequence of tensors with the same number of dimensions
      :param fillValue: value used for the padded area
      :return: the padded batch tensor
      """
      # Per-dimension maximum over the batch, e.g. the largest (C, H, W).
      max_shape = tuple(max(dims) for dims in zip(*(img.shape for img in images)))
      batch_shape = (len(images),) + max_shape
      # new_full is the documented replacement for the legacy ``Tensor.new(*sizes)``
      # constructor followed by ``fill_``.
      batched_imgs = images[0].new_full(batch_shape, fillValue)
      for image, slot in zip(images, batched_imgs):
          # ``slot`` is a view into the batch, so copy_ writes in place.
          slot[..., :image.shape[-2], :image.shape[-1]].copy_(image)
      return batched_imgs
  class TransformTrain:
      """Training pipeline: random resize, optional flip, random crop, tensor, normalize.

      :param base_size: reference size; the random resize draws from
          ``int(0.5 * base_size)`` to ``int(1.5 * base_size)``
      :param crop_size: side length of the random crop
      :param hflip_prob: flip probability; the flip step is skipped when it is not > 0
      :param mean: mean handed to ``Normalize``
      :param std: standard deviation handed to ``Normalize``
      """

      def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
          max_size = int(1.5 * base_size)
          min_size = int(0.5 * base_size)

          trans_list = [RandomResize(min_size, max_size)]
          if hflip_prob > 0:
              trans_list.append(RandomHorizontalFlip(hflip_prob))
          trans_list.extend([
              RandomCrop(crop_size),
              ToTensor(),
              Normalize(mean, std),
          ])
          # Use this file's paired Compose: torchvision's T.Compose calls every
          # transform with a single argument and cannot pass (image, target) along.
          self.transforms = Compose(trans_list)

      def __call__(self, image, target):
          """Run the pipeline and return the transformed ``(image, target)`` pair."""
          return self.transforms(image, target)
  class TransformVal:
      """Validation pipeline: resize to ``base_size``, convert to tensor, normalize."""

      def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
          # min and max are both base_size, so the "random" resize is deterministic.
          steps = [
              RandomResize(base_size, base_size),
              ToTensor(),
              Normalize(mean, std),
          ]
          self.transforms = Compose(steps)

      def __call__(self, img, target):
          pipeline = self.transforms
          return pipeline(img, target)
  def get_transform(train, base_size=520, crop_size=480):
      """Build the training or validation transform pipeline.

      :param train: truthy for ``TransformTrain``, falsy for ``TransformVal``
      :param base_size: reference size passed to the chosen pipeline (default 520)
      :param crop_size: random-crop size, used by the training pipeline only
          (default 480)
      :return: a ``TransformTrain`` or ``TransformVal`` instance
      """
      if train:
          return TransformTrain(base_size, crop_size)
      return TransformVal(base_size)

以上内容作为备忘,需要的小伙伴自取咯~~~

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/412278
推荐阅读
相关标签
  

闽ICP备14008679号