
Deep Fun (深度有趣) | 30 Fast Image Style Transfer


Introduction

Implementing fast image style transfer (Fast Neural Style Transfer) with TensorFlow.

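Principle

In the image style transfer introduced earlier, we optimized the input image with respect to a content image and a style image so that the content loss and the style loss became as small as possible.

As with DeepDream, the network parameters stay fixed and the input data itself is adjusted according to the loss function. Generating each image is therefore equivalent to training a model, which takes a long time.

Training a model is slow, but running inference with an already trained model is fast.

Fast image style transfer greatly shortens the time needed to generate a stylized image. Its model structure is shown below and consists of a transformation network and a loss network.

The style image is fixed while the content image is a variable input, so the model can quickly convert any image into the chosen style.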

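  • Transformation network: its parameters are trained; it converts the content image into the stylized image
  • Loss network: a pretrained network whose parameters stay fixed; it computes the style loss between the stylized image and the style image, and the content loss between the stylized image and the original content image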

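After training, the stylized images produced by the transformation network resemble the input content image in content and the chosen style image in style.

At inference time only the transformation network is used: feed in a content image and the corresponding stylized image comes out directly.

If there are several style images, train a separate model for each style.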
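To make the division of labor between the two networks concrete, here is a minimal NumPy sketch of how they interact. It is only a schematic under loud assumptions: `transform_net` is a hypothetical per-channel affine color map standing in for the real CNN, and `loss_net` uses raw pixels as "features" instead of VGG19 activations. What it does show is the structure: the style target is computed once, the loss depends only on the transformer's parameters, and inference is a single forward call with no loss network involved.

```python
import numpy as np

def transform_net(params, images):
    # HYPOTHETICAL stand-in for the transformation network (the only trainable part):
    # a per-channel affine color map instead of the real conv/residual CNN
    return images * params["scale"] + params["shift"]

def gram(images):
    # (N, H, W, C) -> (N, C, C) Gram matrices, normalized by H*W*C
    n, h, w, c = images.shape
    f = images.reshape(n, h * w, c)
    return np.matmul(f.transpose(0, 2, 1), f) / (h * w * c)

def loss_net(stylized, content, style_gram, style_weight=1.0):
    # HYPOTHETICAL stand-in for the frozen loss network: raw pixels act as "features"
    content_loss = np.mean((stylized - content) ** 2)
    style_loss = np.mean((gram(stylized) - style_gram) ** 2)
    return content_loss + style_weight * style_loss

rng = np.random.default_rng(0)
style_gram = gram(rng.random((1, 8, 8, 3)))[0]  # fixed style target, computed once
batch = rng.random((4, 8, 8, 3))                 # varying content images
params = {"scale": np.ones(3), "shift": np.zeros(3)}  # identity transform to start

# Training: the loss is a function of params; an optimizer would update only params
train_loss = loss_net(transform_net(params, batch), batch, style_gram)

# Inference: one forward pass through the transformer
new_image = rng.random((1, 8, 8, 3))
stylized = transform_net(params, new_image)
print(train_loss, stylized.shape)
```

Training one model per style, as described above, corresponds to fixing a different `style_gram` for each model.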

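Implementation

The code is adapted from two projects: github.com/lengstrom/f… and github.com/hzy46/fast-…

As before, the pretrained imagenet-vgg-verydeep-19.mat (VGG19) is used to compute the content loss and the style loss.

A set of images is needed as content inputs. Their subject matter does not matter and no labels are required. Here we use the train2014 split of the MSCOCO dataset (cocodataset.org/#download), 82612 images in total.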

Load the libraries

```python
# -*- coding: utf-8 -*-
# The code targets TensorFlow 1.x (tf.placeholder, tf.Session, tf.contrib)
import tensorflow as tf
import numpy as np
import cv2
from imageio import imread, imsave
import scipy.io
import os
import glob
from tqdm import tqdm
import matplotlib.pyplot as plt
# Jupyter notebook magic; remove this line when running as a plain .py script
%matplotlib inline
```

List the style images (10 in total)

```python
# Sort so that style_index refers to the same file on every system
style_images = sorted(glob.glob('styles/*.jpg'))
print(style_images)
```

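Load the content images, skip grayscale ones, and crop and resize each to the target size. No normalization is applied yet, so pixel values stay in the range 0 to 255.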

```python
def resize_and_crop(image, image_size):
    # Center-crop the longer side so the image becomes square, then resize it
    h = image.shape[0]
    w = image.shape[1]
    if h > w:
        image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
    else:
        image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]
    image = cv2.resize(image, (image_size, image_size))
    return image

X_data = []
image_size = 256
paths = glob.glob('train2014/*.jpg')
for path in tqdm(paths):
    image = imread(path)
    # Grayscale images come back as 2-D arrays without a channel axis; skip them
    if len(image.shape) < 3:
        continue
    X_data.append(resize_and_crop(image, image_size))
# uint8 array of shape (N, 256, 256, 3), values 0-255
X_data = np.array(X_data)
print(X_data.shape)
```
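Keeping the whole dataset in memory is expensive: at 256 × 256 × 3 bytes per uint8 image, about 82,600 images take roughly 16 GB. One alternative, not used in the original article, is to keep only the file paths and load each batch on demand. The sketch below assumes a hypothetical `load_fn(path)` that returns a preprocessed array (for example `resize_and_crop(imread(path), 256)`) or `None` for images to skip; it shuffles the path order each epoch instead of the pixel data.

```python
import numpy as np

def iterate_batches(paths, batch_size, load_fn, rng):
    """Yield batches of images loaded lazily from `paths`.

    `load_fn` is a HYPOTHETICAL callable: path -> (H, W, 3) uint8 array, or None to skip.
    Only the path order is shuffled, so memory holds about one batch of pixels.
    """
    order = rng.permutation(len(paths))
    batch = []
    for idx in order:
        image = load_fn(paths[idx])
        if image is None:  # e.g. grayscale images, skipped as in the original loop
            continue
        batch.append(image)
        if len(batch) == batch_size:
            yield np.stack(batch)
            batch = []
    # A trailing partial batch is dropped, like `X_data.shape[0] // batch_size` in training

# Usage with a fake loader: every third "image" counts as grayscale and is skipped
def fake_load(path):
    i = int(path)
    return None if i % 3 == 0 else np.full((4, 4, 3), i, dtype=np.uint8)

rng = np.random.default_rng(0)
batches = list(iterate_batches([str(i) for i in range(10)], 2, fake_load, rng))
print(len(batches), batches[0].shape)
```

In the training loop one would then write `for X_batch in iterate_batches(paths, batch_size, load_fn, rng):` instead of slicing `X_data`; the trade-off is disk reads and JPEG decoding on every step.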

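Load the VGG19 model and define a function that returns the output of every VGG19 layer for a given input. As in the earlier GAN examples, `variable_scope` with `reuse` is used so that the same network can be applied to several inputs. Because the VGG19 weights here are wrapped in `tf.constant` rather than created as variables, the `reuse` flag has no real effect: each call builds another frozen copy of the network from the same pretrained weights.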

```python
vgg = scipy.io.loadmat('imagenet-vgg-verydeep-19.mat')
vgg_layers = vgg['layers']

def vgg_endpoints(inputs, reuse=None):
    with tf.variable_scope('endpoints', reuse=reuse):
        def _weights(layer, expected_layer_name):
            # Read the pretrained kernel W and bias b of one layer from the .mat structure
            W = vgg_layers[0][layer][0][0][2][0][0]
            b = vgg_layers[0][layer][0][0][2][0][1]
            layer_name = vgg_layers[0][layer][0][0][0][0]
            assert layer_name == expected_layer_name
            return W, b

        def _conv2d_relu(prev_layer, layer, layer_name):
            # Frozen conv + ReLU: the weights are constants and are never trained
            W, b = _weights(layer, layer_name)
            W = tf.constant(W)
            b = tf.constant(np.reshape(b, (b.size)))
            return tf.nn.relu(tf.nn.conv2d(prev_layer, filter=W, strides=[1, 1, 1, 1], padding='SAME') + b)

        def _avgpool(prev_layer):
            # Average pooling in place of VGG19's max pooling
            return tf.nn.avg_pool(prev_layer, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

        graph = {}
        graph['conv1_1'] = _conv2d_relu(inputs, 0, 'conv1_1')
        graph['conv1_2'] = _conv2d_relu(graph['conv1_1'], 2, 'conv1_2')
        graph['avgpool1'] = _avgpool(graph['conv1_2'])
        graph['conv2_1'] = _conv2d_relu(graph['avgpool1'], 5, 'conv2_1')
        graph['conv2_2'] = _conv2d_relu(graph['conv2_1'], 7, 'conv2_2')
        graph['avgpool2'] = _avgpool(graph['conv2_2'])
        graph['conv3_1'] = _conv2d_relu(graph['avgpool2'], 10, 'conv3_1')
        graph['conv3_2'] = _conv2d_relu(graph['conv3_1'], 12, 'conv3_2')
        graph['conv3_3'] = _conv2d_relu(graph['conv3_2'], 14, 'conv3_3')
        graph['conv3_4'] = _conv2d_relu(graph['conv3_3'], 16, 'conv3_4')
        graph['avgpool3'] = _avgpool(graph['conv3_4'])
        graph['conv4_1'] = _conv2d_relu(graph['avgpool3'], 19, 'conv4_1')
        graph['conv4_2'] = _conv2d_relu(graph['conv4_1'], 21, 'conv4_2')
        graph['conv4_3'] = _conv2d_relu(graph['conv4_2'], 23, 'conv4_3')
        graph['conv4_4'] = _conv2d_relu(graph['conv4_3'], 25, 'conv4_4')
        graph['avgpool4'] = _avgpool(graph['conv4_4'])
        graph['conv5_1'] = _conv2d_relu(graph['avgpool4'], 28, 'conv5_1')
        graph['conv5_2'] = _conv2d_relu(graph['conv5_1'], 30, 'conv5_2')
        graph['conv5_3'] = _conv2d_relu(graph['conv5_2'], 32, 'conv5_3')
        graph['conv5_4'] = _conv2d_relu(graph['conv5_3'], 34, 'conv5_4')
        graph['avgpool5'] = _avgpool(graph['conv5_4'])
        return graph
```
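The hard-coded indices (0, 2, 5, 7, 10, ...) follow the layer order stored in imagenet-vgg-verydeep-19.mat, where every convolution is followed by its own ReLU entry and each block ends with a pooling entry (the code substitutes its own `avgpool*` layers for those). The plain-Python sketch below rebuilds that order from the VGG19 block layout to show where the numbers come from. The list is written by hand from the architecture and omits the fully connected layers after pool5, so treat it as an illustration rather than something read from the file.

```python
# Number of convolutions in each of VGG19's five blocks
BLOCKS = [2, 2, 4, 4, 4]

def vgg19_layer_names():
    # Rebuild the convolutional part of the layer list: each conv is followed
    # by a relu entry, and each block ends with a pool entry
    names = []
    for block, n_convs in enumerate(BLOCKS, start=1):
        for i in range(1, n_convs + 1):
            names.append('conv%d_%d' % (block, i))
            names.append('relu%d_%d' % (block, i))
        names.append('pool%d' % block)
    return names

names = vgg19_layer_names()
conv_index = {name: i for i, name in enumerate(names) if name.startswith('conv')}
print(conv_index['conv1_2'], conv_index['conv2_1'], conv_index['conv4_1'], conv_index['conv5_4'])  # 2 5 19 34
```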

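Pick one style image, subtract the per-channel color mean, and run it through VGG19 to get the output of each layer; then compute the Gram matrices of the four style layers.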

```python
style_index = 1
X_style_data = resize_and_crop(imread(style_images[style_index]), image_size)
X_style_data = np.expand_dims(X_style_data, 0)
print(X_style_data.shape)
# Per-channel (RGB) ImageNet mean, subtracted before feeding images to VGG19
MEAN_VALUES = np.array([123.68, 116.779, 103.939]).reshape((1, 1, 1, 3))
X_style = tf.placeholder(dtype=tf.float32, shape=X_style_data.shape, name='X_style')
style_endpoints = vgg_endpoints(X_style - MEAN_VALUES)
STYLE_LAYERS = ['conv1_2', 'conv2_2', 'conv3_3', 'conv4_3']
style_features = {}
sess = tf.Session()
for layer_name in STYLE_LAYERS:
    features = sess.run(style_endpoints[layer_name], feed_dict={X_style: X_style_data})
    # Flatten to (H*W, C) and compute the C x C Gram matrix, normalized by H*W*C
    features = np.reshape(features, (-1, features.shape[3]))
    gram = np.matmul(features.T, features) / features.size
    style_features[layer_name] = gram
```
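For a feature map with C channels, the Gram matrix is the C × C matrix of inner products between channel activations, divided here by the number of elements H·W·C exactly as in `np.matmul(features.T, features) / features.size`. A tiny hand-checkable example with a 2 × 2 × 2 feature map also shows why the Gram matrix captures style rather than layout: shuffling the spatial positions leaves it unchanged.

```python
import numpy as np

# One feature map of shape (H, W, C) = (2, 2, 2)
features = np.array([[[1., 0.], [2., 1.]],
                     [[0., 3.], [1., 1.]]])
F = features.reshape(-1, features.shape[-1])   # (H*W, C) = (4, 2)
gram = F.T @ F / features.size                 # (C, C), divided by H*W*C = 8
print(gram)  # [[0.75, 0.375], [0.375, 1.375]]

# Permuting spatial positions (rows of F) gives the same Gram matrix
shuffled = F[[2, 0, 3, 1]]
print(np.allclose(shuffled.T @ shuffled / features.size, gram))  # True
```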

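Define the transformation network, a typical structure of convolutions, residual blocks, and deconvolutions (upsampling). The content image also has the per-channel color mean subtracted before it enters the network.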

```python
batch_size = 4
X = tf.placeholder(dtype=tf.float32, shape=[None, None, None, 3], name='X')
k_initializer = tf.truncated_normal_initializer(0, 0.1)

def relu(x):
    return tf.nn.relu(x)

def conv2d(inputs, filters, kernel_size, strides):
    # Reflection padding instead of zero padding reduces border artifacts
    p = int(kernel_size / 2)
    h0 = tf.pad(inputs, [[0, 0], [p, p], [p, p], [0, 0]], mode='reflect')
    return tf.layers.conv2d(inputs=h0, filters=filters, kernel_size=kernel_size, strides=strides, padding='valid', kernel_initializer=k_initializer)

def deconv2d(inputs, filters, kernel_size, strides):
    # "Deconvolution" as nearest-neighbor upsampling followed by a strided conv;
    # with strides=2 the net effect is exactly 2x upsampling
    shape = tf.shape(inputs)
    height, width = shape[1], shape[2]
    h0 = tf.image.resize_images(inputs, [height * strides * 2, width * strides * 2], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    return conv2d(h0, filters, kernel_size, strides)

def instance_norm(inputs):
    return tf.contrib.layers.instance_norm(inputs)

def residual(inputs, filters, kernel_size):
    h0 = relu(conv2d(inputs, filters, kernel_size, 1))
    h0 = conv2d(h0, filters, kernel_size, 1)
    return tf.add(inputs, h0)

with tf.variable_scope('transformer', reuse=None):
    # Extra 10-pixel reflect padding on every side, removed again at the end
    h0 = tf.pad(X - MEAN_VALUES, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='reflect')
    h0 = relu(instance_norm(conv2d(h0, 32, 9, 1)))
    h0 = relu(instance_norm(conv2d(h0, 64, 3, 2)))
    h0 = relu(instance_norm(conv2d(h0, 128, 3, 2)))
    for i in range(5):
        h0 = residual(h0, 128, 3)
    h0 = relu(instance_norm(deconv2d(h0, 64, 3, 2)))
    h0 = relu(instance_norm(deconv2d(h0, 32, 3, 2)))
    h0 = tf.nn.tanh(instance_norm(conv2d(h0, 3, 9, 1)))
    # Map tanh output from [-1, 1] to the pixel range [0, 255]
    h0 = (h0 + 1) / 2 * 255.
    shape = tf.shape(h0)
    # Crop off the padding; the result is the stylized image, named 'transformer/g'
    g = tf.slice(h0, [0, 10, 10, 0], [-1, shape[1] - 20, shape[2] - 20, -1], name='g')
```
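Tracing the spatial size through these layers explains the 10-pixel padding and the final slice, and also shows that the output is not always exactly the input size. The helpers below are hypothetical plain-Python functions that follow the same arithmetic as the layers above: a `'valid'` convolution on input reflect-padded by `kernel_size // 2`, and `deconv2d` as a resize to 4× followed by a stride-2 convolution.

```python
def conv_out(size, kernel_size, strides):
    # Reflect-pad by kernel_size // 2 on each side, then a 'valid' convolution
    p = kernel_size // 2
    return (size + 2 * p - kernel_size) // strides + 1

def deconv_out(size, kernel_size, strides):
    # Nearest-neighbor resize to size * strides * 2, then a strided convolution
    return conv_out(size * strides * 2, kernel_size, strides)

def transformer_out(size):
    s = size + 20               # 10-pixel reflect pad on both sides
    s = conv_out(s, 9, 1)
    s = conv_out(s, 3, 2)
    s = conv_out(s, 3, 2)       # the residual blocks keep the size unchanged
    s = deconv_out(s, 3, 2)
    s = deconv_out(s, 3, 2)
    s = conv_out(s, 9, 1)
    return s - 20               # tf.slice removes the padding again

for size in (256, 255, 257, 500):
    print(size, '->', transformer_out(size))
# 256 -> 256, 255 -> 256, 257 -> 260, 500 -> 500
```

The output size is the input size rounded up to a multiple of 4. Training images (256 × 256) are unaffected, but for an arbitrary test image the output must be cropped back to the input size, as done with `gen_img[:h_sample, :w_sample, :]` in the training loop.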

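Feed both the output of the transformation network (the stylized image) and the original content image into VGG19, take their outputs at the content layer, and compute the content loss.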

```python
CONTENT_LAYER = 'conv3_3'
content_endpoints = vgg_endpoints(X - MEAN_VALUES, True)
g_endpoints = vgg_endpoints(g - MEAN_VALUES, True)

def get_content_loss(endpoints_x, endpoints_y, layer_name):
    x = endpoints_x[layer_name]
    y = endpoints_y[layer_name]
    # tf.nn.l2_loss(t) = sum(t ** 2) / 2, so this is the mean squared error
    return 2 * tf.nn.l2_loss(x - y) / tf.to_float(tf.size(x))

content_loss = get_content_loss(content_endpoints, g_endpoints, CONTENT_LAYER)
```
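Since `tf.nn.l2_loss(t)` computes `sum(t ** 2) / 2`, the expression `2 * tf.nn.l2_loss(x - y) / size(x)` is simply the mean squared error between the two feature maps. A NumPy check of that identity, using a hypothetical `l2_loss` helper that mirrors the TensorFlow function and random arrays in place of VGG19 features:

```python
import numpy as np

def l2_loss(t):
    # Same definition as tf.nn.l2_loss: half the sum of squares
    return np.sum(t ** 2) / 2

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16, 16, 8))   # stand-in for content-layer features of X
y = rng.standard_normal((2, 16, 16, 8))   # stand-in for the same layer's features of g
content_loss = 2 * l2_loss(x - y) / x.size
print(np.isclose(content_loss, np.mean((x - y) ** 2)))  # True
```

Dividing by the number of elements keeps the loss on a comparable scale regardless of batch size and image resolution.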

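Compute the style loss from the outputs of the stylized image at the chosen style layers and the precomputed Gram matrices of the style image.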

```python
style_loss = []
for layer_name in STYLE_LAYERS:
    layer = g_endpoints[layer_name]
    shape = tf.shape(layer)
    bs, height, width, channel = shape[0], shape[1], shape[2], shape[3]
    # Batched Gram matrices: (bs, H*W, C) -> (bs, C, C), normalized like the style Gram
    features = tf.reshape(layer, (bs, height * width, channel))
    gram = tf.matmul(tf.transpose(features, (0, 2, 1)), features) / tf.to_float(height * width * channel)
    # The precomputed (C, C) style Gram broadcasts over the batch dimension
    style_gram = style_features[layer_name]
    style_loss.append(2 * tf.nn.l2_loss(gram - style_gram) / tf.to_float(tf.size(layer)))
style_loss = tf.reduce_sum(style_loss)
```
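In the graph, the Gram matrices are computed for every image in the batch with a batched `matmul`, and the precomputed style Gram (shape C × C) is broadcast against the batch (shape N × C × C). A NumPy sketch of one layer's term, with random arrays standing in for VGG19 activations, checks that the batched form matches the per-image computation used for the style image and that the loss vanishes when every image already has the target Gram matrix.

```python
import numpy as np

def l2_loss(t):
    return np.sum(t ** 2) / 2      # mirrors tf.nn.l2_loss

def batched_gram(layer):
    # (N, H, W, C) -> (N, C, C), normalized by H*W*C as in the graph code
    n, h, w, c = layer.shape
    features = layer.reshape(n, h * w, c)
    return np.matmul(features.transpose(0, 2, 1), features) / (h * w * c)

def style_layer_loss(layer, style_gram):
    # style_gram has shape (C, C) and broadcasts over the batch axis;
    # normalized by the number of elements in the layer, as in the article's code
    return 2 * l2_loss(batched_gram(layer) - style_gram) / layer.size

rng = np.random.default_rng(0)
layer = rng.random((4, 8, 8, 16))        # stand-in for g_endpoints[layer_name]
style_gram = rng.random((16, 16))        # stand-in for style_features[layer_name]

# Per-image Gram, computed the same way as for the style image
per_image = [f.reshape(-1, 16).T @ f.reshape(-1, 16) / f.size for f in layer]
print(np.allclose(batched_gram(layer), np.stack(per_image)))  # True
print(style_layer_loss(layer, style_gram))
```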

Compute the total variation regularizer and combine all terms into the total loss.

```python
def get_total_variation_loss(inputs):
    # Differences between vertically and horizontally adjacent pixels
    h = inputs[:, :-1, :, :] - inputs[:, 1:, :, :]
    w = inputs[:, :, :-1, :] - inputs[:, :, 1:, :]
    return tf.nn.l2_loss(h) / tf.to_float(tf.size(h)) + tf.nn.l2_loss(w) / tf.to_float(tf.size(w))

total_variation_loss = get_total_variation_loss(g)
content_weight = 1
style_weight = 250
total_variation_weight = 0.01
loss = content_weight * content_loss + style_weight * style_loss + total_variation_weight * total_variation_loss
```
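The total variation term penalizes differences between neighboring pixels, which suppresses high-frequency noise in the generated image. A NumPy version with the same normalization as `get_total_variation_loss` (each `tf.nn.l2_loss` divided by its element count, i.e. half the mean squared difference per direction), applied to a flat image and to an image with alternating vertical stripes:

```python
import numpy as np

def total_variation_loss(images):
    # images: (N, H, W, C); differences between vertical and horizontal neighbors
    h = images[:, :-1, :, :] - images[:, 1:, :, :]
    w = images[:, :, :-1, :] - images[:, :, 1:, :]
    # tf.nn.l2_loss(t) / size(t) == 0.5 * mean(t ** 2)
    return 0.5 * np.mean(h ** 2) + 0.5 * np.mean(w ** 2)

flat = np.full((1, 4, 4, 3), 128.0)      # constant image: no variation at all
stripes = np.zeros((1, 4, 4, 3))
stripes[:, :, 1::2, :] = 255.0           # columns alternate between 0 and 255
print(total_variation_loss(flat))       # 0.0
print(total_variation_loss(stripes))    # 32512.5 (only horizontal differences)
```

The small weight (0.01) keeps this term from blurring the result while still discouraging pixel-level noise.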

Define the optimizer, which lowers the total loss by adjusting only the parameters of the transformation network.

```python
# Train only the transformation network's variables; VGG19 consists of constants and stays fixed
vars_t = [var for var in tf.trainable_variables() if var.name.startswith('transformer')]
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss, var_list=vars_t)
```

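Train the model. After each epoch, stylize a test image to check progress; every 20 steps, write loss values and images to an events file so they can be viewed in TensorBoard.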

```python
# Output directory from the style file name, e.g. styles/starry.jpg -> samples_starry
style_name = os.path.splitext(os.path.basename(style_images[style_index]))[0]
OUTPUT_DIR = 'samples_%s' % style_name
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Scalars and images shown in TensorBoard
tf.summary.scalar('losses/content_loss', content_loss)
tf.summary.scalar('losses/style_loss', style_loss)
tf.summary.scalar('losses/total_variation_loss', total_variation_loss)
tf.summary.scalar('losses/loss', loss)
tf.summary.scalar('weighted_losses/weighted_content_loss', content_weight * content_loss)
tf.summary.scalar('weighted_losses/weighted_style_loss', style_weight * style_loss)
tf.summary.scalar('weighted_losses/weighted_total_variation_loss', total_variation_weight * total_variation_loss)
tf.summary.image('transformed', g)
tf.summary.image('origin', X)
summary = tf.summary.merge_all()
writer = tf.summary.FileWriter(OUTPUT_DIR)

sess.run(tf.global_variables_initializer())
losses = []
epochs = 2
steps_per_epoch = X_data.shape[0] // batch_size
X_sample = imread('sjtu.jpg')
h_sample = X_sample.shape[0]
w_sample = X_sample.shape[1]
for e in range(epochs):
    # Shuffle indices instead of X_data itself, which would copy the whole array
    data_index = np.random.permutation(X_data.shape[0])
    for i in tqdm(range(steps_per_epoch)):
        X_batch = X_data[data_index[i * batch_size: (i + 1) * batch_size]]
        ls_, _ = sess.run([loss, optimizer], feed_dict={X: X_batch})
        losses.append(ls_)
        if i > 0 and i % 20 == 0:
            writer.add_summary(sess.run(summary, feed_dict={X: X_batch}), e * steps_per_epoch + i)
            writer.flush()
    print('Epoch %d Loss %f' % (e, np.mean(losses)))
    losses = []

    # Stylize the test image and show it next to the original
    gen_img = sess.run(g, feed_dict={X: [X_sample]})[0]
    gen_img = np.clip(gen_img, 0, 255).astype(np.uint8)
    result = np.zeros((h_sample, w_sample * 2, 3), dtype=np.uint8)
    result[:, :w_sample, :] = X_sample
    # The output can be slightly larger than the input, so crop it back
    result[:, w_sample:, :] = gen_img[:h_sample, :w_sample, :]
    plt.axis('off')
    plt.imshow(result)
    plt.show()
    imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % e), result)
```
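Deriving the style name with `str.rstrip('.jpg')` is fragile: `rstrip` removes any trailing run of the characters `.`, `j`, `p` and `g`, not the literal suffix, so `styles/pig.jpg` would become `pi`; searching for `'/'` also depends on the path separator. `os.path.basename` plus `os.path.splitext` removes exactly the directory and one extension, as this comparison shows:

```python
import os

def style_name_rstrip(path):
    # Fragile: strips a set of characters from the end, not the ".jpg" suffix
    return path[path.find('/') + 1:].rstrip('.jpg')

def style_name_splitext(path):
    # Drop the directory, then drop exactly one extension
    return os.path.splitext(os.path.basename(path))[0]

for path in ['styles/starry.jpg', 'styles/pig.jpg', 'styles/jpg.jpg']:
    print(path, repr(style_name_rstrip(path)), repr(style_name_splitext(path)))
# styles/starry.jpg 'starry' 'starry'
# styles/pig.jpg 'pi' 'pig'
# styles/jpg.jpg '' 'jpg'
```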

Save the model

```python
# Writes the .meta graph, the variable files and a 'checkpoint' index into OUTPUT_DIR
saver = tf.train.Saver()
saver.save(sess, os.path.join(OUTPUT_DIR, 'fast_style_transfer'))
```

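The test image is again the Shanghai Jiao Tong University gate (交大庙门) used in earlier articles.

Style transfer results

TensorBoard can be used to monitor the training process: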

```bash
tensorboard --logdir=samples_starry
```

On a single machine, the following script performs the style transfer quickly; even on a CPU it takes only about 10 seconds.

```python
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
from imageio import imread, imsave
import os
import time

def the_current_time():
    print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))))

style = 'wave'
model = 'samples_%s' % style
content_image = 'sjtu.jpg'
result_image = 'sjtu_%s.jpg' % style
X_image = imread(content_image)

sess = tf.Session()
# Rebuild the graph from the .meta file and restore the trained weights
saver = tf.train.import_meta_graph(os.path.join(model, 'fast_style_transfer.meta'))
saver.restore(sess, tf.train.latest_checkpoint(model))
graph = tf.get_default_graph()
# Only the input placeholder and the transformer output are needed for inference
X = graph.get_tensor_by_name('X:0')
g = graph.get_tensor_by_name('transformer/g:0')

the_current_time()
gen_img = sess.run(g, feed_dict={X: [X_image]})[0]
# Crop to the input size (the output can be slightly larger) and convert to 8-bit
gen_img = np.clip(gen_img[:X_image.shape[0], :X_image.shape[1], :], 0, 255).astype(np.uint8)
imsave(result_image, gen_img)
the_current_time()
```
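The placeholder `X` expects images of shape (H, W, 3), but `imread` returns a 2-D array for a grayscale JPEG (the case the training loader skips) and an (H, W, 4) array for a PNG with an alpha channel. A small hypothetical helper, `to_rgb`, could be applied to `X_image` before inference; it replicates the gray channel and drops alpha (ignoring transparency rather than compositing it onto a background).

```python
import numpy as np

def to_rgb(image):
    # HYPOTHETICAL helper: make any imread result fit the (H, W, 3) placeholder X
    if image.ndim == 2:                      # grayscale: replicate into 3 channels
        image = np.stack([image] * 3, axis=-1)
    elif image.shape[-1] == 4:               # RGBA: drop the alpha channel
        image = image[:, :, :3]
    return image

gray = np.arange(35, dtype=np.uint8).reshape(5, 7)
rgba = np.zeros((5, 7, 4), dtype=np.uint8)
print(to_rgb(gray).shape, to_rgb(rgba).shape)   # (5, 7, 3) (5, 7, 3)
```

The same helper could replace the `continue` that skips grayscale images in the training loader, at the cost of adding colorless content images to the training set.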

For other style images, train the corresponding model in the same way.

