赞
踩
InfoGAN是一种把信息论与GAN相融合的神经网络,能够使网络具有信息解读功能。
GAN的生成器在构建样本时使用了任意的噪声向量x’,并从低维的噪声数据x’中还原出来高维的样本数据。这说明数据x’中含有具有与样本相同的特征。
由于随意使用的噪声都能还原出高维样本数据,表明噪声中的特征数据部分是与无用的数据部分高度地纠缠在一起的,即我们能够知道噪声中含有有用特征,但无法知道哪些是有用特征。
InfoGAN是GAN模型的一种改进,是一种能够学习样本中的关键维度信息的GAN,即对生成样本的噪音进行了细化。先来看它的结构,相比对抗自编码,InfoGAN的思路正好相反,InfoGAN是先固定标准高斯分布作为网络输入,再慢慢调整网络输出去匹配复杂样本分布。
图3.1 InfoGAN模型
如图3.1所示,InfoGAN生成器是从标准高斯分布中随机采样来作为输入,生成模拟样本,解码器是将生成器输出的模拟样本还原回生成器输入的随机数中的一部分,判别器是将样本作为输入来区分真假样本。
InfoGAN的理论思想是将输入的随机标准高斯分布当成噪音数据,并将噪音分为两类,第一类是不可压缩的噪音Z,第二类是可解释性的信息C。假设在一个样本中,决定其本身的只有少量重要的维度,那么大多数的维度是可以忽略的。而这里的解码器可以更形象地叫成重构器,即通过重构一部分输入的特征来确定与样本互信息的那些维度。最终被找到的维度可以代替原始样本的特征(类似PCA算法中的主成份),实现降维、解耦的效果。
AC-GAN(Auxiliary Classifier GAN),即在判别器discriminator中再输出相应的分类概率,然后增加输出的分类与真实分类的损失计算,使生成的模拟数据与其所属的class一一对应。一般来讲,AC-GAN可以属于InfoGAN的一部分,class信息可以作为InfoGAN中的潜在信息,只不过这部分信息可以使用半监督方式来学习。
首先明确,GAN的代码没有目标检测的复杂,以一个目标检测程序demo的篇幅就涵盖了GAN的数据输入、训练、定义网络结构和参数、loss函数和优化器以及可视化部分。
还可以学习到的是,GAN基本除开两个大的网络框架G和D以外,就是加各种约束(分类信息、隐含信息等)用以生成想要的数据。
下面是代码实现学习MINST数据特征,生成以假乱真的MNIST模拟样本,并发现内部潜在的特征信息。
代码总纲:
MNIST数据集下载到相应的地址,其加载方式是固定的。
# -*- coding: utf-8 -*- ################################################################## # 1.引入头文件并加载mnist数据 ################################################################## import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from scipy.stats import norm import tensorflow.contrib.slim as slim from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("/media/S318080208/py_pictures/minist/") # ,one_hot=True) tf.reset_default_graph() # 用于清除默认图形堆栈并重置全局默认图形
################################################################## # 2.定义生成器与判别器 ################################################################## def generator(x): # 生成器函数 : 两个全连接+两个反卷积模拟样本的生成,每一层都有BN(批量归一化)处理 reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0 # 确认该变量作用域没有变量 # print (x.get_shape()) with tf.variable_scope('generator', reuse=reuse): x = slim.fully_connected(x, 1024) # print(x) x = slim.batch_norm(x, activation_fn=tf.nn.relu) x = slim.fully_connected(x, 7*7*128) x = slim.batch_norm(x, activation_fn=tf.nn.relu) x = tf.reshape(x, [-1, 7, 7, 128]) # print ('22', tf.tensor.get_shape()) x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn = None) # print ('gen',x.get_shape()) x = slim.batch_norm(x, activation_fn=tf.nn.relu) z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid) # print ('genz',z.get_shape()) return z def leaky_relu(x
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。