当前位置:   article > 正文

生成对抗网络(GAN)是干什么用的?

gan的作用

什么是生成对抗网络?生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前景的方法之一。模型通过框架中(至少)两个模块:生成模型(Generative Model)和判别模型(Discriminative Model)的互相博弈学习产生相当好的输出。原始 GAN 理论中,并不要求 G D 都是神经网络,只需要是能拟合相应生成和判别的函数即可。但实用中一般均使用深度神经网络作为 G D 。一个优秀的GAN应用需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。

一个典型的生成对抗网络模型大概长这个样子:

 

81ebac7551582049df9edad99126da79822.jpg

我们先来理解下GAN的两个模型要做什么。

首先判别模型,就是图中右半部分的网络,直观来看就是一个简单的神经网络结构,输入就是一副图像,输出就是一个概率值,用于判断真假使用(概率值大于0.5那就是真,小于0.5那就是假),真假也不过是人们定义的概率而已。

其次是生成模型,生成模型要做什么呢,同样也可以看成是一个神经网络模型,输入是一组随机数Z,输出是一个图像,不再是一个数值而已。从图中可以看到,会存在两个数据集,一个是真实数据集,这好说,另一个是假的数据集,那这个数据集就是有生成网络造出来的数据集。好了根据这个图我们再来理解一下GAN的目标是要干什么:

判别网络的目的:就是能判别出来属于的一张图它是来自真实样本集还是假样本集。假如输入的是真样本,网络输出就接近1,输入的是假样本,网络输出接近0,那么很完美,达到了很好判别的目的。

生成网络的目的:生成网络是造样本的,它的目的就是使得自己造样本的能力尽可能强,强到什么程度呢,你判别网络没法判断我是真样本还是假样本。

因此辨别网络的作用就是对噪音生成的数据辨别他为假的,对真实的数据辨别他为真的。

而生成网络的损失函数就是使得对于噪音数据,经过辨别网络之后的辨别结果是真的,这样就能达到生成真实图像的目的。

这里会感觉比较饶,这也是生成对抗网络的难点所在,理解了这点,整个生成对抗网络模型也就理解了。

 

  1. 工作模式

一般的工作流程很简单直接:

1. 采样训练样本的一个 minibatch,然后计算它们的鉴别器分数;

2. 得到一个生成样本 minibatch,然后计算它们的鉴别器分数;

3. 使用这两个步骤累积的梯度执行一次更新。

下一个诀窍是避免使用稀疏梯度,尤其是在生成器中。只需将特定的层换成它们对应的「平滑」的类似层就可以了,比如:

1.ReLU 换成 LeakyReLU

2. 最大池化换成平均池化、卷积+stride

3.Unpooling 换成去卷积

两个主要网络模型,一个是生成器模型,一个是辨别器模型。

辨别器模型要辨别两种数据源,一种是真实数据,一种是生成器生成的数据。这里可以分成两个辨别器模型,设置reuse=True来共享模型参数。

 

