赞
踩
目录
笔者已经听说并跑过一些GAN网络,对GAN有了一些基础的认识。但若要我详细地介绍GAN,又仿佛无从说起。借着梳理这篇博客的机会,让我们重新认识一下这位生成领域的大牛——GAN!为了加深对GAN的理解,笔者还会在第三部分详细介绍HIFI-GAN。
当然,在阅读时,最好同时比对GAN的原文进行思考:https://arxiv.org/pdf/1406.2661.pdf
GAN的全称是Generative Adversarial Network,生成对抗网络。在图像、声音生成领域,GAN大放异彩,笔者较为熟悉的HIFI-GAN就是一个广为使用的声码器,用Mel频谱生成wav波形。
GAN通常包含生成模型(generative model)和判别模型(discriminative model)。生成模型的目标是总任务的目标(比如合成图片、音频等),而判别模型的目标是判断输入是来自数据分布的还是模型分布的。
GAN的理论网络架构如下图所示:
generator负责将随机噪声转化为data形式的数据(现在的GAN模型大多数情况还有condition token,因为生成的任务不是单一的,需要prompt提示,比如生成一张“狗”的图片、按照文本生成一段音频等)。discriminator负责针对输入的generator sample或data sample数据,生成一个概率标量,表示输入数据来自data sample的概率。
从目的分析,生成模型G的目的是最大化判别模型D的错误率(让D混淆model sample和data sample),而D的目的则是最小化错误率(换言之,理想下,对data sample的输出为1,对model sample的输出为0)。在训练最开始时,G的能力一定是弱的,此时我们可以训练D。当D训练到对此时模型分布判断较为精准后,可以反过来开始训练G,直到D无法对G生成的模型分布有效判断。这时,我们又可以开始训练D......
如此交替循环往复训练,直到达到纳什均衡(只要别人不改变策略,我就不改变策略)。也就是说,如果G不做改变,D就不再更新参数,反之亦然。具体来说,GAN网络的优化目标函数如下:
其中,分数方程V(D,G)表征了鉴别器对输入的鉴别能力,对来自data的sample,其D(x)越高越好;而对来自model的sample,其D(G(z))越低越好。训练D时,我们的目标就是最大化分数方程;而训练G时,我们的目标就是最小化分数方程。
以上就是GAN的基本训练流程。当然,仅仅是这样我们并无法非常直观地感受GAN的细节与魅力。我们应当从具体设计实例中去仔细分析!
HiFi-GAN是一个常用的vocoder(声码器),可以将Mel谱(通常由前序模型产生)转换为高质量的wav波形。其包括一个生成器和两个判别器(MPD多周期判别器、MSD多尺度判别器)。
其中,生成器的网络架构如下:
如图,生成器是全卷积的神经网络,不断转置上采样,直到输出序列长度与wav序列相同。 每个转置卷积后还加入了MRF模块。共有|K_l|个这样的模块。MRF(多感受野融合)模块是一个并行计算模块,返回|K_r|组残差块输出的总和,这样的设计可以提取不同长度的数据中包含的模式。
当然,这样直接翻译论文的我也是一头雾水,我们一起来看看官方代码:
进行生成时,在102行按照上采样次数进入循环。循环中,首先在104行进行上采样,然后在106行按照kernel进入MRF的循环。代码给出了两种MRF设计方式,我们来看其中一种:
在ResBlock1中,使用带洞卷积和普通卷积交替进行,其中convs1是带洞卷积(dilation),convs2是普通卷积。通过91行和93行这两处代码可知,对每一次上采样,每一组卷积核都创建了MRF,也就是对应的|K_u|和|K_r|。
至此,我们已经完全清楚了生成器的网络架构。
下面来看两个判别器的网络架构如下:
Multi-Scale Discriminator:
MSD的架构来自于MelGAN,用于评估连续音频序列。论文中的MSD是由三个子判别器组成的混合模型,分别在不同的输入尺度上操作:: raw audio, ×2 average-pooled audio, and ×4 average-pooled audio。图(a)所示的就是第二个子MSD,采用了两倍平均池化。
Multi-Period Discriminator:
MPD也是一种由多个子判别器组成的混合模型,每个子判别器只接受等间隔采样p的输入音频。这些子判别器通过观察输入音频的不同部分来捕捉彼此之间的周期隐含结构。论文中将这些周期设置为[2, 3, 5, 7, 11]。图(B)所示即为第二个子MPD,每间隔3个点采样,形成了3×n的矩阵,进一步操作。后续卷积时,在MPD的每个卷积层中,我们将内核大小在宽度轴上限制为1,以独立处理周期性样本(如图中橙色、蓝色、绿色点后续是分别处理的)。
通过这两个判别器的设置,模型不仅可以捕捉连续语音的信息,也可以捕捉周期性的信息,有助于更好地判别真伪。
L(D;G)是固定生成器G时,训练判别器D的损失;L(G;D)是固定D训练G时损失。
梅尔频谱损失是将生成的wav重构回梅尔谱来计算的损失,其中函数就是用来重建梅尔谱的。显而易见,这样的损失函数是合理的。
特征匹配损失是通过测量判别器在真实样本和生成样本之间的特征差异来学习的相似性度量,判别器的每个中间特征都被提取出来,并计算在每个特征空间中真实样本和有条件生成样本之间的L1距离。图中,T表示判别器中的层数;Di表示第i层判别器的特征,Ni表示第i层判别器的特征数量。
分别训练时,生成器和判别器分别使用这些损失函数计算损失。
注意为生成器、鉴别器分别设置optimizer,这样就能在训练时计算保存各自的梯度,进行更新。比如。更新optim_g时,就不会更新判别器,反之亦然。
这是更新判别器D的代码,使用了两个判别器的损失。
这是更新判别器G的代码,使用了两个判别器的损失,特征匹配损失和梅尔谱重构损失。
HiFi-GAN成功使用了GAN结构实现了高质量的音频生成算法,并且每个batch中都分别训练了生成器和鉴别器。三四年过去,HiFi-GAN仍是学界业界脍炙人口的声码器,可见GAN网络结构的强大之处。
1. https://arxiv.org/pdf/2010.05646.pdf
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。