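I originally did not plan to build this model, because anything conditionalGAN can do, infoGAN can do too, and I covered infoGAN in an earlier post: Learning Adversarial Neural Networks (5): Generating Clothing Images of Varying Width and Height with infoGAN (TensorFlow implementation). However, I recently joined a new company and am still in my probation period. The company is short on samples, so my manager asked me to generate some handwritten-digit samples for later work. Since I had already used infoGAN, I decided to try a different model this time: conditionalGAN.
There are plenty of introductions to conditionalGAN online, and my experiments use only the open MNIST dataset, so nothing here is confidential. I am posting the process as a reference for anyone who needs to build this model later.
conditionalGAN (conditional GAN, CGAN) was published by Mehdi Mirza in November 2014 and is one of the early classic models in the GAN family. Looking back, conditionalGAN and infoGAN start from the same idea: generate images you can control rather than random ones. Their approaches differ slightly. infoGAN uses mutual information to constrain the random input, whereas conditionalGAN feeds in a condition along with the image, and a generated image passes the discriminator only when it matches the condition it was given. infoGAN was published in June 2016 and conditionalGAN in November 2014, so infoGAN is in effect the more advanced refinement, but conditionalGAN is still well worth studying.
This post walks through the principles of conditionalGAN and generates the handwritten digits 0-9 with as little code as possible.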
[1] Paper: https://arxiv.org/pdf/1411.1784.pdf
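conditionalGAN was proposed fairly early, so there are many introductions online. Here are two accessible explanations:
[2] Hung-yi Lee (李宏毅), GAN notes (2): Conditional GAN
[3] Hung-yi Lee (李宏毅), 2018 GAN course, class 2: Conditional Generation by GAN
The original paper is short, so below I go through it based on my own understanding.
Starting with the background, the authors note that the main advantages of GANs are that no Markov chain is needed, gradients are obtained with backpropagation alone, no inference is required during learning, and a wide variety of factors and interactions can easily be incorporated into the model.
The problem with a GAN is that the model is unconditioned, so what it generates is random. The authors therefore add conditioning information, such as class labels, to steer generation in a specified direction.
Existing supervised neural networks are very successful but still face two problems. First, it is hard to scale a model to an extremely large number of predicted output categories. Second, the input-to-output mapping is treated as one-to-one, while in practice it may be one-to-many; for example, one image can have several labels. A common way to address the second problem is a conditional probabilistic generative model: the input is taken as the conditioning variable, and the one-to-many mapping is instantiated as a conditional predictive distribution.
For the model itself, the authors feed extra information y, such as class labels or data from another modality, into both the discriminator and the generator as an additional input layer. In the generator, the input noise and y are combined into a joint hidden representation.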
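For reference, the conditioning described above changes the usual GAN two-player objective only by making both networks depend on y. The formula below is the conditional objective as given in the paper [1], where x is a real sample, z is the input noise, and y is the extra information:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
```

The discriminator is rewarded for scoring real pairs (x, y) high and generated pairs (G(z|y), y) low, which is why a generated digit only "passes" when it matches its label.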
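First, here is the model structure given in the paper:
Note that conditionalGAN as described in the paper uses no convolutions, so drawing the structure this way is entirely accurate. I also found another diagram online that I think is well done, so I include it too:
For the MNIST class labels, the authors use one-hot encoding. The paper describes the relevant model parameters; if you need them, refer to the original paper.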
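As a minimal sketch of the one-hot conditioning, the NumPy snippet below encodes digit labels and joins them to the noise vector, which is the same concatenation the generator performs on its input. The helper names `one_hot` and `generator_input` are made up for this illustration; the noise size of 100 and the 10 classes match the defaults used later in this post.

```python
import numpy as np


def one_hot(labels, num_classes=10):
    """Turn integer class labels into one-hot row vectors."""
    labels = np.asarray(labels, dtype=np.int64)
    encoded = np.zeros((labels.size, num_classes), dtype=np.float32)
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded


def generator_input(z, y_onehot):
    """Join noise and condition along the feature axis: [N, z_dim + y_dim]."""
    return np.concatenate([z, y_onehot], axis=1)


rng = np.random.default_rng(0)
z = rng.uniform(-1.0, 1.0, size=(4, 100)).astype(np.float32)  # noise, z_dim = 100
y = one_hot([3, 1, 4, 1])                                     # requested digits
g_in = generator_input(z, y)
print(g_in.shape)  # (4, 110)
```

Changing only `y` while keeping `z` fixed asks the generator for a different digit in the same "style", which is the control that a plain GAN lacks.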
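Finally, here are the authors' experimental results on MNIST:
The main reference code for this post is [4]. That code can only generate the eight fixed digits 0-7, and I wanted all ten digits 0-9, so I made a small change to it. Not much was changed, but it gets the job done.
[4] Reference code: https://github.com/zhangqianhui/Conditional-GAN
The file layout is: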
- -- MNIST_data
- |------ t10k-images-idx3-ubyte.gz
- |------ t10k-labels-idx1-ubyte.gz
- |------ train-images-idx3-ubyte.gz
- |------ train-labels-idx1-ubyte.gz
- -- main.py
- -- model.py
- -- ops.py
- -- utils.py
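For an introduction to the MNIST dataset, see my first post: Learning Adversarial Neural Networks (1): Generating MNIST Handwritten Digits with GAN (TensorFlow implementation). I will not repeat it here. The simplest way to download the dataset is to run the following two lines: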
- from tensorflow.examples.tutorials.mnist import input_data
-
- data = input_data.read_data_sets('MNIST_data/')
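If you see output like the following after running it, the download succeeded:
The data ends up under 'MNIST_data/'. Be sure to decompress the files after downloading: conditionalGAN needs the class labels, so all four files must be decompressed. Then check whether the file names changed. On Ubuntu I found that the '-' had become '.' after extraction, so pay attention to this detail: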
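If you would rather decompress from Python than with a desktop tool, something like the following sketch should work. The function name `gunzip_all` and its `dotted_names` option are made up for this illustration; the renaming from `-idx` to `.idx` is an assumption that mirrors the file names the loading code in this post opens (for example `train-images.idx3-ubyte`). The demo runs on a throwaway directory, so it does not need the real download.

```python
import gzip
import os
import shutil
import tempfile


def gunzip_all(data_dir, dotted_names=True):
    """Decompress every .gz file in data_dir and return the new paths.

    With dotted_names=True, 'train-images-idx3-ubyte' is written as
    'train-images.idx3-ubyte' (assumed to match what the loader expects).
    """
    extracted = []
    for name in sorted(os.listdir(data_dir)):
        if not name.endswith(".gz"):
            continue
        target = name[:-3]  # strip the '.gz' suffix
        if dotted_names:
            target = target.replace("-idx", ".idx")
        src = os.path.join(data_dir, name)
        dst = os.path.join(data_dir, target)
        with gzip.open(src, "rb") as f_in, open(dst, "wb") as f_out:
            shutil.copyfileobj(f_in, f_out)
        extracted.append(dst)
    return extracted


# Self-contained demo: a fake 8-byte label file in a temporary directory.
demo_dir = tempfile.mkdtemp()
payload = b"\x00\x00\x08\x01\x00\x00\x00\x00"
with gzip.open(os.path.join(demo_dir, "train-labels-idx1-ubyte.gz"), "wb") as f:
    f.write(payload)
demo_files = gunzip_all(demo_dir)
print([os.path.basename(p) for p in demo_files])
```

For the layout used in this post the call would be `gunzip_all("MNIST_data")`. Either way, compare the resulting names with the ones opened in `load_mnist` before training.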
This file (utils.py) mainly defines the MNIST data class plus some helper operations such as saving images. Here is the code:
- import numpy as np
- import scipy
- import scipy.misc  # imread/imsave were removed in SciPy 1.2, so this needs an older SciPy (with Pillow)
- import matplotlib.pyplot as plt
- import os
-
-
- class Mnist(object):
-
-     def __init__(self):
-         self.dataname = "Mnist"
-         self.dims = 28 * 28
-         self.shape = [28, 28, 1]
-         self.image_size = 28
-         self.data, self.data_y = self.load_mnist()
-
-     def load_mnist(self):
-         data_dir = "./MNIST_data"
-
-         # IDX files are binary: image files have a 16-byte header, label files an 8-byte header
-         with open(os.path.join(data_dir, 'train-images.idx3-ubyte'), 'rb') as fd:
-             loaded = np.fromfile(file=fd, dtype=np.uint8)
-         trX = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float64)
-
-         with open(os.path.join(data_dir, 'train-labels.idx1-ubyte'), 'rb') as fd:
-             loaded = np.fromfile(file=fd, dtype=np.uint8)
-         trY = loaded[8:].reshape(60000).astype(np.float64)
-
-         with open(os.path.join(data_dir, 't10k-images.idx3-ubyte'), 'rb') as fd:
-             loaded = np.fromfile(file=fd, dtype=np.uint8)
-         teX = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float64)
-
-         with open(os.path.join(data_dir, 't10k-labels.idx1-ubyte'), 'rb') as fd:
-             loaded = np.fromfile(file=fd, dtype=np.uint8)
-         teY = loaded[8:].reshape(10000).astype(np.float64)
-
-         trY = np.asarray(trY)
-         teY = np.asarray(teY)
-
-         # train and test sets are merged: a GAN does not need a held-out split here
-         X = np.concatenate((trX, teX), axis=0)
-         y = np.concatenate((trY, teY), axis=0)
-
-         # shuffle images and labels with the same seed so they stay aligned
-         seed = 547
-         np.random.seed(seed)
-         np.random.shuffle(X)
-         np.random.seed(seed)
-         np.random.shuffle(y)
-
-         # convert label to one-hot
-         y_vec = np.zeros((len(y), 10), dtype=np.float64)
-         for i, label in enumerate(y):
-             y_vec[i, int(label)] = 1.0
-
-         return X / 255., y_vec
-
-     def getNext_batch(self, iter_num=0, batch_size=100):
-         # number of batches per pass (integer division keeps the indices integral)
-         ro_num = len(self.data) // batch_size - 1
-
-         # reshuffle images and labels together at the start of every pass
-         if iter_num % ro_num == 0:
-             length = len(self.data)
-             perm = np.arange(length)
-             np.random.shuffle(perm)
-             self.data = np.array(self.data)
-             self.data = self.data[perm]
-             self.data_y = np.array(self.data_y)
-             self.data_y = self.data_y[perm]
-
-         start = (iter_num % ro_num) * batch_size
-         end = (iter_num % ro_num + 1) * batch_size
-         return self.data[start:end], self.data_y[start:end]
-
-
- def get_image(image_path, is_grayscale=False):
-     return np.array(inverse_transform(imread(image_path, is_grayscale)))
-
-
- def save_images(images, size, image_path):
-     return imsave(inverse_transform(images), size, image_path)
-
-
- def imread(path, is_grayscale=False):
-     if is_grayscale:
-         return scipy.misc.imread(path, flatten=True).astype(np.float64)
-     else:
-         return scipy.misc.imread(path).astype(np.float64)
-
-
- def imsave(images, size, path):
-     return scipy.misc.imsave(path, merge(images, size))
-
-
- def merge(images, size):
-     # tile a batch of images into one grid with size[0] rows and size[1] columns
-     h, w = images.shape[1], images.shape[2]
-     img = np.zeros((h * size[0], w * size[1], 3))
-     for idx, image in enumerate(images):
-         i = idx % size[1]
-         j = idx // size[1]
-         img[j * h:j * h + h, i * w: i * w + w, :] = image
-
-     return img
-
-
- def inverse_transform(image):
-     return (image + 1.) / 2.
-
-
- def read_image_list(category):
-     filenames = []
-     print("list file")
-     entries = os.listdir(category)
-
-     for file in entries:
-         filenames.append(category + "/" + file)
-
-     print("list file ending!")
-
-     return filenames
-
-
- # from caffe
- def vis_square(visu_path, data, type):
-     """Take an array of shape (n, height, width) or (n, height, width , 3)
-     and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""
-
-     # normalize data for display
-     data = (data - data.min()) / (data.max() - data.min())
-
-     # force the number of filters to be square
-     n = int(np.ceil(np.sqrt(data.shape[0])))
-
-     padding = (((0, n ** 2 - data.shape[0]),
-                 (0, 1), (0, 1))  # add some space between filters
-                + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
-     data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)
-
-     # tile the filters into an image
-     data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
-     data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
-
-     plt.imshow(data[:, :, 0])
-     plt.axis('off')
-
-     if type:
-         plt.savefig('./{}/weights.png'.format(visu_path), format='png')
-     else:
-         plt.savefig('./{}/activation.png'.format(visu_path), format='png')
-
-
- def sample_label():
-     # 100 one-hot labels: ten copies of each digit 0-9, in order
-     num = 100
-     label_vector = np.zeros((num, 10), dtype=np.float64)
-     for i in range(0, num):
-         label_vector[i, i // 10] = 1.0
-     return label_vector
-
-
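To see why the saved sample grids show one digit per row, here is a small NumPy-only sketch. It re-implements `sample_label` and a grayscale variant of `merge` (named `merge_gray` here, a made-up name) so that it runs without SciPy or the MNIST files, and it uses flat gray patches as stand-ins for generated images.

```python
import numpy as np


def sample_label(num=100, num_classes=10):
    """Ten one-hot labels per digit, ordered 0,0,...,0,1,1,...,9."""
    labels = np.zeros((num, num_classes), dtype=np.float32)
    for i in range(num):
        labels[i, i // 10] = 1.0
    return labels


def merge_gray(images, size):
    """Tile [N, H, W, 1] images into a size[0] x size[1] grid, row by row."""
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((h * size[0], w * size[1]), dtype=images.dtype)
    for idx, image in enumerate(images):
        col = idx % size[1]
        row = idx // size[1]
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = image[:, :, 0]
    return grid


labels = sample_label()
digits = labels.argmax(axis=1).reshape(10, 10)  # digit requested in each grid cell

# Stand-in "generated" images: brightness encodes the requested digit.
fake = np.stack([np.full((28, 28, 1), d / 9.0, dtype=np.float32)
                 for d in labels.argmax(axis=1)])
grid = merge_gray(fake, [10, 10])
print(digits[:, 0])  # first column of the grid: one digit per row
print(grid.shape)
```

Because sample `i` gets label `i // 10` and the grid is filled row by row with ten images per row, row `r` of the saved image is conditioned on digit `r`.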
ops.py mainly defines layer operations such as deconvolution, fully connected layers, and batch normalization. Here is the code:
- import tensorflow as tf
- from tensorflow.contrib.layers.python.layers import batch_norm, variance_scaling_initializer
-
-
- # leaky ReLU
- def lrelu(x, alpha=0.2, name="LeakyReLU"):
-     return tf.maximum(x, alpha * x)
-
-
- def conv2d(input_, output_dim,
-            k_h=3, k_w=3, d_h=2, d_w=2,
-            name="conv2d"):
-     with tf.variable_scope(name):
-
-         w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim],
-                             initializer=variance_scaling_initializer())
-         conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME')
-
-         biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0))
-         conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape())
-
-         return conv, w
-
-
- def de_conv(input_, output_shape,
-             k_h=3, k_w=3, d_h=2, d_w=2, stddev=0.02, name="deconv2d",
-             with_w=False, initializer=variance_scaling_initializer()):
-
-     with tf.variable_scope(name):
-         # filter : [height, width, output_channels, in_channels]
-         w = tf.get_variable('w', [k_h, k_w, output_shape[-1], input_.get_shape()[-1]],
-                             initializer=initializer)
-         try:
-             deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
-                                             strides=[1, d_h, d_w, 1])
-         # Support for versions of TensorFlow before 0.7.0
-         except AttributeError:
-             deconv = tf.nn.deconv2d(input_, w, output_shape=output_shape,
-                                     strides=[1, d_h, d_w, 1])
-
-         biases = tf.get_variable('biases', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
-         deconv = tf.reshape(tf.nn.bias_add(deconv, biases), deconv.get_shape())
-
-         if with_w:
-             return deconv, w, biases
-         else:
-             return deconv
-
-
- # fully connected layer
- def fully_connect(input_, output_size, scope=None, with_w=False, initializer=variance_scaling_initializer()):
-
-     shape = input_.get_shape().as_list()
-
-     with tf.variable_scope(scope or "Linear"):
-         matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, initializer=initializer)
-         bias = tf.get_variable("bias", [output_size], initializer=tf.constant_initializer(0.0))
-         if with_w:
-             return tf.matmul(input_, matrix) + bias, matrix, bias
-         else:
-             return tf.matmul(input_, matrix) + bias
-
-
- def conv_cond_concat(x, y):
-     """Concatenate conditioning vector on feature map axis."""
-     x_shapes = x.get_shape()
-     y_shapes = y.get_shape()
-
-     # broadcast y ([N, 1, 1, K]) over the spatial dims of x, then append it as extra channels
-     return tf.concat([x, y * tf.ones([x_shapes[0], x_shapes[1], x_shapes[2], y_shapes[3]])], 3)
-
-
- # batch normalization
- def batch_normal(input_, scope="scope", reuse=False):
-     return batch_norm(input_, epsilon=1e-5, decay=0.9, scale=True, scope=scope, reuse=reuse, updates_collections=None)
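The least obvious helper above is `conv_cond_concat`, which is how the label reaches the convolutional layers. The NumPy sketch below reproduces the same broadcast-and-concatenate step outside TensorFlow so the shapes are easy to inspect; `conv_cond_concat_np` is a name made up for this illustration.

```python
import numpy as np


def conv_cond_concat_np(x, y):
    """Append a label to every pixel of a feature map.

    x: feature maps, shape [N, H, W, C]
    y: one-hot labels reshaped to [N, 1, 1, K]
    returns: shape [N, H, W, C + K]
    """
    n, h, w, _ = x.shape
    # Broadcasting copies each label vector to all H x W positions.
    y_tiled = y * np.ones((n, h, w, y.shape[3]), dtype=x.dtype)
    return np.concatenate([x, y_tiled], axis=3)


x = np.zeros((2, 28, 28, 1), dtype=np.float32)                 # two blank "images"
y = np.eye(10, dtype=np.float32)[[7, 2]].reshape(2, 1, 1, 10)  # labels 7 and 2
out = conv_cond_concat_np(x, y)
print(out.shape)  # (2, 28, 28, 11)
```

After this step a 28x28x1 image becomes a 28x28x11 tensor: channel 0 is the image and channels 1-10 are constant planes encoding the digit, so every convolution window sees the condition.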
model.py defines conditionalGAN and is the most important file. Here is the code:
- from utils import save_images, vis_square, sample_label
- from tensorflow.contrib.layers.python.layers import xavier_initializer
- import cv2
- from ops import conv2d, lrelu, de_conv, fully_connect, conv_cond_concat, batch_normal
- import tensorflow as tf
- import numpy as np
-
-
- class CGAN(object):
-
-     # build model
-     def __init__(self, data_ob, sample_dir, output_size, learn_rate, batch_size,
-                  z_dim, y_dim, log_dir, model_path, visua_path):
-
-         self.data_ob = data_ob
-         self.sample_dir = sample_dir
-         self.output_size = output_size
-         self.learn_rate = learn_rate
-         self.batch_size = batch_size
-         self.z_dim = z_dim
-         self.y_dim = y_dim
-         self.log_dir = log_dir
-         self.model_path = model_path
-         self.vi_path = visua_path
-         self.channel = self.data_ob.shape[2]
-         self.images = tf.placeholder(tf.float32, [batch_size, self.output_size, self.output_size, self.channel])
-         self.z = tf.placeholder(tf.float32, [self.batch_size, self.z_dim])
-         self.y = tf.placeholder(tf.float32, [self.batch_size, self.y_dim])
-
-     def build_model(self):
-
-         self.fake_images = self.gern_net(self.z, self.y)
-         G_image = tf.summary.image("G_out", self.fake_images)
-
-         # discriminator on real images (creates the variables) and on fake images (reuses them)
-         D_pro, D_logits = self.dis_net(self.images, self.y, False)
-         D_pro_sum = tf.summary.histogram("D_pro", D_pro)
-
-         G_pro, G_logits = self.dis_net(self.fake_images, self.y, True)
-         G_pro_sum = tf.summary.histogram("G_pro", G_pro)
-
-         # D wants real -> 1 and fake -> 0; G wants fake -> 1
-         D_fake_loss = tf.reduce_mean(
-             tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(G_pro), logits=G_logits))
-         D_real_loss = tf.reduce_mean(
-             tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(D_pro), logits=D_logits))
-         G_fake_loss = tf.reduce_mean(
-             tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(G_pro), logits=G_logits))
-
-         self.D_loss = D_real_loss + D_fake_loss
-         self.G_loss = G_fake_loss
-
-         loss_sum = tf.summary.scalar("D_loss", self.D_loss)
-         G_loss_sum = tf.summary.scalar("G_loss", self.G_loss)
-
-         self.merged_summary_op_d = tf.summary.merge([loss_sum, D_pro_sum])
-         self.merged_summary_op_g = tf.summary.merge([G_loss_sum, G_pro_sum, G_image])
-
-         # split trainable variables by the 'dis' / 'gen' prefixes used in the scope names
-         t_vars = tf.trainable_variables()
-         self.d_var = [var for var in t_vars if 'dis' in var.name]
-         self.g_var = [var for var in t_vars if 'gen' in var.name]
-
-         self.saver = tf.train.Saver()
-
-     def train(self):
-
-         opti_D = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
-                                         beta1=0.5).minimize(self.D_loss, var_list=self.d_var)
-         opti_G = tf.train.AdamOptimizer(learning_rate=self.learn_rate,
-                                         beta1=0.5).minimize(self.G_loss, var_list=self.g_var)
-         init = tf.global_variables_initializer()
-         config = tf.ConfigProto()
-         config.gpu_options.allow_growth = True
-
-         with tf.Session(config=config) as sess:
-
-             sess.run(init)
-             summary_writer = tf.summary.FileWriter(self.log_dir, graph=sess.graph)
-
-             step = 0
-             while step <= 10000:
-
-                 realbatch_array, real_labels = self.data_ob.getNext_batch(step, self.batch_size)
-
-                 # Get the z
-                 batch_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim])
-
-                 # one discriminator update, then one generator update
-                 _, summary_str = sess.run([opti_D, self.merged_summary_op_d],
-                                           feed_dict={self.images: realbatch_array, self.z: batch_z, self.y: real_labels})
-                 summary_writer.add_summary(summary_str, step)
-
-                 _, summary_str = sess.run([opti_G, self.merged_summary_op_g],
-                                           feed_dict={self.z: batch_z, self.y: real_labels})
-                 summary_writer.add_summary(summary_str, step)
-
-                 if step % 50 == 0:
-                     D_loss = sess.run(self.D_loss, feed_dict={self.images: realbatch_array, self.z: batch_z, self.y: real_labels})
-                     fake_loss = sess.run(self.G_loss, feed_dict={self.z: batch_z, self.y: real_labels})
-                     print("Step %d: D: loss = %.7f G: loss=%.7f " % (step, D_loss, fake_loss))
-
-                 if np.mod(step, 50) == 1 and step != 0:
-
-                     # 10 x 10 grid, one digit per row (see sample_label)
-                     sample_images = sess.run(self.fake_images, feed_dict={self.z: batch_z, self.y: sample_label()})
-                     save_images(sample_images, [10, 10],
-                                 './{}/train_{:04d}.png'.format(self.sample_dir, step))
-
-                     self.saver.save(sess, self.model_path)
-
-                 step = step + 1
-
-             save_path = self.saver.save(sess, self.model_path)
-             print("Model saved in file: %s" % save_path)
-
-     def test(self):
-         init = tf.global_variables_initializer()
-
-         with tf.Session() as sess:
-             sess.run(init)
-
-             self.saver.restore(sess, self.model_path)
-             sample_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim])
-             output = sess.run(self.fake_images, feed_dict={self.z: sample_z, self.y: sample_label()})
-             save_images(output, [10, 10], './{}/test{:02d}_{:04d}.png'.format(self.sample_dir, 0, 0))
-             image = cv2.imread('./{}/test{:02d}_{:04d}.png'.format(self.sample_dir, 0, 0), 0)
-             cv2.imshow("test", image)
-             cv2.waitKey(0)
-             print("Test finish!")
-
-     def visual(self):
-         init = tf.global_variables_initializer()
-         with tf.Session() as sess:
-             sess.run(init)
-
-             self.saver.restore(sess, self.model_path)
-
-             realbatch_array, real_labels = self.data_ob.getNext_batch(0, self.batch_size)
-             batch_z = np.random.uniform(-1, 1, size=[self.batch_size, self.z_dim])
-             # visualize the weights of layer 2 (switch to 'weight_1' for layer 1)
-             conv_weights = sess.run([tf.get_collection('weight_2')])
-             vis_square(self.vi_path, conv_weights[0][0].transpose(3, 0, 1, 2), type=1)
-
-             # visualize the activations of layer 2; the images placeholder has a fixed
-             # batch dimension, so a full batch must be fed
-             ac = sess.run([tf.get_collection('ac_2')],
-                           feed_dict={self.images: realbatch_array[:self.batch_size], self.z: batch_z, self.y: sample_label()})
-             vis_square(self.vi_path, ac[0][0].transpose(3, 1, 2, 0), type=0)
-
-             print("the visualization finish!")
-
-     def gern_net(self, z, y):
-         with tf.variable_scope('generator') as scope:
-
-             yb = tf.reshape(y, shape=[self.batch_size, 1, 1, self.y_dim])
-             z = tf.concat([z, y], 1)  # [batch, z_dim + y_dim] = [100, 110]
-             c1, c2 = int(self.output_size / 4), int(self.output_size / 2)  # 7, 14
-
-             # the label y is re-attached after every layer
-             d1 = tf.nn.relu(batch_normal(fully_connect(z, output_size=1024,
-                                                        scope='gen_fully'), scope='gen_bn1'))
-             d1 = tf.concat([d1, y], 1)  # [100, 1024 + 10]
-
-             d2 = tf.nn.relu(batch_normal(fully_connect(d1, output_size=7*7*2*100, scope='gen_fully2'),
-                                          scope='gen_bn2'))
-             d2 = tf.reshape(d2, [self.batch_size, c1, c1, 100 * 2])  # [100, 7, 7, 200]
-             d2 = conv_cond_concat(d2, yb)
-
-             d3 = tf.nn.relu(batch_normal(de_conv(d2, output_shape=[self.batch_size, c2, c2, 200],
-                                                  name='gen_deconv1'), scope='gen_bn3'))
-             d3 = conv_cond_concat(d3, yb)
-             d4 = de_conv(d3, output_shape=[self.batch_size, self.output_size, self.output_size, self.channel],
-                          name='gen_deconv2', initializer=xavier_initializer())
-
-             return tf.nn.sigmoid(d4)
-
-     def dis_net(self, images, y, reuse=False):
-         with tf.variable_scope("discriminator") as scope:
-             if reuse:
-                 scope.reuse_variables()
-
-             # mnist data's shape is (28 , 28 , 1)
-             yb = tf.reshape(y, shape=[self.batch_size, 1, 1, self.y_dim])
-             # concat
-             concat_data = conv_cond_concat(images, yb)
-
-             conv1, w1 = conv2d(concat_data, output_dim=10, name='dis_conv1')
-             tf.add_to_collection('weight_1', w1)
-
-             conv1 = lrelu(conv1)
-             conv1 = conv_cond_concat(conv1, yb)
-             tf.add_to_collection('ac_1', conv1)
-
-             conv2, w2 = conv2d(conv1, output_dim=64, name='dis_conv2')
-             tf.add_to_collection('weight_2', w2)
-
-             conv2 = lrelu(batch_normal(conv2, scope='dis_bn1'))
-             tf.add_to_collection('ac_2', conv2)
-
-             conv2 = tf.reshape(conv2, [self.batch_size, -1])
-             conv2 = tf.concat([conv2, y], 1)
-
-             f1 = lrelu(batch_normal(fully_connect(conv2, output_size=1024, scope='dis_fully1'), scope='dis_bn2', reuse=reuse))
-             f1 = tf.concat([f1, y], 1)
-
-             out = fully_connect(f1, output_size=1, scope='dis_fully2', initializer=xavier_initializer())
-
-             return tf.nn.sigmoid(out), out
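The three loss terms in `build_model` can be checked by hand with a small NumPy sketch. The cross-entropy helper below uses the numerically stable form `max(x, 0) - x * z + log(1 + exp(-|x|))` for logits `x` and labels `z`; `cgan_losses` is a name made up for this illustration, and the logits are hand-picked numbers rather than real network outputs.

```python
import numpy as np


def sigmoid_cross_entropy_with_logits(labels, logits):
    """Stable elementwise cross-entropy between sigmoid(logits) and labels."""
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))


def cgan_losses(d_logits_real, d_logits_fake):
    """Return (D_loss, G_loss) from discriminator logits on real and fake batches."""
    d_real = np.mean(sigmoid_cross_entropy_with_logits(np.ones_like(d_logits_real), d_logits_real))
    d_fake = np.mean(sigmoid_cross_entropy_with_logits(np.zeros_like(d_logits_fake), d_logits_fake))
    g_fake = np.mean(sigmoid_cross_entropy_with_logits(np.ones_like(d_logits_fake), d_logits_fake))
    return d_real + d_fake, g_fake


# An undecided discriminator (all logits 0, i.e. probability 0.5 everywhere).
d_loss, g_loss = cgan_losses(np.zeros(4), np.zeros(4))
print(d_loss, g_loss)  # 2*ln(2) and ln(2)

# A confident, correct discriminator: low D loss, high G loss.
d_loss_sure, g_loss_sure = cgan_losses(np.full(4, 10.0), np.full(4, -10.0))
print(d_loss_sure, g_loss_sure)
```

The second case shows the adversarial tension: the same fake logits that make the discriminator's loss small make the generator's loss large, and the generator is trained to push those logits up. The label `y` does not appear in the loss itself because it enters through the inputs of both networks.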
Finally, the main file controls training, testing, or visualization. Here is the code:
- from model import CGAN
- import tensorflow as tf
- from utils import Mnist
- import os
-
- flags = tf.app.flags
-
- flags.DEFINE_string("sample_dir", "samples_for_test", "the dir of sample images")
- flags.DEFINE_integer("output_size", 28, "the size of generate image")
- flags.DEFINE_float("learn_rate", 0.0002, "the learning rate for gan")
- flags.DEFINE_integer("batch_size", 100, "the batch number")
- flags.DEFINE_integer("z_dim", 100, "the dimension of noise z")
- flags.DEFINE_integer("y_dim", 10, "the dimension of condition y")
- flags.DEFINE_string("log_dir", "/tmp/tensorflow_mnist", "the path of tensorflow's log")
- flags.DEFINE_string("model_path", "model/model.ckpt", "the path of model")
- flags.DEFINE_string("visua_path", "visualization", "the path of visualization images")
- flags.DEFINE_integer("op", 0, "0: train ; 1:test ; 2:visualize")
-
- FLAGS = flags.FLAGS
-
- if not os.path.exists(FLAGS.sample_dir):
-     os.makedirs(FLAGS.sample_dir)
- if not os.path.exists(FLAGS.log_dir):
-     os.makedirs(FLAGS.log_dir)
- # model_path is a checkpoint prefix, so create its parent directory rather than the path itself
- model_dir = os.path.dirname(FLAGS.model_path)
- if model_dir and not os.path.exists(model_dir):
-     os.makedirs(model_dir)
- if not os.path.exists(FLAGS.visua_path):
-     os.makedirs(FLAGS.visua_path)
-
-
- def main(_):
-
-     mn_object = Mnist()
-     cg = CGAN(data_ob=mn_object, sample_dir=FLAGS.sample_dir, output_size=FLAGS.output_size,
-               learn_rate=FLAGS.learn_rate, batch_size=FLAGS.batch_size, z_dim=FLAGS.z_dim,
-               y_dim=FLAGS.y_dim, log_dir=FLAGS.log_dir, model_path=FLAGS.model_path,
-               visua_path=FLAGS.visua_path)
-     cg.build_model()
-
-     if FLAGS.op == 0:
-         cg.train()
-     elif FLAGS.op == 1:
-         cg.test()
-     else:
-         cg.visual()
-
-
- if __name__ == '__main__':
-     tf.app.run()
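With all of that in place, just run main.py. Here is a quick look at the results. The numbers below are values of the `step` counter in `train()`, where each step processes one batch:
Generated results at step 0:
Generated results at step 50:
Generated results at step 100:
Generated results at step 400:
Generated results at step 1000:
Generated results at step 3000:
Generated results at step 5000:
As you can see, the results are already quite good by step 3000. Training is still somewhat slower than a plain GAN, though.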
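1. conditionalGAN is a step beyond GAN. The original model still uses a plain fully connected neural network rather than convolutions, and it constrains the GAN's input with labels.
2. If you need conditionalGAN, you might consider infoGAN instead.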