赞
踩
这是一种GAN网络增强技术----具有匹配感知的判别器。前面讲过,在InfoGAN中,使用了ACGAN的方式进行指导模拟数据与生成数据的对应关系(分类)。在GAN-cls中该效果会以更简单的方式来实现,即增强判别器的功能,令其不仅能判断图片真伪,还能判断匹配真伪。
(个人理解)没啥实质性改变,时间并未缩短,技术也没有怎么简化甚至变得复杂了。就是思想上的一个转变,原本ACGan是模拟样本+正确分类信息输入进去/真实样本+正确分类信息输入进D去。现在的GAN-cls变为输入真实样本和真实标签、虚拟样本和真实标签、虚拟标签和真实样本的三种组合形式(无对应图片的随机标签)
GAN-cls的具体做法是,在原有的GAN网络上,将判别器的输入变为图片与对应标签的连接数据。这样判别器的输入特征中就会有生成图像的特征与对应标签的特征。然后用这样的判别器分别对真实标签与真实图片、假标签与真实图片、真实标签与假图片进行判断,预期的结果依次为真、假、假,在训练的过程中沿着这个方向收敛即可。而对于生成器,则不需要做任何改动。这样简单的一步就完成了生成根据标签匹配的模拟数据功能。
直接修改上一篇 GAN生成对抗网络合集(五):LSGan-最小二乘GAN(附代码) 代码,将其改成GAN-cls。
# def discriminator(x, num_classes=10, num_cont=2): def discriminator(x, y): # 判别器函数 : x两次卷积,再接两次全连接; y代表输入的样本标签 reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0 # print (reuse) # print (x.get_shape()) with tf.variable_scope('discriminator', reuse=reuse): y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu) # 将y变为与图片一样维度的映射 y = tf.reshape(y, shape=[-1, 28, 28, 1]) # 将y统一成图片格式 x = tf.reshape(x, shape=[-1, 28, 28, 1]) # 将二者连接到一起,统一处理 x = tf.concat(axis=3, values=[x, y]) # x.shape = [-1, 28, 28, 2] x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu) x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu) # print ("conv2d",x.get_shape()) x = slim.flatten(x) # 输入扁平化 shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu) # recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu) # 生成的数据可以分别连接不同的输出层产生不同的结果 # 1维的输出层产生判别结果1或是0 disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid) disc = tf.squeeze(disc, -1) # print ("disc",disc.get_shape()) # 0 or 1 # 10维的输出层产生分类结果 (样本标签) # recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None) # 2维输出层产生重构造的隐含维度信息 # recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid) return disc # recog_cat, recog_cont
注:这里是将3种输入的x与y分别按照batch_size维度连接变为判别器的一个输入的。生成结果后再使用split函数将其裁成3个结果disc_real、disc_fake和disc_mis,分别代表真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签所对应的判别值。这么写会使代码看上去简洁一些,当然也可以一个一个地输入x、y,然后调用三次判别器,效果是一样的。
################################################################## # 3.定义网络模型 : 定义 参数/输入/输出/中间过程(经过G/D)的输入输出 ################################################################## batch_size = 10 # 获取样本的批次大小32 classes_dim = 10 # 10 classes con_dim = 2 # 隐含信息变量的维度, 应节点为z_con rand_dim = 38 # 一般噪声的维度, 应节点为z_rand, 二者都是符合标准高斯分布的随机数。 n_input = 784 # 28 * 28 x = tf.placeholder(tf.float32, [None, n_input]) # x为输入真实图片images y = tf.placeholder(tf.int32, [None]) # y为真实标签labels misy = tf.placeholder(tf.int32, [None]) # 错误标签 # z_con = tf.random_normal((batch_size, con_dim)) # 2列 z_rand = tf.random_normal((batch_size, rand_dim)) # 38列 z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand]) # 50列 shape = (10, 50) gen = generator(z) # shape = (10, 28, 28, 1) genout = tf.squeeze(gen, -1) # shape = (10, 28, 28) # labels for discriminator # y_real = tf.ones(batch_size) # 真
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。