当前位置:   article > 正文

style-transfer的实现(tensorflow)_style-transform-tensorflow

style-transform-tensorflow

风格迁移(style transfer)曾经是一个很流行的 App 应用,虽然现在风头已过,但自己动手实现一遍仍然很有收获。论文(paper)需要自己去研读,下面是图解。


中间是一张空白图片或者噪声图片。用 S 表示风格(style)图片、C 表示内容(content)图片,对中间这张图片进行优化,使它与 S、C 之间的损失函数最小。但是这种做法每生成一张图片都要重新迭代优化,耗时很长,测试太慢。因此改用如下的网络:


将网络分成左边的 Image Transform Net 和右边的 Loss Network:左边负责生成转换后的图像,右边负责计算损失函数,即逐层对比特征值。其中,左边的网络先进行下采样,中间是残差网络,最后用反卷积进行上采样。其中输入 x 和内容目标 yc 是同一张图片。

其中style.py文件如下:

  1. from __future__ import print_function
  2. import sys, os, pdb
  3. sys.path.insert(0, 'src')
  4. import numpy as np, scipy.misc
  5. from optimize import optimize
  6. from argparse import ArgumentParser
  7. from utils import save_img, get_img, exists, list_files
  8. import evaluate
  # Default loss weights and training hyper-parameters; build_parser() uses
  # them as the defaults of the matching command-line options.
  CONTENT_WEIGHT = 7.5e0
  STYLE_WEIGHT = 1e2
  TV_WEIGHT = 2e2
  LEARNING_RATE = 1e-3
  NUM_EPOCHS = 2
  CHECKPOINT_ITERATIONS = 2000
  BATCH_SIZE = 4
  # Default locations of the pretrained VGG weights and the training images.
  VGG_PATH = 'data/imagenet-vgg-verydeep-19.mat'
  TRAIN_PATH = 'data/train2014'
  # NOTE(review): CHECKPOINT_DIR, DEVICE and FRAC_GPU are not read anywhere in
  # this listing (--checkpoint-dir is a required option with no default).
  CHECKPOINT_DIR = 'checkpoints'
  DEVICE = '/gpu:0'
  FRAC_GPU = 1
  def build_parser():
      """Build the command-line parser for the training script.

      ``--checkpoint-dir`` and ``--style`` are required. Every other option
      falls back to the module-level default constant of the same name, except
      ``--test`` / ``--test-dir`` (default ``False``) and ``--slow`` (a flag).

      Returns:
          The configured ``argparse.ArgumentParser``.
      """
      parser = ArgumentParser()

      # Required inputs.
      parser.add_argument('--checkpoint-dir', type=str,
                          dest='checkpoint_dir', help='dir to save checkpoint in',
                          metavar='CHECKPOINT_DIR', required=True)
      parser.add_argument('--style', type=str,
                          dest='style', help='style image path',
                          metavar='STYLE', required=True)

      # Data locations.
      parser.add_argument('--train-path', type=str,
                          dest='train_path', help='path to training images folder',
                          metavar='TRAIN_PATH', default=TRAIN_PATH)
      parser.add_argument('--test', type=str,
                          dest='test', help='test image path',
                          metavar='TEST', default=False)
      parser.add_argument('--test-dir', type=str,
                          dest='test_dir', help='test image save dir',
                          metavar='TEST_DIR', default=False)

      parser.add_argument('--slow', dest='slow', action='store_true',
                          help='gatys\' approach (for debugging, not supported)',
                          default=False)

      # Training schedule.
      parser.add_argument('--epochs', type=int,
                          dest='epochs', help='num epochs',
                          metavar='EPOCHS', default=NUM_EPOCHS)
      parser.add_argument('--batch-size', type=int,
                          dest='batch_size', help='batch size',
                          metavar='BATCH_SIZE', default=BATCH_SIZE)
      parser.add_argument('--checkpoint-iterations', type=int,
                          dest='checkpoint_iterations', help='checkpoint frequency',
                          metavar='CHECKPOINT_ITERATIONS',
                          default=CHECKPOINT_ITERATIONS)

      # Loss network and loss weights.
      parser.add_argument('--vgg-path', type=str,
                          dest='vgg_path',
                          help='path to VGG19 network (default %(default)s)',
                          metavar='VGG_PATH', default=VGG_PATH)
      parser.add_argument('--content-weight', type=float,
                          dest='content_weight',
                          help='content weight (default %(default)s)',
                          metavar='CONTENT_WEIGHT', default=CONTENT_WEIGHT)
      parser.add_argument('--style-weight', type=float,
                          dest='style_weight',
                          help='style weight (default %(default)s)',
                          metavar='STYLE_WEIGHT', default=STYLE_WEIGHT)
      parser.add_argument('--tv-weight', type=float,
                          dest='tv_weight',
                          help='total variation regularization weight (default %(default)s)',
                          metavar='TV_WEIGHT', default=TV_WEIGHT)
      parser.add_argument('--learning-rate', type=float,
                          dest='learning_rate',
                          help='learning rate (default %(default)s)',
                          metavar='LEARNING_RATE', default=LEARNING_RATE)

      return parser
  def check_opts(opts):
      """Validate the parsed command-line options.

      Checks that every referenced path exists and that the numeric options
      are in range.

      Args:
          opts: the namespace returned by ``build_parser().parse_args()``.

      Raises:
          AssertionError: if a path is missing or a numeric option is invalid.
      """
      exists(opts.checkpoint_dir, "checkpoint dir not found!")
      exists(opts.style, "style path not found!")
      exists(opts.train_path, "train path not found!")
      if opts.test or opts.test_dir:
          # The two options only make sense together. An omitted one is still
          # the default False, and on Python 3 os.path.exists(False) probes
          # file descriptor 0 instead of failing, so test truthiness first.
          assert opts.test, "test img not found!"
          exists(opts.test, "test img not found!")
          assert opts.test_dir, "test directory not found!"
          exists(opts.test_dir, "test directory not found!")
      exists(opts.vgg_path, "vgg network data not found!")
      assert opts.epochs > 0
      assert opts.batch_size > 0
      assert opts.checkpoint_iterations > 0
      assert opts.content_weight >= 0
      assert opts.style_weight >= 0
      assert opts.tv_weight >= 0
      assert opts.learning_rate >= 0
  def _get_files(img_dir):
      """Return the paths of the files directly inside *img_dir*.

      Each name reported by ``list_files`` is joined onto ``img_dir``.
      """
      return [os.path.join(img_dir, name) for name in list_files(img_dir)]
  def main():
      """Parse the command line, run training and report progress.

      Loads the style image, chooses the content images (the training folder,
      or the single ``--test`` image in ``--slow`` mode), then iterates over the
      snapshots yielded by ``optimize``. Each snapshot's losses are printed;
      when ``--test`` is set a preview image is also written to ``--test-dir``.
      """
      parser = build_parser()
      options = parser.parse_args()
      check_opts(options)

      style_target = get_img(options.style)
      if not options.slow:
          content_targets = _get_files(options.train_path)
      elif options.test:
          content_targets = [options.test]
      else:
          # Slow mode has no other source of content images; without this the
          # code below failed on an unbound content_targets.
          parser.error('--slow requires --test to supply the content image')

      kwargs = {
          "slow": options.slow,
          "epochs": options.epochs,
          "print_iterations": options.checkpoint_iterations,
          "batch_size": options.batch_size,
          "save_path": os.path.join(options.checkpoint_dir, 'fns.ckpt'),
          "learning_rate": options.learning_rate
      }

      if options.slow:
          # Slow mode overrides small epoch counts and learning rates.
          if options.epochs < 10:
              kwargs['epochs'] = 1000
          if options.learning_rate < 1:
              kwargs['learning_rate'] = 1e1

      args = [
          content_targets,
          style_target,
          options.content_weight,
          options.style_weight,
          options.tv_weight,
          options.vgg_path
      ]

      for preds, losses, i, epoch in optimize(*args, **kwargs):
          style_loss, content_loss, tv_loss, loss = losses

          print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, i, loss))
          to_print = (style_loss, content_loss, tv_loss)
          print('style: %s, content:%s, tv: %s' % to_print)
          if options.test:
              assert options.test_dir is not False
              preds_path = '%s/%s_%s.png' % (options.test_dir, epoch, i)
              if not options.slow:
                  # Re-run the freshly saved checkpoint on the test image.
                  evaluate.ffwd_to_img(options.test, preds_path,
                                       options.checkpoint_dir)
              else:
                  # The original passed an undefined name `img` here. In slow
                  # mode optimize() forces batch_size to 1, so the first entry
                  # of the yielded batch is the optimized image.
                  save_img(preds_path, preds[0])

      ckpt_dir = options.checkpoint_dir
      cmd_text = 'python evaluate.py --checkpoint %s ...' % ckpt_dir
      print("Training complete. For evaluation:\n `%s`" % cmd_text)


  if __name__ == '__main__':
      main()

