当前位置:   article > 正文

★理解生成对抗网络GAN的训练过程_简述 gan 的训练过程。

简述 gan 的训练过程。

生成器和判别器本质是两个独立的网络,因此训练的时候交替独立训练。(既“交替”又“独立”)

                    

        其训练机理为:生成器和判别器单独交替训练(先训练判别器 -->  训练生成器 --> 训练判别器... )。步骤如下:

首先你要知道:损失函数就是为了将两者的距离拉进,例如loss(A, 1):就是为了将A通过反向传播后更接近于1

1. 训练判别器(最大化鉴别器的损失):

  • 固定生成器的参数,真实图像x 输入判别器后输出的结果标签为1,【使D(x)为1】
    代码为:loss_d_real = loss_func(d_real, torch.ones([batch_size, 1]))
  • 随机噪声 z 输入生成器得到假图像 G(z),再输入判别器后得到的输出结果标签为0,【使D(G(z))为0,也就是1-D(G(z))为1】
    loss_d_fake = loss_func(d_fake, torch.zeros([batch_size, 1]))
  • 训练判别器到收敛。

      

2. 训练生成器(最小化生成器的损失):

    固定判别器的参数,随机噪声z输入生成器得到的假图像G(z),然后输入判别器得到的结果的标签为1,(使D(G(z))为1,看起来有驳常理,但是这是为了迷惑鉴别器)
loss_G = loss_func(d_g_fake, torch.ones([batch_size, 1]))
    训练生成器到收敛。

3. 交替循环步骤1和2,当然也可以在不收敛的过程中交替训练。

代码进一步理解上述过程

以下为代码(代码和上面解释的部分对应着来看)

  1. def train():
  2. G_mean = []
  3. G_std = [] # 用于记录生成器生成的数据的均值和方差
  4. data_mean = 3
  5. data_std = 1 # 目标分布的均值和方差
  6. batch_size = 64
  7. g_input_size = 16
  8. g_output_size = 512
  9. epochs = 1001
  10. d_epoch = 1 # 每个epoch判别器的训练轮数
  11. # 初始化网络
  12. D = Discriminator()
  13. G = Generator()
  14. # 初始化优化器和损失函数
  15. d_learning_rate = 0.01
  16. g_learning_rate = 0.001
  17. loss_func = nn.BCELoss() # - [p * log(q) + (1-p) * log(1-q)]
  18. optimiser_D = optim.Adam(D.parameters(), lr=d_learning_rate)
  19. optimiser_G = optim.Adam(G.parameters(), lr=g_learning_rate)
  20. plt.ion()
  21. for epoch in range(epochs):
  22. G.train()
  23. # 1 训练判别器d_steps次
  24. for _ in range(d_epoch):
  25. # 1.1 真实数据real_data输入D,得到d_real
  26. real_data = torch.tensor(np.random.normal(data_mean, data_std, (batch_size, g_output_size)), dtype=torch.float)
  27. d_real = D(real_data)
  28. # 1.2 生成数据的输出fake_data输入D,得到d_fake
  29. g_input = torch.rand(batch_size, g_input_size)
  30. fake_data = G(g_input).detach() # detach:只更新判别器的参数
  31. d_fake = D(fake_data)
  32. # 1.3 计算损失值 ,判别器学习使得d_real->1、d_fake->0
  33. loss_d_real = loss_func(d_real, torch.ones([batch_size, 1]))
  34. loss_d_fake = loss_func(d_fake, torch.zeros([batch_size, 1]))
  35. d_loss = loss_d_real + loss_d_fake
  36. # 1.4 反向传播,优化
  37. optimiser_D.zero_grad()
  38. d_loss.backward()
  39. optimiser_D.step()
  40. # 2 训练生成器
  41. # 2.1 G输入g_input,输出fake_data。fake_data输入D,得到d_g_fake
  42. g_input = torch.rand(batch_size, g_input_size)
  43. fake_data = G(g_input)
  44. d_g_fake = D(fake_data)
  45. # 2.2 计算损失值,生成器学习使得d_g_fake->1
  46. loss_G = loss_func(d_g_fake, torch.ones([batch_size, 1]))
  47. # 2.3 反向传播,优化
  48. optimiser_G.zero_grad()
  49. loss_G.backward()
  50. optimiser_G.step()
  51. # 2.4 记录生成器输出的均值和方差
  52. G_mean.append(fake_data.mean().item())
  53. G_std.append(fake_data.std().item())
  54. if epoch % 10 == 0:
  55. print("Epoch: {}, 生成数据的均值: {}, 生成数据的标准差: {}".format(epoch, G_mean[-1], G_std[-1]))
  56. print('-' * 10)
  57. G.eval()
  58. draw(G, epoch, g_input_size)
  59. plt.ioff()
  60. plt.show()
  61. plt.plot(G_mean)
  62. plt.title('均值')
  63. plt.savefig('gan_mean.jpg')
  64. plt.show()
  65. plt.plot(G_std)
  66. plt.title('标准差')
  67. plt.savefig('gan_std.jpg')
  68. plt.show()
  69. if __name__ == '__main__':
  70. train()

4. 总结

通过一个判别器而不是直接使用损失函数来进行逼近,更能够自顶向下地把握全局的信息。比如在图片中,虽然都是相差几像素点,但是这个像素点的位置如果在不同地方,那么他们之间的差别可能就非常之大。

比如上图10中的两组生成样本,对应的目标为字体2,但是图中上面的两个样本虽然只相差一个像素点,但是这个像素点对于全局的影响是比较大的,但是单纯地去使用使用损失函数来判断,那么他们的误差都是相差一个像素点,而下面的两个虽然相差了六个像素点的差距(粉色部分的像素点为误差),但是实际上对于整体的判断来说,是没有太大影响的。但是直接使用损失函数的话,却会得到6个像素点的差距,比上面的两幅图差别更大。而如果使用判别器,则可以更好地判别出这种情况(不会拘束于具体像素的差距)

总之GAN是一个非常有意思的东西,现在也有很多相关的利用GAN的应用,比如利用GAN来生成人物头像,用GAN来进行文字的图片说明等等。

引用:通俗理解生成对抗网络GAN - 知乎
 

最简单易懂的GAN(生成对抗网络)教程:从理论到实践(附代码) | 雷峰网

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/712453
推荐阅读
相关标签
  

闽ICP备14008679号