SAGAN, short for Self-Attention Generative Adversarial Networks, is a model proposed by Han Zhang et al. [1] in May 2018. As the authors explain, traditional GANs generate high-resolution detail only from spatially local points on low-resolution feature maps, whereas SAGAN can generate detail using cues from all feature locations; its discriminator can also check that highly detailed features in distant portions of an image are consistent with each other. SAGAN achieved very strong results at the time of publication.
This article uses the CelebA dataset as an example and applies SAGAN to generate more finely detailed face images; the main reference implementation is [2].
[1] Paper: https://arxiv.org/pdf/1805.08318.pdf
[2] Reference code: https://github.com/taki0112/Self-Attention-GAN-Tensorflow
Introductory GAN papers get plenty of coverage online, but the newer the paper, the fewer write-ups there are, so here is one brief recommendation:
[3] "SA-GAN - Self-Attention Generative Adversarial Networks 论文解读(附代码)" (a paper walkthrough in Chinese, with code)
Below is my own understanding and summary of the paper.
In the previous article in this GAN series, 对抗神经网络学习(十)——attentiveGAN实现影像去雨滴的过程(tensorflow实现) (AttentiveGAN for raindrop removal, a TensorFlow implementation), I first saw the benefit of bringing an attention network into a GAN: generating an attention map lets the network quickly and accurately locate the key regions of an image. At the time I suspected this idea could be used to further improve GAN architectures, and SAGAN turned out to adopt exactly this approach.
First, the authors identify a problem with existing GANs: when trained on multi-class datasets, GANs find some image classes much harder to model than others. Put simply, GANs capture texture easily but struggle with geometric structure:
However, by carefully examining the generated samples from these models, we can observe that convolutional GANs have much more difficulty modeling some image classes than others when trained on multi-class datasets. For example, while the state-of-the-art ImageNet GAN model excels at synthesizing image classes with few structural constraints (e.g. ocean, sky and landscape classes, which are distinguished more by texture than by geometry), it fails to capture geometric or structural patterns that occur consistently in some classes (for example, dogs are often drawn with realistic fur texture but without clearly defined separate feet).
The reason is that these models rely on convolution to build dependencies between image regions, and such dependencies can only propagate through a stack of convolution layers covering a large receptive field. Enlarging the kernels increases the representational capacity of the network, but at the cost of computational efficiency. Self-attention, in contrast, offers a good balance between long-range dependency modeling and computational cost, which is why the paper introduces it.
The authors' main contribution is:
In this work, we propose Self-Attention Generative Adversarial Networks (SAGANs), which introduce a self-attention mechanism into convolutional GANs. The self-attention module is complementary to convolutions and helps with modeling long range, multi-level dependencies across image regions. Armed with self-attention, the generator can draw images in which fine details at every location are carefully coordinated with fine details in distant portions of the image. Moreover, the discriminator can also more accurately enforce complicated geometric constraints on the global image structure. (In short: the authors add self-attention to convolutional GANs and propose SAGAN; with it, the generator can coordinate fine details across the whole image, and the discriminator can enforce geometric constraints on the global structure.)
The authors illustrate self-attention with a figure:
To read the figure: the five dots in the leftmost image are query locations placed by the authors, and each image on the right corresponds to one of those dots, highlighting the regions whose features are most similar to it, i.e. the most-attended regions mentioned above.
Next, the authors describe the self-attention mechanism itself:
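The equation images are not reproduced on this page, so here is a LaTeX transcription of the paper's formulation (it matches the attention() method in SAGAN.py below). The feature map x from the previous layer is first projected by three 1×1 convolutions f, g, h:

\[ f(x) = W_f x, \qquad g(x) = W_g x, \qquad h(x) = W_h x \]

Attention weights are then computed between all pairs of the N = h × w feature locations,

\[ s_{ij} = f(x_i)^{\top} g(x_j), \qquad \beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})}, \]

where \(\beta_{j,i}\) is the extent to which the model attends to location i when synthesizing location j. The output of the attention layer and the final residual connection are

\[ o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i), \qquad y_j = \gamma\, o_j + x_j, \]

with \(\gamma\) a learnable scalar initialized to 0, so the network relies on the local neighborhood first and gradually learns to weight the non-local evidence.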
Introducing the self-attention mechanism into a GAN yields SAGAN's loss function:
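The loss equations were also shown as images; written out, the hinge version of the adversarial loss that the paper uses (and that the 'hinge' branch of discriminator_loss / generator_loss in ops.py below implements) is

\[ L_D = -\mathbb{E}_{x \sim p_{\text{data}}}\big[\min(0, -1 + D(x))\big] - \mathbb{E}_{z \sim p_z}\big[\min(0, -1 - D(G(z)))\big], \]
\[ L_G = -\mathbb{E}_{z \sim p_z}\big[D(G(z))\big]. \]

(The paper writes these with class conditioning, D(x, y); for unconditional CelebA training the label simply drops out.)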
To stabilize SAGAN's training, the authors also use two tricks: ① spectral normalization in both the generator and the discriminator; ② imbalanced learning rates for the generator and the discriminator (the two-timescale update rule, TTUR).
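For reference, spectral normalization divides each weight matrix by its largest singular value, so that every layer is roughly 1-Lipschitz:

\[ \sigma(W) = \max_{h \neq 0} \frac{\lVert W h \rVert_2}{\lVert h \rVert_2}, \qquad \bar{W}_{\text{SN}} = \frac{W}{\sigma(W)}. \]

σ(W) is estimated with one step of power iteration per update (see spectral_norm() in ops.py below). The imbalanced learning rates show up in main.py as the defaults g_lr = 0.0001 and d_lr = 0.0004.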
The final SAGAN results are also impressive:
SAGAN significantly outperforms the state of the art in image synthesis by boosting the best reported Inception score from 36.8 to 52.52 and reducing Fréchet Inception distance from 27.62 to 18.65.
I will not go through the derivations and parameter settings here; to wrap up this part, here are the authors' sample results, which look quite striking:
The SAGAN project contains only a few files, organized as follows:
-- dataset        # dataset folder; download it yourself
   |------ CelebA
             |------ image1.jpg
             |------ image2.jpg
             |------ ...
-- ops.py         # layer definitions
-- utils.py       # utility functions
-- SAGAN.py       # model definition
-- main.py        # entry point
The dataset is once again the CelebA face dataset. An earlier article in this series, 对抗神经网络学习(六)——BEGAN实现不同人脸的生成(tensorflow实现) (BEGAN for face generation, a TensorFlow implementation), already introduced it, so I will not repeat that here; the download address is: https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA.
After downloading, extract the dataset into the 'dataset/CelebA/' directory.
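As a quick sanity check (my own hypothetical snippet; the aligned img_align_celeba archive contains 202,599 images), you can count the extracted files:

import glob
print(len(glob.glob('dataset/CelebA/*.jpg')))  # expect 202599 for img_align_celeba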
The full contents of utils.py are as follows:
import scipy.misc
import numpy as np
import os
from glob import glob

import tensorflow as tf
import tensorflow.contrib.slim as slim


class ImageData:

    def __init__(self, load_size, channels):
        self.load_size = load_size
        self.channels = channels

    def image_processing(self, filename):
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1  # scale pixels to [-1, 1]
        return img


def load_data(dataset_name, size=64):
    # collect all image paths under ./dataset/<dataset_name>
    x = glob(os.path.join("./dataset", dataset_name, '*.*'))
    return x


def preprocessing(x, size):
    x = scipy.misc.imread(x, mode='RGB')
    x = scipy.misc.imresize(x, [size, size])
    x = normalize(x)
    return x


def normalize(x):
    return x / 127.5 - 1


def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)


def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    if images.shape[3] in (3, 4):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3] == 1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
        return img
    else:
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: HxW or HxWx3 or HxWx4')


def imsave(images, size, path):
    # merge the batch into one grid image, then write it to disk
    return scipy.misc.imsave(path, merge(images, size))


def inverse_transform(images):
    return (images + 1.) / 2.


def check_folder(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def str2bool(x):
    return x.lower() == 'true'
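As a quick test of the input pipeline (my own sketch, not part of the repo; it assumes the CelebA images are already in place), you can feed the file list through ImageData and verify that pixel values land in [-1, 1]:

# minimal pipeline test: load one batch and check shape and value range
loader = ImageData(load_size=128, channels=3)
files = load_data('CelebA')
dataset = tf.data.Dataset.from_tensor_slices(files) \
    .map(loader.image_processing).batch(4)
iterator = dataset.make_one_shot_iterator()

with tf.Session() as sess:
    batch = sess.run(iterator.get_next())
    print(batch.shape, batch.min(), batch.max())  # (4, 128, 128, 3), within [-1, 1]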
The main contents of the layer file ops.py are:
import tensorflow as tf
import tensorflow.contrib as tf_contrib

# Xavier : tf_contrib.layers.xavier_initializer()
# He : tf_contrib.layers.variance_scaling_initializer()
# Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
# l2_decay : tf_contrib.layers.l2_regularizer(0.0001)

weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None

##################################################################################
# Layer
##################################################################################

def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    with tf.variable_scope(scope):
        if pad_type == 'zero':
            x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
        if pad_type == 'reflect':
            x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)
        return x


def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
    with tf.variable_scope(scope):
        x_shape = x.get_shape().as_list()

        if padding == 'SAME':
            output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels]
        else:
            output_shape = [x_shape[0], x_shape[1] * stride + max(kernel - stride, 0),
                            x_shape[2] * stride + max(kernel - stride, 0), channels]

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
                                       strides=[1, stride, stride, 1], padding=padding)
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)

        else:
            x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                           kernel_size=kernel, kernel_initializer=weight_init,
                                           kernel_regularizer=weight_regularizer,
                                           strides=stride, padding=padding, use_bias=use_bias)
        return x


def fully_conneted(x, units, use_bias=True, sn=False, scope='fully_0'):
    with tf.variable_scope(scope):
        x = flatten(x)
        shape = x.get_shape().as_list()
        channels = shape[-1]

        if sn:
            w = tf.get_variable("kernel", [channels, units], tf.float32,
                                initializer=weight_init, regularizer=weight_regularizer)
            if use_bias:
                bias = tf.get_variable("bias", [units],
                                       initializer=tf.constant_initializer(0.0))
                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))
        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=weight_regularizer, use_bias=use_bias)
        return x


def flatten(x):
    return tf.layers.flatten(x)


def hw_flatten(x):
    # [bs, h, w, c] -> [bs, h*w, c]
    return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])

##################################################################################
# Residual-block
##################################################################################

def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)

        return x + x_init

##################################################################################
# Sampling
##################################################################################

def global_avg_pooling(x):
    gap = tf.reduce_mean(x, axis=[1, 2])
    return gap


def up_sample(x, scale_factor=2):
    _, h, w, _ = x.get_shape().as_list()
    new_size = [h * scale_factor, w * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=new_size)

##################################################################################
# Activation function
##################################################################################

def lrelu(x, alpha=0.2):
    return tf.nn.leaky_relu(x, alpha)


def relu(x):
    return tf.nn.relu(x)


def tanh(x):
    return tf.tanh(x)

##################################################################################
# Normalization function
##################################################################################

def batch_norm(x, is_training=True, scope='batch_norm'):
    return tf_contrib.layers.batch_norm(x,
                                        decay=0.9, epsilon=1e-05,
                                        center=True, scale=True, updates_collections=None,
                                        is_training=is_training, scope=scope)


def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    # persistent estimate of the top singular vector, updated once per step
    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        # power iteration; usually iteration = 1 is enough
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = l2_norm(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = l2_norm(u_)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


def l2_norm(v, eps=1e-12):
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

##################################################################################
# Loss function
##################################################################################

def discriminator_loss(loss_func, real, fake):
    real_loss = 0
    fake_loss = 0

    if loss_func.__contains__('wgan'):
        # restored from the reference repo: Wasserstein loss for the wgan-gp / wgan-lp variants
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)

    if loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))

    if loss_func == 'gan' or loss_func == 'dragan':
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))

    if loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))

    loss = real_loss + fake_loss

    return loss


def generator_loss(loss_func, fake):
    fake_loss = 0

    if loss_func.__contains__('wgan'):
        # restored from the reference repo
        fake_loss = -tf.reduce_mean(fake)

    if loss_func == 'lsgan':
        fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0))

    if loss_func == 'gan' or loss_func == 'dragan':
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))

    if loss_func == 'hinge':
        fake_loss = -tf.reduce_mean(fake)

    loss = fake_loss

    return loss
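To see why one power-iteration step per update is enough in spectral_norm(), here is a small NumPy check I wrote (my own sketch, not part of the repo) that mirrors the matrix shapes used there; because u is persisted between updates and the weights change slowly, the estimate keeps improving across training steps:

import numpy as np

rng = np.random.RandomState(0)
W = rng.randn(64, 32)   # plays the role of the reshaped kernel [N, out]
u = rng.randn(1, 32)    # persistent estimate of the top right-singular vector

for step in range(5):   # spectral_norm() runs one such step per training update
    v = u.dot(W.T); v /= np.linalg.norm(v)
    u = v.dot(W);   u /= np.linalg.norm(u)
    sigma = v.dot(W).dot(u.T)[0, 0]
    print(step, sigma)

print('true sigma:', np.linalg.svd(W, compute_uv=False)[0])  # the estimates converge to this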
The main contents of SAGAN.py are:
import time
from ops import *
from utils import *
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch


class SAGAN(object):

    def __init__(self, sess, args):
        self.model_name = "SAGAN"  # name for checkpoint
        self.sess = sess
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        """ Generator """
        self.layer_num = int(np.log2(self.img_size)) - 3  # e.g. img_size=128 -> 4 up/down blocks
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.up_sample = args.up_sample
        self.gan_type = args.gan_type

        """ Discriminator """
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # train
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        self.custom_dataset = False

        if self.dataset_name == 'mnist':
            self.c_dim = 1
            self.data = load_mnist(size=self.img_size)  # loader omitted from the trimmed utils.py above

        elif self.dataset_name == 'cifar10':
            self.c_dim = 3
            self.data = load_cifar10(size=self.img_size)  # loader omitted from the trimmed utils.py above

        else:
            self.c_dim = 3
            self.data = load_data(dataset_name=self.dataset_name, size=self.img_size)
            self.custom_dataset = True

        self.dataset_num = len(self.data)

        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        print()

        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)

        print()

        print("##### Generator #####")
        print("# generator layer : ", self.layer_num)
        print("# upsample conv : ", self.up_sample)

        print()

        print("##### Discriminator #####")
        print("# discriminator layer : ", self.layer_num)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, z, is_training=True, reuse=False):
        with tf.variable_scope("generator", reuse=reuse):
            ch = 1024
            x = deconv(z, channels=ch, kernel=4, stride=1, padding='VALID', use_bias=False, sn=self.sn, scope='deconv')
            x = batch_norm(x, is_training, scope='batch_norm')
            x = relu(x)

            for i in range(self.layer_num // 2):
                if self.up_sample:
                    x = up_sample(x, scale_factor=2)
                    x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)
                else:
                    x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)

                ch = ch // 2

            # Self Attention
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(self.layer_num // 2, self.layer_num):
                if self.up_sample:
                    x = up_sample(x, scale_factor=2)
                    x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)
                else:
                    x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)

                ch = ch // 2

            if self.up_sample:
                x = up_sample(x, scale_factor=2)
                x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, sn=self.sn, scope='G_conv_logit')
                x = tanh(x)
            else:
                x = deconv(x, channels=self.c_dim, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='G_deconv_logit')
                x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x, is_training=True, reuse=False):
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = 64
            x = conv(x, channels=ch, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv')
            x = lrelu(x, 0.2)

            for i in range(self.layer_num // 2):
                x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm' + str(i))
                x = lrelu(x, 0.2)

                ch = ch * 2

            # Self Attention
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(self.layer_num // 2, self.layer_num):
                x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm' + str(i))
                x = lrelu(x, 0.2)

                ch = ch * 2

            x = conv(x, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')

            return x

    def attention(self, x, ch, sn=False, scope='attention', reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            f = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='f_conv')  # [bs, h, w, c']
            g = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='g_conv')  # [bs, h, w, c']
            h = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv')  # [bs, h, w, c]

            # N = h * w
            s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]

            beta = tf.nn.softmax(s)  # attention map

            o = tf.matmul(beta, hw_flatten(h))  # [bs, N, C]
            gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

            o = tf.reshape(o, shape=x.shape)  # [bs, h, w, C]
            x = gamma * o + x  # gamma starts at 0, so attention is blended in gradually

            return x

    def gradient_penalty(self, real, fake):
        if self.gan_type == 'dragan':
            shape = tf.shape(real)
            eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
            x_mean, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region
            noise = 0.5 * x_std * eps  # delta in paper

            # The author suggested U[0,1] in the original paper, but admitted on GitHub
            # (https://github.com/kodalinaveen3/DRAGAN) that it is a bug; it should be two-sided.
            alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.)
            interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)  # x_hat should be in the space of X

        else:
            alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
            interpolated = alpha * real + (1. - alpha) * fake

        logit = self.discriminator(interpolated, reuse=True)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        GP = 0

        # WGAN - LP
        if self.gan_type == 'wgan-lp':
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))

        elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """ Graph Input """
        # images
        if self.custom_dataset:
            Image_Data_Class = ImageData(self.img_size, self.c_dim)
            inputs = tf.data.Dataset.from_tensor_slices(self.data)

            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)) \
                           .apply(map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                                                num_parallel_batches=16, drop_remainder=True)) \
                           .apply(prefetch_to_device(gpu_device, self.batch_size))

            inputs_iterator = inputs.make_one_shot_iterator()

            self.inputs = inputs_iterator.get_next()

        else:
            self.inputs = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.c_dim], name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [self.batch_size, 1, 1, self.z_dim], name='z')

        """ Loss Function """
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers (note the imbalanced learning rates, TTUR)
        self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)

        """ Testing """
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        """ Summary """
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

    ##################################################################################
    # Train
    ##################################################################################

    def train(self):
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualizing training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore checkpoint if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = int(checkpoint_counter // self.iteration)
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        for epoch in range(start_epoch, self.epoch):
            # get batch data
            for idx in range(start_batch_id, self.iteration):
                batch_z = np.random.uniform(-1, 1, [self.batch_size, 1, 1, self.z_dim])

                if self.custom_dataset:
                    train_feed_dict = {
                        self.z: batch_z
                    }
                else:
                    random_index = np.random.choice(self.dataset_num, size=self.batch_size, replace=False)
                    # batch_images = self.data[idx*self.batch_size : (idx+1)*self.batch_size]
                    batch_images = self.data[random_index]

                    train_feed_dict = {
                        self.inputs: batch_images,
                        self.z: batch_z
                    }

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss], feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # update G network (once every n_critic D updates)
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # save intermediate samples every print_freq steps
                if np.mod(idx + 1, self.print_freq) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format(epoch, idx + 1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        return "{}_{}_{}_{}_{}_{}".format(
            self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, self.sn)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png')

    def test(self):
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        for i in range(self.test_num):
            z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
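As a reading aid (my own trace, not from the repo): with the default img_size=128, layer_num = log2(128) − 3 = 4, so both networks place the attention module at the 16×16 feature map:

# Generator (up_sample=True):
#   z [bs, 1, 1, 128] -> deconv 4x4/1 VALID -> 4x4x1024
#   -> up+conv -> 8x8x512 -> 16x16x256
#   -> self-attention on 16x16x256
#   -> up+conv -> 32x32x128 -> 64x64x64
#   -> up+conv+tanh -> 128x128x3
#
# Discriminator:
#   128x128x3 -> conv/2 -> 64x64x64 -> 32x32x128 -> 16x16x256
#   -> self-attention on 16x16x256
#   -> 8x8x512 -> 4x4x1024 -> conv 4x4 VALID -> 1x1x4 logits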
The main contents of main.py are:
from SAGAN import SAGAN
import argparse
from utils import *


"""parsing and configuration"""
def parse_args():
    desc = "Tensorflow implementation of Self-Attention GAN"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', help='train or test ?')
    parser.add_argument('--dataset', type=str, default='celebA', help='dataset name')

    parser.add_argument('--epoch', type=int, default=10, help='The number of epochs to run')
    parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=32, help='The size of batch per gpu')
    parser.add_argument('--print_freq', type=int, default=500, help='How often (in iterations) to save sample images')
    parser.add_argument('--save_freq', type=int, default=500, help='How often (in iterations) to save checkpoints')

    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for generator')
    parser.add_argument('--d_lr', type=float, default=0.0004, help='learning rate for discriminator')
    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for Adam optimizer')

    parser.add_argument('--z_dim', type=int, default=128, help='Dimension of noise vector')
    parser.add_argument('--up_sample', type=str2bool, default=True, help='using upsample-conv')
    parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')
    parser.add_argument('--gan_type', type=str, default='hinge', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
    parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')

    parser.add_argument('--img_size', type=int, default=128, help='The size of image')
    parser.add_argument('--sample_num', type=int, default=64, help='The number of sample images')

    parser.add_argument('--test_num', type=int, default=10, help='The number of images generated by the test')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())


"""checking arguments"""
def check_args(args):
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # --epoch
    try:
        assert args.epoch >= 1
    except:
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    try:
        assert args.batch_size >= 1
    except:
        print('batch size must be larger than or equal to one')
    return args


"""main"""
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = SAGAN(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        if args.phase == 'train':
            # launch the graph in a session
            gan.train()

            # visualize learned generator
            gan.visualize_results(args.epoch - 1)
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")


if __name__ == '__main__':
    main()
Once all the files above are in place, just run main.py. Training is slow: with epoch set to 10 and 10,000 iterations per epoch, my GTX 1060 3G takes nearly 4,000 seconds per epoch, i.e. about 1 hour 5 minutes, so the full run is quite time-consuming. So far I have only run the training phase, so only samples from training are shown below.
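For reference, a typical invocation might look like this (the directory created earlier is dataset/CelebA, and the --dataset flag is matched against that folder name, so pass the same capitalization on case-sensitive file systems; all other flags keep their defaults):

python main.py --phase train --dataset CelebA --img_size 128 --batch_size 32 --epoch 10
python main.py --phase test  --dataset CelebA --test_num 10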
At epoch=0, iter=500 (i.e. after 500 iterations), the generated samples look like this:
At epoch=1, iter=0 (10,000 iterations):
At epoch=2, iter=0 (20,000 iterations):
At epoch=4, iter=0 (40,000 iterations):
At epoch=6, iter=0 (60,000 iterations):
Since the loss was already quite small and later results differed little, I stopped training there.
1. Mode collapse seems to have occurred during training: the generated samples are all very similar. This still needs further investigation.
2. I have also generated faces with DCGAN, WGAN, and BEGAN before; here is a comparison of their results:
DCGAN generates 64*64 images, but many of the generated faces clearly share nearly identical poses: DCGAN falls into mode collapse easily. Its faces also come out with darker skin tones, lots of noise, and very rough edges, so the overall quality is poor.
WGAN theoretically resolves mode collapse; it generates 128*128 faces with much more natural skin tones, but the results are still poor: edges are barely visible, some images are almost unrecognizable, and there is a great deal of noise.
BEGAN generates 64*64 faces. It builds on DCGAN and performs noticeably better: the images only occasionally contain noise, and facial features and contours are very clear. Overall it works well.
SAGAN generates 128*128 faces. Its results show mode collapse, possibly because I did no preprocessing on the face dataset. The facial features it generates are quite clear, but everything outside the face region looks very poor.
Among these models, BEGAN currently gives me the best results, perhaps because I did little data preprocessing for the other experiments.
3. The reference author also provides a script for downloading the celebA dataset, reproduced here:
import os
import zipfile
import argparse
import requests

from tqdm import tqdm

parser = argparse.ArgumentParser(description='Download dataset for SAGAN')
parser.add_argument('dataset', metavar='N', type=str, nargs='+', choices=['celebA'],
                    help='name of dataset to download [celebA]')


def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"
    session = requests.Session()

    response = session.get(URL, params={'id': id}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response, destination, chunk_size=32 * 1024):
    total_size = int(response.headers.get('content-length', 0))
    with open(destination, "wb") as f:
        for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
                          unit='B', unit_scale=True, desc=destination):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)


def download_celeb_a(dirpath):
    data_dir = 'celebA'
    if os.path.exists(os.path.join(dirpath, data_dir)):
        print('Found Celeb-A - skip')
        return

    filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    save_path = os.path.join(dirpath, filename)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    zip_dir = ''
    with zipfile.ZipFile(save_path) as zf:
        zip_dir = zf.namelist()[0]
        zf.extractall(dirpath)
    os.remove(save_path)
    os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))


def prepare_data_dir(path='./dataset'):
    if not os.path.exists(path):
        os.mkdir(path)


if __name__ == '__main__':
    args = parser.parse_args()
    prepare_data_dir()

    if any(name in args.dataset for name in ['CelebA', 'celebA']):
        download_celeb_a('./dataset')