当前位置:   article > 正文

自动数据增强方法(附代码)_数据增强代码

数据增强代码

前言

一个模型的性能除了和网络结构本身有关,还非常依赖具体的训练策略,比如优化器,数据增强以及正则化策略等(当然也很训练数据强相关,训练数据量往往决定模型性能的上线)。近年来,图像分类模型在ImageNet数据集的top1 acc已经由原来的56.5(AlexNet,2012)提升至90.88(CoAtNet,2021,用了额外的数据集JFT-3B),这进步除了主要归功于模型,算力和数据的提升,也与训练策略的提升紧密相关。最近刚兴起的vision transformer相比CNN模型往往也需要更heavy的数据增强和正则化策略。这里简单介绍图像分类训练技巧中的常用数据增强策略。

baseline

ImageNet数据集训练常用的数据增强策略如下,训练过程的数据增强包括随机缩放裁剪(RandomResizedCrop,这种处理方式源自谷歌的Inception,所以称为 Inception-style pre-processing)和水平翻转(RandomHorizontalFlip),而测试阶段是执行缩放和中心裁剪。这其实是一种轻量级的策略,这里称之为baseline。torchvision的实现的ResNet50训练采用的策略就是这个,在ImageNet上的top1 acc可以达到76.1。

  1. from torchvision import transforms
  2. normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
  3. std=[0.229, 0.224, 0.225])
  4. # 训练
  5. train_transform = transforms.Compose([
  6. # 这里的scale指的是面积,ratio是宽高比
  7. # 具体实现每次先随机确定scale和ratio,可以生成w和h,然后随机确定裁剪位置进行crop
  8. # 最后是resize到target size
  9. transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)),
  10. transforms.RandomHorizontalFlip(),
  11. transforms.ToTensor(),
  12. normalize
  13. ])
  14. # 测试
  15. test_transform = transforms.Compose([
  16. transforms.Resize(256),
  17. transforms.CenterCrop(224),
  18. transforms.ToTensor(),
  19. normalize,
  20. ])

AutoAugment

