
Semi-Supervised Learning MixMatch: A Holistic Approach to Semi-Supervised Learning (Core Code)

Background

In pseudo-label-based semi-supervised methods, selecting pseudo-labels is hard: early in training the model tends to produce high-error pseudo-labels, which hurts final performance, and the weight on the unsupervised loss term is difficult to set.

Proposed method: MixMatch

MixMatch combines the main approaches currently used for semi-supervised learning into a single new algorithm: it guesses low-entropy labels for data-augmented unlabeled examples and mixes labeled and unlabeled data using MixUp.

Using only 250 labels, it already reduces the error rate substantially.

Semi-supervised learning (SSL) aims to greatly reduce the need for labeled data by letting the model also exploit unlabeled data. One family of SSL methods adds an extra term to the loss function, computed on the unlabeled data, that encourages the model to generalize better to unseen data.

This loss term usually falls into one of three categories:

Entropy minimization, which encourages the model to make confident predictions on unlabeled data;

Consistency regularization, which encourages the model to produce the same output distribution when its input is perturbed;

Generic regularization, which encourages the model to generalize well and avoid overfitting the training data.

MixMatch is an SSL algorithm that introduces a single loss unifying the terms above into one semi-supervised method. Unlike previous approaches, MixMatch's unified loss term for unlabeled data seamlessly reduces entropy while maintaining consistency and staying compatible with traditional regularization techniques.

Method core: MixUp

Label Guessing

Label guessing aims to play the same role as pseudo-labeling, but it works differently.

For each unlabeled example in the unlabeled dataset, MixMatch uses the model's predictions to generate a "guess" for that example's label. This guess is later used in the unsupervised loss term.

Sharpening

Inspired by entropy minimization in semi-supervised learning, an extra step is performed when generating the label guesses: given the average prediction \bar{q}_b over the augmented copies of an example, a sharpening function is applied to lower the entropy of the label distribution. In practice, the strength of the sharpening is controlled through a "temperature" T.

MixUp

\mathcal{X}: labeled data, \mathcal{U}: unlabeled data

The labeled and unlabeled data use the same batch size; each unlabeled example goes through K augmentations.

Algorithm steps

First, apply one data augmentation to each labeled example;

Then apply K augmentations to each unlabeled example; the paper uses K = 2, i.e., two augmentations;

Feed the augmented unlabeled data through the model, so each augmented copy produces a prediction; average these predictions and apply Sharpen. This is the label-guessing step.

After the steps above we have two sets of data: \mathcal{X}, the augmented labeled data, and \mathcal{U}, the augmented unlabeled data that now carries labels (these labels are the guesses produced by the model during training).

Concatenate the labeled and unlabeled data into one large batch, shuffle it randomly, and then mix it with the original (unshuffled) data via MixUp.

The mixed data is then fed through the model for a forward pass, and finally the loss is computed.

The loss function is defined as follows (from the MixMatch paper):
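The total objective combines the supervised and unsupervised terms, with \lambda_{\mathcal{U}} weighting the unlabeled loss:

\mathcal{L}=\mathcal{L}_{\mathcal{X}}+\lambda_{\mathcal{U}} \mathcal{L}_{\mathcal{U}}

Both individual terms \mathcal{L}_{\mathcal{X}} and \mathcal{L}_{\mathcal{U}} are written out in the Loss Function section below.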

Model setup

Hyperparameters

Exponential moving average (EMA)

an exponential moving average of model parameter values, it provides a more stable target and was found empirically to significantly improve results.

An exponential moving average of the model parameters is maintained, which gives a more stable model (and a more stable training target).
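The WeightEMA helper used later in the training code is not shown in this post. Below is a minimal sketch of such an EMA updater, assuming a standard formulation (the decay value 0.999 and the initialization-by-copy are assumptions, not taken from the original project):

import torch

class WeightEMA:
    # keeps ema_model's parameters as an exponential moving average of model's parameters
    def __init__(self, model, ema_model, alpha=0.999):
        self.model = model
        self.ema_model = ema_model
        self.alpha = alpha  # EMA decay rate (assumed value)
        self.ema_model.load_state_dict(model.state_dict())  # start the EMA model from the current weights

    @torch.no_grad()
    def step(self):
        for ema_p, p in zip(self.ema_model.parameters(), self.model.parameters()):
            # ema = alpha * ema + (1 - alpha) * current
            ema_p.mul_(self.alpha).add_(p, alpha=1.0 - self.alpha)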

rampup

we only change \alpha and \lambda_u on a per-dataset basis; we found that \alpha = 0.75 and \lambda_u = 100 are good starting points for tuning. In all experiments, we linearly ramp up \lambda_u to its maximum value over the first 16,000 steps of training as is common practice.

The ramp-up parameters are empirical values found through experiment; \lambda_u is increased slowly and linearly to its maximum over the first 16,000 steps. Early in training \lambda_u must not be too large, otherwise the unreliable guessed labels dominate the loss and training goes wrong.
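The lambda_rampup helper called in the training loop is not shown either. A minimal sketch of a linear ramp-up consistent with the description above (the 16,000-step length comes from the paper; the name and signature follow the call site in the training code):

def lambda_rampup(step, max_v, rampup_length=16000):
    # linearly ramp the unlabeled-loss weight from 0 to max_v over the first rampup_length steps
    return max_v * min(step, rampup_length) / rampup_length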

Label Guessing and Sharpening

Label guessing:

\bar{q}_{b}=\frac{1}{K} \sum_{k=1}^{K} \operatorname{p}_{\operatorname{model}}\left(y \mid \hat{u}_{b, k} ; \theta\right)

Sharpening:

\operatorname{Sharpen}(p, T)_{i}:=p_{i}^{\frac{1}{T}} / \sum_{j=1}^{L} p_{j}^{\frac{1}{T}}

As T → 0, the output of Sharpen(p, T) approaches a Dirac ("one-hot") distribution. As T → \infty, the output approaches a uniform distribution, so the differences between classes can no longer be expressed.
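The label_guessing and sharpen functions called in the training code are not included in the post. A minimal sketch implementing the two formulas above (the names and signatures follow the calls in the training loop; treating the model outputs as logits and converting them to probabilities with softmax is an assumption about the original code):

import torch
import torch.nn.functional as F

def label_guessing(out_u, out_u2):
    # average the predicted class distributions of the K=2 augmented copies (\bar{q}_b)
    p1 = F.softmax(out_u, dim=-1)   # logits -> probabilities (assumption: the model outputs logits)
    p2 = F.softmax(out_u2, dim=-1)
    return (p1 + p2) / 2.0

def sharpen(q, T=0.5):
    # Sharpen(p, T): raise each probability to the power 1/T and renormalize, lowering the entropy
    # (T = 0.5 is the paper's default temperature)
    q_pow = q ** (1.0 / T)
    return q_pow / q_pow.sum(dim=-1, keepdim=True)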

MixUp

First sample \lambda from a Beta(\alpha, \alpha) distribution, then take the larger of \lambda and 1-\lambda as \lambda'. Mixing x_1 and x_2 with weights \lambda' and 1-\lambda' gives x'; the labels p are mixed in the same way.

import numpy as np
import torch


def mixup(x, u, u2, trg_x, out_u, out_u2, alpha=0.75):
    """
    MixUp: eq. 8-11, algorithm lines 12-14
    :param x: labeled inputs [N, 3, H, W]
    :param u: unlabeled inputs from the first augmentation [N, 3, H, W]
    :param u2: unlabeled inputs from the second augmentation [N, 3, H, W]
    :param trg_x: targets (y) of the labeled x, shape [N, ], e.g. [0, 7, 8, ...] where each entry is a class index
    :param out_u: guessed label distribution q_b for u (after label guessing and sharpening)
    :param out_u2: guessed label distribution q_b for u2
    :param alpha: Beta distribution hyperparameter
    :return: mixed inputs x: [3*N, 3, H, W], mixed targets y: [3*N, 10]
    """
    batch_size = x.size(0)      # batch_size = HP.batch_size
    n_classes = out_u.size(1)   # number of classes: 10
    device = x.device
    # a class index such as 8 cannot be mixed with a distribution like [0.1, 0.3, 0.01, ...],
    # so convert trg_x back to one-hot vectors of dim=10
    trg_x_onehot = torch.zeros(size=(batch_size, n_classes)).float().to(device)
    # e.g. trg_x[i] = 7 -> [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
    trg_x_onehot.scatter_(1, trg_x.view(-1, 1), 1.)
    # mixup
    x_cat = torch.cat([x, u, u2], dim=0)
    trg_cat = torch.cat([trg_x_onehot, out_u, out_u2], dim=0)
    n_item = x_cat.size(0)               # 3*N (the batch dimension grows)
    lam = np.random.beta(alpha, alpha)   # eq. 8
    lam_prime = max(lam, 1 - lam)        # eq. 9
    # random permutation of indices, e.g. randperm(3) -> [0, 2, 1] or [1, 0, 2]
    rand_idx = torch.randperm(n_item)
    x_cat_shuffled = x_cat[rand_idx]       # x2: the randomly shuffled inputs
    trg_cat_shuffled = trg_cat[rand_idx]   # p2: the targets shuffled with the same permutation
    x_cat_mixup = lam_prime * x_cat + (1 - lam_prime) * x_cat_shuffled        # eq. 10
    trg_cat_mixup = lam_prime * trg_cat + (1 - lam_prime) * trg_cat_shuffled  # eq. 11
    return x_cat_mixup, trg_cat_mixup

 Loss Function

The supervised loss is essentially a cross-entropy between the target and the predicted output:

\mathcal{L}_{\mathcal{X}}=\frac{1}{\left|\mathcal{X}^{\prime}\right|} \sum_{x, p \in \mathcal{X}^{\prime}} \mathrm{H}\left(p, \mathrm{p}_{\text {model }}(y \mid x ; \theta)\right)

The unsupervised loss is essentially a mean-squared-error loss, acting as consistency regularization:

\mathcal{L}_{\mathcal{U}}=\frac{1}{L\left|\mathcal{U}^{\prime}\right|} \sum_{u, q \in \mathcal{U}^{\prime}}\left\|q-\mathrm{p}_{\text {model }}(y \mid u ; \theta)\right\|_{2}^{2}

import torch
import torch.nn as nn
import torch.nn.functional as F


class MixUpLoss(nn.Module):
    def __init__(self):
        super(MixUpLoss, self).__init__()

    def forward(self, output_x, trg_x, output_u, trg_u):
        """
        loss function: eq. 2-4 (eq. 5 is applied in the trainer)
        :param output_x: model output for the mixed-up labeled part, shape [N, 10]
        :param trg_x: mixed-up targets for the labeled part, shape [N, 10]
        :param output_u: model output for the mixed-up unlabeled part ([x, u, u2] has 3*N items), shape [2*N, 10]
        :param trg_u: mixed-up targets for the unlabeled part, shape [2*N, 10]
        :return: Lx, Lu
        """
        # cross-entropy against soft targets: supervised loss (sum over the last, class dimension)
        Lx = -torch.mean(torch.sum(F.log_softmax(output_x, dim=-1) * trg_x, dim=-1))
        # mean squared error between the predicted distribution and the guessed targets: consistency regularization
        # (softmax turns the logits into probabilities so the comparison matches eq. 4)
        Lu = F.mse_loss(torch.softmax(output_u, dim=-1), trg_u)
        return Lx, Lu

Training process

import os
from argparse import ArgumentParser

import torch
import torch.nn as nn
from torch import optim

# HP (hyperparameters), logger, labeled_trainloader, unlabeled_trainloader, val_loader,
# WideResnet50_2, new_ema_model, WeightEMA, MixUpLoss, label_guessing, sharpen, mixup,
# lambda_rampup, evaluate and save_checkpoint are assumed to be defined in the project's other modules.


# train func
def train():
    parser = ArgumentParser(description='Model Training')
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume from checkpoint'
    )
    args = parser.parse_args()
    # new models: model / ema_model
    model = WideResnet50_2()
    model = model.to(HP.device)
    ema_model = new_ema_model()
    model_ema_opt = WeightEMA(model, ema_model)
    # loss
    criterion_val = nn.CrossEntropyLoss()  # for eval
    criterion_train = MixUpLoss()          # for training
    opt = optim.Adam(model.parameters(), lr=HP.init_lr, weight_decay=0.001)  # optimizer with L2 regularization
    start_epoch, step = 0, 0
    if args.c:
        checkpoint = torch.load(args.c)
        model.load_state_dict(checkpoint['model_state_dict'])
        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print('Resume From %s.' % args.c)
    else:
        print('Training from scratch!')
    model.train()
    eval_loss = 0.
    # labeled data is scarce in SSL, so the number of steps per epoch follows the unlabeled loader
    n_unlabeled = len(unlabeled_trainloader)  # number of training steps per epoch
    # train loop
    for epoch in range(start_epoch, HP.epochs):
        print('Start epoch: %d, Steps per epoch: %d' % (epoch, n_unlabeled))
        for i in range(n_unlabeled):  # one pass over the unlabeled data counts as an epoch
            # inputs_x: [N, 3, H, W], trg_x: [N, ]
            inputs_x, trg_x = next(iter(labeled_trainloader))  # get one batch from the labeled dataloader
            # inputs_u / inputs_u2 -> [N, 3, H, W]
            (inputs_u, inputs_u2), _ = next(iter(unlabeled_trainloader))
            inputs_x, trg_x, inputs_u, inputs_u2 = inputs_x.to(HP.device), trg_x.long().to(HP.device), \
                inputs_u.to(HP.device), inputs_u2.to(HP.device)
            # $$$$$ Algorithm lines 7-8 $$$$$ label guessing
            with torch.no_grad():
                out_u = model(inputs_u)    # augmentation k=1, inference [N, 10]
                out_u2 = model(inputs_u2)  # augmentation k=2, inference [N, 10]
                q = label_guessing(out_u, out_u2)  # average predicted distribution [N, 10]
                q = sharpen(q, T=HP.T)             # [N, 10]
            # $$$$$ Algorithm lines 10-15 $$$$$ MixUp
            # u and u2 are two augmentations of the same data, so they share the same guessed label q
            # mixuped_x: [3*N, 3, H, W] (three groups concatenated), mixuped_out: [3*N, 10]
            mixuped_x, mixuped_out = mixup(x=inputs_x, u=inputs_u, u2=inputs_u2, trg_x=trg_x, out_u=q, out_u2=q)
            # model forward pass
            mixuped_logits = model(mixuped_x)  # [3*N, 10]
            # split the labeled and unlabeled parts, since the loss consists of two separate terms
            logits_x = mixuped_logits[:HP.batch_size]  # [N, 10]
            logits_u = mixuped_logits[HP.batch_size:]  # [2*N, 10]
            # eq. 2-5
            loss_x, loss_u = criterion_train(logits_x, mixuped_out[:HP.batch_size],
                                             logits_u, mixuped_out[HP.batch_size:])
            loss = loss_x + lambda_rampup(step, max_v=HP.lambda_u) * loss_u  # eq. 5
            logger.add_scalar('Loss/Train', loss, step)
            opt.zero_grad()   # clear gradients
            loss.backward()   # back-propagation
            opt.step()
            model_ema_opt.step()  # update ema_model
            if not step % HP.verbose_step:  # evaluation
                acc1, acc5, eval_loss = evaluate(model, val_loader, criterion_val)
                logger.add_scalar("Loss/Dev", eval_loss, step)
                logger.add_scalar("Acc1", acc1, step)
                logger.add_scalar("Acc5", acc5, step)
            if not step % HP.save_step:  # save model
                model_path = 'model_%d_%d.pth' % (epoch, step)
                save_checkpoint(model, ema_model, epoch, opt, os.path.join('./model_save', model_path))
            # progress bar / tqdm could be used here instead of print
            print('Epoch: [%d/%d], step: %d, Train Loss: %.5f, Dev Loss: %.5f, Acc1: %.3f, Acc5: %.3f' %
                  (epoch, HP.epochs, step, loss.item(), eval_loss, acc1, acc5))
            step += 1
        logger.flush()
    logger.close()

 

 

 
