赞
踩
这是一个用Python编写的简单实现生成对抗网络(GAN)的脚本。
该脚本包括两个神经网络:生成器(G)和判别器(D)。生成器采用随机向量作为输入,生成一幅艺术作品。判别器接收一幅艺术作品作为输入,并判断它是否是由真实艺术家所绘制的。两个网络通过对抗学习的方式相互竞争,直到生成器可以生成与真实艺术品相似的作品。
生成器的目标是生成类似于训练数据的“假”数据,而判别器的目标是识别“真实”数据和生成器生成的“假”数据。两个网络通过博弈的方式相互对抗学习,最终生成器可以生成与训练数据相似的新数据。
具体来说,这段代码实现的是一个简单的 GAN,其中生成器(Generator)试图学习如何生成一个类似于二次函数的曲线。代码的第一部分定义了生成器和判别器的神经网络结构。然后,在训练过程中,生成器产生一个“假”数据,判别器评估这个“假”数据和真实数据的相似度,并根据评估结果更新判别器和生成器的权重。这个过程不断重复,直到生成器可以生成与真实数据相似的数据。
这段代码实现的训练过程如下:
1.定义了一个判别器和一个生成器的神经网络结构;
2.在每一步迭代中,生成器生成一个“假”数据,判别器评估这个“假”数据和真实数据的相似度;
3.计算判别器的损失函数,根据损失函数更新判别器的权重;
4.生成器再次生成一个“假”数据,判别器再次评估这个“假”数据和真实数据的相似度;
5.计算生成器的损失函数,根据损失函数更新生成器的权重;
6.重复步骤 2-5,直到生成器可以生成与真实数据相似的数据。
在每 50 步迭代之后,代码还会画出当前生成的曲线,以及真实曲线和生成器曲线的误差。
import os os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE' # 防止Intel MKL导致的内存问题 import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt # 超参数 BATCH_SIZE = 64 # 批大小 LR_G = 0.0001 # 生成器学习率 LR_D = 0.0001 # 判别器学习率 N_IDEAS = 5 # 随机向量的维度 ART_COMPONENTS = 15 # 可以绘制的画作中点的数量 PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)]) # 每个画作的点 def artist_works(): # 真实画作的生成函数(真实数据) a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis] paintings = a * np.power(PAINT_POINTS, 2) + (a - 1) paintings = torch.from_numpy(paintings).float() return paintings G = nn.Sequential( # 生成器 nn.Linear(N_IDEAS, 128), # 接收随机向量作为输入 nn.ReLU(), nn.Linear(128, ART_COMPONENTS), # 输出生成的画作 ) D = nn.Sequential( # 判别器 nn.Linear(ART_COMPONENTS, 128), # 接收画作作为输入 nn.ReLU(), nn.Linear(128, 1), nn.Sigmoid(), # 输出画作是真实画作的概率 ) opt_D = torch.optim.Adam(D.parameters(), lr=LR_D) # 判别器优化器 opt_G = torch.optim.Adam(G.parameters(), lr=LR_G) # 生成器优化器 plt.ion() # 连续画图 for step in range(10000): # 训练10000个epoch artist_paintings = artist_works() # 真实画作 G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True) # 随机向量 G_paintings = G(G_ideas) # 生成画作 prob_artist1 = D(G_paintings) # 判别器判别生成画作是否为真实画作 G_loss = torch.mean(torch.log(1. - prob_artist1)) # 生成器的损失 opt_G.zero_grad() # 清除梯度 G_loss.backward() # 反向传播 opt_G.step() # 更新生成器的参数 prob_artist0 = D(artist_paintings) # 真实画作的判别器概率 prob_artist1 = D(G_paintings.detach()) # 生成画作的判别器概率 D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1)) # 判别器的损失 opt_D.zero_grad() # 清除梯度 D_loss.backward(retain_graph=True) # 反向传播计算梯度并更新D的参数,retain_graph=True 表示保留计算图以便进行后续的反向传播 opt_D.step() if step % 50 == 0: # 当step是50的倍数时,进行可视化绘图 plt.cla() # 清空当前画布 plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting', ) # 绘制上界,即真实艺术家的作品 plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound') # 绘制下界,即随机产生的画作 plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound') # 显示判别器D的准确率,即对于真实艺术家的画作,D判断为真实画作的概率 plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(), fontdict={'size': 13}) plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13}) # 显示判别器D的得分,即对于真实艺术家的画作和生成器G生成的画作,D判断为真实画作的概率的负数 plt.ylim((0, 3)); # 设置y轴范围 plt.legend(loc='upper right', fontsize=10); # 添加图例 plt.draw(); # 显示绘制的图形 plt.pause(0.01) # 停顿0.01秒,以便查看图形 plt.ioff() plt.show()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。