赞
踩
PyTorch(torchvision)为我们提供了非常好的数据增强包transforms,以下以CIFAR10为例,介绍一下用法:
import torch
import torchvision
import torchvision.transforms as transforms

# Per-channel normalization constants handed to transforms.Normalize below.
cifar_norm_mean = (0.49139968, 0.48215827, 0.44653124)
cifar_norm_std = (0.24703233, 0.24348505, 0.26158768)

# Training pipeline: random augmentation first, then tensor conversion and
# normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

# Test pipeline: no random augmentation, only tensor conversion and
# normalization, so evaluation inputs stay fixed between runs.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)

# Pin host memory only when a CUDA device is available.
use_pinned_memory = torch.cuda.is_available()

trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, pin_memory=use_pinned_memory)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, pin_memory=use_pinned_memory)
关于增强方式还可以有:
CenterCrop(size):在图片的中间区域进行裁剪,裁剪成size×size大小
RandomCrop(size):在一个随机的位置进行裁剪,裁剪成size×size大小
Resize(size):如果传入的参数是int,就把较短的边缩放为size,另一边等比例缩放;如果传入的是tuple,则缩放为指定大小
RandomHorizontalFlip(0.5):以0.5的概率水平翻转给定的PIL图像
RandomVerticalFlip(0.5):以0.5的概率竖直翻转给定的PIL图像
注意在训练的时候可以通过RandomHorizontalFlip()等来进行增强,测试的时候则不能,道理非常简单,因为测试的数据必须是固定不变的。
以上的方式是把transform直接传给Dataset(再交给DataLoader加载),还有一种方式是自己定义变换类,在函数的调用里实现:
import random

import torchvision

# Height used when an undersized frame is resized before cropping.
IMG_INIT_H = 256
# Output size passed to the crop transforms.
IMG_crop_size = (224, 224)


# For video processing: every frame of a clip must get the same augmentation.
class ClipRandomCrop(torchvision.transforms.RandomCrop):
    """Random crop whose window is sampled once and reused for every frame.

    The crop parameters are drawn on the first call and cached on the
    instance; all later calls crop at the same position. Nothing resets the
    cache, so create a new instance for each clip.

    Args:
        size: Desired output size, forwarded to RandomCrop.
    """

    def __init__(self, size):
        # Let the base class initialize itself and store ``size``; the
        # original skipped this and assigned ``self.size`` by hand.
        super().__init__(size)
        self.i = None
        self.j = None
        self.th = None
        self.tw = None

    def __call__(self, img):
        """Crop ``img`` at the cached window, sampling it on first use."""
        if self.i is None:
            self.i, self.j, self.th, self.tw = self.get_params(
                img, output_size=self.size)
        return torchvision.transforms.functional.crop(
            img, self.i, self.j, self.th, self.tw)


class ClipRandomHorizontalFlip(object):
    """Horizontal flip decided once per instance instead of once per image.

    The flip decision is drawn in the constructor, so every image passed to
    the same instance is either flipped or left unchanged.

    Args:
        ratio: Probability that this instance flips its inputs.
    """

    def __init__(self, ratio=0.5):
        self.is_flip = random.random() < ratio

    def __call__(self, img):
        """Return ``img`` flipped horizontally if this instance flips."""
        if self.is_flip:
            return torchvision.transforms.functional.hflip(img)
        return img


def transforms(mode):
    """Build the preprocessing pipeline for one clip.

    NOTE: this function has the same name as the common alias
    ``import torchvision.transforms as transforms``; defining both in one
    module makes the later definition shadow the earlier one.

    Args:
        mode: ``'train'`` selects the clip-consistent random crop and random
            flip; any other value selects a center crop.

    Returns:
        A ``torchvision.transforms.Compose`` that ends with ``ToTensor`` and
        ``Normalize``.
    """
    if mode == 'train':
        spatial_steps = [
            ClipRandomCrop(IMG_crop_size),
            ClipRandomHorizontalFlip(ratio=0.5),
        ]
    else:  # mode == 'test'
        spatial_steps = [torchvision.transforms.CenterCrop(IMG_crop_size)]

    to_tensor = torchvision.transforms.ToTensor()
    # NOTE(review): these look like the usual ImageNet mean/std -- confirm.
    normalize = torchvision.transforms.Normalize(
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    return torchvision.transforms.Compose(
        spatial_steps + [to_tensor, normalize])
调用的时候,可以:
import os

import torch
from PIL import Image

# NOTE: mode, video_frames, video_frame_path, image_id, all_frame_count,
# IMG_INIT_H and transforms() must be defined by the surrounding code
# before this snippet runs.
myTransform = transforms(mode=mode)

video = []
for _ in range(video_frames):
    image_name = 'image_' + ("%05d" % image_id) + '.jpg'
    image_path = os.path.join(video_frame_path, image_name)
    # Use a context manager so the underlying file is closed once the frame
    # has been transformed.
    with Image.open(image_path) as image:
        if image.size[0] < 224:
            # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
            # same resampling filter under its current name.
            image = image.resize((224, IMG_INIT_H), Image.LANCZOS)
        # To apply the same transform to RGB and depth frames, check that
        # the random crop lands in the same position for both.
        image = myTransform(image)
    video.append(image)
    image_id += 1
    if image_id > all_frame_count:
        break

# Stack the per-frame results along a new leading dimension.
video = torch.stack(video, 0)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。