当前位置:   article > 正文

GAN生成对抗网络合集(六):GAN-cls –具有匹配感知的判别器(附代码)_gan匹配

gan匹配

1 GAN-cls原理

       这是一种GAN网络增强技术----具有匹配感知的判别器。前面讲过,在InfoGAN中,使用了ACGAN的方式进行指导模拟数据与生成数据的对应关系(分类)。在GAN-cls中该效果会以更简单的方式来实现,即增强判别器的功能,令其不仅能判断图片真伪,还能判断匹配真伪

(个人理解)没啥实质性改变,时间并未缩短,技术也没有怎么简化甚至变得复杂了。就是思想上的一个转变,原本ACGan是模拟样本+正确分类信息输入进去/真实样本+正确分类信息输入进D去。现在的GAN-cls变为输入真实样本和真实标签、虚拟样本和真实标签、虚拟标签和真实样本的三种组合形式(无对应图片的随机标签

       GAN-cls的具体做法是,在原有的GAN网络上,将判别器的输入变为图片与对应标签的连接数据。这样判别器的输入特征中就会有生成图像的特征与对应标签的特征。然后用这样的判别器分别对真实标签与真实图片、假标签与真实图片、真实标签与假图片进行判断,预期的结果依次为真、假、假,在训练的过程中沿着这个方向收敛即可。而对于生成器,则不需要做任何改动。这样简单的一步就完成了生成根据标签匹配的模拟数据功能。

在这里插入图片描述

2 代码

直接修改上一篇 GAN生成对抗网络合集(五):LSGan-最小二乘GAN(附代码) 代码,将其改成GAN-cls。

  1. 修改判别器D
    将判别器的输入改成x与y,新增加的y代表输入的样本标签(真、假);在内部处理中,先通过全连接网络将y变为与图片一样维度的映射,并调整为图片相同的形状,使用concat将二者连接到一起统一处理。后续的处理过程是一样的,两个卷积后再接两个全连接,最后一层输出disc。该部分代码如下:
# def discriminator(x, num_classes=10, num_cont=2):
def discriminator(x, y):  # 判别器函数 : x两次卷积,再接两次全连接; y代表输入的样本标签
    reuse = len([t for t in tf.global_variables() if t.name.startswith('discriminator')]) > 0
    # print (reuse)
    # print (x.get_shape())
    with tf.variable_scope('discriminator', reuse=reuse):

        y = slim.fully_connected(y, num_outputs=n_input, activation_fn=leaky_relu)  # 将y变为与图片一样维度的映射
        y = tf.reshape(y, shape=[-1, 28, 28, 1])    # 将y统一成图片格式

        x = tf.reshape(x, shape=[-1, 28, 28, 1])

        # 将二者连接到一起,统一处理
        x = tf.concat(axis=3, values=[x, y])  # x.shape = [-1, 28, 28, 2]

        x = slim.conv2d(x, num_outputs=64, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4, 4], stride=2, activation_fn=leaky_relu)
        # print ("conv2d",x.get_shape())
        x = slim.flatten(x)  # 输入扁平化
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn=leaky_relu)
        # recog_shared = slim.fully_connected(shared_tensor, num_outputs=128, activation_fn=leaky_relu)

        # 生成的数据可以分别连接不同的输出层产生不同的结果
        # 1维的输出层产生判别结果1或是0
        disc = slim.fully_connected(shared_tensor, num_outputs=1, activation_fn=tf.nn.sigmoid)
        disc = tf.squeeze(disc, -1)
        # print ("disc",disc.get_shape()) # 0 or 1

        # 10维的输出层产生分类结果 (样本标签)
        # recog_cat = slim.fully_connected(recog_shared, num_outputs=num_classes, activation_fn=None)

        # 2维输出层产生重构造的隐含维度信息
        # recog_cont = slim.fully_connected(recog_shared, num_outputs=num_cont, activation_fn=tf.nn.sigmoid)
    return disc  # recog_cat, recog_cont
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  1. 添加错误标签输入符,构建网络结构
    添加错误标签misy,同时在判别器中分别将真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签组成的输入传入判别器中。去掉隐含信息z_con部分。

注:这里是将3种输入的x与y分别按照batch_size维度连接变为判别器的一个输入的。生成结果后再使用split函数将其裁成3个结果disc_real、disc_fake和disc_mis,分别代表真实样本与真实标签、生成的图像gen与真实标签、真实样本与错误标签所对应的判别值。这么写会使代码看上去简洁一些,当然也可以一个一个地输入x、y,然后调用三次判别器,效果是一样的。

##################################################################
#  3.定义网络模型 : 定义 参数/输入/输出/中间过程(经过G/D)的输入输出
##################################################################
batch_size = 10  # 获取样本的批次大小32
classes_dim = 10  # 10 classes
con_dim = 2  # 隐含信息变量的维度, 应节点为z_con
rand_dim = 38  # 一般噪声的维度, 应节点为z_rand, 二者都是符合标准高斯分布的随机数。
n_input = 784  # 28 * 28

x = tf.placeholder(tf.float32, [None, n_input])  # x为输入真实图片images
y = tf.placeholder(tf.int32, [None])  # y为真实标签labels
misy = tf.placeholder(tf.int32, [None])  # 错误标签

# z_con = tf.random_normal((batch_size, con_dim))  # 2列
z_rand = tf.random_normal((batch_size, rand_dim))  # 38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth=classes_dim), z_rand])  # 50列 shape = (10, 50)
gen = generator(z)  # shape = (10, 28, 28, 1)
genout = tf.squeeze(gen, -1)  # shape = (10, 28, 28)

# labels for discriminator
# y_real = tf.ones(batch_size)  # 真
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/441200
推荐阅读
相关标签
  

闽ICP备14008679号