下面是utils.py工具类的使用:

在里面实现获取图片,缩放图片,保存图片等操作

  1. import scipy.misc, numpy as np, os, sys
  def save_img(out_path, img):
      """Write *img* to *out_path* as an 8-bit image.

      Values are clipped to [0, 255] and cast to ``uint8`` before saving with
      ``scipy.misc.imsave``.
      """
      img = np.clip(img, 0, 255).astype(np.uint8)
      scipy.misc.imsave(out_path, img)
  def scale_img(style_path, style_scale):
      """Load the image at *style_path* resized by the factor *style_scale*.

      Height and width are multiplied by ``float(style_scale)`` and truncated
      to integers; the channel count is left unchanged.

      Returns:
          The resized image as returned by ``get_img``.
      """
      scale = float(style_scale)
      o0, o1, o2 = scipy.misc.imread(style_path, mode='RGB').shape
      new_shape = (int(o0 * scale), int(o1 * scale), o2)
      # The original called `_get_img`, which is not defined in this module;
      # the loader defined here is `get_img`.
      style_target = get_img(style_path, img_size=new_shape)
      return style_target
  def get_img(src, img_size=False):
      """Read the image at *src* as RGB and optionally resize it.

      Args:
          src: path of the image file.
          img_size: target shape passed to ``scipy.misc.imresize``; the default
              ``False`` keeps the original size.

      Returns:
          The image array produced by ``scipy.misc.imread`` (resized if asked).
      """
      img = scipy.misc.imread(src, mode='RGB')
      if not (len(img.shape) == 3 and img.shape[2] == 3):
          # Not an H x W x 3 array: replicate it into three channels.
          img = np.dstack((img, img, img))
          # NOTE(review): the listing lost its indentation, so whether this
          # debug print belongs inside this branch or at function level
          # cannot be told from the source -- confirm.
          print (img.shape)
      if img_size != False:
          img = scipy.misc.imresize(img, img_size)
      return img
  def exists(p, msg):
      """Assert that the path *p* exists.

      Raises:
          AssertionError: carrying *msg* when ``os.path.exists(p)`` is false.
      """
      assert os.path.exists(p), msg
  def list_files(in_path):
      """Return the names of the files directly inside *in_path*.

      Only the first level is inspected; sub-directories are not descended
      into. A path that ``os.walk`` cannot list yields an empty list.
      """
      # os.walk reports the top directory first, so the first triple is all
      # that is needed.
      for _dirpath, _dirnames, filenames in os.walk(in_path):
          return list(filenames)
      return []


