当前位置:   article > 正文

深度卷积生成对抗网络DCGAN——生成手写数字图片_生成对抗网络实现手写数字

生成对抗网络实现手写数字

前言

本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。

 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:

  1. # 用于生成 GIF 图片
  2. pip install -q imageio

目录

前言

一、什么是生成对抗网络?

二、加载数据集

三、创建模型

3.1 生成器

3.1 判别器

四、定义损失函数和优化器

4.1 生成器的损失和优化器

4.2 判别器的损失和优化器

五、训练模型

5.1 保存检查点

5.2 定义训练过程

5.3 训练模型

六、评估模型


一、什么是生成对抗网络?

生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。

生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。

判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。

训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。

当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。

本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。

二、加载数据集

使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。

  1. (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
  2. train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
  3. train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
  4. BUFFER_SIZE = 60000
  5. BATCH_SIZE = 256
  6. # 批量化和打乱数据
  7. train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

三、创建模型

主要创建两个模型,一个是生成器,另一个是判别器

3.1 生成器

生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。

然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。

后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。

  1. def make_generator_model():
  2. model = tf.keras.Sequential()
  3. model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
  4. model.add(layers.BatchNormalization())
  5. model.add(layers.LeakyReLU())
  6. model.add(layers.Reshape((7, 7, 256)))
  7. assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
  8. model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
  9. assert model.output_shape == (None, 7, 7, 128)
  10. model.add(layers.BatchNormalization())
  11. model.add(layers.LeakyReLU())
  12. model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  13. assert model.output_shape == (None, 14, 14, 64)
  14. model.add(layers.BatchNormalization())
  15. model.add(layers.LeakyReLU())
  16. model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
  17. assert model.output_shape == (None, 28, 28, 1)
  18. return model

用tf.keras.utils.plot_model( ),看一下模型结构

 

用summary(),看一下模型结构和参数

使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。

  1. generator = make_generator_model()
  2. noise = tf.random.normal([1, 100])
  3. generated_image = generator(noise, training=False)
  4. plt.imshow(generated_image[0, :, :, 0], cmap='gray')

3.1 判别器

判别器是基于 CNN卷积神经网络 的图片分类器。

  1. def make_discriminator_model():
  2. model = tf.keras.Sequential()
  3. model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
  4. input_shape=[28, 28, 1]))
  5. model.add(layers.LeakyReLU())
  6. model.add(layers.Dropout(0.3))
  7. model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
  8. model.add(layers.LeakyReLU())
  9. model.add(layers.Dropout(0.3))
  10. model.add(layers.Flatten())
  11. model.add(layers.Dense(1))
  12. return model

用tf.keras.utils.plot_model( ),看一下模型结构

用summary(),看一下模型结构和参数

四、定义损失函数和优化器

由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。

首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。

  1. # 该方法返回计算交叉熵损失的辅助函数
  2. cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

4.1 生成器的损失和优化器

1)生成器损失

生成器损失是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。

这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。

  1. def generator_loss(fake_output):
  2. return cross_entropy(tf.ones_like(fake_output), fake_output)

2)生成器优化器

generator_optimizer = tf.keras.optimizers.Adam(1e-4)

4.2 判别器的损失和优化器

1)判别器损失

判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。

  1. def discriminator_loss(real_output, fake_output):
  2. real_loss = cross_entropy(tf.ones_like(real_output), real_output)
  3. fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
  4. total_loss = real_loss + fake_loss
  5. return total_loss

2)判别器优化器

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

五、训练模型

5.1 保存检查点

保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。

  1. checkpoint_dir = './training_checkpoints'
  2. checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
  3. checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
  4. discriminator_optimizer=discriminator_optimizer,
  5. generator=generator,
  6. discriminator=discriminator)

5.2 定义训练过程

  1. EPOCHS = 50
  2. noise_dim = 100
  3. num_examples_to_generate = 16
  4. # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
  5. seed = tf.random.normal([num_examples_to_generate, noise_dim])

训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。

判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。

