SAGAN全称为Self-Attention Generative Adversarial Networks,是由Han Zhang等人[1]于2018年5月提出的一种模型。文章中作者解释到,传统的GAN模型都是在低分辨率特征图的空间局部点上来生成高分辨率的细节,而SAGAN可以利用所有位置的特征来生成细节,并且SAGAN的判别器可以检查图像中相距较远的区域之间的精细特征是否彼此一致。SAGAN目前取得了非常好的效果。
在我的上一篇文章GAN系列文章中:对抗神经网络学习(十)——attentiveGAN实现影像去雨滴的过程(tensorflow实现),初次了解到了attentiveNet引入GAN中的优势,引入attentiveNet来生成attentive map,能够让网络快速准确的定位到图像中的重点关注区域,当时就隐约觉得可以用这个思路来进一步优化GAN的模型结构,后来就看到了SAGAN采用了这个方法。
However, by carefully examining the generated samples from these models, we can observe that convolutional GANs have much more difficulty modeling some image classes than others when trained on multi-class datasets. For example, while the state-of-the-art ImageNet GAN model excels at synthesizing image classes with few structural constraints (e.g. ocean, sky and landscape classes, which are distinguished more by texture than by geometry), it fails to capture geometric or structural patterns that occur consistently in some classes (for example, dogs are often drawn with realistic fur texture but without clearly defined separate feet).
In this work, we propose Self-Attention Generative Adversarial Networks (SAGANs), which introduce a self-attention mechanism into convolutional GANs. The self-attention module is complementary to convolutions and helps with modeling long range, multi-level dependencies across image regions. Armed with self-attention, the generator can draw images in which fine details at every location are carefully coordinated with fine details in distant portions of the image. Moreover, the discriminator can also more accurately enforce complicated geometric constraints on the global image structure. (作者引入self-attention机制,提出SAGAN。引入该机制后,生成器能够生成与远处区域相互协调的精细细节,判别器能够对全局图像结构施加更准确的几何约束。)
简单来理解,最左边的图中的5个点是作者放置的5个点。而右侧的每张图对应一个点,内容则是与这个点具有类似特征的区域,也就是上面说的most-attended regions。
为了使得SAGAN的训练稳定,作者还使用了两个trick:①在生成器和判别器中都使用了spectral normalization。②生成器和判别器使用非平衡学习率(Imbalanced learning rate)。
SAGAN significantly outperforms the state of the art in image synthesis by boosting the best reported Inception score from 36.8 to 52.52 and reducing Fréchet Inception distance from 27.62 to 18.65.
- -- dataset # 数据集文件,需要自己下载
- |------ celebA
- |------ image1.jpg
- |------ image2.jpg
- |------ ...
- -- ops.py # 图层文件
- -- utils.py # 操作文件
- -- SAGAN.py # 模型文件
- -- main.py # 主函数文件
- import scipy.misc
- import numpy as np
- import os
- from glob import glob
- import tensorflow as tf
- import tensorflow.contrib.slim as slim
class ImageData:
    """Decode and normalise image files inside a TensorFlow input pipeline.

    Args:
        load_size: Side length, in pixels, of the square image that
            ``image_processing`` produces.
        channels: Number of colour channels requested from the JPEG decoder.
    """

    def __init__(self, load_size, channels):
        self.load_size = load_size
        self.channels = channels

    def image_processing(self, filename):
        """Read ``filename``, decode it as JPEG, resize and rescale it.

        Args:
            filename: Path of the image file (passed to ``tf.read_file``).

        Returns:
            A float32 tensor resized to ``load_size`` x ``load_size`` whose
            values are mapped from [0, 255] to [-1, 1].
        """
        raw = tf.read_file(filename)
        decoded = tf.image.decode_jpeg(raw, channels=self.channels)
        img = tf.image.resize_images(decoded, [self.load_size, self.load_size])
        # 8-bit pixel range [0, 255] -> [-1, 1].
        img = tf.cast(img, tf.float32) / 127.5 - 1
        return img
def load_data(dataset_name, size=64):
    """Return the paths of all files in ``./dataset/<dataset_name>``.

    Args:
        dataset_name: Sub-directory of ``./dataset`` that holds the images.
        size: Unused here; kept so callers that pass it keep working.

    Returns:
        A sorted list of paths matching ``*.*`` in that directory (empty if
        the directory is missing or has no such files).
    """
    pattern = os.path.join("./dataset", dataset_name, '*.*')
    # glob order depends on the filesystem; sort so runs are reproducible.
    return sorted(glob(pattern))