下面是模型优化的代码,也是最重要的文件 optimize.py:

  1. from __future__ import print_function
  2. import functools
  3. import vgg, pdb, time
  4. import tensorflow as tf, numpy as np, os
  5. import transform
  6. from utils import get_img
  # VGG layers whose Gram matrices make up the style loss.
  STYLE_LAYERS = ('relu1_1', 'relu2_1', 'relu3_1', 'relu4_1', 'relu5_1')
  # VGG layer whose activations make up the content loss.
  CONTENT_LAYER = 'relu4_2'
  # NOTE(review): DEVICES is not read anywhere in this listing.
  DEVICES = 'CUDA_VISIBLE_DEVICES'
  10. # np arr, np arr
  def optimize(content_targets, style_target, content_weight, style_weight,
               tv_weight, vgg_path, epochs=2, print_iterations=1000,
               batch_size=4, save_path='saver/fns.ckpt', slow=False,
               learning_rate=1e-3, debug=False):
      """Train the style-transfer model, yielding periodic progress snapshots.

      This is a generator. It first computes the Gram matrices of the style
      image on the VGG ``STYLE_LAYERS``, then builds a graph whose loss is the
      sum of a content loss (on ``CONTENT_LAYER``), a style loss and a total
      variation loss, and minimizes it with Adam.

      Args:
          content_targets: list of content image paths. It is trimmed so its
              length is a multiple of the batch size.
          style_target: the style image as a numpy array.
          content_weight, style_weight, tv_weight: weights of the three losses.
          vgg_path: path of the VGG weights file passed to ``vgg.net``.
          epochs: number of passes over ``content_targets``.
          print_iterations: a snapshot is yielded every this many iterations
              (every this many epochs when ``slow``), and on the last batch.
          batch_size: images per step; forced to 1 when ``slow``.
          save_path: checkpoint path written at each snapshot (not in slow mode).
          slow: optimize an image variable directly instead of training
              ``transform.net``.
          learning_rate: Adam learning rate.
          debug: print the time taken by each batch.

      Yields:
          ``(preds, losses, iterations, epoch)`` where ``losses`` is
          ``(style_loss, content_loss, tv_loss, total_loss)`` evaluated on the
          current batch.
      """
      if slow:
          batch_size = 1
      mod = len(content_targets) % batch_size
      if mod > 0:
          print("Train set has been trimmed slightly..")
          content_targets = content_targets[:-mod]

      style_features = {}

      # Content images are always fed as 256x256 RGB batches.
      batch_shape = (batch_size, 256, 256, 3)
      style_shape = (1,) + style_target.shape

      # Precompute the style Gram matrices in a throw-away graph on the CPU.
      with tf.Graph().as_default(), tf.device('/cpu:0'), tf.Session() as sess:
          style_image = tf.placeholder(tf.float32, shape=style_shape, name='style_image')
          style_image_pre = vgg.preprocess(style_image)
          net = vgg.net(vgg_path, style_image_pre)
          style_pre = np.array([style_target])
          for layer in STYLE_LAYERS:
              features = net[layer].eval(feed_dict={style_image: style_pre})
              # Flatten everything but the channel axis.
              features = np.reshape(features, (-1, features.shape[3]))
              gram = np.matmul(features.T, features) / features.size
              style_features[layer] = gram

      with tf.Graph().as_default(), tf.Session() as sess:
          X_content = tf.placeholder(tf.float32, shape=batch_shape, name="X_content")
          X_pre = vgg.preprocess(X_content)

          # Content features of the input batch.
          content_features = {}
          content_net = vgg.net(vgg_path, X_pre)
          content_features[CONTENT_LAYER] = content_net[CONTENT_LAYER]

          if slow:
              # Optimize the image itself, starting from noise.
              preds = tf.Variable(
                  tf.random_normal(X_content.get_shape()) * 0.256
              )
              preds_pre = preds
          else:
              preds = transform.net(X_content / 255.0)
              preds_pre = vgg.preprocess(preds)

          net = vgg.net(vgg_path, preds_pre)

          content_size = _tensor_size(content_features[CONTENT_LAYER]) * batch_size
          assert _tensor_size(content_features[CONTENT_LAYER]) == _tensor_size(net[CONTENT_LAYER])
          content_loss = content_weight * (2 * tf.nn.l2_loss(
              net[CONTENT_LAYER] - content_features[CONTENT_LAYER]) / content_size
          )

          style_losses = []
          for style_layer in STYLE_LAYERS:
              layer = net[style_layer]
              bs, height, width, filters = map(lambda i: i.value, layer.get_shape())
              size = height * width * filters
              feats = tf.reshape(layer, (bs, height * width, filters))
              feats_T = tf.transpose(feats, perm=[0, 2, 1])
              grams = tf.matmul(feats_T, feats) / size
              style_gram = style_features[style_layer]
              style_losses.append(2 * tf.nn.l2_loss(grams - style_gram) / style_gram.size)

          style_loss = style_weight * functools.reduce(tf.add, style_losses) / batch_size

          # Total variation denoising: penalize differences between
          # neighbouring pixels along the height and width axes.
          tv_y_size = _tensor_size(preds[:, 1:, :, :])
          tv_x_size = _tensor_size(preds[:, :, 1:, :])
          y_tv = tf.nn.l2_loss(preds[:, 1:, :, :] - preds[:, :batch_shape[1] - 1, :, :])
          x_tv = tf.nn.l2_loss(preds[:, :, 1:, :] - preds[:, :, :batch_shape[2] - 1, :])
          tv_loss = tv_weight * 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size

          # Overall loss.
          loss = content_loss + style_loss + tv_loss

          train_step = tf.train.AdamOptimizer(learning_rate).minimize(loss)
          sess.run(tf.global_variables_initializer())
          # Build the saver once. The original constructed a new tf.train.Saver
          # at every checkpoint, adding save/restore ops to the graph each time.
          saver = None if slow else tf.train.Saver()

          import random
          uid = random.randint(1, 100)
          print("UID: %s" % uid)
          for epoch in range(epochs):
              num_examples = len(content_targets)
              iterations = 0
              while iterations * batch_size < num_examples:
                  start_time = time.time()
                  curr = iterations * batch_size
                  step = curr + batch_size
                  X_batch = np.zeros(batch_shape, dtype=np.float32)
                  for j, img_p in enumerate(content_targets[curr:step]):
                      X_batch[j] = get_img(img_p, (256, 256, 3)).astype(np.float32)

                  iterations += 1
                  assert X_batch.shape[0] == batch_size

                  feed_dict = {
                      X_content: X_batch
                  }

                  train_step.run(feed_dict=feed_dict)
                  end_time = time.time()
                  delta_time = end_time - start_time
                  if debug:
                      print("UID: %s, batch time: %s" % (uid, delta_time))

                  is_print_iter = int(iterations) % print_iterations == 0
                  if slow:
                      # Slow mode counts epochs rather than iterations.
                      is_print_iter = epoch % print_iterations == 0
                  is_last = epoch == epochs - 1 and iterations * batch_size >= num_examples
                  should_print = is_print_iter or is_last
                  if should_print:
                      to_get = [style_loss, content_loss, tv_loss, loss, preds]
                      test_feed_dict = {
                          X_content: X_batch
                      }

                      tup = sess.run(to_get, feed_dict=test_feed_dict)
                      _style_loss, _content_loss, _tv_loss, _loss, _preds = tup
                      losses = (_style_loss, _content_loss, _tv_loss, _loss)
                      if slow:
                          _preds = vgg.unprocess(_preds)
                      else:
                          saver.save(sess, save_path)
                      yield (_preds, losses, iterations, epoch)
  121. def _tensor_size(tensor):
  122. from operator import mul
  123. return functools.reduce(mul, (d.value for d in tensor.get_shape()[1:]), 1)

