当前位置:   article > 正文

使用生成对抗网络(GAN)实现对图像的生成_基于生成对抗网络的图像生成

基于生成对抗网络的图像生成

目录

前言

一、GAN模型简介

二、Fashion MNIST数据集简介

三、算法实现

1.导入必要的库

2.下载并展示数据集

3.数据的预处理

4.定义生成器

5.定义判别器

6.构建模型

7.训练模型

四、总结

参考资料:


前言

        生成对抗网络(GAN)是一种无监督学习模型,它可以生成与真实数据相似的假数据,其应用非常广泛。本文基于python,使用生成对抗网络(GAN模型)对Fashion MNIST数据集中的图像,进行了生成。

一、GAN模型简介

        GAN的英文全称为:Generative Adversarial Networks,这是一种生成模型,它由Goodfellow等人于2014年提出。

        GAN由两个神经网络组成:生成器(G)和判别器(D)。生成器用于生成假数据;判别器用于判断数据的真假。两个网络相互对抗又彼此促进,生成器生成的假数据越来越逼真,而判别器的判断能力也越来越强。最终,生成器生成的假数据足以骗过判别器,达到了生成真实数据的目的。就像在草原上,狮子为了生存,需要捕捉到斑马,就要跑得比斑马更快;而斑马为了生存,需要逃避狮子的追捕,就要跑得比狮子更快,所以狮子和斑马都会跑得越来越快。

二、Fashion MNIST数据集简介

        Fashion-MNIST是一个服装分类数据集,有如下表所示的10个类别,每个类别都包含训练集(6k个图像)和测试集(1k个图像),故训练集与测试集的图像分别共有6万张和1万张。

t-shirttrouserpulloverdresscoatsandalshirtsneakerbagankle boot
T恤牛仔裤套衫裙子外套凉鞋衬衫运动鞋短靴

三、算法实现

1.导入必要的库

  1. import numpy as np
  2. import matplotlib.pyplot as plt
  3. import tensorflow as tf
  4. from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
  5. from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
  6. from tensorflow.keras.layers import LeakyReLU
  7. from tensorflow.keras.models import Sequential, Model
  8. from tensorflow.keras.optimizers import Adam

2.下载并展示数据集

  1. # 下载数据集
  2. (X_train, y_train), (_, _) = tf.keras.datasets.fashion_mnist.load_data()
  3. # 定义类别的名字
  4. class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
  5. 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
  6. # 创建图片
  7. fig, axes = plt.subplots(3, 3, figsize=(8, 8))
  8. axes = axes.ravel()
  9. # 随机选择9张图片
  10. for i in np.arange(0, 9):
  11. index = np.random.randint(0, len(X_train))
  12. axes[i].imshow(X_train[index], cmap='gray')
  13. axes[i].set_title(class_names[y_train[index]])
  14. axes[i].axis('off')
  15. plt.savefig("服装分类数据集示例.png")
  16. # 显示图片
  17. plt.show()

3.数据的预处理

        对训练数据进行归一化处理,将像素值缩放到了[-1,1],并将图像的通道数从1变为3,以便与模型的输入形状匹配。

  1. # 归一化数据
  2. X_train = X_train / 127.5 - 1.
  3. X_train = np.expand_dims(X_train, axis=3)

4.定义生成器

        该模型输入一个形状为 (100,) 的噪声向量,并输出一个形状为 (28, 28, 1) 的图像。包含了四个全连接层,前三个全连接层后面都跟着一个斜率为 0.2 的 LeakyReLU 激活函数和一个批量归一化层,最后一个全连接层具有 tanh 激活函数,输出一个范围在 -1 到 1 之间的值(生成图像的像素值)。最终,输出的图像形状被重塑为 (28, 28, 1)。

  1. def build_generator():
  2. model = Sequential()# 创建了一个序列模型
  3. model.add(Dense(256, input_dim=100))# 添加全连接层,输入维度为100,输出维度为256
  4. model.add(LeakyReLU(alpha=0.2))# 添加LeakyReLU激活函数层
  5. model.add(BatchNormalization(momentum=0.8))# 添加批量归一化层
  6. model.add(Dense(512))
  7. model.add(LeakyReLU(alpha=0.2))
  8. model.add(BatchNormalization(momentum=0.8))
  9. model.add(Dense(1024))
  10. model.add(LeakyReLU(alpha=0.2))
  11. model.add(BatchNormalization(momentum=0.8))
  12. model.add(Dense(784, activation='tanh'))
  13. model.add(Reshape((28, 28, 1)))
  14. noise = Input(shape=(100,))# 定义输入层,维度为100
  15. img = model(noise)# 生成图像
  16. return Model(noise, img)