def preprocessing(x, size):
    """Load an image file as RGB, resize it to a square and normalise it.

    Args:
        x: Path of the image file.
        size: Side length of the resized image.

    Returns:
        The resized image passed through ``normalize``.
    """
    # NOTE(review): scipy.misc.imread / imresize only exist in old SciPy
    # releases (with Pillow installed); confirm the environment provides them.
    image = scipy.misc.imread(x, mode='RGB')
    image = scipy.misc.imresize(image, [size, size])
    return normalize(image)
def normalize(x):
    """Map pixel values from [0, 255] to [-1, 1] (``x / 127.5 - 1``)."""
    return x / 127.5 - 1
def save_images(images, size, image_path):
    """Rescale a batch from [-1, 1] to [0, 1] and save it as one grid image.

    Args:
        images: Batch of images, forwarded to ``inverse_transform``.
        size: ``[rows, cols]`` of the grid, forwarded to ``imsave``.
        image_path: Destination file path.

    Returns:
        Whatever ``imsave`` returns.
    """
    return imsave(inverse_transform(images), size, image_path)
def merge(images, size):
    """Tile a batch of images into a single grid image.

    Args:
        images: Array of shape ``(N, H, W, C)`` with ``C`` in {1, 3, 4}.
        size: ``[rows, cols]`` of the grid. Image ``idx`` is placed at row
            ``idx // cols`` and column ``idx % cols``.

    Returns:
        An array of shape ``(rows * H, cols * W, C)`` for 3- or 4-channel
        input, or ``(rows * H, cols * W)`` for single-channel input. Cells
        without an image stay zero.

    Raises:
        ValueError: If the channel count is not 1, 3 or 4.
    """
    h, w, c = images.shape[1], images.shape[2], images.shape[3]
    if c not in (1, 3, 4):
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: HxW or HxWx3 or HxWx4')

    rows, cols = size[0], size[1]
    grid = np.zeros((h * rows, w * cols, c))
    for idx, image in enumerate(images):
        row, col = divmod(idx, cols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w, :] = image

    # Single-channel batches are returned as a 2-D image.
    return grid[:, :, 0] if c == 1 else grid
def imsave(images, size, path):
    """Tile ``images`` into a ``size`` grid with ``merge`` and write it to ``path``.

    Returns:
        Whatever ``scipy.misc.imsave`` returns.
    """
    # Disabled alternative from the original author (squeeze away a size-1
    # channel?): image = np.squeeze(merge(images, size))
    # NOTE(review): scipy.misc.imsave only exists in old SciPy releases;
    # confirm the environment provides it.
    return scipy.misc.imsave(path, merge(images, size))
def inverse_transform(images):
    """Map values from [-1, 1] back to [0, 1] (``(images + 1) / 2``)."""
    return (images + 1.) / 2.
def check_folder(log_dir):
    """Create ``log_dir`` (and any missing parents) if needed and return it.

    Args:
        log_dir: Directory path to ensure.

    Returns:
        ``log_dir`` unchanged.

    Raises:
        FileExistsError: If ``log_dir`` exists but is not a directory.
    """
    # exist_ok avoids the race between an existence check and the creation.
    os.makedirs(log_dir, exist_ok=True)
    return log_dir
def show_all_variables():
    """Print a summary of all trainable variables in the default graph."""
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)
def str2bool(x):
    """Parse a command-line flag value into a bool.

    Args:
        x: String to interpret.

    Returns:
        ``True`` only when ``x`` equals ``"true"`` ignoring case; ``False``
        for everything else.
    """
    # The trailing comma matters: ('true') is just the string 'true', which
    # turns `in` into a substring test ('' and 't' would be accepted).
    return x.lower() in ('true',)

- import tensorflow as tf
- import tensorflow.contrib as tf_contrib
# Initialisers / regularisers that can be swapped in below:
#   Xavier   : tf_contrib.layers.xavier_initializer()
#   He       : tf_contrib.layers.variance_scaling_initializer()
#   Normal   : tf.random_normal_initializer(mean=0.0, stddev=0.02)
#   l2_decay : tf_contrib.layers.l2_regularizer(0.0001)