下面是vgg.py

  1. import tensorflow as tf
  2. import numpy as np
  3. import scipy.io
  4. import pdb
  # Per-channel offset subtracted by preprocess() and added back by unprocess().
  MEAN_PIXEL = np.array([ 123.68 , 116.779, 103.939])
  def net(data_path, input_image):
      """Build the VGG network on top of *input_image*.

      Args:
          data_path: path of the ``.mat`` weights file, read with
              ``scipy.io.loadmat``.
          input_image: the input tensor; callers in this listing pass the
              output of ``preprocess``.

      Returns:
          A dict mapping every layer name in ``layers`` to that layer's output
          tensor.
      """
      layers = (
          'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1',

          'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',

          'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3',
          'relu3_3', 'conv3_4', 'relu3_4', 'pool3',

          'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3',
          'relu4_3', 'conv4_4', 'relu4_4', 'pool4',

          'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3',
          'relu5_3', 'conv5_4', 'relu5_4'
      )

      data = scipy.io.loadmat(data_path)
      # The original also read data['normalization'] into a mean pixel that it
      # never used; that dead computation has been removed.
      weights = data['layers'][0]

      net = {}
      current = input_image
      for i, name in enumerate(layers):
          # The first four characters of the name give the layer type.
          kind = name[:4]
          if kind == 'conv':
              kernels, bias = weights[i][0][0][0][0]
              # matconvnet: weights are [width, height, in_channels, out_channels]
              # tensorflow: weights are [height, width, in_channels, out_channels]
              kernels = np.transpose(kernels, (1, 0, 2, 3))
              bias = bias.reshape(-1)
              current = _conv_layer(current, kernels, bias)
          elif kind == 'relu':
              current = tf.nn.relu(current)
          elif kind == 'pool':
              current = _pool_layer(current)
          net[name] = current

      assert len(net) == len(layers)
      return net
  def _conv_layer(input, weights, bias):
      """Apply a stride-1, SAME-padded convolution with fixed weights plus bias.

      *weights* is wrapped in ``tf.constant``, so it is not trainable.
      """
      conv = tf.nn.conv2d(input, tf.constant(weights), strides=(1, 1, 1, 1),
                          padding='SAME')
      return tf.nn.bias_add(conv, bias)
  def _pool_layer(input):
      """Apply 2x2 max pooling with stride 2 and SAME padding."""
      return tf.nn.max_pool(input, ksize=(1, 2, 2, 1), strides=(1, 2, 2, 1),
                            padding='SAME')
  def preprocess(image):
      """Subtract ``MEAN_PIXEL`` from *image* (inverse of ``unprocess``)."""
      return image - MEAN_PIXEL
  def unprocess(image):
      """Add ``MEAN_PIXEL`` back to *image* (inverse of ``preprocess``)."""
      return image + MEAN_PIXEL


