# (removed web-scrape artifacts: page vote-widget text, not part of the source)
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :pt_tf_lea
@Author  :Anjou
@Date    :2023/5/15 13:38
"""
import os
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import random
import PIL
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
-
def pad_if_smaller(img, size, fill=0):
    """Pad ``img`` on the right/bottom with ``fill`` until no side is shorter than ``size``.

    Images whose smallest side already reaches ``size`` are returned unchanged.
    """
    w, h = img.size
    if min(w, h) >= size:
        return img
    pad_right = max(size - w, 0)
    pad_bottom = max(size - h, 0)
    # F.pad padding order is (left, top, right, bottom): only extend right/bottom edges.
    return F.pad(img, (0, 0, pad_right, pad_bottom), fill=fill)
-
-
class Compose(object):
    """Chain (image, target) transforms; every step receives and returns the pair."""

    def __init__(self, transforms):
        # Steps are applied in list order.
        self.transforms = transforms

    def __call__(self, image, target):
        pair = (image, target)
        for step in self.transforms:
            pair = step(*pair)
        return pair
-
-
class RandomResize(object):
    """Resize image and mask so the shorter edge is a random length in [min_size, max_size]."""

    def __init__(self, min_size, max_size=None):
        self.min_size = min_size
        # With no upper bound the resize degenerates to a single fixed size.
        self.max_size = min_size if max_size is None else max_size

    def __call__(self, image, target):
        edge = random.randint(self.min_size, self.max_size)
        # Passing an int resizes the shorter side while keeping aspect ratio.
        image = F.resize(image, edge)
        # NEAREST keeps mask labels discrete; torchvision >= 0.9 also accepts
        # InterpolationMode.NEAREST, older versions need PIL.Image.NEAREST.
        target = F.resize(target, edge, interpolation=PIL.Image.NEAREST)
        return image, target
-
-
class RandomHorizontalFlip(object):
    """Flip image and mask left-right together with probability ``flip_prob``."""

    def __init__(self, flip_prob):
        self.flip_prob = flip_prob

    def __call__(self, image, target):
        # Draw once so image and mask always flip (or not) in lockstep.
        if random.random() >= self.flip_prob:
            return image, target
        return F.hflip(image), F.hflip(target)
-
-
class RandomCrop(object):
    """Crop a random ``size`` x ``size`` window from image and mask at the same location."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        # Guarantee the crop window fits; masks pad with 255 (the ignore index).
        image = pad_if_smaller(image, self.size)
        target = pad_if_smaller(target, self.size, fill=255)
        # One parameter draw (top, left, height, width) shared by image and mask.
        top, left, height, width = T.RandomCrop.get_params(image, (self.size, self.size))
        image = F.crop(image, top, left, height, width)
        target = F.crop(target, top, left, height, width)
        return image, target
-
-
class CenterCrop(object):
    """Crop image and mask to ``size`` around their centers."""

    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        return F.center_crop(image, self.size), F.center_crop(target, self.size)
-
-
class ToTensor(object):
    """Convert a PIL image to a float tensor and the mask to an int64 label tensor."""

    def __call__(self, image, target):
        # Mask values are class ids, so keep them integral (no /255 scaling).
        tensor_image = F.to_tensor(image)
        tensor_target = torch.as_tensor(np.array(target), dtype=torch.int64)
        return tensor_image, tensor_target
