当前位置:   article > 正文

生成对抗网络GAN论文解读及原理分析_生成对抗网络 论文

生成对抗网络 论文

1. 摘要解读

作者提出了一个通过对抗性过程估计生成模型新框架,在该框架中,将同时训练两个模型:

生成模型G:捕捉数据分布。在统计学中,确定数据的分布就可以生成数据。

判别模型D:估计一个样本的概率。分辨这个样本是来自训练数据,还是来自于G生成的。

对于G的训练过程是最大化D犯错的概率。这个框架对应于一个极小极大的两人博弈。

这在GAN的目标函数中可以明确体现,D会希望这个目标函数越大越好,G则希望该目标函数越来越小。G和D的对抗就体现在这个目标函数中。

在任意函数G和D的空间中,存在一个唯一的解,其中G恢复了训练数据分布,而D在任何地方都等于1/2。

人们最终希望G取得胜利,G能够生成(恢复)数据分布,这也代表着D将无法分辨这个样本是来自训练数据,还是来自于G生成的。

在G和D由多层感知机定义的情况下,整个系统可以使用反向传播进行训练。

原始的GAN重在提出了生成对抗的思想,两个模型用的都是简单的多层感知机(MLP)。后期的很多GAN的改进版本会解决原始GAN的诸多问题,如模式崩溃、梯度爆炸等。

在训练或生成样本期间,不需要任何马尔可夫链或展开的近似推理网络。实验证明了该框架通过对生成的样本的定性和定量评估的潜力。

2.GAN的设计思想解读

2.1 生成对抗思想--原文比喻

在所提出的对抗性网络框架中,生成模型与对手对抗:一个判别模型D,学习确定样本是来自模型分布还是来自数据分布。生成模型G可以被认为类似于一个伪造者团队,试图生产假货币并在没有检测到的情况下使用,而判别模型类似于警察,试图检测假货币。这场比赛中的竞争促使两支球队改进他们的方法,直到赝品与真品分不开。

2.2 生成对抗思想--故事解释

这个故事有两个主角,一个造假者和一个警察。造假者试图制造假币并在未被发现的情况下使用它,警察则试图检测假币。见图2-1(因人民币图片违规用彩纸代替)。

图2-1 故事背景

表2-1 前三轮较量情况

第一轮较量

造假者用一张白纸企图糊弄过去。

警察直接没收了纸币,毕竟很容易发现百元大钞是红色的而不是白色的。

第二轮较量

造假者又拿着一张红纸打算蒙混过关。

警察又很快识破了这拙劣的把戏,毕竟上面连数字都没有。

第三轮较量

造假者通过前两次的较量提高了造假技术。他在红色的纸上添加了数字100,并画了个人像。

警察这次有些迟疑,但仔细观察还是发现了问题,并没收了纸币。警察的检测技术得到了提高。

图2-2 前三轮较量情况

经历了n轮较量,警察和造假者互有胜负。同时他们的造假水平和鉴别水平都有了显著提高。

此时,假币上有数字、人头、纹路、凸点等关键信息,已经十分接近真钞了。见图2-3。

图2-3 多轮较量情况

最终,我们希望造假者更胜一筹,造出高质量的假钞。获得胜利。也就是说,警察将无法分辨原始的真钞和生成的假钞。

2.3 框架说明

该框架可以为许多类型的模型和优化算法生成特定的训练算法。

在本文中,生成模型G通过多层感知器传递随机噪声来生成样本的特殊情况,并且判别模型D也是多层感知机。作者将这种特殊情况称为对抗性网络(adversarial nets)。

生成模型G是一个MLP,输入的是一个随机的噪声。这个MLP能够把产生随机噪声的数据分布(通常是高斯分布)映射到任何一个我们想去拟合的分布。同样,判别模型也是一个MLP。

在这种情况下,可以只使用非常成功的反向传播和丢弃算法来训练这两个模型,并只使用正向传播从生成模型中进行采样。不需要近似推理或马尔可夫链。(在计算上更具有优势)

3.生成对抗网络解析

3.1 目标函数求解说明

表3-1 原文第三章第一段的翻译

The adversarial modeling framework is most straightforward to apply when the models are both multilayer perceptrons.

