
Learning Generative Adversarial Networks (13): Generating the Handwritten Digits You Want with conditionalGAN (TensorFlow Implementation)

Generating handwritten digits on MNIST with the conditional GAN framework

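I. Background

I originally did not plan to cover this model, because everything conditionalGAN can do, infoGAN can do too, and I already wrote about infoGAN in an earlier post: "Learning Adversarial Neural Networks (5): Generating Clothing Images of Varying Width and Height with infoGAN (TensorFlow Implementation)". I recently joined a new company and am still on probation. The company is short of samples, so my manager asked me to generate some handwritten-digit samples for later work. Since I had already used infoGAN, I decided to try a different model this time: conditionalGAN.

There are plenty of introductions to conditionalGAN online, and my experiments use only the open MNIST dataset, so nothing here is confidential. I am posting the process in the hope that it helps anyone who needs to build this model later.

conditionalGAN (conditional GAN, CGAN) was published by Mehdi Mirza in November 2014 and is one of the early classics of the GAN family. Looking back, conditionalGAN and infoGAN started from the same idea: generating images you can control rather than random ones. Their approaches differ slightly. infoGAN uses mutual information to constrain the random input, while conditionalGAN feeds in a condition together with the image, so that a generated image passes the discriminator only when it matches the given condition. infoGAN was published in June 2016 and conditionalGAN in November 2014, so infoGAN is effectively the more advanced of the two, but conditionalGAN is still well worth studying.

This post covers the main ideas behind conditionalGAN and generates the handwritten digits 0-9 with as little code as possible.

[1] Paper: https://arxiv.org/pdf/1411.1784.pdf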

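II. How conditionalGAN Works

conditionalGAN was proposed quite early, so there are many introductions online. Here are a few accessible explanations:

[2] Hung-yi Lee's GAN notes (2): Conditional GAN

[3] Hung-yi Lee, 2018 GAN course, class 2: Conditional Generation by GAN

The original paper is short, so below I walk through it based on my own understanding.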

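First, in the background section, the authors note that the biggest advantages of GANs are that they need no Markov chains, gradients are obtained by backpropagation alone, no inference is required during learning, and a wide variety of factors and interactions can easily be incorporated into the model.

The problem with a GAN is that the model is unconditioned, so the output is random. The authors therefore add conditioning information, such as a class label, to steer generation in a specified direction.

Existing supervised neural networks are very successful but still have problems. First, the model may have to handle an extremely large number of predicted output categories. Second, they treat the input-output mapping as one-to-one, whereas it may really be one-to-many; for example, an image may carry several labels. The usual way to address the second problem is a conditional probabilistic generative model: the input is taken as the conditioning variable, and the one-to-many mapping is instantiated as a conditional probability distribution.

For the model, the authors feed extra information y, such as a class label or data of some other type, into both the discriminator and the generator as an additional input layer. In the generator, the input noise and y are combined into a joint hidden representation.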
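
For reference, the objective in the original paper [1] is the standard GAN minimax game with both networks conditioned on y. Here p_data denotes the data distribution and p_z the noise prior:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]
```

The only difference from the unconditional objective is the extra y in both D and G.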

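First, the model structure given in the paper:

Note that the conditionalGAN in the paper does not use any convolution operations, so this representation is perfectly fine. I also found another diagram online that I think is well done, so I include it here as well:

For the MNIST class labels, the authors use one-hot encoding. The other model parameters are described in the paper; consult the original paper if you need them.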
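
As a rough illustration of the conditioning step, here is a minimal NumPy sketch that one-hot encodes a few labels and joins them to the noise vector along the feature axis. The helper name generator_input is hypothetical. The sizes (100-dimensional noise, 10 classes) follow the defaults used later in this post, not necessarily the paper's settings.

```python
import numpy as np


def generator_input(z, y):
    """Join noise z (batch, z_dim) and condition y (batch, y_dim) along the feature axis."""
    return np.concatenate([z, y], axis=1)


rng = np.random.RandomState(0)
z = rng.uniform(-1, 1, size=(3, 100)).astype(np.float32)  # noise batch
y = np.eye(10, dtype=np.float32)[[7, 0, 3]]  # one-hot rows for the labels 7, 0, 3
g_in = generator_input(z, y)
print(g_in.shape)  # (3, 110)
```

Each row of g_in carries the noise in its first 100 entries and the requested class in its last 10, so the first generator layer sees both.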

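Finally, the authors' experimental results on MNIST:

This post is mainly based on the code in [4]. That code can only generate a fixed set of eight handwritten digits, 0-7, whereas I wanted all ten digits, 0-9, so I made a small change to the original code. The change is minor, but it does the job.

[4] Reference code: https://github.com/zhangqianhui/Conditional-GAN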

III. Implementing conditionalGAN

1. File structure

All the files are laid out as follows:

    -- MNIST_data
    |------ t10k-images-idx3-ubyte.gz
    |------ t10k-labels-idx1-ubyte.gz
    |------ train-images-idx3-ubyte.gz
    |------ train-labels-idx1-ubyte.gz
    -- main.py
    -- model.py
    -- ops.py
    -- utils.py

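2. Data preparation

For an introduction to the MNIST dataset, see my first post: "Learning Adversarial Neural Networks (1): Generating MNIST Handwritten Digits with a GAN (TensorFlow Implementation)". I won't repeat it here. The simplest way to download the dataset is to run the following two lines: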

    from tensorflow.examples.tutorials.mnist import input_data
    data = input_data.read_data_sets('MNIST_data/')

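If you see output like the following after running it, the download succeeded:

The downloaded data is placed under 'MNIST_data/'. Note that you must decompress the files after downloading. conditionalGAN needs the class labels, so decompress all four files, and then check whether the file names have changed. On Ubuntu, I found that a '-' in each name had become a '.' after decompression (for example 'train-images.idx3-ubyte', which is the name utils.py opens), so be sure to check this detail.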
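
If you would rather decompress the files from Python than by hand, something like the following sketch should work. Both gunzip_all and dash_to_dot are hypothetical helpers written for this post, not part of the reference code. The renaming rule simply produces the file names that utils.py opens.

```python
import gzip
import os
import shutil


def dash_to_dot(name):
    # 'train-images-idx3-ubyte' -> 'train-images.idx3-ubyte'
    return name.replace('-idx', '.idx')


def gunzip_all(data_dir, rename=None):
    """Decompress every *.gz file in data_dir and return the written paths.

    rename: optional function mapping the decompressed file name to the
    name that should be written instead.
    """
    written = []
    for name in sorted(os.listdir(data_dir)):
        if not name.endswith('.gz'):
            continue
        target = name[:-len('.gz')]
        if rename is not None:
            target = rename(target)
        src_path = os.path.join(data_dir, name)
        dst_path = os.path.join(data_dir, target)
        # Stream the archive to disk instead of loading it into memory.
        with gzip.open(src_path, 'rb') as src, open(dst_path, 'wb') as dst:
            shutil.copyfileobj(src, dst)
        written.append(dst_path)
    return written
```

Calling gunzip_all('MNIST_data', rename=dash_to_dot) after the download would leave the four decompressed files next to the archives.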

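3. utils.py: helper operations for the MNIST dataset

This file defines the class for the MNIST data, along with helpers for saving images and so on. The code is given below: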

    import os

    import matplotlib.pyplot as plt
    import numpy as np
    import scipy
    import scipy.misc  # imread/imsave are only available in older SciPy releases


    class Mnist(object):

        def __init__(self):
            self.dataname = "Mnist"
            self.dims = 28 * 28
            self.shape = [28, 28, 1]
            self.image_size = 28
            self.data, self.data_y = self.load_mnist()

        def load_mnist(self):
            data_dir = "./MNIST_data"

            # Read the decompressed IDX files in binary mode and skip their headers
            # (16 bytes for image files, 8 bytes for label files).
            fd = open(os.path.join(data_dir, 'train-images.idx3-ubyte'), 'rb')
            loaded = np.fromfile(file=fd, dtype=np.uint8)
            trX = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float64)

            fd = open(os.path.join(data_dir, 'train-labels.idx1-ubyte'), 'rb')
            loaded = np.fromfile(file=fd, dtype=np.uint8)
            trY = loaded[8:].reshape(60000).astype(np.float64)

            fd = open(os.path.join(data_dir, 't10k-images.idx3-ubyte'), 'rb')
            loaded = np.fromfile(file=fd, dtype=np.uint8)
            teX = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float64)

            fd = open(os.path.join(data_dir, 't10k-labels.idx1-ubyte'), 'rb')
            loaded = np.fromfile(file=fd, dtype=np.uint8)
            teY = loaded[8:].reshape(10000).astype(np.float64)

            trY = np.asarray(trY)
            teY = np.asarray(teY)

            # Use the training and test sets together (70000 images).
            X = np.concatenate((trX, teX), axis=0)
            y = np.concatenate((trY, teY), axis=0)

            # Re-seed before each shuffle so that X and y are shuffled in the same order.
            seed = 547
            np.random.seed(seed)
            np.random.shuffle(X)
            np.random.seed(seed)
            np.random.shuffle(y)

            # convert label to one-hot
            y_vec = np.zeros((len(y), 10), dtype=np.float64)
            for i, label in enumerate(y):
                y_vec[i, int(y[i])] = 1.0

            # Scale pixel values to [0, 1].
            return X / 255., y_vec

        def getNext_batch(self, iter_num=0, batch_size=100):
            # Integer division keeps the modulo arithmetic and the slice bounds integral.
            ro_num = len(self.data) // batch_size - 1

            # Reshuffle images and labels with one shared permutation at the start of each pass.
            if iter_num % ro_num == 0:
                length = len(self.data)
                perm = np.arange(length)
                np.random.shuffle(perm)
                self.data = np.array(self.data)
                self.data = self.data[perm]
                self.data_y = np.array(self.data_y)
                self.data_y = self.data_y[perm]

            return self.data[int(iter_num % ro_num) * batch_size: int(iter_num % ro_num + 1) * batch_size] \
                , self.data_y[int(iter_num % ro_num) * batch_size: int(iter_num % ro_num + 1) * batch_size]


    def get_image(image_path, is_grayscale=False):
        return np.array(inverse_transform(imread(image_path, is_grayscale)))


    def save_images(images, size, image_path):
        return imsave(inverse_transform(images), size, image_path)


    def imread(path, is_grayscale=False):
        if (is_grayscale):
            return scipy.misc.imread(path, flatten=True).astype(np.float64)
        else:
            return scipy.misc.imread(path).astype(np.float64)


    def imsave(images, size, path):
        return scipy.misc.imsave(path, merge(images, size))


    def merge(images, size):
        # Tile a batch of images into a size[0] x size[1] grid, row by row.
        h, w = images.shape[1], images.shape[2]
        img = np.zeros((h * size[0], w * size[1], 3))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w: i * w + w, :] = image
        return img


    def inverse_transform(image):
        return (image + 1.) / 2.


    def read_image_list(category):
        filenames = []
        print("list file")
        entries = os.listdir(category)
        for file in entries:
            filenames.append(category + "/" + file)
        print("list file ending!")
        return filenames


    ## from caffe
    def vis_square(visu_path, data, type):
        """Take an array of shape (n, height, width) or (n, height, width , 3)
           and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""

        # normalize data for display
        data = (data - data.min()) / (data.max() - data.min())

        # force the number of filters to be square
        n = int(np.ceil(np.sqrt(data.shape[0])))
        padding = (((0, n ** 2 - data.shape[0]),
                    (0, 1), (0, 1))  # add some space between filters
                   + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
        data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)

        # tile the filters into an image
        data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
        data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

        plt.imshow(data[:, :, 0])
        plt.axis('off')

        if type:
            plt.savefig('./{}/weights.png'.format(visu_path), format='png')
        else:
            plt.savefig('./{}/activation.png'.format(visu_path), format='png')


    def sample_label():
        # 100 one-hot labels: ten consecutive rows for each digit 0-9.
        num = 100
        label_vector = np.zeros((num, 10), dtype=np.float64)
        for i in range(0, num):
            label_vector[i, int(i / 10)] = 1.0
        return label_vector
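
sample_label() always requests the same 10 x 10 grid of digits. To ask for specific digits instead, you could build the label batch yourself. The sketch below uses a hypothetical helper, labels_for_digits, which is not part of the reference code. Because the placeholders in model.py have a fixed batch size, the batch still needs exactly batch_size rows (100 by default).

```python
import numpy as np


def labels_for_digits(digits, y_dim=10):
    """One-hot label batch with one row per requested digit (hypothetical helper)."""
    digits = np.asarray(digits, dtype=np.int64)
    if digits.min() < 0 or digits.max() >= y_dim:
        raise ValueError("digits must lie in [0, y_dim)")
    labels = np.zeros((digits.shape[0], y_dim), dtype=np.float64)
    labels[np.arange(digits.shape[0]), digits] = 1.0
    return labels


# Same layout as sample_label(): 100 rows, row i asks for digit i // 10.
grid = labels_for_digits(np.arange(100) // 10)

# A batch that asks for the digit 5 one hundred times.
fives = labels_for_digits(np.full(100, 5))

print(grid.argmax(axis=1)[:12])  # [0 0 0 0 0 0 0 0 0 0 1 1]
```

Feeding fives in place of sample_label() as the value of self.y should then produce a grid containing only the digit 5.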

4. ops.py: layer definitions

ops.py defines the layer operations, such as deconvolution, fully connected, and batch normalization (BN) layers. The code first:

    import tensorflow as tf
    from tensorflow.contrib.layers.python.layers import batch_norm, variance_scaling_initializer


    # leaky ReLU
    def lrelu(x, alpha=0.2, name="LeakyReLU"):
        return tf.maximum(x, alpha * x)


    # strided convolution; returns the output and the filter weights
    def conv2d(input_, output_dim,
               k_h=3, k_w=3, d_h=2, d_w=2,
               name="conv2d"):
        with tf.variable_scope(name):
            w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
                                initializer=variance_scaling_initializer())
            conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
            biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
            conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())

            return conv, w


    # transposed convolution ("deconvolution")
    def de_conv(input_, output_shape,
                k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, name="deconv2d",
                with_w=False, initializer=variance_scaling_initializer()):
        with tf.variable_scope(name):
            # filter : [height, width, output_channels, in_channels]
            w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
                                initializer=initializer)
            try:
                deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                                strides=[1, d_h, d_w, 1])
            # Support for versions of TensorFlow before 0.7.0
            except AttributeError:
                deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
                                        strides=[1, d_h, d_w, 1])

            biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
            deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())

            if with_w:
                return deconv, w, biases
            else:
                return deconv


    # fully connected layer
    def fully_connect(input_, output_size, scope=None, with_w=False, initializer=variance_scaling_initializer()):
        shape = input_.get_shape().as_list()
        with tf.variable_scope(scope or "Linear"):
            matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, initializer=initializer)
            bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(0.0))

            if with_w:
                return tf.matmul(input_, matrix) + bias, matrix, bias
            else:
                return tf.matmul(input_, matrix) + bias


    def conv_cond_concat(x, y):
        """Concatenate conditioning vector on feature map axis."""
        x_shapes = x.get_shape()
        y_shapes = y.get_shape()
        return tf.concat([x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)


    # batch normalization
    def batch_normal(input, scope="scope", reuse=False):
        return batch_norm(input, epsilon=1e-5, decay=0.9, scale=True, scope=scope, reuse=reuse, updates_collections=None)
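
The least obvious helper here is conv_cond_concat, which attaches the label to a feature map. The following NumPy analogue (the name conv_cond_concat_np is hypothetical) shows what it does: the label of shape (batch, 1, 1, y_dim) is broadcast over every spatial position and then appended as extra channels.

```python
import numpy as np


def conv_cond_concat_np(x, y):
    """NumPy analogue of conv_cond_concat.

    x: feature map of shape (batch, height, width, channels)
    y: labels of shape (batch, 1, 1, y_dim)
    """
    batch, height, width, _ = x.shape
    # Broadcasting copies each label vector to all height x width positions.
    y_tiled = y * np.ones((batch, height, width, y.shape[3]), dtype=x.dtype)
    return np.concatenate([x, y_tiled], axis=3)


x = np.zeros((2, 28, 28, 1), dtype=np.float32)  # two blank 28x28 grayscale images
y = np.eye(10, dtype=np.float32)[[3, 8]].reshape(2, 1, 1, 10)  # labels 3 and 8
out = conv_cond_concat_np(x, y)
print(out.shape)  # (2, 28, 28, 11)
```

Every pixel of the first image now carries a 1 in channel 1 + 3 and every pixel of the second in channel 1 + 8, so a convolution sees the class label wherever its window lands.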

5. model.py: the conditionalGAN model

model.py defines conditionalGAN and is the most important file. The code is given below:

    from utils import save_images, vis_square, sample_label
    from tensorflow.contrib.layers.python.layers import xavier_initializer
    import cv2
    from ops import conv2d, lrelu, de_conv, fully_connect, conv_cond_concat, batch_normal
    import tensorflow as tf
    import numpy as np


    class CGAN(object):

        # build model
        def __init__(self, data_ob, sample_dir, output_size, learn_rate, batch_size,
                     z_dim, y_dim, log_dir, model_path, visua_path):
            self.data_ob = data_ob
            self.sample_dir = sample_dir
            self.output_size = output_size
            self.learn_rate = learn_rate
            self.batch_size = batch_size
            self.z_dim = z_dim
            self.y_dim = y_dim
            self.log_dir = log_dir
            self.model_path = model_path
            self.vi_path = visua_path
            self.channel = self.data_ob.shape[2]
            self.images = tf.placeholder(tf.float32, [batch_size, self.output_size, self.output_size, self.channel])
            self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim])
            self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim])

        def build_model(self):
            self.fake_images = self.gern_net(self.z, self.y)
            G_image = tf.summary.image("G_out", self.fake_images)

            # discriminator outputs for real images and for generated images (shared weights)
            D_pro, D_logits = self.dis_net(self.images, self.y, False)
            D_pro_sum = tf.summary.histogram("D_pro", D_pro)
            G_pro, G_logits = self.dis_net(self.fake_images, self.y, True)
            G_pro_sum = tf.summary.histogram("G_pro", G_pro)

            # discriminator loss: real images labeled 1, generated images labeled 0
            D_fake_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(G_pro), logits=G_logits))
            D_real_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_pro), logits=D_logits))
            # generator loss: generated images should be classified as real
            G_fake_loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(G_pro), logits=G_logits))

            self.D_loss = D_real_loss + D_fake_loss
            self.G_loss = G_fake_loss

            loss_sum = tf.summary.scalar("D_loss", self.D_loss)
            G_loss_sum = tf.summary.scalar("G_loss", self.G_loss)
            self.merged_summary_op_d = tf.summary.merge([loss_sum, D_pro_sum])
            self.merged_summary_op_g = tf.summary.merge([G_loss_sum, G_pro_sum, G_image])

            # split the trainable variables by the 'dis' / 'gen' prefixes used in the layer names
            t_vars = tf.trainable_variables()
            self.d_var = [var for var in t_vars if 'dis' in var.name]
            self.g_var = [var for var in t_vars if 'gen' in var.name]
            self.saver = tf.train.Saver()

        def train(self):
            opti_D = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
                                            beta1=0.5).minimize(self.D_loss, var_list=self.d_var)
            opti_G = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
                                            beta1=0.5).minimize(self.G_loss, var_list=self.g_var)
            init = tf.global_variables_initializer()
            config = tf.ConfigProto()
            config.gpu_options.allow_growth = True

            with tf.Session(config=config) as sess:
                sess.run(init)
                summary_writer = tf.summary.FileWriter(self.log_dir, graph=sess.graph)
                step = 0

                while step <= 10000:
                    realbatch_array, real_labels = self.data_ob.getNext_batch(step)
                    # Get the z
                    batch_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim])

                    # one discriminator update, then one generator update
                    _, summary_str = sess.run([opti_D, self.merged_summary_op_d],
                                              feed_dict={self.images: realbatch_array, self.z: batch_z, self.y: real_labels})
                    summary_writer.add_summary(summary_str, step)
                    _, summary_str = sess.run([opti_G, self.merged_summary_op_g],
                                              feed_dict={self.z: batch_z, self.y: real_labels})
                    summary_writer.add_summary(summary_str, step)

                    if step % 50 == 0:
                        D_loss = sess.run(self.D_loss, feed_dict={self.images: realbatch_array, self.z: batch_z, self.y: real_labels})
                        fake_loss = sess.run(self.G_loss, feed_dict={self.z: batch_z, self.y: real_labels})
                        print("Step %d: D: loss = %.7f G: loss=%.7f " % (step, D_loss, fake_loss))

                    if np.mod(step, 50) == 1 and step != 0:
                        sample_images = sess.run(self.fake_images, feed_dict={self.z: batch_z, self.y: sample_label()})
                        save_images(sample_images, [10, 10],
                                    './{}/train_{:04d}.png'.format(self.sample_dir, step))
                        self.saver.save(sess, self.model_path)

                    step = step + 1

                save_path = self.saver.save(sess, self.model_path)
                print("Model saved in file: %s" % save_path)

        def test(self):
            # tf.initialize_all_variables() is deprecated; use the same initializer as train()
            init = tf.global_variables_initializer()
            with tf.Session() as sess:
                sess.run(init)
                self.saver.restore(sess, self.model_path)
                sample_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim])
                output = sess.run(self.fake_images, feed_dict={self.z: sample_z, self.y: sample_label()})
                save_images(output, [10, 10], './{}/test{:02d}_{:04d}.png'.format(self.sample_dir, 0, 0))
                image = cv2.imread('./{}/test{:02d}_{:04d}.png'.format(self.sample_dir, 0, 0), 0)
                cv2.imshow("test", image)
                cv2.waitKey(0)
                print("Test finish!")

        def visual(self):
            init = tf.global_variables_initializer()
            with tf.Session() as sess:
                sess.run(init)
                self.saver.restore(sess, self.model_path)
                realbatch_array, real_labels = self.data_ob.getNext_batch(0)
                batch_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim])

                # visualize the weights of conv layer 2 (change to 'weight_1' for layer 1)
                conv_weights = sess.run([tf.get_collection('weight_2')])
                vis_square(self.vi_path, conv_weights[0][0].transpose(3, 0, 1, 2), type=1)

                # visualize the activations of conv layer 2; feed the full batch,
                # because the placeholders have a fixed batch size
                ac = sess.run([tf.get_collection('ac_2')],
                              feed_dict={self.images: realbatch_array, self.z: batch_z, self.y: sample_label()})
                vis_square(self.vi_path, ac[0][0].transpose(3, 1, 2, 0), type=0)
                print("the visualization finish!")

        def gern_net(self, z, y):
            with tf.variable_scope('generator') as scope:
                yb = tf.reshape(y, shape=[self.batch_size, 1, 1, self.y_dim])
                z = tf.concat([z, y], 1)  # [batch_size, z_dim + y_dim]
                c1, c2 = int(self.output_size / 4), int(self.output_size / 2)  # 7, 14

                d1 = tf.nn.relu(batch_normal(fully_connect(z, output_size=1024,
                                                           scope='gen_fully'), scope='gen_bn1'))
                d1 = tf.concat([d1, y], 1)  # [batch_size, 1024 + y_dim]

                d2 = tf.nn.relu(batch_normal(fully_connect(d1, output_size=7*7*2*100, scope='gen_fully2'),
                                             scope='gen_bn2'))
                d2 = tf.reshape(d2, [self.batch_size, c1, c1, 100 * 2])  # [batch_size, 7, 7, 200]
                d2 = conv_cond_concat(d2, yb)

                d3 = tf.nn.relu(batch_normal(de_conv(d2, output_shape=[self.batch_size, c2, c2, 200],
                                                     name='gen_deconv1'), scope='gen_bn3'))
                d3 = conv_cond_concat(d3, yb)

                d4 = de_conv(d3, output_shape=[self.batch_size, self.output_size, self.output_size, self.channel],
                             name='gen_deconv2', initializer=xavier_initializer())

                return tf.nn.sigmoid(d4)

        def dis_net(self, images, y, reuse=False):
            with tf.variable_scope("discriminator") as scope:
                if reuse:
                    scope.reuse_variables()

                # mnist data's shape is (28 , 28 , 1)
                yb = tf.reshape(y, shape=[self.batch_size, 1, 1, self.y_dim])
                # concat
                concat_data = conv_cond_concat(images, yb)

                conv1, w1 = conv2d(concat_data, output_dim=10, name='dis_conv1')
                tf.add_to_collection('weight_1', w1)
                conv1 = lrelu(conv1)
                conv1 = conv_cond_concat(conv1, yb)
                tf.add_to_collection('ac_1', conv1)

                conv2, w2 = conv2d(conv1, output_dim=64, name='dis_conv2')
                tf.add_to_collection('weight_2', w2)
                conv2 = lrelu(batch_normal(conv2, scope='dis_bn1'))
                tf.add_to_collection('ac_2', conv2)

                conv2 = tf.reshape(conv2, [self.batch_size, -1])
                conv2 = tf.concat([conv2, y], 1)

                f1 = lrelu(batch_normal(fully_connect(conv2, output_size=1024, scope='dis_fully1'), scope='dis_bn2', reuse=reuse))
                f1 = tf.concat([f1, y], 1)

                out = fully_connect(f1, output_size=1, scope='dis_fully2', initializer=xavier_initializer())

                return tf.nn.sigmoid(out), out
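
To make the loss construction in build_model concrete, here is a NumPy sketch of the same arithmetic. The function cgan_losses is a hypothetical name. The cross-entropy helper uses the numerically stable formula documented for tf.nn.sigmoid_cross_entropy_with_logits. Note that the generator loss labels the generated images as 1 (the common non-saturating form) instead of directly minimizing log(1 - D(G(z|y))).

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(labels, logits):
    """Stable form of the sigmoid cross-entropy: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def cgan_losses(d_logits_real, d_logits_fake):
    """Return (D_loss, G_loss) from discriminator logits on real and generated images."""
    d_real = np.mean(sigmoid_cross_entropy_with_logits(np.ones_like(d_logits_real), d_logits_real))
    d_fake = np.mean(sigmoid_cross_entropy_with_logits(np.zeros_like(d_logits_fake), d_logits_fake))
    g_fake = np.mean(sigmoid_cross_entropy_with_logits(np.ones_like(d_logits_fake), d_logits_fake))
    return d_real + d_fake, g_fake


# An undecided discriminator (all logits 0) gives D_loss = 2 * log 2 and G_loss = log 2.
d_loss, g_loss = cgan_losses(np.zeros(4), np.zeros(4))
print(d_loss, g_loss)
```

If the discriminator confidently rejects the generated images (large negative logits), G_loss grows roughly linearly with the logit magnitude, which gives the generator a usable gradient early in training.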

6. main.py: the main file

Finally, the main file controls whether to train, test, or visualize. The code first:

    from model import CGAN
    import tensorflow as tf
    from utils import Mnist
    import os

    flags = tf.app.flags

    flags.DEFINE_string("sample_dir", "samples_for_test", "the dir of sample images")
    flags.DEFINE_integer("output_size", 28, "the size of generate image")
    flags.DEFINE_float("learn_rate", 0.0002, "the learning rate for gan")
    flags.DEFINE_integer("batch_size", 100, "the batch number")
    flags.DEFINE_integer("z_dim", 100, "the dimension of noise z")
    flags.DEFINE_integer("y_dim", 10, "the dimension of condition y")
    flags.DEFINE_string("log_dir", "/tmp/tensorflow_mnist", "the path of tensorflow's log")
    flags.DEFINE_string("model_path", "model/model.ckpt", "the path of model")
    flags.DEFINE_string("visua_path", "visualization", "the path of visualization images")
    flags.DEFINE_integer("op", 0, "0: train ; 1:test ; 2:visualize")

    FLAGS = flags.FLAGS

    # create the output directories if they do not exist yet
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)
    if not os.path.exists(FLAGS.log_dir):
        os.makedirs(FLAGS.log_dir)
    # model_path is a checkpoint prefix, so create its parent directory rather than the path itself
    model_dir = os.path.dirname(FLAGS.model_path)
    if model_dir and not os.path.exists(model_dir):
        os.makedirs(model_dir)
    if not os.path.exists(FLAGS.visua_path):
        os.makedirs(FLAGS.visua_path)


    def main(_):
        mn_object = Mnist()

        cg = CGAN(data_ob=mn_object, sample_dir=FLAGS.sample_dir, output_size=FLAGS.output_size,
                  learn_rate=FLAGS.learn_rate, batch_size=FLAGS.batch_size, z_dim=FLAGS.z_dim,
                  y_dim=FLAGS.y_dim, log_dir=FLAGS.log_dir, model_path=FLAGS.model_path,
                  visua_path=FLAGS.visua_path)

        cg.build_model()

        if FLAGS.op == 0:
            cg.train()
        elif FLAGS.op == 1:
            cg.test()
        else:
            cg.visual()


    if __name__ == '__main__':
        tf.app.run()

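IV. Experimental Results

Once everything above is in place, just run main.py. Here is a quick look at the results (the "epoch" values below correspond to the step counter in the training loop):

Generated result at epoch=0:

Generated result at epoch=50:

Generated result at epoch=100:

Generated result at epoch=400:

Generated result at epoch=1000:

Generated result at epoch=3000:

Generated result at epoch=5000:

As you can see, the results at epoch=3000 are already quite good. Compared with a plain GAN, though, training is still somewhat slow.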

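V. Analysis

1. conditionalGAN is a step beyond the original GAN. In the paper it still uses plain neural networks rather than convolutions (the reference implementation above does use convolution and deconvolution layers), and it constrains the GAN's input with labels.

2. If you need conditionalGAN, you might as well consider infoGAN.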
