赞
踩
上篇文章见:GAN及其变体C_GAN,infoGAN,AC_GAN,DC_GAN(一)
论文:
前一篇论文并没有介绍一个算法,而是给出GAN动态训练的理论理解,一堆公式定理,伤不起,伤不起呀,幸亏有《令人拍案叫绝的Wasserstein GAN》,讲解地很详细,也很通俗易懂,很棒。
在介绍WGAN之前,先介绍两个数学概念KL散度(Kullback-Leibler divergence)和JS散度(Jensen-Shannon divergence):
KL散度又称相对熵,信息散度和信息增益,度量两个概率分布的匹配程度,两个分布差异越大,KL散度就越大,公式定义如下:
有时候将KL散度称为KL距离,但是它并不满足距离的性质,首先KL散度是不对称的,再次KL散度不满足三角不等式
JS散度是相似度衡量指标,衡量了两个概率分布的相似度,如果,
完全相同,那么JS为0,如果完全不相同,则取值为1。假设有两个分布
和
,其JS散度公式为:
JS基于KL散度的变体,解决了KL散度非对称的问题,JS是对称的,并且取值范围为[0, 1]。
下面开始进入正题,原始GAN的损失函数为:
分成两步进行计算,首先,训练判别器D最大化
可以得出最简化的最优判别器为:
接下来将(2)式代入代入(1)式,就可以得到,所以,原始GAN的优化目标经过一定的数学推导后,可以等价于当判别器最优的时候,最小化真实分布
和生成分布
之间JS散度。然而由于
和
几乎不可能有不可忽略的重叠,所以无论它们相距多远,JS散度都是常数log2,最终导致生成器的梯度近似于0,梯度消失。即使是对接近于最优的判别器来说,生成器有很大机会面临梯度消失的问题。总结来说,判别器训练的太好,生成器梯度消失,使得生成器loss降不下去;判别器训练的不好,生成器梯度不准,四处乱跑。所以只有训练器训练的将将好的时候,才能达到要求,然而这个火候很不好把握,即使是同一轮的训练的前后不同阶段,所以GAN才会不好训练。
基于等价优化的衡量标准JS散度不合理。WGAN中提出了Wasserstein距离,Wasserstein距离,又称Earth-Mover(EM)距离,度量两个概率分布之间的距离。定义如下:
其中是联合分布,对于每一个可能的联合分布
而言,从中取样一个真实样本x和一个生成样本y,
,计算出这两个样本之间的距离
,所以可以计算出该联合分布
下样本的期望值,在所有可能的联合分布中能够对这个期望值取得下限,就定义为Wasserstein距离。Wasserstein距离相比KL散度,JS散度的优越性在于,即使两个分布没有重叠,Wasserstein距离仍然能够反映它们的远近。论文中给出KL散度和JS散度是突变的,要么最大或者最小,Wasserstein距离却是平滑的,如果使用梯度下降法优化
这个参数,前两者是提供不了梯度的,Wasserstein距离却是可以的。
上述的Wasserstein公式中的下确界没法直接求解,又经过一系列数学推导,最后变换成如下形式:
由于用到K-Lipschitz函数,公式限制条件为,在此,作者采取了一个简单的做法,每次参数更新后,限制神经网络的所有参数的范围为
。
至此,就构造了一个含参数,最后一层不是非线性激活函数的判别器网络
,在限制
不超过某个范围的条件下,使得
尽可能最大,此时L就会近似真实分布与生成分布之间的Wasserstein距离,接下来生成器要最小化Wasserstein距离,可以最小化L。而且由于Wasserstein距离的优良特性,不用担心生成器梯度消失的问题。如下图是WGAN的算法过程
由于原始GAN的判别器是一个二分类问题,而WGAN中的判别器是去近似拟合Wasserstein 距离,二分类问题就变成了回归任务,所以需要将最后一个的sigmoid函数拿掉。既然Wasserstein 具体可以量化真实分布和生成分布
之间的距离,可以作为训练进程的判别标准,其值越小,表示GAN训练得越好。
总结来说,Wasserstein GAN(WGAN)主要贡献在于:
上述介绍的WGAN虽然能够解决GAN模型训练时的不稳定性问题,但是参数的修剪策略(weight clipping)会导致最优化困难。由于限制权重的范围,使得尝试获得最大梯度范数的神经网络架构常常以学得简单的函数而告终。也就是说,通过权重剪枝实现K-Lipshitz将会趋向更简单的函数,为了展示这个结论,使得真实分布加上unit-variance高斯噪声作为生成分布,作者在几个toy分布上训练WGAN得到最优值,如下图所示,top是使用weight clipping的结果,bottom是gradient penalty的结果,gradient penalty是为weight clipping出现的问题提出的改进策略。
如果一个可微函数为1-Lipschtiz,当且仅当它的gradient with norm <= 1, 那么我们就可以考虑直接限制input到output的gradient norm,, 因此,在原有critic loss上添加对于来自随机的样本
gradient norm的惩罚项,新的目标函数如下:
那么是什么呢?
就是从data distribution
中sample一个点
,从generator distribution
中sample一个点
,然后连接这两个点成一条直线,从这条直线上sample一个点
,作为
中的点。论文给出了利用weight clipping和gradient penalty训练得到的gradient norm(如下图左边)和weight分布(如下图右边),可以看到weight clipping学到的weight主要集中在两个边界值处,而使用gradient penalty学到的weight的分布符合我们的设想。
下面是带gradient penalty的WGAN的算法过程
论文:DualGAN:Unsupervised Dual Learning for Image-to-Image Translation
对偶学习,是出现在机器翻译领域的一种新的学习范式,对偶学习最关键的一点在于给定一个原始任务模型,其对偶任务的模型可以给其提供反馈;同样的,给定一个对偶任务的模型,其原始任务的模型也可以给该对偶任务的模型提供反馈,那么这两个互为对偶的任务可以相互提供反馈,相互学习,相互提高。对偶学习是微软亚洲研究院提出来的,见《对偶学习:一种新的机器学习范式》。
下面以一个中英翻译游戏为例,假设有两个玩家Alice和小明,Alice讲英文,小明讲中文,两人的目地是想要提高中译英和英译中模型的准确度。给定一个英文句子x,Alice首先通过英译中模型f将句子x翻译成中文,并将
传送给小明,小明虽然不知道Alice具体想要表达的意思,但是小明可以判断收到的中文句子
是不是语法正确,符不符合中文的语言模型,这些信息可以帮助小明大概判断英译中模型f是不是做的好,然后小明将这个中文的句子
通过中译英模型g翻译成一个新的英文句子
,并发给Alice,通过比较x和
是不是相似,Alice就能够知道英译中模型f和中译英模型g是不是做的好,尽管x是一个没有标准的句子。可以看到面,这些互为对偶的任务可以形成一个闭环,使从来没有标注的数据中进行学习成为可能。
DualGAN的灵感来源于上述机器翻译中的对偶问题,如下图所示,DualGAN中存在两个生成器和两个判别器,以素描和照片为例,生成器对素描像u进行翻译,其中包含噪声z,翻译结果为
,把这个翻译结果作为生成器
的输入,附上噪声
,翻译结果为
,同理对于图像v。判别器A判别一张图片是否是photo,而判别器B判别一张图片是否是sketch。
论文:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks
cycleGAN与DualGAN模型相似。都包含两个生成器和两个判别器。cycleGAN主要目地是让两个domain里的图片互相转化,思想主要是在不配对的训练样本中,从一类图片中捕获出特定的特征,然后找出如何将这些特征转化成另一类图片(capturing special characteristics of one image collection and figuring out how these characteristics counld be translated into the other image collection, all in the absence of any [aired training samples)
cycleGAN模型学习两个domainX和Y之间的映射,对于训练数据和
,定义两个映射关系,
,和两个判别器
和
,
判别器的作用是区分图片
和转换的图片
,同样地
判别器的作用是判别图片
和转换的图片
。目标方程主要包含两种类型,adversarial losses是为了使得生成的图片的分布和目标域中的数据分布相等,cycle consistency loss的目地是为了防止两个映射G和F互相矛盾。如果我们从domainX转换到domain Y中然后再从domain Y转换到X,应当回到原先开始的地方,看起来这是一个循环。下图(b)中,对来自domain X中的图像
,经过图像转化环之后应当将
还原,我们称之为一个forward cycle-consistency loss:
,下图(c)中,对来自domainY中的图像
,经过图像转化环之后应当将
还原一个backward cycle-consistency loss:
对于映射函数和它的判别器
,损失函数表示为:
同理对于映射函数和它的判别器
,损失函数表示为:
cycle consistency loss表示为
最后,整个目标方程为
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。