当前位置:   article > 正文

AI修图!pix2pix网络介绍与tensorflow实现_pix2pix tensorflow

pix2pix tensorflow

1.引言

    在现实生活当中,除了语言之间的翻译之外,我们也经常会遇到各种图像的“翻译”任务,即给定一张图像,生成目标图像,常见的场景有:图像风格迁移、图像超级分辨率、图像上色、图像去马赛克等。而在现实生活当中,图像翻译任务更常见的场景可能是图像的修图与美化,因此,本文将准备介绍另一个新的图像翻译任务——AI修图,即给定一张图像,让机器自动对该图像进行修图,从而达到一个更加美化的效果。

    本文将利用GAN网络中一个比较经典的模型,即pix2pix模型,该网络采用一种完全监督的方法,即利用完全配对的输入和输出图像训练模型,通过训练好的模型将输入的图像生成指定任务的目标图像。目前该方法是图像翻译任务中完全监督方法里面效果和通用性最好的一个模型,在介绍这个模型的结构之前,可以先来看下作者利用这个网络所做的一些有趣的实验:

  • 图像语义标签——真实图像
  • 白天——夜景 
  • 简笔画上色
  • 黑白图像——彩色图像

具体效果如下图所示 :

2.pix2pix网络介绍

    pix2pix网络是GAN网络中的一种,主要是采用cGAN网络的结构,它依然包括了一个生成器和一个判别器。生成器采用的是一个U-net的结构,其结构有点类似Encoder-decoder,总共包含15层,分别有8层卷积层作为encoder,7层反卷积层(关于反卷积层的概念可以参考这篇博客:反卷积原理不可多得的好文)作为decoder,与传统的encoder-decoder不同的是引入了一个叫做“skip-connect”的技巧,即每一层反卷积层的输入都是:前一层的输出+与该层对称的卷积层的输出,从而保证encoder的信息在decoder时可以不断地被重新记忆,使得生成的图像尽可能保留原图像的一些信息。具体如下图所示:

    对于判别器,pix2pix采用的是一个6层的卷积网络,其思想与传统的判别器类似,只是有以下两点比较特别的地方:

  • 将输入图像与目标图像进行堆叠:pix2pix的判别器的输入不仅仅只是真实图像与生成图像,还将输入图像也一起作为输入的一部分,即将输入图像与真实图像、生成图像分别在第3通道进行拼接,然后一起作为输入传入判别器模型。
  • 引入PatchGAN的思想:传统的判别器是对一张图像输出一个softmax概率值,而pix2pix的判别器则引入了PatchGAN的思想,将一张图像通过多层卷积层后最终输出了一个比较小的矩阵,比如30*30,然后对每个像素点输出一个softmax概率值,这就相当于对一张输入图像切分为很多小块,对每一小块分别计算一个输出。作者表示引入PatchGAN其实可以起到一种类似计算风格或纹理损失的效果。