谷歌在2018年提出通过AutoML来自动搜索数据增强策略,称之为AutoAugment(算是自动数据增强开山之作)。搜索方法采用强化学习,和NAS类似,只不过搜索空间是数据增强策略,而不是网络架构。在搜索空间里,一个policy包含5个sub-policies,每个sub-policy包含两个串行的图像增强操作,每个增强操作有两个超参数:进行该操作的概率图像增强的幅度(magnitude,这个表示数据增强的强度,比如对于旋转,旋转的角度就是增强幅度,旋转角度越大,增强越大)。每个policy在执行时,首先随机从5个策略中随机选择一个sub-policy,然后序列执行两个图像操作。

 搜索空间一共有16种图像增强类型,具体如下所示,大部分操作都定义了图像增强的幅度范围,在搜索时需要将幅度值离散化,具体地是将幅度值在定义范围内均匀地取10个值。

 论文在不同的数据集上( CIFAR-10 , SVHN, ImageNet)做了实验,这里给出在ImageNet数据集上搜索得到的最优policy(最后实际上是将搜索得到的前5个最好的policies合成了一个policy,所以这里包含25个sub-policies):

  1. # operation, probability, magnitude
  2. (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
  3. (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
  4. (("Equalize", 0.8, None), ("Equalize", 0.6, None)),
  5. (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
  6. (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
  7. (("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
  8. (("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
  9. (("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
  10. (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
  11. (("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
  12. (("Rotate", 0.8, 8), ("Color", 0.4, 0)),
  13. (("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
  14. (("Equalize", 0.0, None), ("Equalize", 0.8, None)),
  15. (("Invert", 0.6, None), ("Equalize", 1.0, None)),
  16. (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
  17. (("Rotate", 0.8, 8), ("Color", 1.0, 2)),
  18. (("Color", 0.8, 8), ("Solarize", 0.8, 7)),
  19. (("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
  20. (("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
  21. (("Color", 0.4, 0), ("Equalize", 0.6, None)),
  22. (("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
  23. (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
  24. (("Invert", 0.6, None), ("Equalize", 1.0, None)),
  25. (("Color", 0.6, 4), ("Contrast", 1.0, 8)),
  26. (("Equalize", 0.8, None), ("Equalize", 0.6, None))

基于搜索得到的AutoAugment训练可以将ResNet50在ImageNet数据集上的top1 acc从76.3提升至77.6。一个比较重要的问题,这些从某一个数据集搜索得到的策略是否只对固定的数据集有效,论文也通过具体实验证明了AutoAugment的迁移能力,比如将ImageNet数据集上得到的策略用在5个 FGVC数据集(与ImageNet图像输入大小相似)也均有提升。 ​

目前torchvision库已经实现了AutoAugment,具体使用如下所示(注意AutoAug前也需要包括一个RandomResizedCrop):

  1. from torchvision.transforms import autoaugment, transforms
  2. train_transform = transforms.Compose([
  3. transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
  4. transforms.RandomHorizontalFlip(hflip_prob),
  5. # 这里policy属于torchvision.transforms.autoaugment.AutoAugmentPolicy,
  6. # 对于ImageNet就是 AutoAugmentPolicy.IMAGENET
  7. # 此时aa_policy = autoaugment.AutoAugmentPolicy('imagenet')
  8. autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation),
  9. transforms.PILToTensor(),
  10. transforms.ConvertImageDtype(torch.float),
  11. transforms.Normalize(mean=mean, std=std)
  12. ])

RandAugment

AutoAugment存在的一个问题是搜索空间巨大,这使得搜索只能在代理任务中进行:使用小的模型在ImageNet的一个小的子集( 120类和6000图片)搜索。谷歌在2019年又提出了一个更简单的数据增强策略:RandAugment。这篇论文首先发现AutoAugment这样在小数据集上搜索出来的策略在大的数据集上应用会存在问题,这主要是因为数据增强策略和模型大小和数据量大小存在强相关,如下图所示可以看到模型或者训练数据量越大,其最优的数据增强的幅度越大,这说明AutoAugment得到的结果应该是次优的。另外,Population Based Augmentation这篇论文发现最优的数据增强幅度是随训练过程增加,而且不同的增强操作遵循类似的规律,这启发作者采用固定的增强幅度而不是去搜索。RandAugment相比AutoAugment的策略空间很小(10^{2} vs 10^{32}),所以它不需要采用代理任务,甚至直接采用简单的网格搜索。

 具体地,RandAugment共包含两个超参数:图像增强操作的数量N和一个全局的增强幅度M,其实现代码如下所示,每次从候选操作集合(共14种策略)随机选择N个操作(等概率),然后串行执行(这里没有判断概率,是一定执行)。这里的M取值范围为{0, . . . , 30}(每个图像增强操作归一化到同样的幅度范围),而N取值范围一般为 {1, 2, 3}。

  1. # Identity是恒等变换,不做任何增强
  2. transforms = ['Identity', 'AutoContrast', 'Equalize', 'Rotate', 'Solarize', 'Color', 'Posterize',
  3. 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY']
  4. def randaugment(N, M):
  5. """Generate a set of distortions.
  6. Args:
  7. N: Number of augmentation transformations to
  8. apply sequentially.
  9. M: Magnitude for all the transformations.
  10. """
  11. sampled_ops = np.random.choice(transforms, N)
  12. return [(op, M) for op in sampled_ops]

对于ResNet50,其搜索得到的N=2,M=9,RandAugment相比AutoAugment可以在ImageNet得到相似的效果(77.6),不过DeiT中发现使用RandAugment效果更好一些( DeiT-B:81.8 vs 81.2)。目前torchvision库也已经实现了RandAugment,具体使用如下所示:

  1. from torchvision.transforms import autoaugment, transforms
  2. train_transform = transforms.Compose([
  3. transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
  4. transforms.RandomHorizontalFlip(hflip_prob),
  5. autoaugment.RandAugment(interpolation=interpolation),
  6. transforms.PILToTensor(),
  7. transforms.ConvertImageDtype(torch.float),
  8. transforms.Normalize(mean=mean, std=std)
  9. ])

TrivialAugment

虽然RandAugment的搜索空间极小,但是对于不同的数据集还是需要确定最优的N和M,这依然有较大的实验成本。RandAugment后,华为提出了UniformAugment,这种策略不需要搜索也能取得较好的结果。不过这里我们介绍一项更新的工作:TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation。TrivialAugment也不需要任何搜索,整个方法非常简单:每次随机选择一个图像增强操作,然后随机确定它的增强幅度,并对图像进行增强。由于没有任何超参数,所以不需要任何搜索。从实验结果上看,TA可以在多个数据集上取得更好的结果,如在ImageNet数据集上,ResNet50的top1 acc可以达到78.1,超过RandAugment。

 TrivialAugment的图像增强集合和RandAugment基本一样,不过TA也定义了一套更宽的增强幅度,目前torchvision中已经实现了TrivialAugmentWide,具体使用代码如下所示:

  1. from torchvision.transforms import autoaugment, transforms
  2. augmentation_space = {
  3. # op_name: (magnitudes, signed)
  4. "Identity": (torch.tensor(0.0), False),
  5. "ShearX": (torch.linspace(0.0, 0.99, num_bins), True),
  6. "ShearY": (torch.linspace(0.0, 0.99, num_bins), True),
  7. "TranslateX": (torch.linspace(0.0, 32.0, num_bins), True),
  8. "TranslateY": (torch.linspace(0.0, 32.0, num_bins), True),
  9. "Rotate": (torch.linspace(0.0, 135.0, num_bins), True),
  10. "Brightness": (torch.linspace(0.0, 0.99, num_bins), True),
  11. "Color": (torch.linspace(0.0, 0.99, num_bins), True),
  12. "Contrast": (torch.linspace(0.0, 0.99, num_bins), True),
  13. "Sharpness": (torch.linspace(0.0, 0.99, num_bins), True),
  14. "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 6)).round().int(), False),
  15. "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
  16. "AutoContrast": (torch.tensor(0.0), False),
  17. "Equalize": (torch.tensor(0.0), False),
  18. }
  19. train_transform = transforms.Compose([
  20. transforms.RandomResizedCrop(crop_size, interpolation=interpolation),
  21. transforms.RandomHorizontalFlip(hflip_prob),
  22. autoaugment.TrivialAugmentWide(interpolation=interpolation),
  23. transforms.PILToTensor(),
  24. transforms.ConvertImageDtype(torch.float),
  25. transforms.Normalize(mean=mean, std=std)
  26. ])

RandomErasing

RandomErasing是厦门大学在2017年提出的一种简单的数据增强(这个策略和同期的CutOut基本一样),基本原理是:随机从图像中擦除一个矩形区域而不改变图像的原始标签。DeiT的训练策略中也包括了RandomErasing。

 目前torchvision也实现了RandomErasing,其具体使用代码如下(注意这个op不支持PIL图像,需要在转换为torch.tensor后使用):

  1. train_transform = transforms.Compose([
  2. transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.)),
  3. transforms.RandomHorizontalFlip(),
  4. transforms.PILToTensor()
  5. transforms.ConvertImageDtype(torch.float),
  6. normalize,
  7. # scale是指相对于原图的擦除面积范围
  8. # ratio是指擦除区域的宽高比
  9. # value是指擦除区域的值,如果是int,也可以是tuple(RGB3个通道值),或者是str,需为'random',表示随机生成
  10. transforms.RandomErasing(p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False),
  11. ])

MixUp

MixUp在FAIR在2017年提出的一种数据增强方法:两张不同的图像随机线性组合,而同时生成线性组合的标签。

这里的x_{i}y_{j}是两张不同的图像,y_{i}y_{j}是它们对应的one-hot标签,而\lambda \in(0,1)是线性组合系数,每次执行时随机生成。假定图像分类任务是2分类(区分狗和猫),两张输入图像分别是狗和猫(如下图所示),它们对应的one-hot标签分别是[1,0]和[0, 1]。在进行mixup之前,首先对它们进行必要的数据增强得到aug_img1和aug_img2,然后随机生成线性组合系数,对于\lambda = 0.7得到的图像是mix_img1,标签变为[0.7, 0.3],而\lambda = 0.3得到的图像是mix_img2,标签变为[0.3, 0.7]。

 目前timm和torchvision中已经实现了mixup,这里以torchvision为例来讲述具体的代码实现。由于mixup需要两个输入,而不单单是对当前图像进行操作,所以一般是在得到batch数据后再进行mixup,这也意味着图像也已经完成了其它的数据增强如RandAugment,对于batch中的每个样本可以随机选择另外一个样本进行mixup。具体的实现代码如下所示:

  1. # from https://github.com/pytorch/vision/blob/main/references/classification/transforms.py
  2. class RandomMixup(torch.nn.Module):
  3. """Randomly apply Mixup to the provided batch and targets.
  4. The class implements the data augmentations as described in the paper
  5. `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.
  6. Args:
  7. num_classes (int): number of classes used for one-hot encoding.
  8. p (float): probability of the batch being transformed. Default value is 0.5.
  9. alpha (float): hyperparameter of the Beta distribution used for mixup.
  10. Default value is 1.0. # beta分布超参数
  11. inplace (bool): boolean to make this transform inplace. Default set to False.
  12. """
  13. def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
  14. super().__init__()
  15. assert num_classes > 0, "Please provide a valid positive value for the num_classes."
  16. assert alpha > 0, "Alpha param can't be zero."
  17. self.num_classes = num_classes
  18. self.p = p
  19. self.alpha = alpha
  20. self.inplace = inplace
  21. def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
  22. """
  23. Args:
  24. batch (Tensor): Float tensor of size (B, C, H, W)
  25. target (Tensor): Integer tensor of size (B, )
  26. Returns:
  27. Tensor: Randomly transformed batch.
  28. """
  29. if batch.ndim != 4:
  30. raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
  31. if target.ndim != 1:
  32. raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
  33. if not batch.is_floating_point():
  34. raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
  35. if target.dtype != torch.int64:
  36. raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
  37. if not self.inplace:
  38. batch = batch.clone()
  39. target = target.clone()
  40. # 建立one-hot标签
  41. if target.ndim == 1:
  42. target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
  43. # 判断是否进行mixup
  44. if torch.rand(1).item() >= self.p:
  45. return batch, target
  46. # 这里将batch数据平移一个单位,产生mixup的图像对,这意味着每个图像与相邻的下一个图像进行mixup
  47. # timm实现是通过flip来做的,这意味着第一个图像和最后一个图像进行mixup
  48. # It's faster to roll the batch by one instead of shuffling it to create image pairs
  49. batch_rolled = batch.roll(1, 0)
  50. target_rolled = target.roll(1, 0)
  51. # 随机生成组合系数
  52. # Implemented as on mixup paper, page 3.
  53. lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
  54. batch_rolled.mul_(1.0 - lambda_param)
  55. batch.mul_(lambda_param).add_(batch_rolled) # 得到mixup后的图像
  56. target_rolled.mul_(1.0 - lambda_param)
  57. target.mul_(lambda_param).add_(target_rolled) # 得到mixup后的标签
  58. return batch, target

然后可以将MixUp操作放在DataLoader的collate_fn中,这个函数要实现的是将多个样本合并成一个mini-batch,所以可以将MixUp插在得到mini-batch后,具体实现如下所示:

  1. from torch.utils.data.dataloader import default_collate
  2. mixup_transform = RandomMixup(num_classes, p=1.0, alpha=mixup_alpha)
  3. collate_fn = lambda batch: mixup_transform(*default_collate(batch))
  4. data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
  5. sampler=train_sampler, collate_fn=collate_fn)

对于MixUp,还要注意两个两点。第一个是如果同时采用了label smoothing,那么在创建one-hot标签时要直接得到smooth后的标签,具体实现如下(参考timm):

  1. def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
  2. x = x.long().view(-1, 1)
  3. return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)
  4. off_value = smoothing / num_classes
  5. on_value = 1. - smoothing + off_value
  6. smooth_one_hot = one_hot(target, num_classes, on_value=on_value, off_value=off_value)

第二个要注意的是MixUp后得到标签时soft label,不能直接采用torch.nn.CrossEntropyLoss来计算loss,而是直接计算交叉熵(参考timm):

  1. class SoftTargetCrossEntropy(nn.Module):
  2. def __init__(self):
  3. super(SoftTargetCrossEntropy, self).__init__()
  4. def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
  5. loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
  6. return loss.mean()

注意在PyTorch1.10版本之后,torch.nn.CrossEntropyLoss已经支持直接送入的target是probabilities for each class,原来只支持target是class indices;而且也支持label_smoothing参数,所以上述两个注意点就不再需要了。

说到计算loss,timm作者近期在ResNet strikes back: An improved training procedure in timm指出采用MixUp后可以将多分类改成多标签分类(multi-label classification),即从N分类变成N个2分类(直接采用BinaryCrossEntropy),这应该更符合MixUp后图像的语义,从对比实验来看效果有微弱的提升。 MixUp除了可以用于图像分类任务,还可以用于物体检测任务中,比如YOLOX就采用了MixUp,这里面的做法是对图像mixup后,其box为两个图像的box的合并集合,而没有对标签软化,这块也可以见论文Bag of Freebies for Training Object Detection Neural Networks

CutMix

CutMix是2019年提出的一项和MixUp和类似的数据增强策略,它也是同时对两个图像和标签进行混合,与MixUp不同的是它的图像混合方式。CutMix不是对两个图像线性组合,而是从另外一张图像随机剪切一个patch并粘贴到第一张图像上,patch的起始坐标随机生成,而宽高是由\lambda来控制:

这里W和H是原始图像的宽和高,所以λ其实决定的是patch和原图的面积比:。下图展示了λ分别取0.7和0.3的混合效果,λ越小,粘贴的patch越大。对于标签,其处理方式和MixUp一样,通过λ来得到两张图像的线性组合。

 CutMix做了ImageNet上的对比实验,相比MixUp,ResNet50的top1 acc大约能提升一个点(77.4 vs 78.6):

目前timm和torchvision中也已经实现了CutMix,这里还是以torchvision为例来讲述具体的代码实现,如下所示(和MixUp基本类似,只不过内部处理存在差异):

  1. class RandomCutmix(torch.nn.Module):
  2. """Randomly apply Cutmix to the provided batch and targets.
  3. The class implements the data augmentations as described in the paper
  4. `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
  5. <https://arxiv.org/abs/1905.04899>`_.
  6. Args:
  7. num_classes (int): number of classes used for one-hot encoding.
  8. p (float): probability of the batch being transformed. Default value is 0.5.
  9. alpha (float): hyperparameter of the Beta distribution used for cutmix.
  10. Default value is 1.0.
  11. inplace (bool): boolean to make this transform inplace. Default set to False.
  12. """
  13. def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
  14. super().__init__()
  15. assert num_classes > 0, "Please provide a valid positive value for the num_classes."
  16. assert alpha > 0, "Alpha param can't be zero."
  17. self.num_classes = num_classes
  18. self.p = p
  19. self.alpha = alpha
  20. self.inplace = inplace
  21. def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
  22. """
  23. Args:
  24. batch (Tensor): Float tensor of size (B, C, H, W)
  25. target (Tensor): Integer tensor of size (B, )
  26. Returns:
  27. Tensor: Randomly transformed batch.
  28. """
  29. if batch.ndim != 4:
  30. raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}")
  31. if target.ndim != 1:
  32. raise ValueError(f"Target ndim should be 1. Got {target.ndim}")
  33. if not batch.is_floating_point():
  34. raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.")
  35. if target.dtype != torch.int64:
  36. raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}")
  37. if not self.inplace:
  38. batch = batch.clone()
  39. target = target.clone()
  40. if target.ndim == 1:
  41. target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)
  42. if torch.rand(1).item() >= self.p:
  43. return batch, target
  44. # It's faster to roll the batch by one instead of shuffling it to create image pairs
  45. batch_rolled = batch.roll(1, 0)
  46. target_rolled = target.roll(1, 0)
  47. # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
  48. lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
  49. W, H = F.get_image_size(batch)
  50. # 确定patch的起点
  51. r_x = torch.randint(W, (1,))
  52. r_y = torch.randint(H, (1,))
  53. # 确定patch的w和h(其实是一半大小)
  54. r = 0.5 * math.sqrt(1.0 - lambda_param)
  55. r_w_half = int(r * W)
  56. r_h_half = int(r * H)
  57. # 越界处理
  58. x1 = int(torch.clamp(r_x - r_w_half, min=0))
  59. y1 = int(torch.clamp(r_y - r_h_half, min=0))
  60. x2 = int(torch.clamp(r_x + r_w_half, max=W))
  61. y2 = int(torch.clamp(r_y + r_h_half, max=H))
  62. batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
  63. # 由于越界处理, λ可能发生改变,所以要重新计算
  64. lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))
  65. target_rolled.mul_(1.0 - lambda_param)
  66. target.mul_(lambda_param).add_(target_rolled)
  67. return batch, target

其它使用和MixUp一样。

Repeated Augmentation

Repeated Augmentation (RA)是FAIR在MultiGrain提出的一种抽样策略,一般情况下,训练的mini-batch包含的增强过的sample都是来自不同的图像,但是RA这种抽样策略允许一个mini-batch中包含来自同一个图像的不同增强版本,此时mini-batch的各个样本并非是完全独立的,这相当于对同一个样本进行重复抽样,所以称为Repeated Augmentation。这篇论文认为在一个mini-batch学习来自同一个图像的不同增强版本能让模型更容易学习到增强不变的特征。关于RA,其实另外一篇较早的论文Augment your batch: better training with larger batches也提出了类似的策略,另外DeepMind在最近的论文Drawing Multiple Augmentation Samples Per Image During Training Efficiently Decreases Test Error也进一步通过实验来证明这种策略的效果。 ​

DeiT的训练也采用了RA,严格来说RA不属于数据增强策略,而是一种mini-batch抽样方法,这里也简单给出DeiT实现的RA(可以替换torch.utils.data.DistributedSampler):

  1. class RASampler(torch.utils.data.Sampler):
  2. """Sampler that restricts data loading to a subset of the dataset for distributed,
  3. with repeated augmentation.
  4. It ensures that different each augmented version of a sample will be visible to a
  5. different process (GPU)
  6. Heavily based on torch.utils.data.DistributedSampler
  7. """
  8. def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
  9. if num_replicas is None:
  10. if not dist.is_available():
  11. raise RuntimeError("Requires distributed package to be available")
  12. num_replicas = dist.get_world_size()
  13. if rank is None:
  14. if not dist.is_available():
  15. raise RuntimeError("Requires distributed package to be available")
  16. rank = dist.get_rank()
  17. self.dataset = dataset
  18. self.num_replicas = num_replicas
  19. self.rank = rank
  20. self.epoch = 0
  21. # 重复采样后每个replica的样本量
  22. self.num_samples = int(math.ceil(len(self.dataset) * 3.0 / self.num_replicas))
  23. # 重复采样后的总样本量
  24. self.total_size = self.num_samples * self.num_replicas
  25. # self.num_selected_samples = int(math.ceil(len(self.dataset) / self.num_replicas))
  26. # 每个replica实际样本量,即不重复采样时的每个replica的样本量
  27. self.num_selected_samples = int(math.floor(len(self.dataset) // 256 * 256 / self.num_replicas))
  28. self.shuffle = shuffle
  29. def __iter__(self):
  30. # deterministically shuffle based on epoch
  31. g = torch.Generator()
  32. g.manual_seed(self.epoch)
  33. if self.shuffle:
  34. indices = torch.randperm(len(self.dataset), generator=g).tolist()
  35. else:
  36. indices = list(range(len(self.dataset)))
  37. # add extra samples to make it evenly divisible
  38. indices = [ele for ele in indices for i in range(3)] # 重复3次
  39. indices += indices[:(self.total_size - len(indices))]
  40. assert len(indices) == self.total_size
  41. # subsample: 使得同一个样本的重复版本进入不同的进程(GPU)
  42. indices = indices[self.rank:self.total_size:self.num_replicas]
  43. assert len(indices) == self.num_samples
  44. return iter(indices[:self.num_selected_samples]) # 截取实际样本量
  45. def __len__(self):
  46. return self.num_selected_samples
  47. def set_epoch(self, epoch):
  48. self.epoch = epoch

小结

这里简单介绍了几种常用且有效的数据增强策略,这些策略在vision transformer模型被使用,而且timm训练的ResNet新baseline也使用了这些策略。

参考

  1. Training data-efficient image transformers & distillation through attention
  2. AutoAugment: Learning Augmentation Policies from Data
  3. RandAugment: Practical automated data augmentation with a reduced search space
  4. TrivialAugment: Tuning-free Yet State-of-the-Art Data Augmentation
  5. Random Erasing Data Augmentation
  6. Augment your batch: better training with larger batches
  7. MultiGrain: a unified image embedding for classes and instances
  8. mixup: Beyond Empirical Risk Minimization
  9. CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features

转载

图像分类训练技巧之数据增强篇 - 知乎

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/255253
推荐阅读
相关标签
  

闽ICP备14008679号