2、代码

  1. import numpy as np
  2. import tensorflow as tf
  3. import matplotlib.pyplot as plt
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. # TODO:数据准备
  6. mnist = input_data.read_data_sets('data')
  7. # TODO:获得输入数据
  8. def get_inputs(noise_dim, image_height, image_width, image_depth):
  9. # 真实数据
  10. inputs_real = tf.placeholder(tf.float32, [None, image_height, image_width, image_depth], name='inputs_real')
  11. # 噪声数据
  12. inputs_noise = tf.placeholder(tf.float32, [None, noise_dim], name='inputs_noise')
  13. return inputs_real, inputs_noise
  14. # TODO:生成器
  15. def get_generator(noise_img, output_dim, is_train=True, alpha=0.01):
  16. with tf.variable_scope("generator", reuse=(not is_train)):
  17. # 100 x 1 to 4 x 4 x 512
  18. # 全连接层
  19. layer1 = tf.layers.dense(noise_img, 4 * 4 * 512)
  20. layer1 = tf.reshape(layer1, [-1, 4, 4, 512])
  21. # batch normalization
  22. layer1 = tf.layers.batch_normalization(layer1, training=is_train)
  23. # Leaky ReLU
  24. layer1 = tf.maximum(alpha * layer1, layer1)
  25. # dropout
  26. layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
  27. # 4 x 4 x 512 to 7 x 7 x 256
  28. layer2 = tf.layers.conv2d_transpose(layer1, 256, 4, strides=1, padding='valid')
  29. layer2 = tf.layers.batch_normalization(layer2, training=is_train)
  30. layer2 = tf.maximum(alpha * layer2, layer2)
  31. layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
  32. # 7 x 7 256 to 14 x 14 x 128
  33. layer3 = tf.layers.conv2d_transpose(layer2, 128, 3, strides=2, padding='same')
  34. layer3 = tf.layers.batch_normalization(layer3, training=is_train)
  35. layer3 = tf.maximum(alpha * layer3, layer3)
  36. layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
  37. # 14 x 14 x 128 to 28 x 28 x 1
  38. logits = tf.layers.conv2d_transpose(layer3, output_dim, 3, strides=2, padding='same')
  39. # MNIST原始数据集的像素范围在0-1,这里的生成图片范围为(-1,1)
  40. # 因此在训练时,记住要把MNIST像素范围进行resize
  41. outputs = tf.tanh(logits)
  42. return outputs
  43. # TODO:判别器
  44. def get_discriminator(inputs_img, reuse=False, alpha=0.01):
  45. with tf.variable_scope("discriminator", reuse=reuse):
  46. # 28 x 28 x 1 to 14 x 14 x 128
  47. # 第一层不加入BN
  48. layer1 = tf.layers.conv2d(inputs_img, 128, 3, strides=2, padding='same')
  49. layer1 = tf.maximum(alpha * layer1, layer1)
  50. layer1 = tf.nn.dropout(layer1, keep_prob=0.8)
  51. # 14 x 14 x 128 to 7 x 7 x 256
  52. layer2 = tf.layers.conv2d(layer1, 256, 3, strides=2, padding='same')
  53. layer2 = tf.layers.batch_normalization(layer2, training=True)
  54. layer2 = tf.maximum(alpha * layer2, layer2)
  55. layer2 = tf.nn.dropout(layer2, keep_prob=0.8)
  56. # 7 x 7 x 256 to 4 x 4 x 512
  57. layer3 = tf.layers.conv2d(layer2, 512, 3, strides=2, padding='same')
  58. layer3 = tf.layers.batch_normalization(layer3, training=True)
  59. layer3 = tf.maximum(alpha * layer3, layer3)
  60. layer3 = tf.nn.dropout(layer3, keep_prob=0.8)
  61. # 4 x 4 x 512 to 4*4*512 x 1
  62. flatten = tf.reshape(layer3, (-1, 4 * 4 * 512))
  63. logits = tf.layers.dense(flatten, 1)
  64. outputs = tf.sigmoid(logits)
  65. return logits, outputs
  66. # TODO: 目标函数
  67. def get_loss(inputs_real, inputs_noise, image_depth, smooth=0.1):
  68. g_outputs = get_generator(inputs_noise, image_depth, is_train=True)
  69. d_logits_real, d_outputs_real = get_discriminator(inputs_real)
  70. d_logits_fake, d_outputs_fake = get_discriminator(g_outputs, reuse=True)
  71. # 计算Loss
  72. g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.ones_like(d_outputs_fake) * (1 - smooth)))
  73. d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,labels=tf.ones_like(d_outputs_real) * (1 - smooth)))
  74. d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,labels=tf.zeros_like(d_outputs_fake)))
  75. d_loss = tf.add(d_loss_real, d_loss_fake)
  76. return g_loss, d_loss
  77. # TODO:优化器
  78. def get_optimizer(g_loss, d_loss, learning_rate=0.001):
  79. train_vars = tf.trainable_variables()
  80. g_vars = [var for var in train_vars if var.name.startswith("generator")]
  81. d_vars = [var for var in train_vars if var.name.startswith("discriminator")]
  82. # Optimizer
  83. with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
  84. g_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
  85. d_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
  86. return g_opt, d_opt
  87. # 显示图片
  88. def plot_images(samples):
  89. fig, axes = plt.subplots(nrows=5, ncols=5, sharex=True, sharey=True, figsize=(7, 7))
  90. for img, ax in zip(samples, axes.flatten()):
  91. ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
  92. ax.get_xaxis().set_visible(False)
  93. ax.get_yaxis().set_visible(False)
  94. fig.tight_layout(pad=0)
  95. plt.show()
  96. def show_generator_output(sess, n_images, inputs_noise, output_dim):
  97. noise_shape = inputs_noise.get_shape().as_list()[-1]
  98. # 生成噪声图片
  99. examples_noise = np.random.uniform(-1, 1, size=[n_images, noise_shape])
  100. samples = sess.run(get_generator(inputs_noise, output_dim, False),
  101. feed_dict={inputs_noise: examples_noise})
  102. result = np.squeeze(samples, -1)
  103. return result
  104. # TODO:开始训练
  105. # 定义参数
  106. batch_size = 64
  107. noise_size = 100
  108. epochs = 5
  109. n_samples = 25
  110. learning_rate = 0.001
  111. def train(noise_size, data_shape, batch_size, n_samples):
  112. # 存储loss
  113. losses = []
  114. steps = 0
  115. inputs_real, inputs_noise = get_inputs(noise_size, data_shape[1], data_shape[2], data_shape[3])
  116. g_loss, d_loss = get_loss(inputs_real, inputs_noise, data_shape[-1])
  117. print("FUNCTION READY!!")
  118. g_train_opt, d_train_opt = get_optimizer(g_loss, d_loss, learning_rate)
  119. print("TRAINING....")
  120. with tf.Session() as sess:
  121. sess.run(tf.global_variables_initializer())
  122. # 迭代epoch
  123. for e in range(epochs):
  124. for batch_i in range(mnist.train.num_examples // batch_size):
  125. steps += 1
  126. batch = mnist.train.next_batch(batch_size)
  127. batch_images = batch[0].reshape((batch_size, data_shape[1], data_shape[2], data_shape[3]))
  128. # scale to -1, 1
  129. batch_images = batch_images * 2 - 1
  130. # noise
  131. batch_noise = np.random.uniform(-1, 1, size=(batch_size, noise_size))
  132. # run optimizer
  133. sess.run(g_train_opt, feed_dict={inputs_real: batch_images,
  134. inputs_noise: batch_noise})
  135. sess.run(d_train_opt, feed_dict={inputs_real: batch_images,
  136. inputs_noise: batch_noise})
  137. if steps % 101 == 0:
  138. train_loss_d = d_loss.eval({inputs_real: batch_images,
  139. inputs_noise: batch_noise})
  140. train_loss_g = g_loss.eval({inputs_real: batch_images,
  141. inputs_noise: batch_noise})
  142. losses.append((train_loss_d, train_loss_g))
  143. print("Epoch {}/{}....".format(e + 1, epochs),
  144. "Discriminator Loss: {:.4f}....".format(train_loss_d),
  145. "Generator Loss: {:.4f}....".format(train_loss_g))
  146. if e % 1 == 0:
  147. # 显示图片
  148. samples = show_generator_output(sess, n_samples, inputs_noise, data_shape[-1])
  149. plot_images(samples)
  150. with tf.Graph().as_default():
  151. train(noise_size, [-1, 28, 28, 1], batch_size, n_samples)
  152. print("OPTIMIZER END!!")

 

转载于:https://my.oschina.net/u/778683/blog/3100336

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/712540
推荐阅读
相关标签
  

闽ICP备14008679号