其具体的结构如下图所示:

 3.模型的损失函数

    pix2pix的损失函数除了标准的GAN网络的损失函数之外,还引入了L1的损失函数。记x为输入的图像,y为真实图像(输出图像),G为生成器,D为判别器,则标准的GAN网络的损失函数为:

                      LcGAN(G,D)=Ex,yPdata(x,y)[logD(x,y)]+ExPdata(x)[log(1D(x,G(x))]

对G施加L1惩罚,即:

                                                    LL1(G)=Ex,yPdata(x,y)[yG(x)]

因此,最终GAN网络的损失函数为:

                                             G=argminGmaxDLcGAN(G,D)+λLL1(G)

这样一来,标准的GAN损失负责捕捉图像高频特征,而L1损失则负责捕捉低频特征,使得生成结果既真实且清晰。

4.pix2pix的tensorflow实现

    本文利用pix2pix进行AI修图,采用的框架是tensorflow实现。首先是将输入图像和真实图像(输出图像)分别压缩至256*256的规格,并将两者拼接在一起,形式如下:

其中,左侧为修图前的原图,右侧为人工修图的结果,总共采集了1700对这样的图像作为模型的训练集,模型的主要代码模块如下:

  1. import tensorflow as tf
  2. import numpy as np
  3. from PIL import Image
  4. from data_loader import get_batch_data
  5. import os
  6. import re
  7. class pix2pix(object):
  8. def __init__(self, sess, batch_size, L1_lambda):
  9. """
  10. :param sess: tf.Session
  11. :param batch_size: batch_size. [int]
  12. :param L1_lambda: L1_loss lambda. [int]
  13. """
  14. self.sess = sess
  15. self.k_initializer = tf.random_normal_initializer(0, 0.02)
  16. self.g_initializer = tf.random_normal_initializer(1, 0.02)
  17. self.L1_lambda = L1_lambda
  18. self.bulid_model()
  19. def bulid_model(self):
  20. """
  21. 初始化模型
  22. :return:
  23. """
  24. # init variable
  25. self.x_ = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 3], name='x')
  26. self.y_ = tf.placeholder(dtype=tf.float32, shape=[None, 256, 256, 3], name='y')
  27. # generator
  28. self.g = self.generator(self.x_)
  29. # discriminator
  30. self.d_real = self.discriminator(self.x_, self.y_)
  31. self.d_fake = self.discriminator(self.x_, self.g, reuse=True)
  32. # loss
  33. self.loss_g, self.loss_d = self.loss(self.d_real, self.d_fake, self.y_, self.g)
  34. # summary
  35. tf.summary.scalar("loss_g", self.loss_g)
  36. tf.summary.scalar("loss_d", self.loss_d)
  37. self.merged = tf.summary.merge_all()
  38. # vars
  39. self.vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
  40. self.vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]
  41. # saver
  42. self.saver = tf.train.Saver()
  43. def discriminator(self, x, y, reuse=None):
  44. """
  45. 判别器
  46. :param x: 输入图像. [tensor]
  47. :param y: 目标图像. [tensor]
  48. :param reuse: reuse or not. [boolean]
  49. :return:
  50. """
  51. with tf.variable_scope('discriminator', reuse=reuse):
  52. x = tf.concat([x, y], axis=3)
  53. h0 = self.lrelu(self.d_conv(x, 64, 2)) # 128 128 64
  54. h0 = self.d_conv(h0, 128, 2)
  55. h0 = self.lrelu(self.batch_norm(h0)) # 64 64 128
  56. h0 = self.d_conv(h0, 256, 2)
  57. h0 = self.lrelu(self.batch_norm(h0)) # 32 32 256
  58. h0 = self.d_conv(h0, 512, 1)
  59. h0 = self.lrelu(self.batch_norm(h0)) # 31 31 512
  60. h0 = self.d_conv(h0, 1, 1) # 30 30 1
  61. h0 = tf.nn.sigmoid(h0)
  62. return h0
  63. def generator(self, x):
  64. """
  65. 生成器
  66. :param x: 输入图像. [tensor]
  67. :return: h0,生成的图像. [tensor]
  68. """
  69. with tf.variable_scope('generator', reuse=None):
  70. layers = []
  71. h0 = self.g_conv(x, 64)
  72. layers.append(h0)
  73. for filters in [128, 256, 512, 512, 512, 512, 512]: # [128, 256, 512, 512, 512, 512, 512]
  74. h0 = self.lrelu(layers[-1])
  75. h0 = self.g_conv(h0, filters)
  76. h0 = self.batch_norm(h0)
  77. layers.append(h0)
  78. encode_layers_num = len(layers) # 8
  79. for i, filters in enumerate([512, 512, 512, 512, 256, 128, 64]): # [512, 512, 512, 512, 256, 128, 64]
  80. skip_layer = encode_layers_num - i - 1
  81. if i == 0:
  82. inputs = layers[-1]
  83. else:
  84. inputs = tf.concat([layers[-1], layers[skip_layer]], axis=3)
  85. h0 = tf.nn.relu(inputs)
  86. h0 = self.g_deconv(h0, filters)
  87. h0 = self.batch_norm(h0)
  88. if i < 3:
  89. h0 = tf.nn.dropout(h0, keep_prob=0.5)
  90. layers.append(h0)
  91. inputs = tf.concat([layers[-1], layers[0]], axis=3)
  92. h0 = tf.nn.relu(inputs)
  93. h0 = self.g_deconv(h0, 3)
  94. h0 = tf.nn.tanh(h0, name='g')
  95. return h0
  96. def loss(self, d_real, d_fake, y, g):
  97. """
  98. 定义损失函数
  99. :param d_real: 真实图像判别器的输出. [tensor]
  100. :param d_fake: 生成图像判别器的输出. [tensor]
  101. :param y: 目标图像. [tensor]
  102. :param g: 生成图像. [tensor]
  103. :return: loss_g, loss_d, 分别对应生成器的损失函数和判别器的损失函数
  104. """
  105. loss_d_real = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_real, tf.ones_like(d_real)))
  106. loss_d_fake = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_fake, tf.zeros_like(d_fake)))
  107. loss_d = loss_d_real + loss_d_fake
  108. loss_g_gan = tf.reduce_mean(self.sigmoid_cross_entropy_with_logits(d_fake, tf.ones_like(d_fake)))
  109. loss_g_l1 = tf.reduce_mean(tf.abs(y - g))
  110. loss_g = loss_g_gan + loss_g_l1 * self.L1_lambda
  111. return loss_g, loss_d
  112. def lrelu(self, x, leak=0.2):
  113. """
  114. lrelu函数
  115. :param x:
  116. :param leak:
  117. :return:
  118. """
  119. return tf.maximum(x, leak * x)
  120. def d_conv(self, inputs, filters, strides):
  121. """
  122. 判别器卷积层
  123. :param inputs: 输入. [tensor]
  124. :param filters: 输出通道数. [int]
  125. :param strides: 卷积核步伐. [int]
  126. :return:
  127. """
  128. padded = tf.pad(inputs, [[0, 0], [1, 1], [1, 1], [0, 0]], mode='CONSTANT')
  129. return tf.layers.conv2d(padded,
  130. kernel_size=4,
  131. filters=filters,
  132. strides=strides,
  133. padding='valid',
  134. kernel_initializer=self.k_initializer)
  135. def g_conv(self, inputs, filters):
  136. """
  137. 生成器卷积层
  138. :param inputs: 输入. [tensor]
  139. :param filters: 输出通道数. [int]
  140. :return:
  141. """
  142. return tf.layers.conv2d(inputs,
  143. kernel_size=4,
  144. filters=filters,
  145. strides=2,
  146. padding='same',
  147. kernel_initializer=self.k_initializer)
  148. def g_deconv(self, inputs, filters):
  149. """
  150. 生成器反卷积层
  151. :param inputs: 输入. [tensor]
  152. :param filters: 输出通道数. [int]
  153. :return:
  154. """
  155. return tf.layers.conv2d_transpose(inputs,
  156. kernel_size=4,
  157. filters=filters,
  158. strides=2,
  159. padding='same',
  160. kernel_initializer=self.k_initializer)
  161. def batch_norm(self, inputs):
  162. """
  163. 批标准化函数
  164. :param inputs: 输入. [tensor]
  165. :return:
  166. """
  167. return tf.layers.batch_normalization(inputs,
  168. axis=3,
  169. epsilon=1e-5,
  170. momentum=0.1,
  171. training=True,
  172. gamma_initializer=self.g_initializer)
  173. def sigmoid_cross_entropy_with_logits(self, x, y):
  174. """
  175. 交叉熵函数
  176. :param x:
  177. :param y:
  178. :return:
  179. """
  180. return tf.nn.sigmoid_cross_entropy_with_logits(logits=x,
  181. labels=y)
  182. def train(self, images, epoch, batch_size):
  183. """
  184. 训练函数
  185. :param images: 图像路径列表. [list]
  186. :param epoch: 迭代次数. [int]
  187. :param batch_size: batch_size. [int]
  188. :return:
  189. """
  190. # optimizer
  191. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  192. with tf.control_dependencies(update_ops):
  193. optim_d = tf.train.AdamOptimizer(learning_rate=0.0002,
  194. beta1=0.5
  195. ).minimize(self.loss_d, var_list=self.vars_d)
  196. optim_g = tf.train.AdamOptimizer(learning_rate=0.0002,
  197. beta1=0.5
  198. ).minimize(self.loss_g, var_list=self.vars_g)
  199. # init variables
  200. init_op = tf.global_variables_initializer()
  201. self.sess.run(init_op)
  202. self.writer = tf.summary.FileWriter("./log", self.sess.graph)
  203. # training
  204. for i in range(epoch):
  205. # 获取图像列表
  206. print("Epoch:%d/%d:" % ((i + 1), epoch))
  207. batch_num = int(np.ceil(len(images) / batch_size))
  208. # batch_list = np.array_split(random.sample(images, len(images)), batch_num)
  209. batch_list = np.array_split(images, batch_num)
  210. # 训练生成器和判别器
  211. for j in range(len(batch_list)):
  212. batch_x, batch_y = get_batch_data(batch_list[j])
  213. _, loss_d = self.sess.run([optim_d, self.loss_d],
  214. feed_dict={self.x_: batch_x, self.y_: batch_y})
  215. _, loss_g = self.sess.run([optim_g, self.loss_g],
  216. feed_dict={self.x_: batch_x, self.y_: batch_y})
  217. print("%d/%d -loss_d:%.4f -loss_g:%.4f" % ((j + 1), len(batch_list), loss_d, loss_g))
  218. # 保存损失值
  219. summary = self.sess.run(self.merged,
  220. feed_dict={self.x_: batch_x, self.y_: batch_y})
  221. self.writer.add_summary(summary, global_step=i)
  222. # 保存模型,每10次保存一次
  223. if (i + 1) % 10 == 0:
  224. self.saver.save(self.sess, './checkpoint/epoch_%d.ckpt' % (i + 1))
  225. # 测试,每循环一次测试一次
  226. if (i + 1) % 1 == 0:
  227. # 对训练集最后一张图像进行测试
  228. train_save_path = os.path.join('./result/train',
  229. re.sub('.jpg',
  230. '',
  231. os.path.basename(images[-1])
  232. ) + '_' + str(i + 1) + '.jpg'
  233. )
  234. train_g = self.sess.run(self.g,
  235. feed_dict={self.x_: batch_x}
  236. )
  237. train_g = 255 * (np.array(train_g[0] + 1) / 2)
  238. im = Image.fromarray(np.uint8(train_g))
  239. im.save(train_save_path)
  240. # 对验证集进行测试
  241. img = np.zeros((256, 256 * 3, 3))
  242. val_img_path = np.array(['./data/val/color/10901.jpg'])
  243. batch_x, batch_y = get_batch_data(val_img_path)
  244. val_g = self.sess.run(self.g, feed_dict={self.x_: batch_x})
  245. img[:, :256, :] = 255 * (np.array(batch_x + 1) / 2)
  246. img[:, 256:256 * 2, :] = 255 * (np.array(batch_y + 1) / 2)
  247. img[:, 256 * 2:, :] = 255 * (np.array(val_g[0] + 1) / 2)
  248. img = Image.fromarray(np.uint8(img))
  249. img.save('./result/val/10901_%d.jpg' % (i + 1))
  250. def save_img(self, g, data, save_path):
  251. """
  252. 保存图像
  253. :param g: 生成的图像. [array]
  254. :param data: 测试数据. [list]
  255. :param save_path: 保存路径. [str]
  256. :return:
  257. """
  258. if len(data) == 1:
  259. img = np.zeros((256, 256 * 2, 3))
  260. img[:, :256, :] = 255* (np.array(data[0] + 1) / 2)
  261. img[:, 256:, :] = 255 * (np.array(g[0] + 1) / 2)
  262. else:
  263. img = np.zeros((256, 256 * 3, 3))
  264. img[:, :256, :] = 255 * (np.array(data[0] + 1) / 2)
  265. img[:, 256:256 * 2, :] = 255 * (np.array(data[1] + 1) / 2)
  266. img[:, 256 * 2:, :] = 255 * (np.array(g[0] + 1) / 2)
  267. im = Image.fromarray(np.uint8(img))
  268. im.save(os.path.join('./result/test', os.path.basename(save_path)))
  269. def test(self, images, batch_size=1, save_path=None, mode=None):
  270. """
  271. 测试函数
  272. :param images: 测试图像列表. [list]
  273. :param batch_size: batch_size. [int]
  274. :param save_path: 保存路径
  275. :return:
  276. """
  277. # init variables
  278. init_op = tf.global_variables_initializer()
  279. self.sess.run(init_op)
  280. # load model
  281. self.saver.restore(self.sess,
  282. tf.train.latest_checkpoint('./checkpoint')
  283. )
  284. # test
  285. if mode != 'orig':
  286. for j in range(len(images)):
  287. batch_x, batch_y = get_batch_data(np.array([images[j]]))
  288. g = self.sess.run(self.g, feed_dict={self.x_: batch_x})
  289. if save_path == None:
  290. self.save_img(g,
  291. data=[batch_x[0], batch_y[0]],
  292. save_path=images[j]
  293. )
  294. else:
  295. self.save_img(g,
  296. data=[batch_x[0], batch_y[0]],
  297. save_path=save_path
  298. )
  299. else:
  300. for j in range(len(images)):
  301. batch_x = get_batch_data(np.array([images[j]]), mode=mode)
  302. g = self.sess.run(self.g, feed_dict={self.x_: batch_x})
  303. batch_x = 255 * (np.array(batch_x[0] + 1) / 2)
  304. g = 255 * (np.array(g[0] + 1) / 2)
  305. img = np.hstack((batch_x, g))
  306. im = Image.fromarray(np.uint8(img))
  307. im.save(os.path.join('./result/test', os.path.basename(images[j])))

    最终经过训练40个epoch后,判别器和生成器的损失函数均达到了平衡状态,因此,对训练过程进行了终止,如下图所示:

5.模型的效果

    利用训练40个epoch后的模型对测试集进行测试,得到模型最终的效果如下:

其中,从左到右分别对应原图、人工修图、AI修图,可以发现,AI修图的结果会使得色彩更加艳丽,并且修图的效果比人工修图更加真实一点,本文也利用训练好的模型对任意规格的高清图像进行了测试,得到效果如下:

左边是从百度上直接下载下来的两张风景图,右边是本文训练出来的模型修图后的结果,可以发现,虽然这两张原图的已经是经过p图之后的结果,但是用AI修图后在亮度、色彩对比度等方面还是有进一步的提升,模型的泛化效果还是蛮不错滴!

    最后,大概讲一下模型的缺点吧,pix2pix虽然通用性很强,但是模型能否收敛对数据的质量要求很高,如果数据质量比较差的话,则训练出来的模型效果就比较差,笔者最开始没有对数据进行清洗,因此训练出来的效果比较模糊,另外,pix2pix要求必须是严格的配对数据,因此,对数据的要求更加苛刻,如果对这方面比较感兴趣的朋友,也可以考虑一下非监督学习方面的模型,比如WESPE模型等。以下是原论文的地址和作者的pytorch实现:


招聘信息:

熊猫书院算法工程师:

https://www.lagou.com/jobs/4842081.html

希望对深度学习算法感兴趣的小伙伴们可以加入我们,一起改变教育! 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/363348
推荐阅读
相关标签
  

闽ICP备14008679号