
TensorFlow AEGAN: feature compression and reconstruction on the MNIST dataset



Previous post: TensorFlow InfoGAN generating the MNIST dataset

Next post: WGAN-GP generating the MNIST dataset

The basic principle of an autoencoder (AE) is feature mapping: high-dimensional features are compressed into low-dimensional ones, but during reconstruction the network can only reproduce the single sample it was given as input. The advantage of AEGAN is that during reconstruction it can generate samples similar to the input, which gives it roughly the same effect as the decoder of a variational autoencoder.
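For reference, a minimal slim-style autoencoder of the kind described above might look as follows. This is a sketch only; the layer sizes are illustrative and are not part of the AEGAN code later in this post.

import tensorflow as tf
import tensorflow.contrib.slim as slim

def autoencoder(x):                                                        # x: (batch, 784) flattened MNIST image
    code = slim.fully_connected(x, 50, activation_fn=tf.nn.relu)           # encoder: 784 -> 50
    recon = slim.fully_connected(code, 784, activation_fn=tf.nn.sigmoid)   # decoder: 50 -> 784
    return code, recon

x = tf.placeholder(tf.float32, [None, 784])
code, recon = autoencoder(x)
loss = tf.reduce_mean(tf.square(x - recon))   # reconstruction error drives the feature mapping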

AEGAN simply attaches an auto-encoding network after a GAN. The GAN contributes its ability to generate realistic data from noise, and the auto-encoder stage learns the reverse of the feature-to-image mapping (image back to feature code). Together they form a network that can both map data into a low-dimensional space and restore a low-dimensional code back into data that follows the original distribution.

AEGAN training therefore has two steps (sketched right after this list, and implemented in the full listing further down):

1. Train a GAN in the conventional way.

2. Freeze the GAN and train the reverse generator with an auto-encoder objective; the resulting reverse generator then has the ability to map from the high-dimensional space down to the low-dimensional one.
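A rough sketch of that two-phase schedule follows. The names train_disc, train_gen, train_ae and the placeholders x, y refer to the ops built in the full listing below; the epoch counts mirror the listing and are otherwise illustrative.

# phase 1: ordinary GAN training -- discriminator and generator updated alternately
for epoch in range(3):
    for _ in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(train_disc, {x: batch_xs, y: batch_ys})
        sess.run(train_gen, {x: batch_xs, y: batch_ys})

# phase 2: GAN weights frozen (train_ae only updates inversegenerator variables);
# the auto-encoder loss trains the reverse generator
for epoch in range(3):
    for _ in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(train_ae, {x: batch_xs, y: batch_ys})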

Conceptually, this second step fixes the complex sample distribution as the network input and then gradually adjusts the network output to match a standard Gaussian distribution.

The reverse generator, inversegenerator, has a structure similar to the discriminator; both are essentially the generator run in reverse. It uses two convolutional layers followed by two fully connected layers.

# Reverse generator: structure similar to the discriminator; turns an image into a feature code
def inversegenerator(x):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('inversegenerator')]) > 0
    print('inv---')
    print(x.shape)  # (10, 28, 28, 1)
    with tf.variable_scope('inversegenerator', reuse=reuse):
        # two convolutions
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        print(x.shape)  # (10, 28, 28, 1)
        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (10, 14, 14, 64)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (10, 7, 7, 128)
        # two fully connected layers
        x = slim.flatten(x)
        print(x.shape)  # (10, 6272)
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=tf.nn.leaky_relu)
        print(shared_tensor.shape)  # (10, 1024)
        z = slim.fully_connected(shared_tensor, num_outputs=50, activation_fn=tf.nn.leaky_relu)
        print(z.shape)  # (10, 50)
    print('z---')
    return z

The auto-encoder's input is the image produced by the generator, generator(z). inversegenerator compresses this image into a feature code with the same dimensionality as the noise fed to the generator, and the generator itself is then reused as the decoder of the auto-encoder to reconstruct the original image.
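In graph terms the wiring is just a few lines, condensed here from the full listing below (ae_vars are the inversegenerator variables, so the generator and discriminator stay fixed during this phase):

gen = generator(z)                                   # "original" image, used as the AE input
code = inversegenerator(gen)                         # encoder: image -> 50-dim feature code
igen = generator(code)                               # decoder: the fixed generator reconstructs the image
loss_ae = tf.reduce_mean(tf.pow(gen - igen, 2))      # pixel-wise MSE between input and reconstruction
train_ae = tf.train.AdamOptimizer(0.01).minimize(loss_ae, var_list=ae_vars)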

GAN result: the InfoGAN only generates images that belong to the original data distribution; they need not resemble the particular input images.


AEGAN result: the generated images are much closer to the original input images.


import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.contrib.slim as slim
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/")
tf.reset_default_graph()


# generator: turns a 50-dim feature code into a 28x28 image
def generator(x):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0
    print('g---', x.shape)  # (10, 50)
    with tf.variable_scope('generator', reuse=reuse):
        # two fully connected layers with batch norm
        x = slim.fully_connected(x, 1024)
        print(x.shape)  # (10, 1024)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        print(x.shape)  # (10, 1024)
        x = slim.fully_connected(x, 7 * 7 * 128)
        print(x.shape)  # (10, 6272)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        print(x.shape)  # (10, 6272)
        # two transposed convolutions
        x = tf.reshape(x, [-1, 7, 7, 128])
        print(x.shape)  # (10, 7, 7, 128)
        x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn=None)
        print(x.shape)  # (10, 14, 14, 64)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        print(x.shape)  # (10, 14, 14, 64)
        z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)
        print(z.shape)  # (10, 28, 28, 1)
    print('z---')
    return z


# reverse generator: structure similar to the discriminator; turns an image into a feature code
def inversegenerator(x):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('inversegenerator')]) > 0
    print('inv---')
    print(x.shape)  # (10, 28, 28, 1)
    with tf.variable_scope('inversegenerator', reuse=reuse):
        # two convolutions
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        print(x.shape)  # (10, 28, 28, 1)
        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (10, 14, 14, 64)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (10, 7, 7, 128)
        # two fully connected layers
        x = slim.flatten(x)
        print(x.shape)  # (10, 6272)
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=tf.nn.leaky_relu)
        print(shared_tensor.shape)  # (10, 1024)
        z = slim.fully_connected(shared_tensor, num_outputs=50, activation_fn=tf.nn.leaky_relu)
        print(z.shape)  # (10, 50)
    print('z---')
    return z


# discriminator: real/fake score plus InfoGAN class and continuous-factor heads
def discriminator(x, num_classes=10, num_cont=2):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    print('dis---')
    print(x.shape)  # (10, 28, 28, 1)
    with tf.variable_scope('discriminator', reuse=reuse):
        # two convolutions
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        print(x.shape)  # (10, 28, 28, 1)
        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (10, 14, 14, 64)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.leaky_relu)
        print(x.shape)  # (10, 7, 7, 128)
        x = slim.flatten(x)
        print(x.shape)  # (10, 6272)
        # two fully connected layers
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=tf.nn.leaky_relu)
        print(shared_tensor.shape)  # (10, 1024)
        recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=tf.nn.leaky_relu)
        print(recog_shared.shape)  # (10, 128)
        # fully connected heads producing the outputs
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=None)
        print(disc.shape)  # (10, 1)
        disc = tf.squeeze(disc, -1)
        print(disc.shape)  # (10,)
        # print("disc", disc.get_shape())  # 0 or 1
        recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)
        print(recog_cat.shape)  # (10, 10)
        recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
        print(recog_cont.shape)  # (10, 2)
    print('dis end---')
    return disc, recog_cat, recog_cont


batch_size = 10    # mini-batch size
classes_dim = 10   # 10 digit classes
con_dim = 2        # total continuous factors
rand_dim = 38
n_input = 784

x = tf.placeholder(tf.float32, [None, n_input])
y = tf.placeholder(tf.int32, [None])

z_con = tf.random_normal((batch_size, con_dim))    # 2
z_rand = tf.random_normal((batch_size, rand_dim))  # 38
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_con, z_rand])  # 50
gen = generator(z)
genout = tf.squeeze(gen, -1)

# auto-encoder network
aelearning_rate = 0.01
igen = generator(inversegenerator(generator(z)))
loss_ae = tf.reduce_mean(tf.pow(gen - igen, 2))

# output: reconstruct real images through the learned encoder
igenout = generator(inversegenerator(x))

# labels for the discriminator
y_real = tf.ones(batch_size)   # real
y_fake = tf.zeros(batch_size)  # fake

# discriminator outputs
disc_real, class_real, _ = discriminator(x)
disc_fake, class_fake, con_fake = discriminator(gen)
pred_class = tf.argmax(class_fake, dimension=1)

# discriminator loss
loss_d_r = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_real, labels=y_real))
loss_d_f = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_fake))
loss_d = (loss_d_r + loss_d_f) / 2

# generator loss
loss_g = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=disc_fake, labels=y_real))

# categorical factor loss
loss_cf = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_fake, labels=y))  # classification loss on generated images
loss_cr = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(logits=class_real, labels=y))  # classification loss on real images
loss_c = (loss_cf + loss_cr) / 2

# continuous factor loss
loss_con = tf.reduce_mean(tf.square(con_fake - z_con))

# collect the trainable variables of each sub-network
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
# note: the substring 'generator' also matches 'inversegenerator' variables, but those
# receive no gradient from the GAN losses, so only the true generator weights are updated
g_vars = [var for var in t_vars if 'generator' in var.name]
ae_vars = [var for var in t_vars if 'inversegenerator' in var.name]

# disc_global_step = tf.Variable(0, trainable=False)
gen_global_step = tf.Variable(0, trainable=False)
# ae_global_step = tf.Variable(0, trainable=False)
global_step = tf.train.get_or_create_global_step()  # required when using MonitoredTrainingSession

train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d + loss_c + loss_con, var_list=d_vars,
                                                     global_step=global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g + loss_c + loss_con, var_list=g_vars,
                                                   global_step=gen_global_step)
train_ae = tf.train.AdamOptimizer(aelearning_rate).minimize(loss_ae, var_list=ae_vars, global_step=global_step)

training_GANepochs = 3  # phase 1: train the GAN for 3 epochs over the dataset
training_aeepochs = 6   # phase 2: train the AE for 3 more epochs (epoch counter runs from 3 to 6)
display_step = 1

with tf.train.MonitoredTrainingSession(checkpoint_dir='log/aecheckpoints', save_checkpoint_secs=120) as sess:
    total_batch = int(mnist.train.num_examples / batch_size)
    print("ae_global_step.eval(session=sess)", global_step.eval(session=sess),
          int(global_step.eval(session=sess) / total_batch))
    for epoch in range(int(global_step.eval(session=sess) / total_batch), training_GANepochs):
        avg_cost = 0.
        # loop over the whole training set
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # fetch a mini-batch
            feeds = {x: batch_xs, y: batch_ys}
            # Fit training using batch data
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step], feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step], feeds)
        # display training progress
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc), l_gen)
    print("GAN done!")

    # evaluate on the test set
    print("Result:", loss_d.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]}, session=sess)
          , loss_g.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]}, session=sess))

    # generate simulated images from the test labels and compare them with the test images
    show_num = 10
    gensimple, inputx = sess.run(
        [genout, x], feed_dict={x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]})
    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
        a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))
    plt.draw()
    plt.show()

    # begin ae
    print("ae_global_step.eval(session=sess)", global_step.eval(session=sess),
          int(global_step.eval(session=sess) / total_batch))
    for epoch in range(int(global_step.eval(session=sess) / total_batch), training_aeepochs):
        avg_cost = 0.
        # loop over the whole training set
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)  # fetch a mini-batch
            feeds = {x: batch_xs, y: batch_ys}
            # Fit training using batch data
            l_ae, _, ae_step = sess.run([loss_ae, train_ae, global_step], feeds)
        # display training progress
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_ae))

    # evaluate on the test set
    print("Result:", loss_ae.eval({x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]}, session=sess))

    # reconstruct test images through inversegenerator + generator and compare with the inputs
    show_num = 10
    gensimple, inputx = sess.run(
        [igenout, x], feed_dict={x: mnist.test.images[:batch_size], y: mnist.test.labels[:batch_size]})
    f, a = plt.subplots(2, 10, figsize=(10, 2))
    for i in range(show_num):
        a[0][i].imshow(np.reshape(inputx[i], (28, 28)))
        a[1][i].imshow(np.reshape(gensimple[i], (28, 28)))
    plt.draw()
    plt.show()
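After both phases, the reconstruction path igenout = generator(inversegenerator(x)) is what produces the AEGAN result shown above. If you also want to inspect the 50-dimensional codes themselves, a small addition works; this is a sketch under the assumption that the two extra ops are defined next to the igenout line (they must be created before the MonitoredTrainingSession, which finalizes the graph):

# defined near "igenout = generator(inversegenerator(x))", before the session is opened
feature = inversegenerator(x)   # (batch_size, 50) compressed code for real images
recon = generator(feature)      # reconstruction decoded from that code

# inside the session, after the AE phase:
codes, images = sess.run([feature, recon],
                         feed_dict={x: mnist.test.images[:batch_size]})
print(codes.shape)   # (10, 50)
print(images.shape)  # (10, 28, 28, 1)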
