当前位置:   article > 正文

生成对抗网络(GAN)

生成对抗网络

基本概念

生成器

生成对抗网络中的一部分,用于从随机数据或其他分布的数据中生成一个与训练集类似的数据.。

判别器

也是生成对抗网络中的一部分,用于识别出哪些是有生成网络生成的“假”数据,哪些是真正的训练集数据。

生成对抗网络

在这里插入图片描述
生成对抗网络的目的就是生成不存在的数据,类似于让人工智能拥有想象力。

在一个图像生成对抗网络中,生成器负责生成假图像,这里认为训练集是真图像,而判别器负责判别图像真假。“对抗”的含义就是生成器通过不断的训练尽可能的生成以假乱真的图像,判别器通过不断的识别尽可能的区分图像的真假。

生成器的训练过程,简单来说就是,给定网络,给定标签,然后更新输入(这里是随机数据,可以符合一定的分布),使输出图像对应的标签逐渐靠近给定的标签;判别器的训练过程和一般神经网络的训练类似。

总结一下,模型通过 “ Generator生成器” 生成图像,并与真实图像一起输入 “Discriminator判别器” 进行判别。“Discriminator判别器” 通过 loss 以“真实图像为真,生成图像为假” 进行参数更新,而 “ Generator生成器” 通过 loss 以 “生成图像为真” 进行参数更新。

测试代码

import tensorflow as tf
import matplotlib.pyplot as plt

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape((-1, 28, 28, 1)).astype("float32")
# 由于tanh函数的值域为(-1, 1),故这里应该将像素值统一到(-1, 1)
train_images = (train_images - 127.5) / 127.5
# 将图像数组转换成张量
datasets = tf.data.Dataset.from_tensor_slices(train_images)
# 将这60000个张量打乱,每一批有256个张量
datasets = datasets.shuffle(60000).batch(256)
# 参数
n_dim = 100
batch_size = 256
epochs = 10
# 生成器
# 输入格式(batch_size, 100) => 输出格式(batch_size, 28, 28, 1)
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(3 * 3 * 512),
    tf.keras.layers.Reshape([3, 3, 512]),
    tf.keras.layers.Conv2DTranspose(256, 3, 2, activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2DTranspose(128, 2, 2, activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2DTranspose(64, 2, 2, activation="tanh"),
    tf.keras.layers.Dense(1),
    tf.keras.layers.Reshape([28, 28, 1])
])
# 辨别器
# 输入格式(batch_size, 28, 28, 1) => 输出格式(batch_size, 1)
discriminator = tf.keras.Sequential([
    tf.keras.layers.Reshape([28, 28, 1]),
    tf.keras.layers.Conv2D(64, 4, 2, activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(128, 4, 3, activation="relu"),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(256, 4, 3, activation="relu"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1),
])
# 优化器
g_optimizer = tf.optimizers.Adam(learning_rate=1e-4)
d_optimizer = tf.optimizers.Adam(learning_rate=1e-4)
# 损失函数
# 判别器的输出用的是二分类的one-hot编码
loss = tf.losses.BinaryCrossentropy(from_logits=True)
# 训练模型
for i in range(epochs):
    for real_images in datasets:
        noise = tf.random.normal([batch_size, n_dim])
        with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
            fake_images = generator(noise, training=True)
            fake_out = discriminator(fake_images, training=True)
            real_out = discriminator(real_images, training=True)
            g_loss = loss(tf.ones_like(fake_out), fake_out)
            d_loss = loss(tf.zeros_like(fake_out), fake_out) + loss(tf.ones_like(real_out), real_out)
        g_gradient = g_tape.gradient(g_loss, generator.trainable_variables)
        d_gradient = d_tape.gradient(d_loss, discriminator.trainable_variables)
        g_optimizer.apply_gradients(zip(g_gradient, generator.trainable_variables))
        d_optimizer.apply_gradients(zip(d_gradient, discriminator.trainable_variables))
# 预测结果
test_noise = tf.random.normal([16, n_dim])
pre_images = generator(test_noise, training=False)
fig = plt.figure(1)
for i in range(16):
    plt.subplot(4, 4, i+1)
    plt.imshow((pre_images[i, :, :, 0] + 1)/2)
    plt.axis('off')
plt.show()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69

测试结果

在这里插入图片描述
在这里插入图片描述
第一张图是训练集选取的部分图片,第二张图片是由生成器产生的图片,可以看出,通过多次训练,生成可以生成比较接近训练集图片的图片。这里只训练了10轮,通过增加训练轮数,生成器产生的图像会更加接近训练集的图像。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/492804
推荐阅读
相关标签
  

闽ICP备14008679号