# Shared by every layer helper in this module.
weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None
- ##################################################################################
- # Layer
- ##################################################################################
def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    """2-D convolution with explicit padding and optional spectral norm.

    Args:
        x: Input tensor; axes 1 and 2 are padded by ``pad`` on each side.
        channels: Number of output channels.
        kernel: Square kernel size.
        stride: Stride along axes 1 and 2.
        pad: Padding added on each side before the convolution.
        pad_type: ``'zero'`` or ``'reflect'``; any other value skips padding.
        use_bias: Whether to add a bias term.
        sn: If true, the kernel is passed through ``spectral_norm``.
        scope: Variable scope holding the layer's variables.

    Returns:
        The convolved tensor.
    """
    with tf.variable_scope(scope):
        paddings = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
        if pad_type == 'zero':
            x = tf.pad(x, paddings)
        elif pad_type == 'reflect':
            x = tf.pad(x, paddings, mode='REFLECT')

        if sn:
            # Manual variables so the kernel can be normalised before use.
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)
        return x
def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
    """Transposed 2-D convolution with optional spectral normalisation.

    Args:
        x: Input tensor with a fully static shape (read via ``as_list``).
        channels: Number of output channels.
        kernel: Square kernel size.
        stride: Stride along axes 1 and 2.
        padding: ``'SAME'`` scales axes 1 and 2 by ``stride``; any other value
            additionally adds ``max(kernel - stride, 0)`` to each.
        use_bias: Whether to add a bias term.
        sn: If true, the kernel is passed through ``spectral_norm``.
        scope: Variable scope holding the layer's variables.

    Returns:
        The up-sampled tensor.
    """
    with tf.variable_scope(scope):
        x_shape = x.get_shape().as_list()
        if padding == 'SAME':
            output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels]
        else:
            extra = max(kernel - stride, 0)
            output_shape = [x_shape[0], x_shape[1] * stride + extra, x_shape[2] * stride + extra, channels]

        if sn:
            # conv2d_transpose expects the kernel as [h, w, out_ch, in_ch].
            w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
                                       strides=[1, stride, stride, 1], padding=padding)
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                           kernel_size=kernel, kernel_initializer=weight_init,
                                           kernel_regularizer=weight_regularizer,
                                           strides=stride, padding=padding, use_bias=use_bias)
        return x
def fully_conneted(x, units, use_bias=True, sn=False, scope='fully_0'):
    """Flatten ``x`` and apply a dense layer, optionally spectrally normalised.

    The misspelled name is kept because it is the public interface.

    Args:
        x: Input tensor; flattened with ``flatten`` first.
        units: Number of output units.
        use_bias: Whether to add a bias term.
        sn: If true, the weight matrix is passed through ``spectral_norm``.
        scope: Variable scope holding the layer's variables.

    Returns:
        The dense layer output.
    """
    with tf.variable_scope(scope):
        x = flatten(x)
        in_features = x.get_shape().as_list()[-1]

        if sn:
            w = tf.get_variable("kernel", [in_features, units], tf.float32,
                                initializer=weight_init, regularizer=weight_regularizer)
            if use_bias:
                bias = tf.get_variable("bias", [units],
                                       initializer=tf.constant_initializer(0.0))
                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))
        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=weight_regularizer, use_bias=use_bias)
        return x
def flatten(x):
    """Flatten ``x`` with ``tf.layers.flatten``."""
    return tf.layers.flatten(x)
def hw_flatten(x):
    """Reshape ``x`` to ``[x.shape[0], -1, x.shape[-1]]``.

    The first and last axes are kept and everything in between is merged
    into one axis.
    """
    return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])
- ##################################################################################
- # Residual-block
- ##################################################################################
def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'):
    """Residual block: conv-BN-ReLU, conv-BN, then add the input back.

    Args:
        x_init: Input tensor; also the skip connection, so ``channels`` must
            match its channel count for the final addition.
        channels: Output channels of both 3x3 convolutions.
        use_bias: Forwarded to ``conv``.
        is_training: Forwarded to ``batch_norm``.
        sn: Forwarded to ``conv``.
        scope: Variable scope for the block.

    Returns:
        ``x_init`` plus the output of the two-layer branch.
    """
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)

        return x + x_init
- ##################################################################################
- # Sampling
- ##################################################################################
def global_avg_pooling(x):
    """Average ``x`` over axes 1 and 2."""
    return tf.reduce_mean(x, axis=[1, 2])
