
Generative Adversarial Network (GAN): Handwritten Digit Image Generation

An experiment in generating digit images with a generative adversarial network

Generative Adversarial Networks (GAN)

Introduction

The GAN framework consists of two models. The first is the generative model, the generator, denoted G: by learning from a large number of samples, it produces data realistic enough to pass for real. The second is the discriminative model, the discriminator, denoted D: it receives both samples produced by G and real samples, and classifies them as real or fake. The generator G takes a random noise vector z and produces an image, written G(z); the discriminator D judges whether an image x is real, i.e. for an input x, D(x) is the probability that x is a real image. G and D play against each other, and through training the generative ability of G and the discriminative ability of D improve together until the process converges.
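Training alternates between the two players: each iteration first updates D on a batch of real images together with a batch of G's outputs, then updates G so that its outputs fool the current D. The skeleton below is only an illustration of that loop; generator, discriminator_step and generator_step are hypothetical stand-ins for the TensorFlow network and update operations built in the full code later in this article.

import numpy as np

def generator(noise):
    # stand-in: the real G in this article is a deconvolutional network ending in tanh
    return np.zeros((noise.shape[0], 28, 28, 1))

def discriminator_step(real_batch, fake_batch):
    # stand-in: one gradient step pushing D towards "real" on real_batch and "fake" on fake_batch
    pass

def generator_step(noise):
    # stand-in: one gradient step pushing D(G(noise)) towards "real"
    pass

for step in range(3):                               # a few iterations, purely for illustration
    z = np.random.uniform(-1.0, 1.0, (100, 100))    # batch of noise vectors
    real_batch = np.zeros((100, 28, 28, 1))         # stand-in for a batch of real MNIST images
    fake_batch = generator(z)                       # G(z): generated samples
    discriminator_step(real_batch, fake_batch)      # D learns to tell real from generated
    generator_step(z)                               # G learns to fool the current D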

Principle

A noise vector z is drawn at random from some prior distribution, and the generator G maps it through a complex, learned transformation to a fake sample:

    x̂ = G(z; Θ_g)

Given a real or generated sample, the discriminator outputs a score between 0 and 1; the larger the score, the more likely the sample is real:

    s = D(x; Θ_d)

The overall objective is the two-player minimax game:

    min_G max_D V(D, G) = E_{x ~ p_data(x)}[log D(x)] + E_{z ~ p_z(z)}[log(1 − D(G(z)))]
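In implementations (including the code below) the expectations are replaced by averages over a mini-batch, and the generator is usually trained with the non-saturating variant, i.e. it maximizes log D(G(z)) rather than minimizing log(1 − D(G(z))). Written with sigmoid cross-entropy on the discriminator's logits, this yields the loss terms used later. A minimal sketch, assuming TensorFlow 1.x and two hypothetical logit tensors d_real_logits and d_fake_logits produced by D on real and generated batches:

import tensorflow as tf  # TensorFlow 1.x graph API

# hypothetical stand-ins for D's raw (pre-sigmoid) outputs on real and generated images
d_real_logits = tf.placeholder(tf.float32, shape=[None, 1])
d_fake_logits = tf.placeholder(tf.float32, shape=[None, 1])

# discriminator: push D(x) towards 1 on real samples and towards 0 on fakes
loss_d_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_real_logits, labels=tf.ones_like(d_real_logits)))
loss_d_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_fake_logits, labels=tf.zeros_like(d_fake_logits)))
loss_d = loss_d_real + loss_d_fake

# generator (non-saturating form): push D(G(z)) towards 1
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
    logits=d_fake_logits, labels=tf.ones_like(d_fake_logits)))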

Code

# encoding: utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data/')

# hyperparameters and output directory for sample images
batch_size = 100
z_dim = 100
OUTPUT_DIR = 'samples'
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)

# placeholders: real images, noise vectors, and a training-mode flag for batch norm
X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1], name='X')
Noise = tf.placeholder(dtype=tf.float32, shape=[None, z_dim], name='Noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')

def relu(x, leak=0.2):
    # leaky ReLU
    return tf.maximum(x, leak * x)

def sigmoid_cross_entropy_with_logits(x, y):
    return tf.nn.sigmoid_cross_entropy_with_logits(logits=x, labels=y)

# discriminator: four strided convolutions, then a single-logit dense layer
def discriminator(image, reuse=None, is_training=is_training):
    m = 0.9
    with tf.variable_scope('discriminator', reuse=reuse):
        H0 = relu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))
        H1 = tf.layers.conv2d(H0, kernel_size=5, filters=128, strides=2, padding='same')
        H1 = relu(tf.contrib.layers.batch_norm(H1, is_training=is_training, decay=m))
        H2 = tf.layers.conv2d(H1, kernel_size=5, filters=256, strides=2, padding='same')
        H2 = relu(tf.contrib.layers.batch_norm(H2, is_training=is_training, decay=m))
        H3 = tf.layers.conv2d(H2, kernel_size=5, filters=512, strides=2, padding='same')
        H3 = relu(tf.contrib.layers.batch_norm(H3, is_training=is_training, decay=m))
        H4 = tf.contrib.layers.flatten(H3)
        H4 = tf.layers.dense(H4, units=1)
        return tf.nn.sigmoid(H4), H4

# generator: project the noise to 3x3x512, then upsample with transposed convolutions to 28x28x1
def generator(z, is_training=is_training):
    m = 0.8
    with tf.variable_scope('generator', reuse=None):
        d = 3
        H0 = tf.layers.dense(z, units=d * d * 512)
        H0 = tf.reshape(H0, shape=[-1, d, d, 512])
        H0 = tf.nn.relu(tf.contrib.layers.batch_norm(H0, is_training=is_training, decay=m))
        H1 = tf.layers.conv2d_transpose(H0, kernel_size=5, filters=256, strides=2, padding='same')
        H1 = tf.nn.relu(tf.contrib.layers.batch_norm(H1, is_training=is_training, decay=m))
        H2 = tf.layers.conv2d_transpose(H1, kernel_size=5, filters=128, strides=2, padding='same')
        H2 = tf.nn.relu(tf.contrib.layers.batch_norm(H2, is_training=is_training, decay=m))
        H3 = tf.layers.conv2d_transpose(H2, kernel_size=5, filters=64, strides=2, padding='same')
        H3 = tf.nn.relu(tf.contrib.layers.batch_norm(H3, is_training=is_training, decay=m))
        H4 = tf.layers.conv2d_transpose(H3, kernel_size=5, filters=1, strides=1, padding='valid',
                                        activation=tf.nn.tanh, name='g')
        return H4

g = generator(Noise)
d_real, d_real_logits = discriminator(X)
d_fake, d_fake_logits = discriminator(g, reuse=True)
vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]

# D: real samples should score 1 and generated samples 0; G: generated samples should score 1
loss_d_real = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_real_logits, tf.ones_like(d_real)))
loss_d_fake = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.zeros_like(d_fake)))
loss_g = tf.reduce_mean(sigmoid_cross_entropy_with_logits(d_fake_logits, tf.ones_like(d_fake)))
loss_d = loss_d_real + loss_d_fake

# run the batch-norm moving-average updates before each optimizer step
updates_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(updates_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)

# tile a batch of images into one large grid image for visualization
def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    image_h = images.shape[1]
    image_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    m = np.ones((images.shape[1] * n_plots + n_plots + 1,
                 images.shape[2] * n_plots + n_plots + 1)) * 0.5
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * image_h: 1 + i + (i + 1) * image_h,
                  1 + j + j * image_w: 1 + j + (j + 1) * image_w] = this_img
    return m

sess = tf.Session()
sess.run(tf.global_variables_initializer())
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)  # fixed noise for monitoring
samples = []
loss = {'d': [], 'g': []}
for i in range(30000):
    n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    batch = mnist.train.next_batch(batch_size=batch_size)[0]
    batch = np.reshape(batch, [-1, 28, 28, 1])
    batch = (batch - 0.5) * 2   # rescale from [0, 1] to [-1, 1] to match the tanh output of G
    d_ls, g_ls = sess.run([loss_d, loss_g], feed_dict={X: batch, Noise: n, is_training: True})
    loss['d'].append(d_ls)
    loss['g'].append(g_ls)
    # one discriminator step, then two generator steps
    sess.run(optimizer_d, feed_dict={X: batch, Noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, Noise: n, is_training: True})
    sess.run(optimizer_g, feed_dict={X: batch, Noise: n, is_training: True})
    if i % 20 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={Noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2   # map back to [0, 1] for display
        imgs = [img[:, :, 0] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs, cmap='gray')
        plt.savefig(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i))
        plt.show()
        samples.append(gen_imgs)

# plot the loss curves and save the trained model
plt.plot(loss['d'], label='discriminator')
plt.plot(loss['g'], label='generator')
plt.legend(loc='upper right')
plt.show()
saver = tf.train.Saver()
saver.save(sess, './mnist_dcgan', global_step=30000)
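Note that the script uses the TensorFlow 1.x graph API (tf.placeholder, tf.contrib, tf.Session) and the old MNIST tutorial loader, so it needs a TensorFlow 1.x environment. After training, the saved checkpoint can be restored to draw fresh samples. A minimal sketch, assuming it runs in the same script (so g, Noise, is_training, montage and the hyperparameters are still in scope) and that the checkpoint was written by the saver.save call above:

saver = tf.train.Saver()
with tf.Session() as restore_sess:
    # './mnist_dcgan-30000' is the path produced by saver.save(sess, './mnist_dcgan', global_step=30000)
    saver.restore(restore_sess, './mnist_dcgan-30000')
    z = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
    imgs = restore_sess.run(g, feed_dict={Noise: z, is_training: False})
    imgs = (imgs + 1) / 2   # map tanh output from [-1, 1] back to [0, 1]
    plt.axis('off')
    plt.imshow(montage([img[:, :, 0] for img in imgs]), cmap='gray')
    plt.show()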

Results

Reference: https://zhuanlan.zhihu.com/p/44167207