-
-
class Normalize(object):
    """Standardize image channels as (x - mean) / std; the mask passes through unchanged."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        # Only the image is normalized; label maps must stay untouched.
        return F.normalize(image, mean=self.mean, std=self.std), target
-
-
class loadDataset(Dataset):
    """Paired image / ground-truth dataset loader.

    Images and targets are matched by the sorted order of their filenames, so
    the two directories must contain correspondingly named files.
    """

    def __init__(self, ROOT_IMAGE: str, ROOT_TARGET: str, TRANSFORM=None):
        """
        :param ROOT_IMAGE: directory containing input images
        :param ROOT_TARGET: directory containing ground-truth (GT) files
        :param TRANSFORM: optional augmentation applied to the image only
        """
        self.imagePaths = [os.path.join(ROOT_IMAGE, name) for name in os.listdir(ROOT_IMAGE)]
        self.targetPaths = [os.path.join(ROOT_TARGET, name) for name in os.listdir(ROOT_TARGET)]
        # Sort both lists so index i of images lines up with index i of targets.
        self.imagePaths.sort()
        self.targetPaths.sort()
        self.transform = TRANSFORM

    def __getitem__(self, item):
        image = Image.open(self.imagePaths[item])
        # BUG FIX: the original used `image.mode is not 'RGB'` — an identity
        # comparison against a string literal, which is not a reliable equality
        # test (and a SyntaxWarning since Python 3.8). Use != instead.
        if image.mode != 'RGB':
            raise ValueError(f'{self.imagePaths[item]} is not RGB mode')
        # NOTE: this returns the target *path*, not its loaded contents.
        target = self.targetPaths[item]
        if self.transform:
            image = self.transform(image)
        return image, target

    def __len__(self):
        return len(self.imagePaths)

    @staticmethod
    def collect_fn(batch):
        """Batch images of different sizes by padding to the largest (see cat_list)."""
        images, targets = list(zip(*batch))
        batched_imgs = cat_list(images, fillValue=0)
        # Targets here are paths, so no padding is applied; if the targets were
        # masks they would need cat_list(targets, fillValue=255).
        batched_targets = targets
        return batched_imgs, batched_targets
-
-
def cat_list(images, fillValue=0):
    """Stack variable-sized CHW tensors into one NCHW batch, padding with ``fillValue``.

    Each tensor is copied into the top-left corner of a canvas shaped like the
    per-dimension maximum over the batch; the remaining area keeps the fill value.
    """
    max_size = tuple(max(dims) for dims in zip(*(img.shape for img in images)))
    batched = images[0].new(len(images), *max_size).fill_(fillValue)
    for src, dst in zip(images, batched):
        dst[..., :src.shape[-2], :src.shape[-1]].copy_(src)
    return batched
-
-
class TransformTrain:
    """Training-time augmentation pipeline for (image, mask) pairs.

    random resize -> optional horizontal flip -> random crop -> tensor -> normalize
    """

    def __init__(self, base_size, crop_size, hflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Scale jitter in [0.5, 1.5] x base_size.
        min_size = int(0.5 * base_size)
        max_size = int(1.5 * base_size)
        trans_list = [RandomResize(min_size, max_size)]
        if hflip_prob > 0:
            trans_list.append(RandomHorizontalFlip(hflip_prob))
        trans_list.extend([
            RandomCrop(crop_size),
            ToTensor(),
            Normalize(mean, std),
        ])
        # BUG FIX: the original wrapped the steps in torchvision's T.Compose,
        # whose __call__ accepts a single image, whereas these steps (and
        # __call__ below) pass an (image, target) pair — it would raise a
        # TypeError at runtime. Use the pair-aware Compose defined in this
        # module, consistent with TransformVal.
        self.transforms = Compose(trans_list)

    def __call__(self, image, target):
        return self.transforms(image, target)
-
-
class TransformVal:
    """Validation pipeline: deterministic resize to base_size, then tensor + normalize."""

    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # RandomResize(base_size, base_size) degenerates to a fixed resize.
        steps = [
            RandomResize(base_size, base_size),
            ToTensor(),
            Normalize(mean, std),
        ]
        self.transforms = Compose(steps)

    def __call__(self, img, target):
        return self.transforms(img, target)
-
-
def get_transform(train):
    """Return the training pipeline when ``train`` is truthy, else the validation one."""
    base_size = 520
    crop_size = 480
    if train:
        return TransformTrain(base_size, crop_size)
    return TransformVal(base_size)

# (removed web-scrape artifacts: page footer / copyright notice, not part of the source)