def up_sample(x, scale_factor=2):
    """Nearest-neighbour resize of axes 1 and 2 by ``scale_factor``.

    Args:
        x: 4-D tensor whose axes 1 and 2 have static sizes.
        scale_factor: Multiplier applied to both axes.

    Returns:
        The resized tensor.
    """
    _, h, w, _ = x.get_shape().as_list()
    new_size = [h * scale_factor, w * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=new_size)
- ##################################################################################
- # Activation function
- ##################################################################################
def lrelu(x, alpha=0.2):
    """Leaky ReLU with negative slope ``alpha``."""
    return tf.nn.leaky_relu(x, alpha)
def relu(x):
    """ReLU activation (``tf.nn.relu``)."""
    return tf.nn.relu(x)
def tanh(x):
    """Hyperbolic tangent activation (``tf.tanh``)."""
    return tf.tanh(x)
- ##################################################################################
- # Normalization function
- ##################################################################################
def batch_norm(x, is_training=True, scope='batch_norm'):
    """Batch normalisation via ``tf.contrib.layers.batch_norm``.

    Args:
        x: Input tensor.
        is_training: Forwarded as ``is_training``.
        scope: Variable scope for the layer's variables.

    Returns:
        The normalised tensor.
    """
    # updates_collections=None is passed so the moving statistics are updated
    # in place rather than collected for a separate update op.
    return tf_contrib.layers.batch_norm(x,
                                        decay=0.9, epsilon=1e-05,
                                        center=True, scale=True, updates_collections=None,
                                        is_training=is_training, scope=scope)
def spectral_norm(w, iteration=1):
    """Divide ``w`` by a power-iteration estimate of its largest singular value.

    Args:
        w: Weight variable; it is viewed as a 2-D matrix whose columns are
            indexed by the last axis of ``w``.
        iteration: Number of power-iteration steps (at least 1).

    Returns:
        A tensor with the shape of ``w`` holding ``w / sigma``. Evaluating it
        also writes the updated estimate back to the persistent vector ``u``.

    Raises:
        ValueError: If ``iteration`` is smaller than 1.
    """
    if iteration < 1:
        # With zero steps v_hat would stay None and fail obscurely below.
        raise ValueError("iteration must be >= 1, got {}".format(iteration))

    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    # Non-trainable vector created in the caller's variable scope; it carries
    # the estimate from one call to the next.
    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for _ in range(iteration):
        # Power iteration; usually iteration = 1 is enough.
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = l2_norm(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = l2_norm(u_)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    # Tie the update of u to the evaluation of the normalised weight.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm
def l2_norm(v, eps=1e-12):
    """Scale ``v`` by the inverse of its L2 norm over all elements.

    ``eps`` is added to the norm to avoid division by zero.
    """
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
- ##################################################################################
- # Loss function
- ##################################################################################
def discriminator_loss(loss_func, real, fake):
    """Discriminator loss for the selected GAN objective.

    Args:
        loss_func: One of ``'lsgan'``, ``'gan'``, ``'dragan'``, ``'hinge'`` or
            any name containing ``'wgan'`` (e.g. ``'wgan-gp'``, ``'wgan-lp'``).
        real: Discriminator outputs (logits) for real images.
        fake: Discriminator outputs (logits) for generated images.

    Returns:
        Scalar tensor ``real_loss + fake_loss``. Any gradient penalty is
        added by the caller.

    Raises:
        ValueError: If ``loss_func`` is not a supported name.
    """
    if 'wgan' in loss_func:
        # Wasserstein critic loss; without this branch the wgan-* variants
        # would train the critic on the gradient penalty alone.
        real_loss = -tf.reduce_mean(real)
        fake_loss = tf.reduce_mean(fake)
    elif loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))
    elif loss_func == 'gan' or loss_func == 'dragan':
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))
    elif loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))
    else:
        # A silent zero loss would only fail later, inside the optimiser.
        raise ValueError("unsupported loss_func: {!r}".format(loss_func))

    return real_loss + fake_loss