两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。

  1. # 注意 `tf.function` 的使用
  2. # 该注解使函数被“编译”
  3. @tf.function
  4. def train_step(images):
  5. noise = tf.random.normal([BATCH_SIZE, noise_dim])
  6. with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
  7. generated_images = generator(noise, training=True)
  8. real_output = discriminator(images, training=True)
  9. fake_output = discriminator(generated_images, training=True)
  10. gen_loss = generator_loss(fake_output)
  11. disc_loss = discriminator_loss(real_output, fake_output)
  12. gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  13. gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
  14. generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
  15. discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
  16. def train(dataset, epochs):
  17. for epoch in range(epochs):
  18. start = time.time()
  19. for image_batch in dataset:
  20. train_step(image_batch)
  21. # 继续进行时为 GIF 生成图像
  22. display.clear_output(wait=True)
  23. generate_and_save_images(generator,
  24. epoch + 1,
  25. seed)
  26. # 每 15 个 epoch 保存一次模型
  27. if (epoch + 1) % 15 == 0:
  28. checkpoint.save(file_prefix = checkpoint_prefix)
  29. print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
  30. # 最后一个 epoch 结束后生成图片
  31. display.clear_output(wait=True)
  32. generate_and_save_images(generator,
  33. epochs,
  34. seed)
  35. # 生成与保存图片
  36. def generate_and_save_images(model, epoch, test_input):
  37. # 注意 training` 设定为 False
  38. # 因此,所有层都在推理模式下运行(batchnorm)。
  39. predictions = model(test_input, training=False)
  40. fig = plt.figure(figsize=(4,4))
  41. for i in range(predictions.shape[0]):
  42. plt.subplot(4, 4, i+1)
  43. plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
  44. plt.axis('off')
  45. plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  46. plt.show()

5.3 训练模型

调用上面定义的train()函数,来同时训练生成器和判别器。

注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。

  1. %%time
  2. train(train_dataset, EPOCHS)

在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。

训练了15轮的效果:

训练了30轮的效果:

训练过程:

恢复最新的检查点

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

六、评估模型

这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。

  1. # 使用 epoch 数生成单张图片
  2. def display_image(epoch_no):
  3. return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
  4. display_image(EPOCHS)
  1. anim_file = 'dcgan.gif'
  2. with imageio.get_writer(anim_file, mode='I') as writer:
  3. filenames = glob.glob('image*.png')
  4. filenames = sorted(filenames)
  5. last = -1
  6. for i,filename in enumerate(filenames):
  7. frame = 2*(i**0.5)
  8. if round(frame) > round(last):
  9. last = frame
  10. else:
  11. continue
  12. image = imageio.imread(filename)
  13. writer.append_data(image)
  14. image = imageio.imread(filename)
  15. writer.append_data(image)
  16. import IPython
  17. if IPython.version_info > (6,2,0,''):
  18. display.Image(filename=anim_file)