然后就是transform.py的转换网络,即生成网络:

  1. import tensorflow as tf, pdb
  # Standard deviation of the truncated-normal weight initializer used by
  # _conv_init_vars().
  WEIGHTS_INIT_STDEV = .1
  def net(image):
      """Build the image transformation (generator) network.

      Three convolutions (the last two with stride 2) are followed by five
      residual blocks, two stride-2 transposed convolutions and a final 9x9
      convolution without ReLU.

      Returns:
          ``tanh(conv) * 150 + 255/2``, i.e. values within 127.5 +/- 150.
      """
      # Downsampling.
      conv1 = _conv_layer(image, 32, 9, 1)
      conv2 = _conv_layer(conv1, 64, 3, 2)
      conv3 = _conv_layer(conv2, 128, 3, 2)
      # Residual blocks.
      resid1 = _residual_block(conv3, 3)
      resid2 = _residual_block(resid1, 3)
      resid3 = _residual_block(resid2, 3)
      resid4 = _residual_block(resid3, 3)
      resid5 = _residual_block(resid4, 3)
      # Upsampling.
      conv_t1 = _conv_tranpose_layer(resid5, 64, 3, 2)
      conv_t2 = _conv_tranpose_layer(conv_t1, 32, 3, 2)
      conv_t3 = _conv_layer(conv_t2, 3, 9, 1, relu=False)
      preds = tf.nn.tanh(conv_t3) * 150 + 255./2
      return preds
  def _conv_layer(net, num_filters, filter_size, strides, relu=True):
      """Convolution, then instance normalization, then an optional ReLU.

      Args:
          net: input tensor.
          num_filters: number of output channels.
          filter_size: side length of the square kernel.
          strides: stride used along both spatial axes.
          relu: apply ``tf.nn.relu`` to the result when true.
      """
      weights_init = _conv_init_vars(net, num_filters, filter_size)
      strides_shape = [1, strides, strides, 1]
      net = tf.nn.conv2d(net, weights_init, strides_shape, padding='SAME')
      net = _instance_norm(net)
      if relu:
          net = tf.nn.relu(net)

      return net
  def _conv_tranpose_layer(net, num_filters, filter_size, strides):
      """Transposed convolution, then instance normalization, then ReLU.

      The output height and width are the input's multiplied by *strides*;
      the batch size is taken from the static shape of *net*.
      """
      weights_init = _conv_init_vars(net, num_filters, filter_size, transpose=True)

      batch_size, rows, cols, _in_channels = [i.value for i in net.get_shape()]
      new_rows, new_cols = int(rows * strides), int(cols * strides)

      new_shape = [batch_size, new_rows, new_cols, num_filters]
      tf_shape = tf.stack(new_shape)
      strides_shape = [1, strides, strides, 1]

      net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape, padding='SAME')
      net = _instance_norm(net)
      return tf.nn.relu(net)
  def _residual_block(net, filter_size=3):
      """Add to *net* the output of two stacked 128-filter conv layers.

      The first layer uses ReLU, the second does not.
      """
      tmp = _conv_layer(net, 128, filter_size, 1)
      return net + _conv_layer(tmp, 128, filter_size, 1, relu=False)
  def _instance_norm(net, train=True):
      """Normalize *net* over axes 1 and 2, then apply a learned scale and shift.

      Mean and variance are computed with ``tf.nn.moments`` over axes
      ``[1, 2]``. ``scale`` (initialized to ones) and ``shift`` (zeros) are new
      variables with one entry per channel. *train* is accepted but not used.
      """
      _batch, _rows, _cols, channels = [i.value for i in net.get_shape()]
      var_shape = [channels]
      mu, sigma_sq = tf.nn.moments(net, [1, 2], keep_dims=True)
      shift = tf.Variable(tf.zeros(var_shape))
      scale = tf.Variable(tf.ones(var_shape))
      # Keeps the division finite when the variance is zero.
      epsilon = 1e-3
      normalized = (net - mu) / (sigma_sq + epsilon)**(.5)
      return scale * normalized + shift
  def _conv_init_vars(net, out_channels, filter_size, transpose=False):
      """Create the kernel variable for a (transposed) convolution on *net*.

      The kernel shape is ``[filter_size, filter_size, in, out]``; the last two
      entries are swapped when *transpose* is true. Values are drawn from a
      truncated normal with stddev ``WEIGHTS_INIT_STDEV`` and a fixed seed of 1.
      """
      _, _rows, _cols, in_channels = [i.value for i in net.get_shape()]
      if not transpose:
          weights_shape = [filter_size, filter_size, in_channels, out_channels]
      else:
          weights_shape = [filter_size, filter_size, out_channels, in_channels]

      weights_init = tf.Variable(
          tf.truncated_normal(weights_shape, stddev=WEIGHTS_INIT_STDEV, seed=1),
          dtype=tf.float32)
      return weights_init



声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/123005
推荐阅读