对抗性建模框架最直接的应用是在模型都是多层感知机的情况下。

To learn the generator’s distribution pg over data x, we define a prior on input noise variables Pz(z), then represent a mapping to data space as G(z; θg), where G is a differentiable function represented by a multilayer perceptron with parameters θg.

为了学习生成器对数据x的分布Pg,我们定义了输入噪声变量Pz(z)的先验分布,然后将映射到数据空间的函数表示为G(z; θg),其中G是一个可微分的函数,由多层感知机参数θg表示。

We also define a second multilayer perceptron D(x; θd) that outputs a single scalar.D(x) represents the probability that x came from the data rather than Pg.

我们还定义了一个第二个多层感知机D(x; θd),它输出一个单个标量。D(x)表示x来自(真实)数据而不是来自Pg的概率。

We train D to maximize the probability of assigning the correct label to both training examples and samples from G.

我们训练D以最大化分配正确标签给训练样本和从G生成的样本的概率。

We simultaneously train G to minimize log(1 − D(G(z))).In other words, D and G play the following two-player minimax game with value function V (G, D):

我们训练D以最大化分配正确标签给训练样本和从G生成的样本的概率。我们同时训练G以最小化log(1 − D(G(z)))。换句话说,D和G进行以下两人极小极大博弈,其价值函数为V(G, D):


表3-2 符号说明

下面是生成对抗网络的目标函数。

生成器G的目标:让 V(D,G) 尽可能的变小。

辨别器D的目标:让 V(D,G) 尽可能的变大。

目标函数=两项期望的值的和。V(D,G) 的值域是 (-,0]

我们对该公式进行分析可以发现:

当辨别器D表现良好时,D(x) 会较大,D(G(z)) 会较小。

在最好情况下,辨别器D能完全区分真假,此时:

D(x)max=1 → log(D(x))max=0

D(G(z))min=0 → log(1-D(G(z)))max=0

所以 V(D,G)max=0+0=0

当辨别器D表现不好时,D(x) 会较小,D(G(z)) 会较大。

在最坏情况下,辨别器D不能区分真假,此时:

D(x)min=0 → log(D(x))min=

D(G(z))max=1 → log(1-D(G(z)))min=

所以 V(D,G)min=+=

需要注意的是,当辨别器D表现不好时,并不代表生成器G就表现很好。

最终的理想结果也不是D(x)=0或D(G(z))=1,而是达到一种纳什均衡

表3-3 原文第三章第三段的翻译

In practice, equation 1 may not provide sufficient gradient for G to learn well.

在实践中,方程1(目标函数)可能无法为G提供足够的梯度以进行良好的学习。

Early in learning, when G is poor, D can reject samples with high confidence because they are clearly different from the training data. In this case, log(1 − D(G(z))) saturates.

在学习早期,当G表现较差时,D可以以高置信度拒绝样本,因为它们明显与训练数据不同。在这种情况下,log(1 − D(G(z)))会饱和。

Rather than training G to minimize log(1 − D(G(z))) we can train G to maximize log D(G(z)).

我们可以训练G以最大化log D(G(z)),而不是将G的训练目标设为最小化log(1 − D(G(z)))。

This objective function results in the same fixed point of the dynamics of G and D but provides much stronger gradients early in learning.

这个目标函数会导致G和D动态的相同固定点,但在学习早期提供更强的梯度。

如图3-1所示,黑色虚线为真实数据分布(Px),绿色实线来自生成分布(Pg(G)),蓝色虚线时判别分布(D)。我们同时更新判别分布(D)来训练生成对抗性网络,使得其在来自数据生成分布(Px)的样本与来自生成分布(Pg(G))的样本之间进行区分。

通俗的理解:

生成分布(Pg(G))和真实数据分布(Px)的峰值越接近,则越相似。

在判别分布(D)中,认为该处存在真实数据→值为1,认为该处存在生成(假)数据→值为0,完全无法分辨→值为1/2。

图3-1(a) :一开始,Pg(G) 与 Px 有明显差别,峰值不在一个位置,这表明生成器G的效果并不好 。同时,D也不具备很好的分辨能力,在 Px 的非峰值区域也会输出较高的值。