def generator_loss(loss_func, fake):
    """Generator loss for the selected GAN objective.

    Args:
        loss_func: One of ``'lsgan'``, ``'gan'``, ``'dragan'``, ``'hinge'`` or
            any name containing ``'wgan'`` (e.g. ``'wgan-gp'``, ``'wgan-lp'``).
        fake: Discriminator outputs (logits) for generated images.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``loss_func`` is not a supported name.
    """
    if 'wgan' in loss_func or loss_func == 'hinge':
        # Wasserstein and hinge generators both maximise the critic score.
        return -tf.reduce_mean(fake)
    if loss_func == 'lsgan':
        return tf.reduce_mean(tf.squared_difference(fake, 1.0))
    if loss_func == 'gan' or loss_func == 'dragan':
        return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))
    # A constant zero loss would give the optimiser nothing to differentiate.
    raise ValueError("unsupported loss_func: {!r}".format(loss_func))

- import time
- from ops import *
- from utils import *
- from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch
class SAGAN(object):
    """Self-Attention GAN: graph construction, training loop and sampling.

    Args:
        sess: TensorFlow session used for all ``run`` calls. ``train`` and
            ``test`` call ``.run()`` on ops directly, so it must also be the
            default session.
        args: Parsed command-line namespace; every attribute read in
            ``__init__`` must be present.
    """

    def __init__(self, sess, args):
        self.model_name = "SAGAN"  # name for checkpoint
        self.sess = sess
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        # Generator
        # Number of x2 up/down-sampling stages between the first and last layer.
        self.layer_num = int(np.log2(self.img_size)) - 3
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.up_sample = args.up_sample
        self.gan_type = args.gan_type

        # Discriminator
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # train
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        self.custom_dataset = False

        # NOTE(review): load_mnist / load_cifar10 must come from the star
        # imports of this module; confirm they are defined there.
        if self.dataset_name == 'mnist':
            self.c_dim = 1
            self.data = load_mnist(size=self.img_size)
        elif self.dataset_name == 'cifar10':
            self.c_dim = 3
            self.data = load_cifar10(size=self.img_size)
        else:
            self.c_dim = 3
            self.data = load_data(dataset_name=self.dataset_name, size=self.img_size)
            self.custom_dataset = True

        self.dataset_num = len(self.data)

        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        print()
        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)
        print()
        print("##### Generator #####")
        print("# generator layer : ", self.layer_num)
        print("# upsample conv : ", self.up_sample)
        print()
        print("##### Discriminator #####")
        print("# discriminator layer : ", self.layer_num)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

    ##################################################################################
    # Generator
    ##################################################################################
    def _g_up_block(self, x, channels, idx, is_training):
        """Double the spatial size of ``x``, then batch-norm and ReLU."""
        suffix = str(idx)
        if self.up_sample:
            x = up_sample(x, scale_factor=2)
            x = conv(x, channels=channels, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + suffix)
        else:
            x = deconv(x, channels=channels, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + suffix)
        x = batch_norm(x, is_training, scope='batch_norm_' + suffix)
        return relu(x)

    def generator(self, z, is_training=True, reuse=False):
        """Map noise ``z`` to an image with values in [-1, 1] (tanh output).

        Args:
            z: Noise tensor; ``build_model`` feeds ``[batch, 1, 1, z_dim]``.
            is_training: Forwarded to every batch-norm layer.
            reuse: Whether to reuse the ``generator`` variable scope.

        Returns:
            Generated image tensor with ``c_dim`` channels.
        """
        with tf.variable_scope("generator", reuse=reuse):
            ch = 1024
            x = deconv(z, channels=ch, kernel=4, stride=1, padding='VALID', use_bias=False, sn=self.sn, scope='deconv')
            x = batch_norm(x, is_training, scope='batch_norm')
            x = relu(x)

            half = self.layer_num // 2
            for i in range(half):
                x = self._g_up_block(x, ch // 2, i, is_training)
                ch = ch // 2

            # Self Attention, inserted halfway through the up-sampling stack.
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(half, self.layer_num):
                x = self._g_up_block(x, ch // 2, i, is_training)
                ch = ch // 2

            # Output layer: one more x2 up-sampling, no batch norm.
            if self.up_sample:
                x = up_sample(x, scale_factor=2)
                x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, sn=self.sn, scope='G_conv_logit')
            else:
                x = deconv(x, channels=self.c_dim, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='G_deconv_logit')
            return tanh(x)

    ##################################################################################
    # Discriminator
    ##################################################################################
    def _d_down_block(self, x, channels, idx, is_training):
        """Stride-2 convolution followed by batch norm and leaky ReLU."""
        suffix = str(idx)
        x = conv(x, channels=channels, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + suffix)
        x = batch_norm(x, is_training, scope='batch_norm' + suffix)
        return lrelu(x, 0.2)

    def discriminator(self, x, is_training=True, reuse=False):
        """Score images; returns the raw output of the last convolution.

        Args:
            x: Image batch.
            is_training: Forwarded to every batch-norm layer.
            reuse: Whether to reuse the ``discriminator`` variable scope.

        Returns:
            Output of the final ``D_logit`` convolution (4 channels).
        """
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = 64
            x = conv(x, channels=ch, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv')
            x = lrelu(x, 0.2)

            half = self.layer_num // 2
            for i in range(half):
                x = self._d_down_block(x, ch * 2, i, is_training)
                ch = ch * 2

            # Self Attention, inserted halfway through the down-sampling stack.
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(half, self.layer_num):
                x = self._d_down_block(x, ch * 2, i, is_training)
                ch = ch * 2

            x = conv(x, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')
            return x

    def attention(self, x, ch, sn=False, scope='attention', reuse=False):
        """Self-attention over all spatial positions with a learned gate.

        Args:
            x: Feature map ``[bs, h, w, C]`` with a fully static shape.
            ch: Channel count ``C`` of ``x``.
            sn: Forwarded to the 1x1 convolutions.
            scope: Variable scope for the layer.
            reuse: Whether to reuse that scope.

        Returns:
            ``gamma * o + x`` where ``o`` is the attention output and
            ``gamma`` is a scalar variable initialised to 0.
        """
        with tf.variable_scope(scope, reuse=reuse):
            f = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='f_conv')  # [bs, h, w, c']
            g = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='g_conv')  # [bs, h, w, c']
            h = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv')  # [bs, h, w, c]

            # N = h * w
            s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]
            beta = tf.nn.softmax(s)  # attention map

            o = tf.matmul(beta, hw_flatten(h))  # [bs, N, C]
            gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

            o = tf.reshape(o, shape=x.shape)  # [bs, h, w, C]
            x = gamma * o + x
        return x

    def gradient_penalty(self, real, fake):
        """Gradient penalty on points near or between real and fake samples.

        Args:
            real: Batch of real images.
            fake: Batch of generated images.

        Returns:
            ``ld``-scaled penalty tensor for ``wgan-lp``, ``wgan-gp`` and
            ``dragan``; the integer ``0`` for any other ``gan_type``.
        """
        if self.gan_type == 'dragan':
            shape = tf.shape(real)
            eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
            x_mean, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region
            noise = 0.5 * x_std * eps  # delta in paper

            # Author suggested U[0,1] in original paper, but he admitted it is bug in github
            # (https://github.com/kodalinaveen3/DRAGAN). It should be two-sided.
            alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.)
            interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)  # x_hat should be in the space of X
        else:
            alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
            interpolated = alpha * real + (1. - alpha) * fake

        logit = self.discriminator(interpolated, reuse=True)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)  # l2 norm

        GP = 0
        if self.gan_type == 'wgan-lp':
            # WGAN-LP: only norms above 1 are penalised.
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
        elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))
        return GP

    ##################################################################################
    # Model
    ##################################################################################
    def build_model(self):
        """Build inputs, losses, optimisers, the sampling op and summaries."""
        # Graph Input
        # images
        if self.custom_dataset:
            Image_Data_Class = ImageData(self.img_size, self.c_dim)
            inputs = tf.data.Dataset.from_tensor_slices(self.data)

            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing, self.batch_size, num_parallel_batches=16,
                              drop_remainder=True)).apply(prefetch_to_device(gpu_device, self.batch_size))

            inputs_iterator = inputs.make_one_shot_iterator()
            self.inputs = inputs_iterator.get_next()
        else:
            self.inputs = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.c_dim],
                                         name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [self.batch_size, 1, 1, self.z_dim], name='z')

        # Loss Function
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if 'wgan' in self.gan_type or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        # Training
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers
        self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(
            self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(
            self.g_loss, var_list=g_vars)

        # Testing
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        # Summary
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

    ##################################################################################
    # Train
    ##################################################################################
    def train(self):
        """Run the training loop, resuming from the latest checkpoint if any.

        The discriminator is updated every step and the generator every
        ``n_critic`` steps. Sample grids are written every ``print_freq``
        steps and checkpoints every ``save_freq`` steps, after each epoch and
        at the end.
        """
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exits
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = checkpoint_counter // self.iteration
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        for epoch in range(start_epoch, self.epoch):
            # get batch data
            for idx in range(start_batch_id, self.iteration):
                batch_z = np.random.uniform(-1, 1, [self.batch_size, 1, 1, self.z_dim])

                if self.custom_dataset:
                    # Real images come from the tf.data pipeline.
                    train_feed_dict = {
                        self.z: batch_z
                    }
                else:
                    random_index = np.random.choice(self.dataset_num, size=self.batch_size, replace=False)
                    batch_images = self.data[random_index]

                    train_feed_dict = {
                        self.inputs: batch_images,
                        self.z: batch_z
                    }

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss],
                                                       feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # update G network
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss],
                                                           feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    # Generator was skipped this step; show its last loss.
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # save training results every print_freq steps
                if np.mod(idx + 1, self.print_freq) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format(
                                    epoch, idx + 1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        """Sub-directory name encoding the model configuration."""
        return "{}_{}_{}_{}_{}_{}".format(
            self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, self.sn)

    def save(self, checkpoint_dir, step):
        """Write a checkpoint for ``step`` under ``checkpoint_dir/model_dir``."""
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint under ``checkpoint_dir/model_dir``.

        Returns:
            ``(True, step)`` when a checkpoint was restored, where ``step`` is
            the last number in the checkpoint file name; otherwise
            ``(False, 0)``.
        """
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # Raw string: "\d" in a normal literal is an invalid escape.
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def visualize_results(self, epoch):
        """Save a square grid of freshly sampled images for ``epoch``."""
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png')

    def test(self):
        """Load the latest checkpoint and write ``test_num`` sample grids.

        Sampling proceeds even when no checkpoint is found; in that case the
        freshly initialised weights are used.
        """
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        # random noise
        for i in range(self.test_num):
            z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(i))

- from SAGAN import SAGAN
- import argparse
- from utils import *
def parse_args():
    """Parse the command line and validate it with ``check_args``.

    Returns:
        Whatever ``check_args`` returns for the parsed namespace.
    """
    desc = "Tensorflow implementation of Self-Attention GAN"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', help='train or test ?')
    parser.add_argument('--dataset', type=str, default='celebA', help='dataset name')

    parser.add_argument('--epoch', type=int, default=10, help='The number of epochs to run')
    parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=32, help='The size of batch per gpu')
    parser.add_argument('--print_freq', type=int, default=500, help='The number of image_print_freqy')
    parser.add_argument('--save_freq', type=int, default=500, help='The number of ckpt_save_freq')

    # Generator and discriminator default to different learning rates.
    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for generator')
    parser.add_argument('--d_lr', type=float, default=0.0004, help='learning rate for discriminator')
    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for Adam optimizer')

    parser.add_argument('--z_dim', type=int, default=128, help='Dimension of noise vector')
    parser.add_argument('--up_sample', type=str2bool, default=True, help='using upsample-conv')
    parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')
    parser.add_argument('--gan_type', type=str, default='hinge', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
    parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')

    parser.add_argument('--img_size', type=int, default=128, help='The size of image')
    parser.add_argument('--sample_num', type=int, default=64, help='The number of sample images')
    parser.add_argument('--test_num', type=int, default=10, help='The number of images generated by the test')

    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())
def check_args(args):
    """Create the output directories and validate numeric arguments.

    Args:
        args: Parsed namespace with ``checkpoint_dir``, ``result_dir``,
            ``log_dir``, ``sample_dir``, ``epoch`` and ``batch_size``.

    Returns:
        ``args`` when every check passes; ``None`` (after printing the
        reason) when ``epoch`` or ``batch_size`` is smaller than one, so the
        caller's ``args is None`` guard can stop the program.
    """
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # Explicit checks rather than assert, which is stripped under `python -O`.
    valid = True

    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        valid = False

    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        valid = False

    return args if valid else None
def main():
    """Entry point: parse arguments, build the SAGAN graph, train or test."""
    # parse arguments
    args = parse_args()
    if args is None:
        return

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = SAGAN(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        if args.phase == 'train':
            # launch the graph in a session
            gan.train()

            # visualize learned generator
            gan.visualize_results(args.epoch - 1)

            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")


if __name__ == '__main__':
    main()

写完上述所有文件之后,运行main.py即可。实验过程比较慢,设置epoch为10,每个epoch内迭代10000次,我的GPU是GTX1060 3G,每个epoch大概需要近4000秒,即约1小时7分钟,所以还是非常耗时的。目前只进行了训练过程,所以只有训练过程的生成样本。
当epoch=0, iter=500时,也就是迭代500次,生成样本为:
1. 训练过程中似乎出现了模式崩塌的现象,因为生成的样本都非常类似,这个还需要进一步检查。
2. 之前用DCGAN、WGAN、BEGAN也做过人脸生成,下面来比较一下他们的效果:
DCGAN生成图像的大小为64*64的, 但是可以明显的看到,DCGAN的生成结果中,很多人脸的姿态都非常相似,DCGAN很容易出现模式崩塌现象,而且DCGAN生成的人脸肤色偏黑,且图像中的噪点很多,边缘非常不平滑,生成的效果比较差。
3. 关于celebA数据集的下载,作者也给了代码,这里也同时给出:
- import os
- import zipfile
- import argparse
- import requests
- from tqdm import tqdm
# Command-line interface: one or more dataset names; only 'celebA' is accepted.
parser = argparse.ArgumentParser(description='Download dataset for SAGAN')
parser.add_argument('dataset', metavar='N', type=str, nargs='+', choices=['celebA'],
                    help='name of dataset to download [celebA]')
def download_file_from_google_drive(id, destination):
    """Download a Google Drive file to ``destination``.

    Args:
        id: Google Drive file id (the name shadows the builtin but is kept
            because it is part of the public signature).
        destination: Local path the file is written to.
    """
    URL = "https://docs.google.com/uc?export=download"
    session = requests.Session()

    response = session.get(URL, params={'id': id}, stream=True)
    token = get_confirm_token(response)

    if token:
        # A download-warning cookie was set: repeat the request with its
        # value as the `confirm` parameter.
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)
def get_confirm_token(response):
    """Return the value of the first ``download_warning*`` cookie, if any.

    Args:
        response: Object whose ``cookies.items()`` yields ``(name, value)``
            pairs.

    Returns:
        The cookie value, or ``None`` when no cookie name starts with
        ``download_warning``.
    """
    return next(
        (value for key, value in response.cookies.items()
         if key.startswith('download_warning')),
        None,
    )
def save_response_content(response, destination, chunk_size=32 * 1024):
    """Stream a response body to ``destination`` with a byte progress bar.

    Args:
        response: Streaming response providing ``headers`` and
            ``iter_content(chunk_size)``.
        destination: Local path to write; also used as the bar's label.
        chunk_size: Number of bytes requested per chunk.
    """
    total_size = int(response.headers.get('content-length', 0))

    # The bar counts bytes (unit='B'), so it is advanced by len(chunk) rather
    # than once per chunk. A missing content-length leaves the total unknown.
    with open(destination, "wb") as f, \
            tqdm(total=total_size or None, unit='B', unit_scale=True, desc=destination) as progress:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress.update(len(chunk))
def download_celeb_a(dirpath):
    """Download and unpack CelebA into ``dirpath/celebA``.

    Does nothing when ``dirpath/celebA`` already exists. An existing archive
    in ``dirpath`` is reused instead of being downloaded again; the archive
    is deleted after extraction.

    Args:
        dirpath: Directory that will contain the ``celebA`` folder.
    """
    data_dir = 'celebA'
    if os.path.exists(os.path.join(dirpath, data_dir)):
        print('Found Celeb-A - skip')
        return

    filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    save_path = os.path.join(dirpath, filename)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    with zipfile.ZipFile(save_path) as zf:
        # The first archive entry is taken as the extracted folder's name.
        zip_dir = zf.namelist()[0]
        zf.extractall(dirpath)

    os.remove(save_path)
    os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))
def prepare_data_dir(path='./dataset'):
    """Create the dataset directory ``path`` (and parents) if it is missing."""
    # exist_ok avoids the check-then-create race and tolerates reruns.
    os.makedirs(path, exist_ok=True)
if __name__ == '__main__':
    args = parser.parse_args()
    prepare_data_dir()

    # The parser's `choices` only admits 'celebA', so one membership test is
    # enough.
    if 'celebA' in args.dataset:
        download_celeb_a('./dataset')

