In pseudo-labeling approaches to semi-supervised learning, selecting pseudo-labels is difficult: early in training the model tends to produce high-error pseudo-labels, which degrades final performance, and the weight on the unsupervised loss term is hard to set.
MixMatch combines the dominant approaches to semi-supervised learning into a single new algorithm: it guesses low-entropy labels for data-augmented unlabeled examples and mixes labeled and unlabeled data with MixUp.
For example, on CIFAR-10 with only 250 labels it reduces the error rate by a factor of about 4 (from 38% to 11%).
Semi-supervised learning (SSL) aims to greatly reduce the need for labeled data by letting the model exploit unlabeled data. One family of SSL methods adds a loss term, computed on the unlabeled data, that encourages the model to generalize better to unseen data.
This loss term typically falls into one of three categories:
Entropy minimization — encourages the model to output confident (low-entropy) predictions on unlabeled data;
Consistency regularization — encourages the model to produce the same output distribution when its input is perturbed;
Generic regularization — encourages the model to generalize well and avoid overfitting the training data.
MixMatch is an SSL algorithm that introduces a single loss unifying the terms above into one semi-supervised method. Unlike previous approaches, MixMatch uses a single unified loss term for unlabeled data that seamlessly reduces entropy while maintaining consistency and remaining compatible with traditional regularization techniques.
Label guessing is intended to play the same role as pseudo-labeling, but it works differently.
For each unlabeled example in the unlabeled set, MixMatch uses the model's prediction to produce a "guess" for its label. This guess is later used in the unsupervised loss term.
Sharpening
Inspired by entropy minimization in semi-supervised learning, an extra step is applied when generating the label guess: given the average prediction over the augmented copies of an example, a sharpening function is applied to lower the entropy of the label distribution. In practice, the strength of sharpening is controlled by a "temperature" T.
X: labeled data, U: unlabeled data
Labeled and unlabeled data use the same batch size; each unlabeled example is passed through K augmentations.
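The unlabeled dataloader used in the training code below yields a pair (inputs_u, inputs_u2) for every example. One simple way to obtain this, shown here as a sketch rather than the author's exact data pipeline, is a transform wrapper that applies the same stochastic augmentation K times:

class KAugmentations:
    """Apply the same stochastic transform K times to one image and return
    the K augmented views (hypothetical helper; K = 2 as in the paper)."""
    def __init__(self, transform, k=2):
        self.transform = transform
        self.k = k

    def __call__(self, img):
        return tuple(self.transform(img) for _ in range(self.k))

The unlabeled dataset would then use KAugmentations(base_transform) as its transform, which matches the (inputs_u, inputs_u2), _ = next(iter(unlabeled_trainloader)) line in the training loop below.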
Algorithm steps
First, apply one data augmentation to each labeled example;
Then apply K augmentations to each unlabeled example (the paper uses K = 2, i.e. two augmentations);
Feed the augmented unlabeled data into the model; each augmented copy yields a prediction. Average these predictions and apply Sharpen — this is the label-guessing operation.
After these steps we have two sets of data: first, the augmented labeled data; second, the augmented unlabeled data, now carrying labels (these labels are the guesses produced by the model).
Concatenate the labeled and unlabeled data into one large batch, shuffle a copy of it randomly, and mix the shuffled copy with the original batch (MixUp).
The mixed data is then fed through the model for a forward pass, and finally the loss is computed.
The loss function is defined as follows.
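As given in the MixMatch paper (the eq. 2-5 referenced in the code comments below), where H(·,·) is cross-entropy, L is the number of classes, and X', U' are the mixed-up labeled and unlabeled batches:

$$\mathcal{L}_{\mathcal{X}} = \frac{1}{|\mathcal{X}'|} \sum_{(x,p) \in \mathcal{X}'} H\big(p,\ p_{\text{model}}(y \mid x;\ \theta)\big)$$
$$\mathcal{L}_{\mathcal{U}} = \frac{1}{L\,|\mathcal{U}'|} \sum_{(u,q) \in \mathcal{U}'} \big\| q - p_{\text{model}}(y \mid u;\ \theta) \big\|_2^2$$
$$\mathcal{L} = \mathcal{L}_{\mathcal{X}} + \lambda_{\mathcal{U}}\,\mathcal{L}_{\mathcal{U}}$$

Here λ_U is the weight on the unsupervised term, which the training loop below ramps up over time.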
An exponential moving average of the model parameter values provides a more stable target and was found empirically to significantly improve results.
That is, a second model is updated as an exponential moving average of the trained weights, which makes it more stable.
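The helpers new_ema_model() and WeightEMA used in the training code are not listed in this post. Below is a minimal sketch of what the EMA update could look like, assuming new_ema_model() returns a parameter-for-parameter copy of the network; the decay value 0.999 is an assumption, and the repo's version may differ (e.g. by also applying weight decay).

import torch
import torch.nn as nn

class WeightEMA:
    """Keep a second set of weights as an exponential moving average of the
    trained model's weights (sketch only)."""
    def __init__(self, model: nn.Module, ema_model: nn.Module, decay: float = 0.999):
        self.model = model
        self.ema_model = ema_model
        self.decay = decay
        # start the EMA weights from the current model weights
        self.ema_model.load_state_dict(self.model.state_dict())

    @torch.no_grad()
    def step(self):
        # ema_w <- decay * ema_w + (1 - decay) * w, applied to every parameter
        for ema_p, p in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)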
"We only change α and λ_U on a per-dataset basis; we found that α = 0.75 and λ_U = 100 are good starting points for tuning. In all experiments, we linearly ramp up λ_U to its maximum value over the first 16,000 steps of training as is common practice."
The ramp-up schedule is an empirical choice: λ_U grows linearly to its maximum over the first 16,000 steps. Early in training λ_U must not be too large, otherwise the still-unreliable guessed labels dominate the loss and destabilize training.
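The lambda_rampup function called in the training loop below is not listed in this post either; here is a minimal sketch consistent with the linear ramp just described (16,000 steps is the value quoted above; the name and signature simply follow the call site in the code below).

def lambda_rampup(step, max_v, rampup_steps=16000):
    """Linearly ramp the unsupervised-loss weight from 0 to max_v over the
    first rampup_steps optimization steps, then hold it at max_v."""
    return max_v * min(step / rampup_steps, 1.0)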
Label guessing:
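From the paper, the guessed label for an unlabeled example u_b is the model's average prediction over its K augmented versions (this is what "Algorithm Line 7 - Line 8" in the code below implements):

$$\bar{q}_b = \frac{1}{K} \sum_{k=1}^{K} p_{\text{model}}\big(y \mid \hat{u}_{b,k};\ \theta\big)$$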
Sharpening:
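The sharpening function from the paper, applied to the averaged prediction, is

$$\text{Sharpen}(p, T)_i := \frac{p_i^{1/T}}{\sum_{j=1}^{L} p_j^{1/T}}$$

where T is the temperature and L is the number of classes.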
As T → 0, the output of Sharpen(p, T) approaches a Dirac ("one-hot") distribution. As T → ∞, the output approaches a uniform distribution, so the differences between classes can no longer be expressed. Lowering T therefore encourages the model to produce lower-entropy predictions.
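The training loop below calls label_guessing and sharpen, which are not listed in this post. A minimal sketch of both, matching the formulas above, is given here; whether the softmax is applied inside label_guessing or inside the model is an assumption, and T = 0.5 is only a default.

import torch
import torch.nn.functional as F

def label_guessing(*logits):
    """Average the softmax predictions of the K augmented views of each example."""
    probs = [F.softmax(l, dim=-1) for l in logits]
    return torch.stack(probs, dim=0).mean(dim=0)  # [N, n_classes]

def sharpen(p, T=0.5):
    """Temperature sharpening: raise each probability to 1/T and renormalize."""
    p_t = p.pow(1.0 / T)
    return p_t / p_t.sum(dim=-1, keepdim=True)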
First sample λ from a Beta(α, α) distribution, then take λ' = max(λ, 1 − λ). Each example x1 is mixed with a shuffled partner x2 as x' = λ'·x1 + (1 − λ')·x2, and the labels p are mixed in the same way.
import numpy as np
import torch


def mixup(x, u, u2, trg_x, out_u, out_u2, alpha=0.75):
    """
    mixup: eq.8-11, algorithm: line 12-14
    :param x: labeled input data [N, 3, H, W]
    :param u: the first augmented view of the unlabeled data [N, 3, H, W]
    :param u2: the second augmented view of the unlabeled data [N, 3, H, W]
    :param trg_x: targets (y) of the labeled data, shape [N, ], e.g. [0, 7, 8, ...] = class indices
    :param out_u: guessed labels q_b obtained after label guessing
    :param out_u2: q_b (the same guess, since u and u2 come from the same example)
    :param alpha: Beta distribution hyperparameter
    :return: mixup result: x: [3*N, 3, H, W], y: [3*N, 10]
    """
    batch_size = x.size(0)     # batch_size = HP.batch_size
    n_classes = out_u.size(1)  # number of classes: 10
    device = x.device

    # guessed labels are distributions like [0.1, 0.3, 0.01, ...] of dim=10, while trg_x holds
    # class indices such as 8, which cannot be mixed with them directly; convert trg_x to one-hot
    trg_x_onehot = torch.zeros(size=(batch_size, n_classes)).float().to(device)
    # e.g. trg_x[i] = 7:
    # [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -> [0, 0, 0, 0, 0, 0, 0, 1, 0, 0]
    trg_x_onehot.scatter_(1, trg_x.view(-1, 1), 1.)

    # mixup
    x_cat = torch.cat([x, u, u2], dim=0)
    trg_cat = torch.cat([trg_x_onehot, out_u, out_u2], dim=0)
    n_item = x_cat.size(0)                  # 3*N (the batch dimension grows)
    lam = np.random.beta(alpha, alpha)      # eq.8
    lam_prime = max(lam, 1 - lam)           # eq.9
    # a random permutation of the indices, e.g. randperm(3) -> [0, 2, 1] or [1, 0, 2]
    rand_idx = torch.randperm(n_item)
    x_cat_shuffled = x_cat[rand_idx]        # x2: randomly chosen mixing partners
    trg_cat_shuffled = trg_cat[rand_idx]    # the targets p2, shuffled with the same permutation

    x_cat_mixup = lam_prime * x_cat + (1 - lam_prime) * x_cat_shuffled          # eq.10
    trg_cat_mixup = lam_prime * trg_cat + (1 - lam_prime) * trg_cat_shuffled    # eq.11

    return x_cat_mixup, trg_cat_mixup
Lx is essentially the cross-entropy between the targets and the model's predictions: the supervised loss.
Lu is essentially a mean-squared-error loss: the consistency regularization term.
import torch
import torch.nn as nn
import torch.nn.functional as F


class MixUpLoss(nn.Module):
    def __init__(self):
        super(MixUpLoss, self).__init__()

    def forward(self, output_x, trg_x, output_u, trg_u):
        """
        loss function: eq.2-4 (eq.5 is applied in the trainer)
        :param output_x: logits for the mixed-up labeled part, shape [N, 10]
        :param trg_x: mixed-up targets for the labeled part, shape [N, 10]
        :param output_u: logits for the unlabeled part of the mixed batch [x, u, u2], shape [2*N, 10]
        :param trg_u: mixed-up targets for the unlabeled part, shape [2*N, 10]
        :return: Lx, Lu
        """
        # cross-entropy, supervised loss (sum over the last/class dimension, mean over the batch)
        Lx = -torch.mean(torch.sum(F.log_softmax(output_x, dim=-1) * trg_x, dim=-1))
        # consistency regularization: squared distance between probability distributions,
        # so the logits are converted to probabilities first
        Lu = F.mse_loss(torch.softmax(output_u, dim=-1), trg_u)
        return Lx, Lu
# train func
# (relies on project-level objects defined elsewhere in the repo: HP, WideResnet50_2,
#  new_ema_model, WeightEMA, label_guessing, sharpen, mixup, lambda_rampup,
#  labeled_trainloader, unlabeled_trainloader, val_loader, evaluate, save_checkpoint, logger)
import os
from argparse import ArgumentParser

import torch
import torch.nn as nn
from torch import optim


def train():
    parser = ArgumentParser(description='Model Training')
    parser.add_argument(
        '--c',
        default=None,
        type=str,
        help='train from scratch or resume from checkpoint'
    )
    args = parser.parse_args()

    # new models: model / ema_model
    model = WideResnet50_2()
    model = model.to(HP.device)
    ema_model = new_ema_model()
    model_ema_opt = WeightEMA(model, ema_model)

    # loss
    criterion_val = nn.CrossEntropyLoss()  # for eval
    criterion_train = MixUpLoss()          # for training

    opt = optim.Adam(model.parameters(), lr=HP.init_lr, weight_decay=0.001)  # optimizer with L2 regularization

    start_epoch, step = 0, 0
    if args.c:
        checkpoint = torch.load(args.c)
        model.load_state_dict(checkpoint['model_state_dict'])
        ema_model.load_state_dict(checkpoint['ema_model_state_dict'])
        opt.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print('Resume From %s.' % args.c)
    else:
        print('Training from scratch!')

    model.train()
    eval_loss = 0.
    # labeled data is scarce in semi-supervised learning, so steps are counted over the unlabeled data
    n_unlabeled = len(unlabeled_trainloader)  # number of training steps per epoch

    # train loop
    for epoch in range(start_epoch, HP.epochs):
        print('Start epoch: %d, Steps per epoch: %d' % (epoch, n_unlabeled))
        for i in range(n_unlabeled):  # one pass over the unlabeled data counts as one epoch
            # inputs_x: [N, 3, H, W], trg_x: [N, ]
            inputs_x, trg_x = next(iter(labeled_trainloader))  # get one batch from the labeled dataloader
            # inputs_u / inputs_u2 -> [N, 3, H, W]
            (inputs_u, inputs_u2), _ = next(iter(unlabeled_trainloader))
            inputs_x, trg_x, inputs_u, inputs_u2 = inputs_x.to(HP.device), trg_x.long().to(HP.device), \
                inputs_u.to(HP.device), inputs_u2.to(HP.device)

            # $$$$$$$$$$$$$$$ Algorithm Line 7 - Line 8 $$$$$$$$$$$$$$$ label guessing
            with torch.no_grad():
                out_u = model(inputs_u)    # augmentation k=1, inference [N, 10]
                out_u2 = model(inputs_u2)  # augmentation k=2, inference [N, 10]

                q = label_guessing(out_u, out_u2)  # averaged posterior distribution [N, 10]
                q = sharpen(q, T=HP.T)             # [N, 10]

            # $$$$$$$$$$$$$$$ Algorithm Line 10 - Line 15 $$$$$$$$$$$$$$$ MixUp
            # u and u2 are two augmentations of the same data, so they share the same label guess
            # mixuped_x: [3*N, 3, H, W] (three groups concatenated), mixuped_out: [3*N, 10]
            mixuped_x, mixuped_out = mixup(x=inputs_x, u=inputs_u, u2=inputs_u2, trg_x=trg_x, out_u=q, out_u2=q)

            # model forward pass
            mixuped_logits = model(mixuped_x)  # [3*N, 10]
            # split the labeled and unlabeled parts: the loss is computed from the two parts separately
            logits_x = mixuped_logits[: HP.batch_size]  # [N, 10]
            logits_u = mixuped_logits[HP.batch_size:]   # [2*N, 10]
            # eq.2-5
            loss_x, loss_u = criterion_train(logits_x, mixuped_out[: HP.batch_size],
                                             logits_u, mixuped_out[HP.batch_size:])
            loss = loss_x + lambda_rampup(step, max_v=HP.lambda_u) * loss_u  # eq.5

            logger.add_scalar('Loss/Train', loss, step)
            opt.zero_grad()  # clear the gradients
            loss.backward()  # backward pass
            opt.step()
            model_ema_opt.step()  # update ema_model

            if not step % HP.verbose_step:  # evaluation
                acc1, acc5, eval_loss = evaluate(model, val_loader, criterion_val)
                logger.add_scalar("Loss/Dev", eval_loss, step)
                logger.add_scalar("Acc1", acc1, step)
                logger.add_scalar("Acc5", acc5, step)

            if not step % HP.save_step:  # save model
                model_path = 'model_%d_%d.pth' % (epoch, step)
                save_checkpoint(model, ema_model, epoch, opt, os.path.join('./model_save', model_path))

            # Bar / tqdm
            print('Epoch: [%d/%d], step: %d, Train Loss: %.5f, Dev Loss: %.5f, Acc1: %.3f, Acc5: %.3f' %
                  (epoch, HP.epochs, step, loss.item(), eval_loss, acc1, acc5))
            step += 1
        logger.flush()
    logger.close()
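Assuming the script is saved as train.py and calls train() under an if __name__ == '__main__': guard (not shown above), training could be launched from scratch, or resumed from a checkpoint saved with the model_%d_%d.pth pattern used by save_checkpoint (the file name below is only an example):

python train.py
python train.py --c ./model_save/model_10_16000.pth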