图3-2(b):辨别器D经过更新后,学会了 Px 和 Pg(G) 的分布,在 Px 的峰值位置置值为1,在 Pg(G) 的峰值位置置值为0。

图3-2(c):生成器经过更新后,Pg(G) 逐渐向 Px 靠拢,D还是可以区分两种分布。

图3-2(d):多次训练更新后,Pg(G) 与 Px 的分布完全重合,这表明生成器G的效果非常好,这也是我们希望看到的。辨别器D则完全无法分辨两个分布,此时D(x)=1/2。

图3-1 模型训练演示

3.2 伪代码说明

图3-2 GAN的伪代码

输入:一个小批量(m个)的噪声样本z,一个小批量(m个)的真实样本x。

更新D:首先更新D中的θd,更新k次,需要调用目标函数中的两项,因为两项中都包含D。

更新G:然后更新G中的θg,更新1次,只需要调用目标函数中的后一项,只有后一项包含G。

收敛:很不稳定,因为存在两个更新。两个都收敛,或是一个波动一个收敛等等情况?无法确定收敛。后续工作会优化这个问题。

3.3 理论结果说明

作者对理论结果进行了说明,主要包括两个部分:

  1. 目标函数存在一个全局最优解:当且仅当是生成器G学到的分布和真实数据的分布是相等的情况。

  1. 伪代码所提出的算法确实可以求解目标函数。

4.代码展示

以下是一个简单的基于PyTorch的生成对抗网络(GAN)的代码示例:

  1. import torch
  2. import torch.nn as nn
  3. import torch.optim as optim
  4. from torchvision import datasets, transforms
  5. from torch.utils.data import DataLoader
  6. from torch.autograd import Variable
  7. # 定义生成器
  8. class Generator(nn.Module):
  9. def __init__(self):
  10. super(Generator, self).__init__()
  11. self.fc1 = nn.Linear(100, 128)
  12. self.fc2 = nn.Linear(128, 392)
  13. self.fc3 = nn.Linear(392, 784)
  14. self.relu = nn.ReLU()
  15. self.tanh = nn.Tanh()
  16. def forward(self, x):
  17. x = self.relu(self.fc1(x))
  18. x = self.relu(self.fc2(x))
  19. x = self.tanh(self.fc3(x))
  20. return x
  21. # 定义判别器
  22. class Discriminator(nn.Module):
  23. def __init__(self):
  24. super(Discriminator, self).__init__()
  25. self.fc1 = nn.Linear(784, 392)
  26. self.fc2 = nn.Linear(392, 128)
  27. self.fc3 = nn.Linear(128, 1)
  28. self.relu = nn.ReLU()
  29. self.sigmoid = nn.Sigmoid()
  30. def forward(self, x):
  31. x = self.relu(self.fc1(x))
  32. x = self.relu(self.fc2(x))
  33. x = self.sigmoid(self.fc3(x))
  34. return x
  35. # 定义训练参数
  36. device = "cuda" if torch.cuda.is_available() else "cpu"
  37. batch_size = 100
  38. num_epochs = 200
  39. learning_rate = 0.0002
  40. # 定义数据集
  41. transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])])
  42. mnist_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
  43. data_loader = DataLoader(dataset=mnist_data, batch_size=batch_size, shuffle=True)
  44. # 初始化网络
  45. G = Generator().to(device)
  46. D = Discriminator().to(device)
  47. # 定义损失函数和优化器
  48. criterion = nn.BCELoss().to(device)
  49. optimizer_G = optim.Adam(G.parameters(), lr=learning_rate)
  50. optimizer_D = optim.Adam(D.parameters(), lr=learning_rate)
  51. # 训练网络
  52. for epoch in range(num_epochs):
  53. for i, (images, _) in enumerate(data_loader):
  54. # 定义真实数据和生成的数据
  55. real_images = Variable(images.view(batch_size, -1)).to(device)
  56. fake_images = Variable(torch.randn(batch_size, 100)).to(device)
  57. # 训练判别器
  58. optimizer_D.zero_grad()
  59. real_labels = Variable(torch.ones(batch_size, 1)).to(device)
  60. fake_labels = Variable(torch.zeros(batch_size, 1)).to(device)
  61. real_outputs = D(real_images)
  62. fake_outputs = D(G(fake_images))
  63. d_loss = criterion(real_outputs, real_labels) + criterion(fake_outputs, fake_labels)
  64. d_loss.backward()
  65. optimizer_D.step()
  66. # 训练生成器
  67. optimizer_G.zero_grad()
  68. fake_images = Variable(torch.randn(batch_size, 100)).to(device)
  69. fake_outputs = D(G(fake_images))
  70. g_loss = criterion(fake_outputs, real_labels)
  71. g_loss.backward()
  72. optimizer_G.step()
  73. # 输出训练过程
  74. print("Epoch [%d/%d], d_loss: %.4f, g_loss: %.4f" % (epoch+1, num_epochs, d_loss.item(), g_loss.item()))
  75. # 保存模型
  76. torch.save(G.state_dict(), 'generator.pth')
  77. torch.save(D.state_dict(), 'discriminator.pth')

