当前位置:   article > 正文

基于GAN的手写数字生成实践_gan手写数字生成实验

gan手写数字生成实验

       之前就对GAN这项技术很感兴趣,可是后面一直没有找到时间研究一下,今天找来了一个很不错的例子学习实践了一下,简单来记录一下自己的实践,具体的代码如下:

  1. #!usr/bin/env python
  2. #encoding:utf-8
  3. from __future__ import division
  4. '''
  5. __Author__:沂水寒城
  6. 功能: 基于GAN的手写数字生成实践
  7. '''
  8. import os
  9. import numpy as np
  10. import tensorflow as tf
  11. import matplotlib.pyplot as plt
  12. import matplotlib.gridspec as gridspec
  13. from tensorflow.examples.tutorials.mnist import input_data
  14. #设置基本的参数信息
  15. mb_size = 32
  16. X_dim = 784
  17. z_dim = 64
  18. h_dim = 128
  19. lr = 1e-3
  20. m = 5
  21. lam = 1e-3
  22. gamma = 0.5
  23. k_curr = 0
  24. if not os.path.exists('result/'):
  25. os.makedirs('result/')
  26. mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
  27. def numberPloter(samples):
  28. '''
  29. 数字图像绘制
  30. '''
  31. figure = plt.figure(figsize=(8, 8))
  32. gs = gridspec.GridSpec(4, 4)
  33. gs.update(wspace=0.05, hspace=0.05)
  34. for i, sample in enumerate(samples):
  35. ax = plt.subplot(gs[i])
  36. plt.axis('off')
  37. ax.set_xticklabels([])
  38. ax.set_yticklabels([])
  39. ax.set_aspect('equal')
  40. plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
  41. return figure
  42. def xavier_init(size):
  43. '''
  44. 初始化
  45. '''
  46. in_dim = size[0]
  47. xavier_stddev = 1. / tf.sqrt(in_dim / 2.)
  48. return tf.random_normal(shape=size, stddev=xavier_stddev)
  49. X = tf.placeholder(tf.float32, shape=[None, X_dim])
  50. z = tf.placeholder(tf.float32, shape=[None, z_dim])
  51. k = tf.placeholder(tf.float32)
  52. D_W1 = tf.Variable(xavier_init([X_dim, h_dim]))
  53. D_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
  54. D_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
  55. D_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
  56. G_W1 = tf.Variable(xavier_init([z_dim, h_dim]))
  57. G_b1 = tf.Variable(tf.zeros(shape=[h_dim]))
  58. G_W2 = tf.Variable(xavier_init([h_dim, X_dim]))
  59. G_b2 = tf.Variable(tf.zeros(shape=[X_dim]))
  60. theta_G = [G_W1, G_W2, G_b1, G_b2]
  61. theta_D = [D_W1, D_W2, D_b1, D_b2]
  62. def sample_z(m, n):
  63. '''
  64. 随机数
  65. '''
  66. return np.random.uniform(-1., 1., size=[m, n])
  67. def G(z):
  68. '''
  69. 定义两个网络
  70. '''
  71. G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
  72. G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
  73. G_prob = tf.nn.sigmoid(G_log_prob)
  74. return G_prob
  75. def D(X):
  76. '''
  77. 定义两个网络
  78. '''
  79. D_h1 = tf.nn.relu(tf.matmul(X, D_W1) + D_b1)
  80. X_recon = tf.matmul(D_h1, D_W2) + D_b2
  81. return tf.reduce_mean(tf.reduce_sum((X - X_recon)**2, 1))
  82. # 计算损失
  83. G_sample = G(z)
  84. D_real = D(X)
  85. D_fake = D(G_sample)
  86. D_loss = D_real - k*D_fake
  87. G_loss = D_fake
  88. D_solver=(tf.train.AdamOptimizer(learning_rate=lr).minimize(D_loss, var_list=theta_D))
  89. G_solver=(tf.train.AdamOptimizer(learning_rate=lr).minimize(G_loss, var_list=theta_G))
  90. sess = tf.Session()
  91. sess.run(tf.global_variables_initializer())
  92. # 迭代计算一百万次,每1000次绘制一张图片
  93. num = 0
  94. for it in range(1000000):
  95. X_mb, _ = mnist.train.next_batch(mb_size)
  96. _, D_real_curr = sess.run([D_solver, D_real],feed_dict={X: X_mb, z: sample_z(mb_size, z_dim), k: k_curr})
  97. _, D_fake_curr = sess.run([G_solver, D_fake],feed_dict={X: X_mb, z: sample_z(mb_size, z_dim)})
  98. k_curr = k_curr + lam * (gamma*D_real_curr - D_fake_curr)
  99. if it % 1000 == 0:
  100. measure = D_real_curr + np.abs(gamma*D_real_curr - D_fake_curr)
  101. print('Iter-{}; Convergence measure: {:.4}'.format(it, measure))
  102. samples = sess.run(G_sample, feed_dict={z: sample_z(16, z_dim)})
  103. fig = plot(samples)
  104. plt.savefig('result/{}.png'.format(str(num).zfill(3)), bbox_inches='tight')
  105. num += 1
  106. plt.close(fig)

      这是一个很简单实用的例子,基于GAN来生成手写数字,关于各部分的代码作用,我在具体的代码里面已经加入了相应的注释,下面我们来简单看一下输出的结果:

。。。。。。。。。。。。。。。。。。

     上面是展示了1000张图片的前100张,和后面将近100张左右的结果缩略图,这里给出来第一张和最后一张:

第一张:

    最后一张:

       之后找时间继续学习,欢迎交流!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/103006
推荐阅读
相关标签
  

闽ICP备14008679号