
[pytorch] Using transforms for data augmentation

PyTorch ships a convenient data augmentation module, torchvision.transforms. Taking CIFAR10 as an example:

import torch
import torchvision
import torchvision.transforms as transforms

cifar_norm_mean = (0.49139968, 0.48215827, 0.44653124)
cifar_norm_std = (0.24703233, 0.24348505, 0.26158768)
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar_norm_mean, cifar_norm_std),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True,
                                          pin_memory=(torch.cuda.is_available()))
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False,
                                         pin_memory=(torch.cuda.is_available()))
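
The mean/std constants fed to Normalize shift each channel toward zero mean and unit variance; the operation is just (x - mean) / std per channel and is exactly invertible. A minimal sketch with a random tensor standing in for a CIFAR10 image:

```python
import torch

# Per-channel normalization: the same arithmetic transforms.Normalize performs.
cifar_norm_mean = torch.tensor([0.49139968, 0.48215827, 0.44653124])
cifar_norm_std = torch.tensor([0.24703233, 0.24348505, 0.26158768])

img = torch.rand(3, 32, 32)  # stand-in for a CIFAR10 image tensor in [0, 1]
mean = cifar_norm_mean[:, None, None]  # broadcast over H and W
std = cifar_norm_std[:, None, None]

normalized = (img - mean) / std
restored = normalized * std + mean  # invert, e.g. for visualizing augmented images
```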


Other available augmentation transforms include:

CenterCrop(size): crops the center region of the image to size × size
RandomCrop(size): crops a size × size patch at a random location
Resize(): if the argument is an int, the shorter side is scaled to size and the other side is scaled proportionally; if a tuple is passed, the image is resized to exactly that size
RandomHorizontalFlip(0.5): horizontally flips the given PIL image with probability 0.5
RandomVerticalFlip(0.5): vertically flips the given PIL image with probability 0.5
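
To make the Resize sizing rule concrete, here is a pure-Python sketch of how the output size can be derived (the helper name is hypothetical; torchvision does this internally):

```python
def resize_output_size(w, h, size):
    """Mimic torchvision Resize sizing: an int scales the shorter side
    and keeps the aspect ratio; a (h, w) tuple is used as-is."""
    if isinstance(size, tuple):
        return size
    if w <= h:  # width is the shorter side
        return (int(size * h / w), size)   # (new_h, new_w)
    return (size, int(size * w / h))

# a 300x400 (w x h) image with Resize(256): shorter side 300 -> 256
print(resize_output_size(300, 400, 256))  # (341, 256)
```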

Note that random augmentations such as RandomHorizontalFlip() should be applied only during training, never at test time. The reason is simple: test data must stay fixed so that evaluations are reproducible.

The approach above wires the transforms into the dataset/dataloader directly; another option is to build them inside a function call:

import random
import torchvision

IMG_INIT_H = 256
IMG_crop_size = (224, 224)

# for Video Processing
class ClipRandomCrop(torchvision.transforms.RandomCrop):
  def __init__(self, size):
    self.size = size
    self.i = None
    self.j = None
    self.th = None
    self.tw = None

  def __call__(self, img):
    if self.i is None:
      # draw crop parameters once; every later frame reuses the same crop
      self.i, self.j, self.th, self.tw = self.get_params(img, output_size=self.size)
      #print('crop:', self.i, self.j, self.th, self.tw)
    return torchvision.transforms.functional.crop(img, self.i, self.j, self.th, self.tw)

class ClipRandomHorizontalFlip(object):
  def __init__(self, ratio=0.5):
    # decide once at construction so every frame of the clip is flipped (or not) together
    self.is_flip = random.random() < ratio

  def __call__(self, img):
    if self.is_flip:
      return torchvision.transforms.functional.hflip(img)
    else:
      return img

def transforms(mode):
    if (mode=='train'):
        random_crop = ClipRandomCrop(IMG_crop_size)
        flip = ClipRandomHorizontalFlip(ratio=0.5)
        toTensor = torchvision.transforms.ToTensor()
        normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return torchvision.transforms.Compose([random_crop, flip,toTensor,normalize])
    else:   # mode=='test'
        center_crop = torchvision.transforms.CenterCrop(IMG_crop_size)
        toTensor = torchvision.transforms.ToTensor()
        normalize = torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        return torchvision.transforms.Compose([center_crop,toTensor,normalize])
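
The key idea in both Clip* classes is drawing the random parameters once and reusing them for every frame, so the whole clip receives an identical transform. A toy stand-in for that pattern, with no torchvision dependency:

```python
import random

class CachedRandomChoice:
    """Draw a random value on the first call, then return the same
    value on every later call -- the same caching trick ClipRandomCrop uses."""
    def __init__(self, n):
        self.n = n
        self.choice = None

    def __call__(self):
        if self.choice is None:
            self.choice = random.randrange(self.n)
        return self.choice

pick = CachedRandomChoice(10_000)
first = pick()
repeats = [pick() for _ in range(5)]  # always the cached value
```

One instance must be created per clip; reusing it across clips would freeze the "random" parameters forever.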

At the call site:

import os
import torch
from PIL import Image

# mode, video_frames, image_id, video_frame_path and all_frame_count are
# assumed to be defined by the surrounding context
myTransform = transforms(mode=mode)
video = []

for i in range(video_frames):
    s = "%05d" % image_id
    image_name = 'image_' + s + '.jpg'
    image_path = os.path.join(video_frame_path, image_name)
    image = Image.open(image_path)
       
    if (image.size[0] < 224):
        image = image.resize((224, IMG_INIT_H), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
            

    # to apply the same transform to RGB and depth, make sure random_crop picks the same position
    image = myTransform(image)
    video.append(image)
        
    image_id += 1
    if (image_id > all_frame_count):
        break
video=torch.stack(video,0)
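
torch.stack adds a new leading dimension, so the per-frame C×H×W tensors become a single T×C×H×W clip tensor. A quick sketch with dummy frames standing in for the transformed images:

```python
import torch

frames = [torch.zeros(3, 224, 224) for _ in range(16)]  # 16 dummy C×H×W frames
clip = torch.stack(frames, 0)  # stack along a new first (time) dimension
print(clip.shape)
```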