在这个代码示例中,我们首先定义了生成器和判别器两个神经网络。生成器将一个100维的噪声向量映射成784维的向量,代表了一张28x28的灰度图像。判别器则将这个784维的向量映射为一个实数,表示该图像是否为真实的MNIST数据集中的图像。

在训练过程中,我们首先训练判别器,将真实的MNIST图像标记为1,生成的图像标记为0,并计算判别器的损失函数。然后我们训练生成器,将生成的图像标记为1,并计算生成器的损失函数。最后,我们输出训练过程中的判别器损失和生成器损失,并保存训练好的生成器和判别器模型。

在使用这个生成器模型生成新的图像时,我们可以使用以下代码:

  1. import torch
  2. import matplotlib.pyplot as plt
  3. import torchvision.utils
  4. from torch.autograd import Variable
  5. import torch.nn as nn
  6. import numpy as np
  7. # 定义生成器
  8. class Generator(nn.Module):
  9. def __init__(self):
  10. super(Generator, self).__init__()
  11. self.fc1 = nn.Linear(100, 128)
  12. self.fc2 = nn.Linear(128, 392)
  13. self.fc3 = nn.Linear(392, 784)
  14. self.relu = nn.ReLU()
  15. self.tanh = nn.Tanh()
  16. def forward(self, x):
  17. x = self.relu(self.fc1(x))
  18. x = self.relu(self.fc2(x))
  19. x = self.tanh(self.fc3(x))
  20. return x
  21. # 加载模型
  22. device = "cuda" if torch.cuda.is_available() else "cpu"
  23. G = Generator().to(device)
  24. # 尝试加载模型的权重参数
  25. try:
  26. G.load_state_dict(torch.load('generator.pth'))
  27. print('模型参数加载成功!')
  28. except:
  29. print('模型参数加载失败,请检查模型的结构是否一致,或者权重参数是否被正确保存。')
  30. # 切换到评估模式
  31. G.eval()
  32. # 打印出生成器的输入和输出
  33. z = Variable(torch.randn(1, 100)).to(device)
  34. print('生成器的输入:', z)
  35. image = G(z)
  36. image = image.view(28, 28) # Reshape to image format
  37. print('生成器的输出:', image)
  38. # 生成图像
  39. img_grid = torchvision.utils.make_grid(image, nrow=8, normalize=True)
  40. # 将网格状图像转换为numpy数组并交换通道维度
  41. img_grid = img_grid.cpu().numpy().transpose((1, 2, 0))
  42. # 将像素值缩放到[0, 255]范围内,并转换为整数类型
  43. img_grid = (img_grid * 255).astype(np.uint8)
  44. # 调整通道顺序为RGB
  45. img_grid = img_grid[:, :, [2, 1, 0]]
  46. # 显示图像
  47. plt.imshow(img_grid, interpolation='nearest')
  48. plt.axis("off")
  49. plt.show()

这里我们首先加载训练好的生成器模型,然后生成一个100维的噪声向量,并将其输入到生成器中生成一张新的图像。由于我们在训练时对生成的噪声向量进行了归一化处理,因此在生成新的图像时也需要对噪声向量进行归一化处理,否则生成的图像可能会很模糊或失真。

图3-3 输出图片示例

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小舞很执着/article/detail/881347
推荐阅读
相关标签
  

闽ICP备14008679号