
GAN Series (11): SAGAN for Generating Finer Face Images (TensorFlow Implementation)

I. Background

SAGAN, short for Self-Attention Generative Adversarial Networks, is a model proposed by Han Zhang et al. [1] in May 2018. As the authors explain, traditional GANs generate high-resolution details as a function of only spatially local points on low-resolution feature maps, whereas SAGAN can generate details using cues from all feature locations. Moreover, its discriminator can check that highly detailed features in distant portions of an image are consistent with each other. SAGAN achieved state-of-the-art results at the time.

This post uses the CelebA dataset as an example and trains SAGAN to generate finer face images; the implementation mainly follows the code in [2].

[1] Paper: https://arxiv.org/pdf/1805.08318.pdf

[2] Reference code: https://github.com/taki0112/Self-Attention-GAN-Tensorflow

II. How SAGAN Works

There are plenty of online write-ups for the introductory GANs in this series, but the newer the paper, the fewer the explanations. Here is one quick recommendation:

[3] SA-GAN: Self-Attention Generative Adversarial Networks, a paper walkthrough (with code)

Below is my own reading of the paper.

In the previous article of this GAN series, GAN Series (10): AttentiveGAN for Raindrop Removal from Images (TensorFlow implementation), I first saw the benefit of bringing an attentive network into a GAN: using it to produce an attention map lets the network quickly and accurately locate the important regions of an image. Back then I suspected the same idea could further improve GAN architectures, and later I found that SAGAN does exactly that.

First, the authors point to a known weakness of GANs: when trained on multi-class datasets, GANs struggle to model certain image classes. Put simply, GANs capture texture easily but have a hard time capturing geometric structure.

However, by carefully examining the generated samples from these models, we can observe that convolutional GANs have much more difficulty modeling some image classes than others when trained on multi-class datasets. For example, while the state-of-the-art ImageNet GAN model excels at synthesizing image classes with few structural constraints (e.g. ocean, sky and landscape classes, which are distinguished more by texture than by geometry), it fails to capture geometric or structural patterns that occur consistently in some classes (for example, dogs are often drawn with realistic fur texture but without clearly defined separate feet).

The reason is that these models rely on convolution to build dependencies between different image regions, and such dependencies can only propagate through many stacked convolutional layers spanning a large receptive field. Increasing the kernel size raises the network's representational capacity, but at the cost of computational efficiency. Self-attention, by contrast, strikes a balance between modeling long-range dependencies and staying computationally efficient, which is why the paper introduces it.

The authors' main contribution:

In this work, we propose Self-Attention Generative Adversarial Networks (SAGANs), which introduce a self-attention mechanism into convolutional GANs. The self-attention module is complementary to convolutions and helps with modeling long range, multi-level dependencies across image regions. Armed with self-attention, the generator can draw images in which fine details at every location are carefully coordinated with fine details in distant portions of the image. Moreover, the discriminator can also more accurately enforce complicated geometric constraints on the global image structure. (In short: the authors introduce a self-attention mechanism into GANs and propose SAGAN. With it, the generator can coordinate fine details across the whole image, and the discriminator can enforce complex geometric constraints on the global structure.)

The authors illustrate self-attention with a figure (not reproduced here):

To read it simply: the five dots in the leftmost image are query locations placed by the authors, and each image on the right corresponds to one dot, showing the regions whose features are similar to that location, i.e. the most-attended regions mentioned above.

The authors then formalize the self-attention mechanism:
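The formula image is missing here, so let me restate it from the paper (this matches the `attention()` method in SAGAN.py below). The feature map is projected by three 1x1 convolutions f, g, h; attention weights come from a softmax over pairwise similarities between spatial positions; the output is added back as a residual scaled by a learned scalar gamma:

```latex
f(x) = W_f\, x, \qquad g(x) = W_g\, x, \qquad h(x) = W_h\, x
```

```latex
\beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})},
\qquad s_{ij} = f(x_i)^{\top} g(x_j)
```

```latex
o_j = \sum_{i=1}^{N} \beta_{j,i}\, h(x_i), \qquad y_i = \gamma\, o_i + x_i
```

Here N is the number of spatial positions (h times w), and gamma is initialized to 0, so the network first relies on local convolutional cues and only gradually learns to weight in non-local evidence.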

Plugging the self-attention mechanism into a GAN yields SAGAN's loss functions:
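The loss image is also missing; for reference, the adversarial hinge loss that SAGAN trains with (the 'hinge' branch of `discriminator_loss`/`generator_loss` in ops.py below implements exactly this) is:

```latex
L_D = -\,\mathbb{E}_{x \sim p_{\mathrm{data}}}\!\left[\min\bigl(0,\, -1 + D(x)\bigr)\right]
      \;-\; \mathbb{E}_{z \sim p_z}\!\left[\min\bigl(0,\, -1 - D(G(z))\bigr)\right]
```

```latex
L_G = -\,\mathbb{E}_{z \sim p_z}\!\left[D(G(z))\right]
```

Note that the two min(0, .) terms are just relu(1 - D(x)) and relu(1 + D(G(z))) with the signs flipped, which is how the code writes them.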

To stabilize SAGAN's training, the authors also use two tricks: (1) spectral normalization in both the generator and the discriminator; (2) imbalanced learning rates for the generator and discriminator (the two-timescale update rule, TTUR).
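Concretely (restated from the paper; it matches `spectral_norm()` in ops.py and the default learning rates in main.py below): spectral normalization divides each weight matrix by its largest singular value, estimated cheaply with one step of power iteration,

```latex
\bar{W}_{\mathrm{SN}} = \frac{W}{\sigma(W)},
\qquad \sigma(W) = \max_{h \neq 0} \frac{\lVert W h \rVert_2}{\lVert h \rVert_2}
```

```latex
v \leftarrow \frac{W^{\top} u}{\lVert W^{\top} u \rVert_2},
\qquad u \leftarrow \frac{W v}{\lVert W v \rVert_2},
\qquad \sigma(W) \approx u^{\top} W v
```

while TTUR simply gives the discriminator a larger learning rate than the generator (0.0004 vs 0.0001 in the paper, which are also the defaults in main.py).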

SAGAN's final results are strong:

SAGAN significantly outperforms the state of the art in image synthesis by boosting the best reported Inception score from 36.8 to 52.52 and reducing Fréchet Inception distance from 27.62 to 18.65.

I will not go through the paper's derivations and hyperparameter settings here. To close this section, the authors' sample images (not reproduced here) look quite impressive.

III. Implementing SAGAN

1. File structure

SAGAN needs only a few files, organized as follows:

-- dataset        # dataset folder; download it yourself
|------ CelebA
|------ image1.jpg
|------ image2.jpg
|------ ...
-- ops.py         # layer definitions
-- utils.py       # utility functions
-- SAGAN.py       # model definition
-- main.py        # entry point

2. Preparing the data

The dataset is again the CelebA face dataset, which was introduced in an earlier article of this series, GAN Series (6): BEGAN for Generating Different Faces (TensorFlow implementation), so I will not describe it again. Here is a download link: https://pan.baidu.com/s/1eSNpdRG#list/path=%2FCelebA

After downloading the dataset, unzip it into the 'dataset/CelebA/' directory.

3. Utility file utils.py

The full content of utils.py:

import scipy.misc
import numpy as np
import os
from glob import glob
import tensorflow as tf
import tensorflow.contrib.slim as slim


class ImageData:

    def __init__(self, load_size, channels):
        self.load_size = load_size
        self.channels = channels

    def image_processing(self, filename):
        x = tf.read_file(filename)
        x_decode = tf.image.decode_jpeg(x, channels=self.channels)
        img = tf.image.resize_images(x_decode, [self.load_size, self.load_size])
        img = tf.cast(img, tf.float32) / 127.5 - 1
        return img


def load_data(dataset_name, size=64):
    x = glob(os.path.join("./dataset", dataset_name, '*.*'))
    return x


def preprocessing(x, size):
    x = scipy.misc.imread(x, mode='RGB')
    x = scipy.misc.imresize(x, [size, size])
    x = normalize(x)
    return x


def normalize(x):
    return x / 127.5 - 1


def save_images(images, size, image_path):
    return imsave(inverse_transform(images), size, image_path)


def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    if images.shape[3] in (3, 4):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3] == 1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
        return img
    else:
        raise ValueError('in merge(images,size) images parameter '
                         'must have dimensions: HxW or HxWx3 or HxWx4')


def imsave(images, size, path):
    # np.squeeze(merge(images, size)) would additionally drop a channel dim of 1
    return scipy.misc.imsave(path, merge(images, size))


def inverse_transform(images):
    return (images + 1.) / 2.


def check_folder(log_dir):
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    return log_dir


def show_all_variables():
    model_vars = tf.trainable_variables()
    slim.model_analyzer.analyze_vars(model_vars, print_info=True)


def str2bool(x):
    return x.lower() in ('true')
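As a quick sanity check of the pixel-scaling helpers above, the sketch below re-implements `normalize` and `inverse_transform` in plain NumPy (standalone, so it runs without TensorFlow) and confirms that chaining them maps [0, 255] pixels onto the [0, 1] range that save_images expects:

```python
import numpy as np

def normalize(x):
    # mirrors utils.normalize: [0, 255] -> [-1, 1]
    return x / 127.5 - 1

def inverse_transform(images):
    # mirrors utils.inverse_transform: [-1, 1] -> [0, 1]
    return (images + 1.) / 2.

x = np.array([0., 127.5, 255.])
y = inverse_transform(normalize(x))
print(y)  # -> [0.  0.5 1. ], i.e. x / 255
```

So the round trip is simply a division by 255, which is why the saved sample grids keep the original brightness.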

4. Layer file ops.py

The main content of ops.py:

import tensorflow as tf
import tensorflow.contrib as tf_contrib

# Xavier : tf_contrib.layers.xavier_initializer()
# He : tf_contrib.layers.variance_scaling_initializer()
# Normal : tf.random_normal_initializer(mean=0.0, stddev=0.02)
# l2_decay : tf_contrib.layers.l2_regularizer(0.0001)

weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None

##################################################################################
# Layer
##################################################################################

def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True, sn=False, scope='conv_0'):
    with tf.variable_scope(scope):
        if pad_type == 'zero':
            x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]])
        if pad_type == 'reflect':
            x = tf.pad(x, [[0, 0], [pad, pad], [pad, pad], [0, 0]], mode='REFLECT')

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d(input=x, filter=spectral_norm(w),
                             strides=[1, stride, stride, 1], padding='VALID')
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d(inputs=x, filters=channels,
                                 kernel_size=kernel, kernel_initializer=weight_init,
                                 kernel_regularizer=weight_regularizer,
                                 strides=stride, use_bias=use_bias)

        return x


def deconv(x, channels, kernel=4, stride=2, padding='SAME', use_bias=True, sn=False, scope='deconv_0'):
    with tf.variable_scope(scope):
        x_shape = x.get_shape().as_list()

        if padding == 'SAME':
            output_shape = [x_shape[0], x_shape[1] * stride, x_shape[2] * stride, channels]
        else:
            output_shape = [x_shape[0], x_shape[1] * stride + max(kernel - stride, 0),
                            x_shape[2] * stride + max(kernel - stride, 0), channels]

        if sn:
            w = tf.get_variable("kernel", shape=[kernel, kernel, channels, x.get_shape()[-1]],
                                initializer=weight_init, regularizer=weight_regularizer)
            x = tf.nn.conv2d_transpose(x, filter=spectral_norm(w), output_shape=output_shape,
                                       strides=[1, stride, stride, 1], padding=padding)
            if use_bias:
                bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
                x = tf.nn.bias_add(x, bias)
        else:
            x = tf.layers.conv2d_transpose(inputs=x, filters=channels,
                                           kernel_size=kernel, kernel_initializer=weight_init,
                                           kernel_regularizer=weight_regularizer,
                                           strides=stride, padding=padding, use_bias=use_bias)

        return x


def fully_conneted(x, units, use_bias=True, sn=False, scope='fully_0'):
    with tf.variable_scope(scope):
        x = flatten(x)
        shape = x.get_shape().as_list()
        channels = shape[-1]

        if sn:
            w = tf.get_variable("kernel", [channels, units], tf.float32,
                                initializer=weight_init, regularizer=weight_regularizer)
            if use_bias:
                bias = tf.get_variable("bias", [units],
                                       initializer=tf.constant_initializer(0.0))
                x = tf.matmul(x, spectral_norm(w)) + bias
            else:
                x = tf.matmul(x, spectral_norm(w))
        else:
            x = tf.layers.dense(x, units=units, kernel_initializer=weight_init,
                                kernel_regularizer=weight_regularizer, use_bias=use_bias)

        return x


def flatten(x):
    return tf.layers.flatten(x)


def hw_flatten(x):
    # [bs, h, w, c] -> [bs, h*w, c]
    return tf.reshape(x, shape=[x.shape[0], -1, x.shape[-1]])

##################################################################################
# Residual-block
##################################################################################

def resblock(x_init, channels, use_bias=True, is_training=True, sn=False, scope='resblock'):
    with tf.variable_scope(scope):
        with tf.variable_scope('res1'):
            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)
            x = relu(x)

        with tf.variable_scope('res2'):
            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
            x = batch_norm(x, is_training)

        return x + x_init

##################################################################################
# Sampling
##################################################################################

def global_avg_pooling(x):
    gap = tf.reduce_mean(x, axis=[1, 2])
    return gap


def up_sample(x, scale_factor=2):
    _, h, w, _ = x.get_shape().as_list()
    new_size = [h * scale_factor, w * scale_factor]
    return tf.image.resize_nearest_neighbor(x, size=new_size)

##################################################################################
# Activation function
##################################################################################

def lrelu(x, alpha=0.2):
    return tf.nn.leaky_relu(x, alpha)


def relu(x):
    return tf.nn.relu(x)


def tanh(x):
    return tf.tanh(x)

##################################################################################
# Normalization function
##################################################################################

def batch_norm(x, is_training=True, scope='batch_norm'):
    return tf_contrib.layers.batch_norm(x,
                                        decay=0.9, epsilon=1e-05,
                                        center=True, scale=True, updates_collections=None,
                                        is_training=is_training, scope=scope)


def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    w = tf.reshape(w, [-1, w_shape[-1]])

    u = tf.get_variable("u", [1, w_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        """
        power iteration
        Usually iteration = 1 will be enough
        """
        v_ = tf.matmul(u_hat, tf.transpose(w))
        v_hat = l2_norm(v_)

        u_ = tf.matmul(v_hat, w)
        u_hat = l2_norm(u_)

    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm


def l2_norm(v, eps=1e-12):
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)

##################################################################################
# Loss function
##################################################################################

def discriminator_loss(loss_func, real, fake):
    real_loss = 0
    fake_loss = 0

    if loss_func == 'lsgan':
        real_loss = tf.reduce_mean(tf.squared_difference(real, 1.0))
        fake_loss = tf.reduce_mean(tf.square(fake))

    if loss_func == 'gan' or loss_func == 'dragan':
        real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real), logits=real))
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake), logits=fake))

    if loss_func == 'hinge':
        real_loss = tf.reduce_mean(relu(1.0 - real))
        fake_loss = tf.reduce_mean(relu(1.0 + fake))

    loss = real_loss + fake_loss
    return loss


def generator_loss(loss_func, fake):
    fake_loss = 0

    if loss_func == 'lsgan':
        fake_loss = tf.reduce_mean(tf.squared_difference(fake, 1.0))

    if loss_func == 'gan' or loss_func == 'dragan':
        fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(fake), logits=fake))

    if loss_func == 'hinge':
        fake_loss = -tf.reduce_mean(fake)

    loss = fake_loss
    return loss
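To see concretely what the 'hinge' branch of `discriminator_loss` and `generator_loss` computes, here is a small NumPy re-implementation on toy logits (names `d_hinge`/`g_hinge` are mine, not from the repo; this is just the same arithmetic without TensorFlow):

```python
import numpy as np

def d_hinge(real, fake):
    # mirrors discriminator_loss('hinge', ...): mean(relu(1 - real)) + mean(relu(1 + fake))
    return np.mean(np.maximum(0., 1. - real)) + np.mean(np.maximum(0., 1. + fake))

def g_hinge(fake):
    # mirrors generator_loss('hinge', ...): -mean(fake)
    return -np.mean(fake)

real = np.array([2.0, 0.5])    # D fairly confident on real images
fake = np.array([-2.0, -0.5])  # D fairly confident on fakes
print(d_hinge(real, fake))     # 0.25 + 0.25 = 0.5
print(g_hinge(fake))           # 1.25
```

Note how logits beyond the margin (the 2.0 and -2.0 entries) contribute zero to the discriminator loss: the hinge only penalizes samples the discriminator is not yet confident about.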

5. Model file SAGAN.py

The main content of SAGAN.py:

import time
from ops import *
from utils import *
from tensorflow.contrib.data import prefetch_to_device, shuffle_and_repeat, map_and_batch


class SAGAN(object):

    def __init__(self, sess, args):
        self.model_name = "SAGAN"  # name for checkpoint
        self.sess = sess
        self.dataset_name = args.dataset
        self.checkpoint_dir = args.checkpoint_dir
        self.sample_dir = args.sample_dir
        self.result_dir = args.result_dir
        self.log_dir = args.log_dir

        self.epoch = args.epoch
        self.iteration = args.iteration
        self.batch_size = args.batch_size
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.img_size

        """ Generator """
        self.layer_num = int(np.log2(self.img_size)) - 3
        self.z_dim = args.z_dim  # dimension of noise-vector
        self.up_sample = args.up_sample
        self.gan_type = args.gan_type

        """ Discriminator """
        self.n_critic = args.n_critic
        self.sn = args.sn
        self.ld = args.ld

        self.sample_num = args.sample_num  # number of generated images to be saved
        self.test_num = args.test_num

        # train
        self.g_learning_rate = args.g_lr
        self.d_learning_rate = args.d_lr
        self.beta1 = args.beta1
        self.beta2 = args.beta2

        self.custom_dataset = False

        if self.dataset_name == 'mnist':
            self.c_dim = 1
            self.data = load_mnist(size=self.img_size)
        elif self.dataset_name == 'cifar10':
            self.c_dim = 3
            self.data = load_cifar10(size=self.img_size)
        else:
            self.c_dim = 3
            self.data = load_data(dataset_name=self.dataset_name, size=self.img_size)
            self.custom_dataset = True

        self.dataset_num = len(self.data)

        self.sample_dir = os.path.join(self.sample_dir, self.model_dir)
        check_folder(self.sample_dir)

        print()
        print("##### Information #####")
        print("# gan type : ", self.gan_type)
        print("# dataset : ", self.dataset_name)
        print("# dataset number : ", self.dataset_num)
        print("# batch_size : ", self.batch_size)
        print("# epoch : ", self.epoch)
        print("# iteration per epoch : ", self.iteration)
        print()
        print("##### Generator #####")
        print("# generator layer : ", self.layer_num)
        print("# upsample conv : ", self.up_sample)
        print()
        print("##### Discriminator #####")
        print("# discriminator layer : ", self.layer_num)
        print("# the number of critic : ", self.n_critic)
        print("# spectral normalization : ", self.sn)

    ##################################################################################
    # Generator
    ##################################################################################

    def generator(self, z, is_training=True, reuse=False):
        with tf.variable_scope("generator", reuse=reuse):
            ch = 1024
            x = deconv(z, channels=ch, kernel=4, stride=1, padding='VALID', use_bias=False, sn=self.sn, scope='deconv')
            x = batch_norm(x, is_training, scope='batch_norm')
            x = relu(x)

            for i in range(self.layer_num // 2):
                if self.up_sample:
                    x = up_sample(x, scale_factor=2)
                    x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)
                else:
                    x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)

                ch = ch // 2

            # Self Attention
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(self.layer_num // 2, self.layer_num):
                if self.up_sample:
                    x = up_sample(x, scale_factor=2)
                    x = conv(x, channels=ch // 2, kernel=3, stride=1, pad=1, sn=self.sn, scope='up_conv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)
                else:
                    x = deconv(x, channels=ch // 2, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='deconv_' + str(i))
                    x = batch_norm(x, is_training, scope='batch_norm_' + str(i))
                    x = relu(x)

                ch = ch // 2

            if self.up_sample:
                x = up_sample(x, scale_factor=2)
                x = conv(x, channels=self.c_dim, kernel=3, stride=1, pad=1, sn=self.sn, scope='G_conv_logit')
                x = tanh(x)
            else:
                x = deconv(x, channels=self.c_dim, kernel=4, stride=2, use_bias=False, sn=self.sn, scope='G_deconv_logit')
                x = tanh(x)

            return x

    ##################################################################################
    # Discriminator
    ##################################################################################

    def discriminator(self, x, is_training=True, reuse=False):
        with tf.variable_scope("discriminator", reuse=reuse):
            ch = 64
            x = conv(x, channels=ch, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv')
            x = lrelu(x, 0.2)

            for i in range(self.layer_num // 2):
                x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm' + str(i))
                x = lrelu(x, 0.2)

                ch = ch * 2

            # Self Attention
            x = self.attention(x, ch, sn=self.sn, scope="attention", reuse=reuse)

            for i in range(self.layer_num // 2, self.layer_num):
                x = conv(x, channels=ch * 2, kernel=4, stride=2, pad=1, sn=self.sn, use_bias=False, scope='conv_' + str(i))
                x = batch_norm(x, is_training, scope='batch_norm' + str(i))
                x = lrelu(x, 0.2)

                ch = ch * 2

            x = conv(x, channels=4, stride=1, sn=self.sn, use_bias=False, scope='D_logit')

            return x

    def attention(self, x, ch, sn=False, scope='attention', reuse=False):
        with tf.variable_scope(scope, reuse=reuse):
            f = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='f_conv')  # [bs, h, w, c']
            g = conv(x, ch // 8, kernel=1, stride=1, sn=sn, scope='g_conv')  # [bs, h, w, c']
            h = conv(x, ch, kernel=1, stride=1, sn=sn, scope='h_conv')       # [bs, h, w, c]

            # N = h * w
            s = tf.matmul(hw_flatten(g), hw_flatten(f), transpose_b=True)  # [bs, N, N]

            beta = tf.nn.softmax(s)  # attention map

            o = tf.matmul(beta, hw_flatten(h))  # [bs, N, C]
            gamma = tf.get_variable("gamma", [1], initializer=tf.constant_initializer(0.0))

            o = tf.reshape(o, shape=x.shape)  # [bs, h, w, C]
            x = gamma * o + x

            return x

    def gradient_penalty(self, real, fake):
        if self.gan_type == 'dragan':
            shape = tf.shape(real)
            eps = tf.random_uniform(shape=shape, minval=0., maxval=1.)
            x_mean, x_var = tf.nn.moments(real, axes=[0, 1, 2, 3])
            x_std = tf.sqrt(x_var)  # magnitude of noise decides the size of local region
            noise = 0.5 * x_std * eps  # delta in paper

            # Author suggested U[0,1] in original paper, but he admitted it is bug in github
            # (https://github.com/kodalinaveen3/DRAGAN). It should be two-sided.
            alpha = tf.random_uniform(shape=[shape[0], 1, 1, 1], minval=-1., maxval=1.)
            interpolated = tf.clip_by_value(real + alpha * noise, -1., 1.)  # x_hat should be in the space of X
        else:
            alpha = tf.random_uniform(shape=[self.batch_size, 1, 1, 1], minval=0., maxval=1.)
            interpolated = alpha * real + (1. - alpha) * fake

        logit = self.discriminator(interpolated, reuse=True)

        grad = tf.gradients(logit, interpolated)[0]  # gradient of D(interpolated)
        grad_norm = tf.norm(flatten(grad), axis=1)   # l2 norm

        GP = 0
        # WGAN-LP
        if self.gan_type == 'wgan-lp':
            GP = self.ld * tf.reduce_mean(tf.square(tf.maximum(0.0, grad_norm - 1.)))
        elif self.gan_type == 'wgan-gp' or self.gan_type == 'dragan':
            GP = self.ld * tf.reduce_mean(tf.square(grad_norm - 1.))

        return GP

    ##################################################################################
    # Model
    ##################################################################################

    def build_model(self):
        """ Graph Input """
        # images
        if self.custom_dataset:
            Image_Data_Class = ImageData(self.img_size, self.c_dim)

            inputs = tf.data.Dataset.from_tensor_slices(self.data)

            gpu_device = '/gpu:0'
            inputs = inputs.apply(shuffle_and_repeat(self.dataset_num)).apply(
                map_and_batch(Image_Data_Class.image_processing, self.batch_size,
                              num_parallel_batches=16, drop_remainder=True)).apply(
                prefetch_to_device(gpu_device, self.batch_size))

            inputs_iterator = inputs.make_one_shot_iterator()
            self.inputs = inputs_iterator.get_next()
        else:
            self.inputs = tf.placeholder(tf.float32, [self.batch_size, self.img_size, self.img_size, self.c_dim], name='real_images')

        # noises
        self.z = tf.placeholder(tf.float32, [self.batch_size, 1, 1, self.z_dim], name='z')

        """ Loss Function """
        # output of D for real images
        real_logits = self.discriminator(self.inputs)

        # output of D for fake images
        fake_images = self.generator(self.z)
        fake_logits = self.discriminator(fake_images, reuse=True)

        if self.gan_type.__contains__('wgan') or self.gan_type == 'dragan':
            GP = self.gradient_penalty(real=self.inputs, fake=fake_images)
        else:
            GP = 0

        # get loss for discriminator
        self.d_loss = discriminator_loss(self.gan_type, real=real_logits, fake=fake_logits) + GP

        # get loss for generator
        self.g_loss = generator_loss(self.gan_type, fake=fake_logits)

        """ Training """
        # divide trainable variables into a group for D and a group for G
        t_vars = tf.trainable_variables()
        d_vars = [var for var in t_vars if 'discriminator' in var.name]
        g_vars = [var for var in t_vars if 'generator' in var.name]

        # optimizers
        self.d_optim = tf.train.AdamOptimizer(self.d_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.d_loss, var_list=d_vars)
        self.g_optim = tf.train.AdamOptimizer(self.g_learning_rate, beta1=self.beta1, beta2=self.beta2).minimize(self.g_loss, var_list=g_vars)

        """ Testing """
        # for test
        self.fake_images = self.generator(self.z, is_training=False, reuse=True)

        """ Summary """
        self.d_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.g_sum = tf.summary.scalar("g_loss", self.g_loss)

    ##################################################################################
    # Train
    ##################################################################################

    def train(self):
        # initialize all variables
        tf.global_variables_initializer().run()

        # graph inputs for visualize training results
        self.sample_z = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

        # saver to save model
        self.saver = tf.train.Saver()

        # summary writer
        self.writer = tf.summary.FileWriter(self.log_dir + '/' + self.model_dir, self.sess.graph)

        # restore check-point if it exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            start_epoch = (int)(checkpoint_counter / self.iteration)
            start_batch_id = checkpoint_counter - start_epoch * self.iteration
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            start_epoch = 0
            start_batch_id = 0
            counter = 1
            print(" [!] Load failed...")

        # loop for epoch
        start_time = time.time()
        past_g_loss = -1.
        for epoch in range(start_epoch, self.epoch):
            # get batch data
            for idx in range(start_batch_id, self.iteration):
                batch_z = np.random.uniform(-1, 1, [self.batch_size, 1, 1, self.z_dim])

                if self.custom_dataset:
                    train_feed_dict = {
                        self.z: batch_z
                    }
                else:
                    random_index = np.random.choice(self.dataset_num, size=self.batch_size, replace=False)
                    # batch_images = self.data[idx*self.batch_size : (idx+1)*self.batch_size]
                    batch_images = self.data[random_index]
                    train_feed_dict = {
                        self.inputs: batch_images,
                        self.z: batch_z
                    }

                # update D network
                _, summary_str, d_loss = self.sess.run([self.d_optim, self.d_sum, self.d_loss], feed_dict=train_feed_dict)
                self.writer.add_summary(summary_str, counter)

                # update G network every n_critic steps
                g_loss = None
                if (counter - 1) % self.n_critic == 0:
                    _, summary_str, g_loss = self.sess.run([self.g_optim, self.g_sum, self.g_loss], feed_dict=train_feed_dict)
                    self.writer.add_summary(summary_str, counter)
                    past_g_loss = g_loss

                # display training status
                counter += 1
                if g_loss is None:
                    g_loss = past_g_loss
                print("Epoch: [%2d] [%5d/%5d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, self.iteration, time.time() - start_time, d_loss, g_loss))

                # save training results every print_freq steps
                if np.mod(idx + 1, self.print_freq) == 0:
                    samples = self.sess.run(self.fake_images, feed_dict={self.z: self.sample_z})
                    tot_num_samples = min(self.sample_num, self.batch_size)
                    manifold_h = int(np.floor(np.sqrt(tot_num_samples)))
                    manifold_w = int(np.floor(np.sqrt(tot_num_samples)))
                    save_images(samples[:manifold_h * manifold_w, :, :, :],
                                [manifold_h, manifold_w],
                                './' + self.sample_dir + '/' + self.model_name + '_train_{:02d}_{:05d}.png'.format(epoch, idx + 1))

                if np.mod(idx + 1, self.save_freq) == 0:
                    self.save(self.checkpoint_dir, counter)

            # After an epoch, start_batch_id is set to zero
            # non-zero value is only for the first epoch after loading pre-trained model
            start_batch_id = 0

            # save model
            self.save(self.checkpoint_dir, counter)

            # show temporal results
            # self.visualize_results(epoch)

        # save model for final step
        self.save(self.checkpoint_dir, counter)

    @property
    def model_dir(self):
        return "{}_{}_{}_{}_{}_{}".format(
            self.model_name, self.dataset_name, self.gan_type, self.img_size, self.z_dim, self.sn)

    def save(self, checkpoint_dir, step):
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess, os.path.join(checkpoint_dir, self.model_name + '.model'), global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            counter = int(next(re.finditer(r"(\d+)(?!.*\d)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0

    def visualize_results(self, epoch):
        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

        samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                    self.sample_dir + '/' + self.model_name + '_epoch%02d' % epoch + '_visualize.png')

    def test(self):
        tf.global_variables_initializer().run()

        self.saver = tf.train.Saver()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        result_dir = os.path.join(self.result_dir, self.model_dir)
        check_folder(result_dir)

        if could_load:
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        """ random condition, random noise """
        for i in range(self.test_num):
            z_sample = np.random.uniform(-1, 1, size=(self.batch_size, 1, 1, self.z_dim))

            samples = self.sess.run(self.fake_images, feed_dict={self.z: z_sample})

            save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                        [image_frame_dim, image_frame_dim],
                        result_dir + '/' + self.model_name + '_test_{}.png'.format(i))
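The tensor shapes inside `attention()` can be traced with a small NumPy sketch: the 1x1 convolutions become plain projection matrices and `hw_flatten` becomes a reshape. All names here (`Wf`, `flat`, etc.) are illustrative, not from the repo:

```python
import numpy as np

rng = np.random.default_rng(0)
bs, h, w, ch = 2, 8, 8, 64
N = h * w                                # number of spatial positions
x = rng.standard_normal((bs, h, w, ch))

def flat(t):
    # hw_flatten: [bs, h, w, c] -> [bs, N, c]
    return t.reshape(t.shape[0], -1, t.shape[-1])

Wf = rng.standard_normal((ch, ch // 8))  # stands in for the 1x1 f_conv
Wg = rng.standard_normal((ch, ch // 8))  # stands in for the 1x1 g_conv
Wh = rng.standard_normal((ch, ch))       # stands in for the 1x1 h_conv

f, g, hh = flat(x) @ Wf, flat(x) @ Wg, flat(x) @ Wh
s = g @ f.transpose(0, 2, 1)             # [bs, N, N] pairwise similarity logits
beta = np.exp(s - s.max(-1, keepdims=True))
beta /= beta.sum(-1, keepdims=True)      # softmax: each row sums to 1
o = (beta @ hh).reshape(x.shape)         # [bs, h, w, ch] attention output
gamma = 0.0                              # learned scalar, initialized to 0
y = gamma * o + x                        # at init, the attention branch is a no-op
print(s.shape, o.shape)                  # (2, 64, 64) (2, 8, 8, 64)
```

Two things worth noticing: the attention map is N x N, i.e. quadratic in the number of spatial positions (which is why the module is placed on a mid-resolution feature map), and with gamma starting at 0 the output equals the input, so training begins from a purely convolutional network.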

6. Entry file main.py

The main content of main.py:

from SAGAN import SAGAN
import argparse
from utils import *

"""parsing and configuration"""
def parse_args():
    desc = "Tensorflow implementation of Self-Attention GAN"
    parser = argparse.ArgumentParser(description=desc)
    parser.add_argument('--phase', type=str, default='train', help='train or test ?')
    parser.add_argument('--dataset', type=str, default='celebA', help='dataset name')
    parser.add_argument('--epoch', type=int, default=10, help='The number of epochs to run')
    parser.add_argument('--iteration', type=int, default=10000, help='The number of training iterations')
    parser.add_argument('--batch_size', type=int, default=32, help='The size of batch per gpu')
    parser.add_argument('--print_freq', type=int, default=500, help='The number of image_print_freq')
    parser.add_argument('--save_freq', type=int, default=500, help='The number of ckpt_save_freq')
    parser.add_argument('--g_lr', type=float, default=0.0001, help='learning rate for generator')
    parser.add_argument('--d_lr', type=float, default=0.0004, help='learning rate for discriminator')
    parser.add_argument('--beta1', type=float, default=0.0, help='beta1 for Adam optimizer')
    parser.add_argument('--beta2', type=float, default=0.9, help='beta2 for Adam optimizer')
    parser.add_argument('--z_dim', type=int, default=128, help='Dimension of noise vector')
    parser.add_argument('--up_sample', type=str2bool, default=True, help='using upsample-conv')
    parser.add_argument('--sn', type=str2bool, default=True, help='using spectral norm')
    parser.add_argument('--gan_type', type=str, default='hinge', help='[gan / lsgan / wgan-gp / wgan-lp / dragan / hinge]')
    parser.add_argument('--ld', type=float, default=10.0, help='The gradient penalty lambda')
    parser.add_argument('--n_critic', type=int, default=1, help='The number of critic')
    parser.add_argument('--img_size', type=int, default=128, help='The size of image')
    parser.add_argument('--sample_num', type=int, default=64, help='The number of sample images')
    parser.add_argument('--test_num', type=int, default=10, help='The number of images generated by the test')
    parser.add_argument('--checkpoint_dir', type=str, default='checkpoint',
                        help='Directory name to save the checkpoints')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--sample_dir', type=str, default='samples',
                        help='Directory name to save the samples on training')

    return check_args(parser.parse_args())

"""checking arguments"""
def check_args(args):
    # --checkpoint_dir
    check_folder(args.checkpoint_dir)

    # --result_dir
    check_folder(args.result_dir)

    # --log_dir
    check_folder(args.log_dir)

    # --sample_dir
    check_folder(args.sample_dir)

    # --epoch
    try:
        assert args.epoch >= 1
    except:
        print('number of epochs must be larger than or equal to one')

    # --batch_size
    try:
        assert args.batch_size >= 1
    except:
        print('batch size must be larger than or equal to one')

    return args

"""main"""
def main():
    # parse arguments
    args = parse_args()
    if args is None:
        exit()

    # open session
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
        gan = SAGAN(sess, args)

        # build graph
        gan.build_model()

        # show network architecture
        show_all_variables()

        if args.phase == 'train':
            # launch the graph in a session
            gan.train()

            # visualize learned generator
            gan.visualize_results(args.epoch - 1)
            print(" [*] Training finished!")

        if args.phase == 'test':
            gan.test()
            print(" [*] Test finished!")

if __name__ == '__main__':
    main()
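Several of the flags above (`--up_sample`, `--sn`) use a `str2bool` converter imported from utils.py, which is not shown in this post. A minimal sketch of what such a converter typically looks like (the exact implementation in the repo may differ):

```python
import argparse

def str2bool(x):
    # Map common truthy/falsy strings to bool for argparse flags;
    # plain `type=bool` would treat any non-empty string (even "False") as True.
    if x.lower() in ('true', 'yes', '1'):
        return True
    if x.lower() in ('false', 'no', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected, got {!r}'.format(x))
```

With this, `--sn False` on the command line parses to the Python value `False` rather than a truthy string.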

IV. Results

After writing all of the files above, simply run main.py. Training is quite slow: with epoch set to 10 and 10,000 iterations per epoch, my GPU (a GTX 1060 3 GB) takes roughly 4,000 seconds per epoch, i.e. about 1 hour 5 minutes, so the whole run is very time-consuming. So far I have only run the training phase, so only samples generated during training are shown.

At epoch=0, iter=500, i.e. after 500 iterations, the generated samples are:

At epoch=1, iter=0, i.e. after 10,000 iterations, the results are:

At epoch=2, iter=0, i.e. after 20,000 iterations, the results are:

At epoch=4, iter=0, i.e. after 40,000 iterations, the results are:

At epoch=6, iter=0, i.e. after 60,000 iterations, the results are:

Since the loss was already very small and later results hardly differed, I did not continue training further.

V. Analysis

1. Mode collapse seems to have occurred during training, since the generated samples are very similar to one another; this still needs further investigation.

2. I have previously generated faces with DCGAN, WGAN, and BEGAN; here is a comparison of their results:

DCGAN generates 64×64 images. As is clearly visible, many of the generated faces share very similar poses: DCGAN is prone to mode collapse. Its faces also tend toward darker skin tones, with heavy noise and very rough edges, so the overall quality is poor.

WGAN theoretically resolves mode collapse. Its generated faces are 128×128 and the skin tones look much more natural, but the results are still poor: edges are barely visible, some images are almost unrecognizable, and there is a great deal of noise.

BEGAN generates 64×64 faces. It extends DCGAN and performs noticeably better: aside from occasional noise, the facial features and contours are very clear, and the overall results are good.

SAGAN generates 128×128 faces. Mode collapse appears in its results, possibly because I did not preprocess the face dataset. The facial features it generates are quite clear, but the contours outside the face are very poor.

Of these models, BEGAN currently gives the best results in my experiments, perhaps because I did little data preprocessing for the other runs.

3. For downloading the CelebA dataset, the author also provides a script, reproduced here:

import os
import zipfile
import argparse
import requests
from tqdm import tqdm

parser = argparse.ArgumentParser(description='Download dataset for SAGAN')
parser.add_argument('dataset', metavar='N', type=str, nargs='+', choices=['celebA'],
                    help='name of dataset to download [celebA]')

def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"
    session = requests.Session()

    response = session.get(URL, params={'id': id}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, destination)

def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None

def save_response_content(response, destination, chunk_size=32 * 1024):
    total_size = int(response.headers.get('content-length', 0))
    with open(destination, "wb") as f:
        for chunk in tqdm(response.iter_content(chunk_size), total=total_size,
                          unit='B', unit_scale=True, desc=destination):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)

def download_celeb_a(dirpath):
    data_dir = 'celebA'
    if os.path.exists(os.path.join(dirpath, data_dir)):
        print('Found Celeb-A - skip')
        return

    filename, drive_id = "img_align_celeba.zip", "0B7EVK8r0v71pZjFTYXZWM3FlRnM"
    save_path = os.path.join(dirpath, filename)

    if os.path.exists(save_path):
        print('[*] {} already exists'.format(save_path))
    else:
        download_file_from_google_drive(drive_id, save_path)

    zip_dir = ''
    with zipfile.ZipFile(save_path) as zf:
        zip_dir = zf.namelist()[0]
        zf.extractall(dirpath)
    os.remove(save_path)
    os.rename(os.path.join(dirpath, zip_dir), os.path.join(dirpath, data_dir))

def prepare_data_dir(path='./dataset'):
    if not os.path.exists(path):
        os.mkdir(path)

if __name__ == '__main__':
    args = parser.parse_args()
    prepare_data_dir()

    if any(name in args.dataset for name in ['CelebA', 'celebA']):
        download_celeb_a('./dataset')
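Large files on Google Drive trigger a virus-scan warning page rather than the file itself, which is why the script retries the request with a `download_warning` cookie token. The cookie scan in `get_confirm_token` can be exercised offline with a stub (`StubResponse` is a hypothetical stand-in for `requests.Response`, just enough for this check):

```python
class StubResponse:
    # Hypothetical stand-in for requests.Response: only the cookie dict matters here.
    def __init__(self, cookies):
        self.cookies = cookies

def get_confirm_token(response):
    # Same cookie scan as in the download script above
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None
```

If no `download_warning` cookie is present (small files), the token is None and the first response already contains the data.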

 
