赞
踩
我们刚刚看到的现象,在GAN训练中非常常见,我们称它为模式崩溃(mode collapse)。或者,模式崩塌、模式坍塌。
在MNIST的案例中,我们希望生成器能够创建代表所有10个数字的图像。当模式崩溃发生时,生成器只能生成10个数字中的一个或部分数字,无法达到我们的要求。
发生模式崩溃的原因尚未被完全理解。许多相关的研究正在进行中,我们选取其中一些相对比较成熟的理论进行讨论。
其中一种解释是,在鉴别器学会向生成器提供良好的反馈之前,生成器率先发现一个一直被判定为真实图像的输出。为此,有人提出一些解决方案,比如更频繁地训练鉴别器。但在实践中,这样做往往效果不佳。这就表明,解决问题的关键不仅在于训练的数量,也在于训练的质量。
在我们的例子中,生成器的损失值不断增加(见2.3.6节),表明它的学习没有进展。可能的原因是,鉴别器没有很好地为它提供有效的反馈。这再次表明,训练质量是一个挑战。接下来,我们将试验一些想法,以提高鉴别器对生成器反馈的质量。
在开始改良之前,先备份之前生成手写数字图像的笔记本。
现在,我们试图通过提高GAN的训练质量,解决模式崩溃和图像清晰度低的问题。有的方法我们已经在第1章改良MNIST分类器时用过。
第一个改良是,使用二元交叉熵BCELoss()代替损失函数中的均方误差MSELoss()。我们在1.3.1节讨论过,在神经网络执行分类任务时,二元交叉熵更适用。相比于均方误差,它更大程度地奖励正确的分类结果,同时惩罚错误的结果。
我们可以做的下一个改良是,在鉴别器和生成器中使用LeakyReLU()激活函数。因为我们所预期的输出值范围为0~1,所以我们只会在中间层后使用LeakyReLU(),最后一层仍保留S型激活函数。我们在1.3.2节已经讨论过LeakyReLU()如何解决梯度消失问题。一般来说,这是一种常用的提高神经网络训练质量的方法。
另一种改良是,将神经网络中的信号进行标准化,以确保它们的均值为0。同时,标准化也可以有效地限制信号的方差,避免较大值引起的网络饱和。在1.3.4节中,我们已经看到LayerNorm()如何对训练产生积极的影响。
下面是一个改良后的鉴别器神经网络的代码。
生成器的代码也进行相同的改良。
还有一种我们之前尝试过的改良是使用Adam优化器(见1.3.3 节)。我们把它同时用于鉴别器和生成器。
让我们看一下采用以上4个改良方案的效果。
遗憾的是,模式崩溃仍然存在。图像的清晰度有所提高,结构更清晰了,但仍然不是一个清楚的数字。
让我们更深入地思考一下如何进一步改良GAN。.
生成过程的起始点是一个种子值。起初,我们用常数值0.5。随后,我们把它改为一个随机值,因为我们知道,对于固定的输入,任何神经网络总会输出相同的结果。也许生成器神经网络觉得,把一个单值转换成784像素来代表一个数字实在太难了。
我们可以通过提供更多的输入种子来降低这种难度。比如,我们可以尝试100个输入节点,每个节点都是一个随机值。让我们在代码中更新生成器的神经网络定义。
再看一下效果。
现在图像更清晰了,看起来也更像手写数字了,具体地说有点像0。遗憾的是,所有生成的图像都是相同的,说明我们还没解决模式崩溃问题。
不要灰心丧气——即便是最顶尖的GAN研究者,也同样面临模式崩溃的问题。
如果我们继续思考,不难想到输入生成器的随机种子和输入鉴别器的种子,不应该是一样的。
现在,让我们分别创建两个生成随机数据的函数。它们看起来很相似,不过一个使用torch.rand(),而另一个使用torch.randn().
下面是改良后的GAN训练循环。
我们看看效果如何。
太赞了!看上去我们已经解决了模式崩溃问题。现在,生成器可以生成不同的数字。图中的形状看起来一个像8,一个像2,还有一个像3。也有的比较模糊,其中一个看起来既像4又像9。
让我们回顾一下到目前为止的进展。我们训练了一个生成器,并能用它画出手写数字图像。即便没有直接看到任何真实的图像,生成的图像也几乎与训练数据看起来没有区别。这真的很酷。更酷的是,只需改变随机种子,训练过的生成器就可以生成多种不同的数字。
这是一个了不起的成绩。有时候,要解决模式崩溃可能非常困难。很多时候,甚至根本找不到有效的解决方案。
让我们观察一下损失图,看看它们是否能提供一些信息。因为现在使用了BCELoss(),所以这些值并不保证在0~1的范围内。我们需要更新鉴别器和生成器的plot_progress() 函数,删除损失值范围的上限,同时添加更多的水平网格线。
下图所示为鉴别器的训练损失值。
由上图可见,损失值迅速下降到接近于0,并一直保持在很低的位置。训练期间,损失值偶尔发生跳跃。这说明生成器和鉴别器之间仍然没有取得平衡。
下图中是生成器的训练损失值。
损失值先是上升,表示在训练早期生成器落后于鉴别器。之后,损失值下降并保持在3左右。记住,与MSELoss不同,BCELoss没有1.0的上限。
这些损失图看起来有些令人失望,因为损失值的范围更广了。不过,它们仍然好于改良之前的损失图。在之前的图中,鉴别器的损失值在下降时没有太大的波动,生成器的损失值在上升时同样非常工整。这些现象看似令人满意,但不断增加的生成器损失值并不是我们希望的。理想的情况应该是,生成器的损失值只在一个有限的平均值附近变化。
一个很好的问题是,如果我们达到了平衡,BCELoss应该是什么?如果我们运行简单的1010 GAN并达到平衡,由于使用BCELoss,我们会看到生成器和鉴别器的损失值都接近于0.69。读者可以自己试试。对一个完全不确定的分类器使用二元交叉熵,根据数学定义可以计算出,理想的损失值为ln 2或0.693。更多内容可以在附录A中找到。
我们成功地解决了模式崩溃的问题,不过,图像质量还有待改良。我们来看看通过增加训练周期 (epoch)来训练更长时间是否有帮助。我们可以很方便地将GAN训练循环与周期外部循环结合起
来。
以下的图像是训练4个周期后,也就是使用所有训练数据4次的生成效果。总共耗时大约30分钟。
图像看起来好多了。如果读者有时间,可以试试训练8个周期,应该需要1小时左右。
事实上,还有更多改良方法有待我们继续探索。但是,由于我们已经解决了模式崩溃的问题,也可以从生成器获得高质量的图像,因此这里就先告一段落了。
读者可能会问,可以解决模式崩溃是不是因为我们在生成器种子中使用了randn()。如果我们还原之前的代码,在GAN架构中只使用最基本的设置,即便我们将种子改为使用randn(),模式崩溃问题依然不会得到解决。解决问题的是多数或者全部改良的组合作用。 例如,仅为大小为100的生成器种子使用randn()并不能解决模式崩溃问题。读者可以自己试试。
读者可能还希望知道,为什么我们满足于尚未像简单的1010GAN那样达到平衡的生成器和鉴别器,本节的损失图显示,鉴别器的损失值迅速下降到接近于0,并保持在低位,而生成器的损失值仍然很高。
在许多真实的GAN场景中,即使没有达到平衡,仍然可以得到一个可以生成高质量图像的生成器。
我们的最终目标是生成看起来逼真的图像。如果能改善这种平衡,我们当然也应该尝试。我们将继续绘制损失图,因为损失图可以帮我们了解训练的实际状况。例如,MNIST损失图告诉我们训练并不是混乱且不稳定的。
到目前为止,我们把GAN的种子当作一个随机数。经过训练后,种子获得了一些有趣的特性。让我们一起来看一下。
假设有种子1(seed1)和种子2(seed2)两个不同的种子。我们可以用它们分别生成图像。
现在,假设seed1和seed2之间有一个中间种子,使用这个种子会生成什么样的图像呢?除此之外,使用在seed1和seed2之间不同位置上的种子又会生成什么样的图像呢?
让我们试一试。首先,我们需要一个以MNIST数据集训练的GAN。我们可以继续使用之前的笔记本。
以下代码将一个随机种子赋值给seed1,以备后用。接着,我们画出生成的图像。
接着,使用seed2重复上述步骤。
并不是每次生成的图像都很清晰。我们需要重复运行上面的代码,直到得到一个较清楚的数字。
以我自己的实验为例,下图是seed1生成的图像,看起来像5。
下图是seed2生成的图像,看起来像3。
接着,让我们通过代码计算seed1与seed2之间距离相等的10个种子。
上面的代码看起来可能比较复杂,不过它做的只是在seed1和seed2之间选择10个点,并以它们为种子生成图像。 下图展示了包含seed1和seed2在内的12个种子生成的图像。
我们可以明显地看出,随着种子从seed1到seed2,图像从5平滑地演变成了3。
让我们再做另外一个实验。如果把种子相加,又会生成什么图像呢?
这段代码很容易理解。一个新的seed3由seed1与seed2相加得到,并输入生成器。
结果图像看起来非常像8。这是合乎情理的,因为我们把5和3重叠,应该也差不多是这个样子。这再次表明了种子一个很好的特性,即种子相加也会造成它们生成的图像的叠加。
我们看到了种子相加的效果。让我们再看看把种子相减会发生什么。
seed1和seed2的差被输入生成器。
结果图像看起来既像5又像6。它看起来并不完全合乎逻辑,至少不像从5的笔画中减去3的笔画。或许种子的特性并没有这么简单。
让我们再试验一个例子。下图中罗列的图像分别由起始种子(seed1和seed2)、插入 (interpolated)种子、总和种子(seed1+seed2),以及差值种子(seed1−seed2)生成。
我们看到,两个起始种子都生成了看起来像9的图像。在它们之间插入的种子也生成了类似9的图像。两个种子的总和种子生成的图像也是9,这也并不令人意外。令人惊讶的是,差值种子生成的图像却成了8。好奇怪呀!
以下是另一个例子,两个种子生成了非常相似的图像,看起来像5,差值种子却生成了一个非常不同的、看起来很像3的图像。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。