完整代码:

  1. import tensorflow as tf
  2. import glob
  3. import imageio
  4. import matplotlib.pyplot as plt
  5. import numpy as np
  6. import os
  7. import PIL
  8. from tensorflow.keras import layers
  9. import time
  10. from IPython import display
  11. (train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
  12. train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
  13. train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
  14. BUFFER_SIZE = 60000
  15. BATCH_SIZE = 256
  16. # 批量化和打乱数据
  17. train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  18. # 创建模型--生成器
  19. def make_generator_model():
  20. model = tf.keras.Sequential()
  21. model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
  22. model.add(layers.BatchNormalization())
  23. model.add(layers.LeakyReLU())
  24. model.add(layers.Reshape((7, 7, 256)))
  25. assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
  26. model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
  27. assert model.output_shape == (None, 7, 7, 128)
  28. model.add(layers.BatchNormalization())
  29. model.add(layers.LeakyReLU())
  30. model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
  31. assert model.output_shape == (None, 14, 14, 64)
  32. model.add(layers.BatchNormalization())
  33. model.add(layers.LeakyReLU())
  34. model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
  35. assert model.output_shape == (None, 28, 28, 1)
  36. return model
  37. # 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
  38. generator = make_generator_model()
  39. noise = tf.random.normal([1, 100])
  40. generated_image = generator(noise, training=False)
  41. plt.imshow(generated_image[0, :, :, 0], cmap='gray')
  42. tf.keras.utils.plot_model(generator)
  43. # 判别器
  44. def make_discriminator_model():
  45. model = tf.keras.Sequential()
  46. model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
  47. input_shape=[28, 28, 1]))
  48. model.add(layers.LeakyReLU())
  49. model.add(layers.Dropout(0.3))
  50. model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
  51. model.add(layers.LeakyReLU())
  52. model.add(layers.Dropout(0.3))
  53. model.add(layers.Flatten())
  54. model.add(layers.Dense(1))
  55. return model
  56. # 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
  57. discriminator = make_discriminator_model()
  58. decision = discriminator(generated_image)
  59. print (decision)
  60. # 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
  61. cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
  62. # 生成器的损失和优化器
  63. def generator_loss(fake_output):
  64. return cross_entropy(tf.ones_like(fake_output), fake_output)
  65. generator_optimizer = tf.keras.optimizers.Adam(1e-4)
  66. # 判别器的损失和优化器
  67. def discriminator_loss(real_output, fake_output):
  68. real_loss = cross_entropy(tf.ones_like(real_output), real_output)
  69. fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
  70. total_loss = real_loss + fake_loss
  71. return total_loss
  72. discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
  73. # 保存检查点
  74. checkpoint_dir = './training_checkpoints'
  75. checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
  76. checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
  77. discriminator_optimizer=discriminator_optimizer,
  78. generator=generator,
  79. discriminator=discriminator)
  80. # 定义训练过程
  81. EPOCHS = 50
  82. noise_dim = 100
  83. num_examples_to_generate = 16
  84. # 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
  85. seed = tf.random.normal([num_examples_to_generate, noise_dim])
  86. # 注意 `tf.function` 的使用
  87. # 该注解使函数被“编译”
  88. @tf.function
  89. def train_step(images):
  90. noise = tf.random.normal([BATCH_SIZE, noise_dim])
  91. with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
  92. generated_images = generator(noise, training=True)
  93. real_output = discriminator(images, training=True)
  94. fake_output = discriminator(generated_images, training=True)
  95. gen_loss = generator_loss(fake_output)
  96. disc_loss = discriminator_loss(real_output, fake_output)
  97. gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
  98. gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
  99. generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
  100. discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
  101. def train(dataset, epochs):
  102. for epoch in range(epochs):
  103. start = time.time()
  104. for image_batch in dataset:
  105. train_step(image_batch)
  106. # 继续进行时为 GIF 生成图像
  107. display.clear_output(wait=True)
  108. generate_and_save_images(generator,
  109. epoch + 1,
  110. seed)
  111. # 每 15 个 epoch 保存一次模型
  112. if (epoch + 1) % 15 == 0:
  113. checkpoint.save(file_prefix = checkpoint_prefix)
  114. print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
  115. # 最后一个 epoch 结束后生成图片
  116. display.clear_output(wait=True)
  117. generate_and_save_images(generator,
  118. epochs,
  119. seed)
  120. # 生成与保存图片
  121. def generate_and_save_images(model, epoch, test_input):
  122. # 注意 training` 设定为 False
  123. # 因此,所有层都在推理模式下运行(batchnorm)。
  124. predictions = model(test_input, training=False)
  125. fig = plt.figure(figsize=(4,4))
  126. for i in range(predictions.shape[0]):
  127. plt.subplot(4, 4, i+1)
  128. plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
  129. plt.axis('off')
  130. plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  131. plt.show()
  132. # 训练模型
  133. train(train_dataset, EPOCHS)
  134. # 恢复最新的检查点
  135. checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
  136. # 评估模型
  137. # 使用 epoch 数生成单张图片
  138. def display_image(epoch_no):
  139. return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
  140. display_image(EPOCHS)
  141. anim_file = 'dcgan.gif'
  142. with imageio.get_writer(anim_file, mode='I') as writer:
  143. filenames = glob.glob('image*.png')
  144. filenames = sorted(filenames)
  145. last = -1
  146. for i,filename in enumerate(filenames):
  147. frame = 2*(i**0.5)
  148. if round(frame) > round(last):
  149. last = frame
  150. else:
  151. continue
  152. image = imageio.imread(filename)
  153. writer.append_data(image)
  154. image = imageio.imread(filename)
  155. writer.append_data(image)
  156. import IPython
  157. if IPython.version_info > (6,2,0,''):
  158. display.Image(filename=anim_file)

参考:https://www.tensorflow.org/tutorials/generative/dcgan

一篇文章“简单”认识《生成对抗网络》(GAN)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/94319
推荐阅读
相关标签
  

闽ICP备14008679号