5.定义判别器

        该模型的输入是一个(28, 28, 1)的图像,输出一个在[0,1]区间的概率值。先将输入的图像在 Flatten()层将其展平,然后通过三个全连接层,最后输出一个概率分数,来判别输入图像的真假。

  1. def build_discriminator():
  2. model = Sequential()
  3. model.add(Flatten(input_shape=(28, 28, 1)))# 将输入的28*28*1的图像展平为一维向量
  4. model.add(Dense(512))
  5. model.add(LeakyReLU(alpha=0.2))
  6. model.add(Dense(256))
  7. model.add(LeakyReLU(alpha=0.2))
  8. model.add(Dense(1, activation='sigmoid'))
  9. img = Input(shape=(28, 28, 1))
  10. validity = model(img)
  11. return Model(img, validity)

6.构建模型

  • 生成器模型生成噪声,然后用于生成假图像。
  • 判别器模型随后对真实和假图像进行训练,以区分它们。
  • 组合模型用于训练生成器生成更逼真的图像。
  1. # 构建生成器
  2. generator = build_generator()
  3. z = Input(shape=(latent_dim,))# 生成噪声
  4. img = generator(z)
  5. # 构建判别器
  6. discriminator = build_discriminator()
  7. discriminator.compile(loss='binary_crossentropy',optimizer=Adam(0.0002, 0.5),metrics=['accuracy'])
  8. discriminator.trainable = False# 固定判别器的权重
  9. # 判别器判断真假
  10. valid = discriminator(img)
  11. # 构建组合模型
  12. combined = Model(z, valid)
  13. combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))

7.训练模型

        定义以下四个超参数:

  1. latent_dim = 100 # 噪声向量的维度
  2. epochs = 10001 # 训练的轮数。为后续方便显示第一个生成图像和最后一个生成图像,故在训练1w轮后,再训练了一轮
  3. batch_size = 128 # 每个训练批次的大小
  4. sam_inter = 1000 # 图像展示频率。即每隔多少轮训练,就展示一次生成器生成的图像。

        训练过程大致分为以下三个部分:

①从Fashion MNIST数据集中随机选择一批真实数据,生成一批噪声向量,用生成器生成一      批假数据。

②判别器分别判断这些真实数据和假数据的真假,并计算出它们的损失值。

③根据损失值更新判别器和生成器的权重。

        这个过程不断重复,直到达到指定的训练轮数(epoch)。

  1. for epoch in range(epochs):
  2. '''训练判别器'''
  3. # 随机选择一批真实图片
  4. idx = np.random.randint(0, X_train.shape[0], batch_size)
  5. imgs = X_train[idx]
  6. # 生成一批假图片
  7. noise = np.random.normal(0, 1, (batch_size, latent_dim))
  8. gen_imgs = generator.predict(noise)
  9. # 训练判别器
  10. d_loss_real = discriminator.train_on_batch(imgs, np.ones((batch_size, 1)))# 真照片的损失值
  11. d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))# 假照片的损失值
  12. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# 取平均值,其将被用作反向传播的损失值,用于更新判别器的权重。
  13. '''训练生成器'''
  14. # 生成一批噪声
  15. noise = np.random.normal(0, 1, (batch_size, latent_dim))
  16. # 训练生成器
  17. g_loss = combined.train_on_batch(noise, np.ones((batch_size, 1)))
  18. # 打印损失
  19. #print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))
  20. '''展示生成的图片'''
  21. if epoch % sam_inter == 0:# 每1000轮展示一次
  22. r, c = 3, 3
  23. noise = np.random.normal(0, 1, (r * c, latent_dim))
  24. gen_imgs = generator.predict(noise)
  25. # 将图片像素值调整到0-1之间
  26. gen_imgs = 0.5 * gen_imgs + 0.5
  27. fig, axs = plt.subplots(r, c)
  28. cnt = 0
  29. for i in range(r):
  30. for j in range(c):
  31. axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
  32. axs[i,j].axis('off')
  33. cnt += 1
  34. # 保存部分最终生成的图片
  35. if epoch == epochs-1:
  36. plt.savefig("最终生成效果.png")
  37. plt.show()

        将此生成图与前面所展示的样本图片比较,可以发现部分图片已经不易通过肉眼识别出,其为真图片还是假图片了。例如:前面示例的真图片最中间那个Shirt,与此生成图的最右边中间的Shirt。

四、总结

        通过上述例子,我们可以发现在仅通过1W轮的训练,所生成的图片,就已经与真实的图片十分相似。倘若经过1亿轮呢?估计已与真实图片别无二致了吧。可以预见,AI绘图定会引发众多行业的变革。

参考资料:

[1406.2661] 生成对抗网络 (arxiv.org)https://arxiv.org/abs/1406.2661

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/103064
推荐阅读
相关标签
  

闽ICP备14008679号