当前位置:   article > 正文

【用Python简单实现生成对抗网络(GAN)】_python训练对抗生成网络

python训练对抗生成网络

这是一个用Python编写的简单实现生成对抗网络(GAN)的脚本。

该脚本包括两个神经网络:生成器(G)和判别器(D)。生成器采用随机向量作为输入,生成一幅艺术作品。判别器接收一幅艺术作品作为输入,并判断它是否是由真实艺术家所绘制的。两个网络通过对抗学习的方式相互竞争,直到生成器可以生成与真实艺术品相似的作品。

生成器的目标是生成类似于训练数据的“假”数据,而判别器的目标是识别“真实”数据和生成器生成的“假”数据。两个网络通过博弈的方式相互对抗学习,最终生成器可以生成与训练数据相似的新数据。

具体来说,这段代码实现的是一个简单的 GAN,其中生成器(Generator)试图学习如何生成一个类似于二次函数的曲线。代码的第一部分定义了生成器和判别器的神经网络结构。然后,在训练过程中,生成器产生一个“假”数据,判别器评估这个“假”数据和真实数据的相似度,并根据评估结果更新判别器和生成器的权重。这个过程不断重复,直到生成器可以生成与真实数据相似的数据。

这段代码实现的训练过程如下:

1.定义了一个判别器和一个生成器的神经网络结构;
2.在每一步迭代中,生成器生成一个“假”数据,判别器评估这个“假”数据和真实数据的相似度;
3.计算判别器的损失函数,根据损失函数更新判别器的权重;
4.生成器再次生成一个“假”数据,判别器再次评估这个“假”数据和真实数据的相似度;
5.计算生成器的损失函数,根据损失函数更新生成器的权重;
6.重复步骤 2-5,直到生成器可以生成与真实数据相似的数据。
在每 50 步迭代之后,代码还会画出当前生成的曲线,以及真实曲线和生成器曲线的误差。

import os

os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'  # 防止Intel MKL导致的内存问题

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 超参数
BATCH_SIZE = 64  # 批大小
LR_G = 0.0001  # 生成器学习率
LR_D = 0.0001  # 判别器学习率
N_IDEAS = 5  # 随机向量的维度
ART_COMPONENTS = 15  # 可以绘制的画作中点的数量
PAINT_POINTS = np.vstack([np.linspace(-1, 1, ART_COMPONENTS) for _ in range(BATCH_SIZE)])  # 每个画作的点

def artist_works():  # 真实画作的生成函数(真实数据)
    a = np.random.uniform(1, 2, size=BATCH_SIZE)[:, np.newaxis]
    paintings = a * np.power(PAINT_POINTS, 2) + (a - 1)
    paintings = torch.from_numpy(paintings).float()
    return paintings


G = nn.Sequential(  # 生成器
    nn.Linear(N_IDEAS, 128),  # 接收随机向量作为输入
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),  # 输出生成的画作
)

D = nn.Sequential(  # 判别器
    nn.Linear(ART_COMPONENTS, 128),  # 接收画作作为输入
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),  # 输出画作是真实画作的概率
)

opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)  # 判别器优化器
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)  # 生成器优化器

plt.ion()  # 连续画图

for step in range(10000):  # 训练10000个epoch
    artist_paintings = artist_works()  # 真实画作
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS, requires_grad=True)  # 随机向量
    G_paintings = G(G_ideas)  # 生成画作
    prob_artist1 = D(G_paintings)  # 判别器判别生成画作是否为真实画作
    G_loss = torch.mean(torch.log(1. - prob_artist1))  # 生成器的损失
    opt_G.zero_grad()  # 清除梯度
    G_loss.backward()  # 反向传播
    opt_G.step()  # 更新生成器的参数

    prob_artist0 = D(artist_paintings)  # 真实画作的判别器概率
    prob_artist1 = D(G_paintings.detach())  # 生成画作的判别器概率
    D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))  # 判别器的损失
    opt_D.zero_grad()  # 清除梯度
    D_loss.backward(retain_graph=True) # 反向传播计算梯度并更新D的参数,retain_graph=True 表示保留计算图以便进行后续的反向传播
    opt_D.step()

    if step % 50 == 0:  # 当step是50的倍数时,进行可视化绘图
        plt.cla()  # 清空当前画布
        plt.plot(PAINT_POINTS[0], G_paintings.data.numpy()[0], c='#4AD631', lw=3, label='Generated painting', )   # 绘制上界,即真实艺术家的作品
        plt.plot(PAINT_POINTS[0], 2 * np.power(PAINT_POINTS[0], 2) + 1, c='#74BCFF', lw=3, label='upper bound')   # 绘制下界,即随机产生的画作
        plt.plot(PAINT_POINTS[0], 1 * np.power(PAINT_POINTS[0], 2) + 0, c='#FF9359', lw=3, label='lower bound')    # 显示判别器D的准确率,即对于真实艺术家的画作,D判断为真实画作的概率
        plt.text(-.5, 2.3, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.data.numpy().mean(),
                 fontdict={'size': 13})
        plt.text(-.5, 2, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.data.numpy(), fontdict={'size': 13})
        # 显示判别器D的得分,即对于真实艺术家的画作和生成器G生成的画作,D判断为真实画作的概率的负数
        plt.ylim((0, 3));    # 设置y轴范围
        plt.legend(loc='upper right', fontsize=10);   # 添加图例
        plt.draw();  # 显示绘制的图形
        plt.pause(0.01)   # 停顿0.01秒,以便查看图形

plt.ioff()
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/94425
推荐阅读
相关标签
  

闽ICP备14008679号