赞
踩
在 PyTorch 中训练模型时,一般要针对输入图像写一个继承自
data.Dataset
类。以前都是参考别人的程序,为每种预处理操作各写一个类,并先把所有数据转换成 NumPy 数组,这样可以处理各种想要处理的数据;但是现在要预处理的数据全是图片,所以想直接用 PIL
做一些简单的操作import torch.utils.data as data import torchvision.transforms as transforms class RandomCrop(object): def __init__(self, output_size): self.crop_size = output_size def __call__(self, sample): raw, gt = sample['raw'], sample['gt'] h, w = raw.shape[0], raw.shape[1] np.random.seed() xx = np.random.randint(0, w - self.crop_size) yy = np.random.randint(0, h - self.crop_size) raw = raw[yy:yy + self.crop_size, xx:xx + self.crop_size, :] gt = gt[yy * 4:yy * 4 + self.crop_size * 4, xx * 4:xx * 4 + self.crop_size * 4, :] sample = { 'raw': raw, 'gt' : gt } return sample class RandomFlip(object): def __init__(self): pass def __call__(self, sample): raw, gt = sample['raw'], sample['gt'] do_reflection = np.random.randint(2) do_mirror = np.random.randint(2) do_transpose = np.random.randint(2) if do_reflection: raw = np.flip(raw, 0) gt = np.flip(gt, 0) if do_mirror: raw = np.flip(raw, 1) gt = np.flip(gt, 1) if do_transpose: raw = np.transpose(raw, (1, 0, 2)) gt = np.transpose(gt, (1, 0, 2)) sample = { 'raw': raw, 'gt' : gt } return sample class ToTensor(object): def __init__(self): pass def __call__(self, sample): raw, gt = sample['raw'], sample['gt'] raw = raw.transpose((2, 0, 1)) gt = gt.transpose((2, 0, 1)) raw, gt = np.ascontiguousarray(raw), np.ascontiguousarray(gt) raw, gt = torch.from_numpy(raw), torch.from_numpy(gt) sample = { 'raw': raw, 'gt' : gt } return sample def get_transform(self): transform_list = [] transform_list.append(RandomCrop(crop_size)) transform_list.append(RandomFlip()) transform_list.append(ToTensor()) return transforms.Compose(transform_list) sample = get_transform()(sample)
PIL
import Image import torchvision.transforms.functional as tf def transform(self, img, gt, crop_size=512): ''' :img PIL image :gt PIL image ''' #crop w, h = img.size assert type(crop_size)==int and crop_size<=min(w, h), 'check crop_size type or num' xx = np.random.randint(0, w-crop_size) yy = np.random.randint(0, h-crop_size) img = img.crop((xx, yy, xx+crop_size, yy+crop_size)) gt = gt.crop((xx, yy, xx+crop_size, yy+crop_size)) #flip if np.random.randint(2): img = img.transpose(Image.FLIP_LEFT_RIGHT) gt = gt.transpose(Image.FLIP_LEFT_RIGHT) if np.random.randint(2): img = img.transpose(Image.FLIP_TOP_BOTTOM) gt = gt.transpose(Image.FLIP_TOP_BOTTOM) if np.random.randint(2): img = img.rotate(180) gt = gt.rotate(180) #toTensor img = tf.to_tensor(img) gt = tf.to_tensor(gt) return img, gt
torchvision.transforms.functional
这个库还挺好用的0.0
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。