赞
踩
现有的半监督学习方法主要有三种:自洽正则化(Consistency Regularization),最小化熵(Entropy Minimization)和传统正则化(Traditional Regularization)。而MixUp同时兼具了这三种方法的优点:集成了自洽正则化,在图像数据增广中使用了对图像的随机左右翻转和剪切(crop);使用“sharpening”函数,最小化未标记数据的熵;使用了单独的权重衰减并使用MixUp作为正则化器(应用于标记数据点)和半监督学习方法。
MixMatch的伪代码如下图所示(图一),接下来将按照步骤详细介绍MixMatch的每一个部分。
图一 MixMatch算法伪代码
1.2.1 数据增强
同时对有标记数据和无标记数据做增强。对一个Batch的有标记数据X和一个Batch的无标记数据U做数据增强,对X做一次增强且标签不变,而对U做K次。
1.2.2 标签猜测
将增强后的未标注数据输入预测模型,模型生成“猜测”标签。为一个Batch中的每一个未标记数据ub的K个增强的猜测标签计算平均值(伪代码第七行所示):
使用Sharpen 算法对上式得到的标签进行处理,得到标签qb。Sharpen 算法具体操作如下:
其中,T为超参数,当T趋近于0时,Sharpen(p, T)i 接近于One-Hot 分布,即对某一类别输出概率1,其他所有类别输出概率0,此时分类熵最低。这很好理解,比如在猫狗二分类中,分类器说,这张图片中50%的几率是猫,50%的几率是狗,对各类别分类概率预测比较平均;而使用Sharpen来使得“伪”标签熵更低,即猫狗分类中,要么百分之九十多是猫,要么百分之九十多是狗。
图二 标签猜测与Sharpen过程
从图中Average到Sharpen的变化也可以看出该操作的作用:使得“伪”标签熵更低,使输出接近于One-Hot 分布。
1.2.3 MixUp
将前两步得到的所有数据增强之后的带标签数据及它们的标签、所有未标注数据及其“猜测”标签整合成以下集合:
将和混合在一起,随机重排得到数据集。最终输出将与做了MixUp() 的一个 Batch 的标记数据,以及与做了MixUp() 的 K 个Batch 的无标记增广数据
。
与之前的Mixup方法不同,MixMatch方法将标记数据与未标记数据做了混合,进行 Mixup。对于两个样本及它们的标签(x1, p1), (x2, p2),混合后的样本为:
其中,权重因子λ’使用超参数α通过Beta函数抽样得到:
关于这个对MixUp的修改,作者给出的解释是需要保持每个Batch中的顺序。这样的操作能让x’更接近于x1而非x2。在Mixup标记数据与混合数据时,这样能增加的权重;在 Mixup 未标记数据与时,这样能增加的权重。
损失函数定义如下:
其中,对于有标签数据,使用Cross Entropy计算Loss。而对于无标签数据,使用L2 Loss。作者对为何无标签数据不使用Cross Entropy Loss而是L2 Loss做出了解释:因为L2 Loss不像Cross Entropy Loss,它是有界的且对错误的预测不太敏感。在文章引用的[25] Temporal ensembling for semi-supervised learning的第三页提供了更详细解释:Cross Entropy 计算是需要先使用 softmax 函数,将Dense Layer输出的类分数转化为类概率,而softmax函数对于常数叠加不敏感,即如果将最后一个Dense Layer的所有输出类分数同时添加一个常数c, 则类概率不发生改变,Cross Entropy Loss不发生改变。因此,如果对未标记数据使用Cross Entropy Loss, 由同一张图片增广得到的两张新图片,最后一层Dense Layer的输出被允许相差一个常数。而使用L2 Loss, 约束更加严格。
最终的整体损失函数是两者的加权,其中超参数λu是无监督学习损失函数的加权因子。
使用到的超参数包括温度参数T,对未标记数据做增强的次数K,MixUp的Beta函数的α以及无监督权重因子λu。作者在实验中发现,这些超参数中的大多数都是可以固定的,不需要对每个实验或每个数据集进行调优。设置T = 0.5, K = 2,只对不同数据集上的α和λu做调整。开始时可以设置α = 0.75, λu= 100。
作者主要进行了三类实验:对比实验在标准半监督学习的基准上测试MixMatch的有效性,消融实验验证MixMatch每个部分的贡献,PATE架构验证MixMatch在隐私保护中的应用。
2.1.1 实验设置
除非特殊说明,在所有实验中,使用的都是Wide ResNet-28模型。在CIFAR-10和CIFAR-100、SVHN和STL-10这四个数据集上进行评估。对比MixMatch和其他四种半监督方法(Π-Model,Mean Teacher,Virtual Adversarial Training和Pseudo-Label),以及MixUp本身,在四个数据集上的错误率。
2.1.2 实验目的
对比实验,对比MixMatch和现有其他方法(5种)在数据集上的错误率,验证MixMatch方法的高性能。
2.1.3 实验结果
CIFAR-10:
使用从250个到4000个不等的带标注数据来评估每种方法的准确性,由均值和方差反应错误率。设置λu= 75。
表一 六种方法在CIFAR-10上的错误率
图三 六种方法在CIFAR-10上的错误率(折线图)
在 CIFAR-10 数据集上,使用全部五万个数据做监督学习,最低误差能降到百分之4.13。而使用MixMatch,250个数据就能将误差降到11%,4000个数据就能将误差降到6.24%。这表明MixMatch使用很少的标记数据点就能达到媲美有监督学习的效果,这正是半监督学习希望达到的效果。此外,从折线图中还可以看到Mean Teacher的错误率的方差是比较大的,中心实线附近还有一大片浅绿色的区域,那片区域就代表算法的表现容易震荡,不稳定。而对比就可以看出MixMatch不仅做到精度最优,还能保证算法本身的稳定性(黑色旁边浅黑色的区域很小)。
CIFAR-10及CIFAR-100(使用更大的模型):
为了与先前工作的结果有合理的比较,使用了有2600万个参数的有28层的Wide ResNet模型。
表二 CIFAR-10及CIFAR-100在大模型上的错误率
由于使用大模型,只将MixMatch和Mean Teacher和SWA做了对比。可以看出MixMatch和先前工作相比,效果相匹配或优于先前工作的最佳结果。
SVHN 及 SVHN+Extra:
和CIFAR-10类似,使用从250到4000个不等的标签数量来评估SVHN上每个方法的性能。设置λu= 250,α = 0.25。
表三 SVHN上的错误率
图四 SVHN上的错误率(折线图)
表四 SVHN+Extra上的错误率
图五 SVHN+Extra上的错误率(折线图)
表五 MixMatch在SVHN及SVHN+Extra上的错误率
SVHN+Extra是将SVHN的额外训练集也组合起来一起训练,这样未标注样本的比例远远超过标注的样本。从结果来看,对于MixMatch,SVHN+Extra上的错误率明显低于SVHN上的错误率。相比其他方法,MixMatch的错误率明显更低,接近监督学习方法(图四、图五)。
STL-10:
STL-10包含5000个训练示例。先前的工作部分使用了全部的5000个标注数据,因此在使用1000/5000个标注的情况下进行对比实验。
表六 STL-10上的错误率
表六中的方法对比没有没有使用相同的实验设置(即模型),因此很难直接比较结果;然而,因为MixMatch的错误率相当于baseline的1/2,因此也能作为证明MixMatch算法有效的证据之一。设置λu= 50。
由于MixMatch结合了各种半监督学习机制,因此,进行消融实验,通过移除或添加组件,进一步了解是哪些部分使得MixMatch表现更好。
具体来说,评估了以下部分的效果:
1)无标签数据的数据增强的次数K;2)移除温度参数T;3)在生成猜测标签时,使用EMA(与Mean Teacher类似);4)只在有标记数据内,只在无标记数据内进行MixUp,并且不混合使用有标记和无标记的数据;5)使用插值一致性训练(Interpolation Consistency Training),这可以被视为本消融研究的一个特例:只对无标记数据进行MixUp,不使用Sharpen函数,EMA方法用于伪标签生成。
消融实验结果如下:
表七 消融实验结果(CIFAR-10)
从结果看出,每个部分都对MixMatch的性能有贡献,其中贡献较大的是MixUp以及Sharpen操作,而使用EMA会略微损害MixMatch的性能。
用于评估方法的泛化性能。并非文章重点,略。
3.2.1 方法优点
提出的MixMatch方法在降低错误率方面效果显著,效果媲美监督学习,优于当时的其他方法。利用很少的标注数据取得媲美监督学习的效果,这正是半监督学习希望实现的。
3.2.2 方法缺点
参考文献:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。