当前位置:   article > 正文

Tensorflow实现图片StyleTransfer_style transfer数据集

style transfer数据集

一. 效果展示:

原图:

风格图:                                                   

二. 数据集为8万多张图片(COCO train2014),每次训练针对一种指定的风格图片训练一个模型

数据集链接:训练数据共8万多张图片,压缩包约12G,体积较大
http://msvocds.blob.core.windows.net/coco2014/train2014.zip

训练代码:

  1. from __future__ import print_function
  2. import sys, os, pdb
  3. import numpy as np
  4. import scipy.misc
  5. from src.optimize import optimize
  6. from argparse import ArgumentParser
  7. from src.utils import save_img, get_img, exists, list_files
  8. import evaluate # 迭代优化
  9. CONTENT_WEIGHT = 7.5e0
  10. STYLE_WEIGHT = 1e2
  11. TV_WEIGHT = 2e2
  12. LEARNING_RATE = 1e-3
  13. NUM_EPOCHS = 2
  14. CHECKPOINT_DIR = 'checkpoints'
  15. CHECKPOINT_ITERATIONS = 2000
  16. VGG_PATH = 'data/imagenet-vgg-verydeep-19.mat'
  17. TRAIN_PATH = 'data/' # path to the training image data
  18. BATCH_SIZE = 4
  19. DEVICE = '/gpu:0' # compute on the GPU
  20. FRAC_GPU = 1
  # Verify that every training option has been set to a usable value.
  def check_opts(opts):
      """Validate the parsed training options.

      Path options are handed to ``exists`` together with the message to
      report (``exists`` comes from ``src.utils``; presumably it fails with
      that message when the path is missing -- confirm there).  Numeric
      options are checked with bare assertions.

      Args:
          opts: Option namespace.  Attributes read: ``checkpoint_dir``,
              ``style``, ``train_path``, ``test``, ``test_dir``, ``vgg_path``,
              ``epochs``, ``batch_size``, ``checkpoint_iterations``,
              ``content_weight``, ``style_weight``, ``tv_weight`` and
              ``learning_rate``.

      Raises:
          AssertionError: If a numeric option is out of range or the VGG file
              does not exist.
      """
      always_required = (
          ("checkpoint_dir", "checkpoint dir not found!"),
          ("style", "style path not found!"),
          ("train_path", "train path not found!"),
      )
      for attr, message in always_required:
          exists(getattr(opts, attr), message)

      # NOTE(review): the listing lost its indentation; both test checks are
      # assumed to belong to this condition.
      if opts.test or opts.test_dir:
          exists(opts.test, "test img not found!")
          exists(opts.test_dir, "test directory not found!")

      exists(opts.vgg_path, "vgg network data not found!")

      for attr in ("epochs", "batch_size", "checkpoint_iterations"):
          assert getattr(opts, attr) > 0

      assert os.path.exists(opts.vgg_path)

      for attr in ("content_weight", "style_weight", "tv_weight",
                   "learning_rate"):
          assert getattr(opts, attr) >= 0
  def _get_files(img_dir):
      """Return every entry reported by ``list_files`` joined onto ``img_dir``.

      Args:
          img_dir: Directory whose entries are listed.

      Returns:
          A list of ``os.path.join(img_dir, name)`` strings, in the order
          ``list_files`` yields the names.
      """
      paths = []
      for name in list_files(img_dir):
          paths.append(os.path.join(img_dir, name))
      return paths
  def main():
      """Parse the training options, run ``optimize`` and report progress.

      In the default (fast) mode the content targets are all files under
      ``options.train_path``; in slow mode the single image ``options.test``
      is optimised directly.  After every yielded step the losses are printed
      and, when a test image is configured, a preview is written to
      ``options.test_dir``.

      Raises:
          ValueError: If slow mode is requested without a test image, because
              there is then no content image to work on.
      """
      parser = build_parser()
      options = parser.parse_args()
      check_opts(options)

      style_target = get_img(options.style)
      if not options.slow:
          content_targets = _get_files(options.train_path)
      elif options.test:
          content_targets = [options.test]
      else:
          # Previously this fell through and crashed later with an unbound
          # ``content_targets``; fail early with an explicit reason instead.
          raise ValueError("slow mode needs a test image to optimise")

      kwargs = {
          "slow": options.slow,
          "epochs": options.epochs,
          "print_iterations": options.checkpoint_iterations,
          "batch_size": options.batch_size,
          "save_path": os.path.join(options.checkpoint_dir, 'fns.ckpt'),
          "learning_rate": options.learning_rate,
      }

      # Slow mode raises small epoch counts and learning rates to fixed floors.
      if options.slow:
          if options.epochs < 10:
              kwargs['epochs'] = 1000
          if options.learning_rate < 1:
              kwargs['learning_rate'] = 1e1

      args = [
          content_targets,
          style_target,
          options.content_weight,
          options.style_weight,
          options.tv_weight,
          options.vgg_path,
      ]

      for preds, losses, i, epoch in optimize(*args, **kwargs):
          style_loss, content_loss, tv_loss, loss = losses

          print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, i, loss))
          to_print = (style_loss, content_loss, tv_loss)
          print('style: %s, content:%s, tv: %s' % to_print)
          if options.test:
              assert options.test_dir != False
              preds_path = '%s/%s_%s.png' % (options.test_dir, epoch, i)
              if not options.slow:
                  # Fast mode: render the test image with the latest checkpoint.
                  evaluate.ffwd_to_img(options.test, preds_path,
                                       options.checkpoint_dir)
              else:
                  # The original passed an undefined name ``img`` here.
                  # NOTE(review): ``preds`` is assumed to be the image that
                  # ``optimize`` yields in slow mode -- confirm its shape
                  # against ``save_img``.
                  save_img(preds_path, preds)

      ckpt_dir = options.checkpoint_dir
      cmd_text = 'python evaluate.py --checkpoint %s ...' % ckpt_dir
      print("Training complete. For evaluation:\n `%s`" % cmd_text)


  if __name__ == '__main__':
      main()

  VGG训练好的模型:
http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat

三. 测试代码,指定一种风格的model,测试便可生成混合图片

  1. from __future__ import print_function
  2. import sys
  3. sys.path.insert(0, 'src')
  4. import numpy as np, src.vgg, pdb, os
  5. from src import transform
  6. import scipy.misc
  7. import tensorflow as tf
  8. from src.utils import save_img, get_img, exists, list_files
  9. from argparse import ArgumentParser
  10. from collections import defaultdict
  11. import time
  12. import json
  13. import subprocess
  14. import numpy
  15. BATCH_SIZE = 4 # default for the --batch-size option in build_parser
  16. DEVICE = '/gpu:0' # default for the --device option in build_parser
  def from_pipe(opts):
      """Stylize a video by streaming raw frames through two ffmpeg processes.

      ``ffprobe`` reports the frame size and rate of ``opts.in_path``.  One
      ``ffmpeg`` process decodes the input to raw ``rgb24`` frames on its
      stdout, the frames are run through ``transform.net`` in batches, and a
      second ``ffmpeg`` process encodes the results to ``opts.out`` with
      libx264.

      Args:
          opts: Object providing ``in_path``, ``out``, ``checkpoint``,
              ``device`` and ``batch_size``.  NOTE(review): ``build_parser``
              in this file stores ``checkpoint_dir`` and ``out_path`` instead,
              so its namespace cannot be passed here unchanged.

      Raises:
          Exception: If ``opts.checkpoint`` is a directory that holds no
              checkpoint.
          subprocess.CalledProcessError: If ``ffprobe`` exits with an error.
      """
      from fractions import Fraction  # only needed to parse the frame rate

      probe_cmd = ["ffprobe",
                   '-v', "quiet",
                   '-print_format', 'json',
                   '-show_streams', opts.in_path]
      info = json.loads(subprocess.check_output(probe_cmd).decode("utf8"))
      # NOTE(review): assumes the first reported stream is the video stream.
      stream = info["streams"][0]
      width = int(stream["width"])
      height = int(stream["height"])
      # The rate is read as a ratio string; Fraction parses it without
      # running eval() on text produced by an external program.
      fps = round(Fraction(stream["r_frame_rate"]))

      decode_cmd = ["ffmpeg",
                    '-loglevel', "quiet",
                    '-i', opts.in_path,
                    '-f', 'image2pipe',
                    '-pix_fmt', 'rgb24',
                    '-vcodec', 'rawvideo', '-']
      pipe_in = subprocess.Popen(decode_cmd, stdout=subprocess.PIPE,
                                 bufsize=10 ** 9, stdin=None, stderr=None)

      encode_cmd = ["ffmpeg",
                    '-loglevel', "info",
                    '-y',  # (optional) overwrite output file if it exists
                    '-f', 'rawvideo',
                    '-vcodec', 'rawvideo',
                    '-s', str(width) + 'x' + str(height),  # size of one frame
                    '-pix_fmt', 'rgb24',
                    '-r', str(fps),  # frames per second
                    '-i', '-',  # the input comes from a pipe
                    '-an',  # tells ffmpeg not to expect any audio
                    '-c:v', 'libx264',
                    '-preset', 'slow',
                    '-crf', '18',
                    opts.out]
      pipe_out = subprocess.Popen(encode_cmd, stdin=subprocess.PIPE,
                                  stdout=None, stderr=None)

      try:
          g = tf.Graph()
          soft_config = tf.ConfigProto(allow_soft_placement=True)
          soft_config.gpu_options.allow_growth = True
          with g.as_default(), g.device(opts.device), \
                  tf.Session(config=soft_config) as sess:
              batch_shape = (opts.batch_size, height, width, 3)
              img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                               name='img_placeholder')
              preds = transform.net(img_placeholder)
              saver = tf.train.Saver()
              if os.path.isdir(opts.checkpoint):
                  ckpt = tf.train.get_checkpoint_state(opts.checkpoint)
                  if ckpt and ckpt.model_checkpoint_path:
                      saver.restore(sess, ckpt.model_checkpoint_path)
                  else:
                      raise Exception("No checkpoint found...")
              else:
                  saver.restore(sess, opts.checkpoint)

              X = np.zeros(batch_shape, dtype=np.float32)
              nbytes = 3 * width * height
              more_input = True
              while more_input:
                  # Fill the batch buffer until it is full or input runs out.
                  count = 0
                  while count < opts.batch_size:
                      raw_image = pipe_in.stdout.read(nbytes)
                      if len(raw_image) != nbytes:
                          more_input = False
                          break
                      frame = np.frombuffer(raw_image, dtype=np.uint8)
                      X[count] = frame.reshape((height, width, 3))
                      count += 1
                  if count == 0:
                      break

                  # A short final batch is run at full batch size and only the
                  # first ``count`` outputs are kept.  Rebuilding the network
                  # for a smaller batch (as before) created fresh variables
                  # that were never restored from the checkpoint.
                  # NOTE(review): relies on the network treating batch rows
                  # independently; the leftover rows hold stale frames.
                  _preds = sess.run(preds, feed_dict={img_placeholder: X})
                  for i in range(count):
                      img = np.clip(_preds[i], 0, 255).astype(np.uint8)
                      try:
                          pipe_out.stdin.write(img.tobytes())
                      except IOError as err:
                          # The encoder's stderr is inherited, not piped, so
                          # there is nothing to read back from it here.
                          print(str(err) + "\n\nFFMPEG stopped accepting "
                                "frames; see its own output for the cause.")
                          more_input = False
                          break
      finally:
          # Close the encoder's stdin and wait, so it can flush and finalise
          # the output file before anything is terminated.
          try:
              pipe_out.stdin.close()
          except IOError:
              pass  # encoder already gone; nothing left to flush
          pipe_out.wait()
          pipe_in.stdout.close()
          pipe_in.terminate()
          pipe_in.wait()
  def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
      """Run the transform network over a set of images and save the results.

      All inputs of one call must share a single image shape, because one
      placeholder of that shape is built for the whole call.  Inputs that do
      not fill a complete batch are handled by a recursive call with
      ``batch_size=1``.

      Args:
          data_in: Either a list of image paths (detected by the first element
              being a ``str``) or an indexable, sliceable collection of
              already-loaded images whose elements expose ``.shape``.
          paths_out: Destination path for each input, in the same order.
          checkpoint_dir: Directory containing a checkpoint, or a checkpoint
              file to restore directly.
          device_t: Device string passed to ``g.device``.
          batch_size: Upper bound on the number of images per ``sess.run``.

      Raises:
          AssertionError: If ``paths_out`` is empty, the input and output
              counts differ, or an image loaded from a path has a different
              shape from the first one.
          Exception: If ``checkpoint_dir`` is a directory with no checkpoint.
      """
      assert len(paths_out) > 0
      is_paths = isinstance(data_in[0], str)
      # The original non-path branch indexed ``data_in.size`` and never set
      # ``img_shape``; both input kinds are now handled uniformly.
      assert len(data_in) == len(paths_out)
      if is_paths:
          img_shape = get_img(data_in[0]).shape
      else:
          img_shape = tuple(data_in[0].shape)

      g = tf.Graph()
      batch_size = min(len(paths_out), batch_size)
      soft_config = tf.ConfigProto(allow_soft_placement=True)
      soft_config.gpu_options.allow_growth = True
      with g.as_default(), g.device(device_t), \
              tf.Session(config=soft_config) as sess:
          batch_shape = (batch_size,) + img_shape
          img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                           name='img_placeholder')

          preds = transform.net(img_placeholder)
          saver = tf.train.Saver()
          if os.path.isdir(checkpoint_dir):
              ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
              if ckpt and ckpt.model_checkpoint_path:
                  saver.restore(sess, ckpt.model_checkpoint_path)
              else:
                  raise Exception("No checkpoint found...")
          else:
              saver.restore(sess, checkpoint_dir)

          num_iters = len(paths_out) // batch_size
          for i in range(num_iters):
              pos = i * batch_size
              curr_batch_out = paths_out[pos:pos + batch_size]
              if is_paths:
                  curr_batch_in = data_in[pos:pos + batch_size]
                  X = np.zeros(batch_shape, dtype=np.float32)
                  for j, path_in in enumerate(curr_batch_in):
                      img = get_img(path_in)
                      assert img.shape == img_shape, \
                          'Images have different dimensions. ' + \
                          'Resize images or use --allow-different-dimensions.'
                      X[j] = img
              else:
                  X = data_in[pos:pos + batch_size]

              _preds = sess.run(preds, feed_dict={img_placeholder: X})
              for j, path_out in enumerate(curr_batch_out):
                  save_img(path_out, _preds[j])

          # Whatever did not fill a whole batch is processed one at a time.
          remaining_in = data_in[num_iters * batch_size:]
          remaining_out = paths_out[num_iters * batch_size:]
      if len(remaining_in) > 0:
          ffwd(remaining_in, remaining_out, checkpoint_dir,
               device_t=device_t, batch_size=1)
  def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
      """Stylize a single image file by delegating to ``ffwd``.

      Args:
          in_path: Path of the image to transform.
          out_path: Path the result is written to.
          checkpoint_dir: Checkpoint directory or file, forwarded unchanged.
          device: Device string forwarded as ``device_t``.
      """
      # Wrap the single paths in one-element lists and use a batch of one.
      ffwd([in_path], [out_path], checkpoint_dir,
           batch_size=1, device_t=device)
  def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
                                device_t=DEVICE, batch_size=4):
      """Stylize images of mixed sizes by running ``ffwd`` once per shape.

      Every input image is loaded to read its shape, inputs and outputs are
      grouped under a ``"HxWxC"``-style key built from that shape, and each
      group is passed to ``ffwd`` separately.

      Args:
          in_path: Sequence of input image paths.
          out_path: Sequence of output paths, indexed in step with ``in_path``.
          checkpoint_dir: Checkpoint directory or file, forwarded to ``ffwd``.
          device_t: Device string forwarded to ``ffwd``.
          batch_size: Batch size forwarded to ``ffwd``.
      """
      inputs_by_shape = defaultdict(list)
      outputs_by_shape = defaultdict(list)
      for idx, source in enumerate(in_path):
          target = out_path[idx]
          key = "%dx%dx%d" % get_img(source).shape
          inputs_by_shape[key].append(source)
          outputs_by_shape[key].append(target)

      for key, group_in in inputs_by_shape.items():
          print('Processing images of shape %s' % key)
          ffwd(group_in, outputs_by_shape[key],
               checkpoint_dir, device_t, batch_size)
  def check_opts(opts):
      """Validate the parsed evaluation options.

      Args:
          opts: Option namespace.  Attributes read: ``checkpoint_dir``,
              ``in_path``, ``out_path`` and ``batch_size``.

      Raises:
          AssertionError: If ``batch_size`` is not positive.  ``exists`` comes
              from ``src.utils``; presumably it fails with the given message
              when a path is missing -- confirm there.
      """
      exists(opts.checkpoint_dir, 'Checkpoint not found!')
      exists(opts.in_path, 'In path not found!')
      # NOTE(review): this check only runs when out_path is already a
      # directory, so it presumably can never fail; kept as in the original.
      if os.path.isdir(opts.out_path):
          exists(opts.out_path, 'out dir not found!')
      # The listing lost its indentation, so the original nesting of this
      # assertion is unknown; it is applied on every call because a bad batch
      # size is invalid regardless of the output path.
      assert opts.batch_size > 0
  def build_parser():
      """Build the command-line parser for the evaluation script.

      ``--checkpoint``, ``--in-path`` and ``--out-path`` were declared with
      both ``required=True`` and a ``default``, which made the defaults
      unreachable.  They are now optional so the declared defaults apply when
      a flag is omitted; command lines that pass the flags behave as before.

      Returns:
          An ``ArgumentParser`` whose namespace has ``checkpoint_dir``,
          ``in_path``, ``out_path``, ``device``, ``batch_size`` and
          ``allow_different_dimensions``.
      """
      parser = ArgumentParser()
      parser.add_argument('--checkpoint', type=str,
                          dest='checkpoint_dir',
                          help='dir or .ckpt file to load checkpoint from',
                          metavar='CHECKPOINT',
                          default='./model/la_muse.ckpt')

      parser.add_argument('--in-path', type=str,
                          dest='in_path', help='dir or file to transform',
                          metavar='IN_PATH',
                          default='./examples/content/stata.jpg')

      help_out = 'destination (dir or file) of transformed file or files'
      parser.add_argument('--out-path', type=str,
                          dest='out_path', help=help_out, metavar='OUT_PATH',
                          default='./')

      parser.add_argument('--device', type=str,
                          dest='device', help='device to perform compute on',
                          metavar='DEVICE', default=DEVICE)

      parser.add_argument('--batch-size', type=int,
                          dest='batch_size',
                          help='batch size for feedforwarding',
                          metavar='BATCH_SIZE', default=BATCH_SIZE)

      parser.add_argument('--allow-different-dimensions', action='store_true',
                          dest='allow_different_dimensions',
                          help='allow different image dimensions')

      return parser
  def main():
      """Stylize one image, or every file in a directory, from the CLI.

      When ``--in-path`` is a file it is transformed on its own; if
      ``--out-path`` is an existing directory the result keeps the input's
      file name inside it, otherwise ``--out-path`` is used as the output
      file.  When ``--in-path`` is a directory every listed file is written
      under ``--out-path`` with the same name.
      """
      parser = build_parser()
      opts = parser.parse_args()
      # Validate the parsed options; check_opts itself creates nothing.
      check_opts(opts)

      if not os.path.isdir(opts.in_path):
          if os.path.isdir(opts.out_path):
              # Reuse the input image's file name for the output image.
              out_path = os.path.join(opts.out_path,
                                      os.path.basename(opts.in_path))
          else:
              out_path = opts.out_path

          ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
                      device=opts.device)
      else:
          # Outputs are joined onto out_path below, so it has to be a
          # directory; create it when it is missing.
          if not os.path.isdir(opts.out_path):
              os.makedirs(opts.out_path)
          files = list_files(opts.in_path)
          full_in = [os.path.join(opts.in_path, x) for x in files]
          full_out = [os.path.join(opts.out_path, x) for x in files]
          if opts.allow_different_dimensions:
              ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
                                        device_t=opts.device,
                                        batch_size=opts.batch_size)
          else:
              ffwd(full_in, full_out, opts.checkpoint_dir,
                   device_t=opts.device, batch_size=opts.batch_size)


  if __name__ == '__main__':
      main()

四.应用的神经网络模型

  1. import tensorflow as tf, pdb
  2. WEIGHTS_INIT_STDEV = .1 # stddev passed to tf.truncated_normal in _conv_init_vars
  # Network structure.
  def net(image):
      """Build the feed-forward image transformation graph.

      The layers are created in this order: three convolutions (32 filters of
      size 9 with stride 1, then 64 and 128 filters of size 3 with stride 2),
      five residual blocks, two stride-2 transposed convolutions (64 then 32
      filters), and a final size-9 convolution to 3 channels without ReLU.

      Args:
          image: Input tensor handed to the first convolution.

      Returns:
          ``tanh`` of the last layer scaled by 150 and offset by ``255 / 2``.
      """
      features = image
      for num_filters, filter_size, strides in ((32, 9, 1),
                                                (64, 3, 2),
                                                (128, 3, 2)):
          features = _conv_layer(features, num_filters, filter_size, strides)

      # Residual section.
      for _ in range(5):
          features = _residual_block(features, 3)

      for num_filters in (64, 32):
          features = _conv_tranpose_layer(features, num_filters, 3, 2)

      last = _conv_layer(features, 3, 9, 1, relu=False)
      return tf.nn.tanh(last) * 150 + 255./2
  def _conv_layer(net, num_filters, filter_size, strides, relu=True):
      """Apply a SAME-padded convolution, instance norm and optional ReLU.

      Args:
          net: Input tensor.
          num_filters: Number of output channels.
          filter_size: Side length of the square kernel.
          strides: Stride used on both spatial axes.
          relu: When true, a ReLU follows the normalization.

      Returns:
          The resulting tensor.
      """
      kernel = _conv_init_vars(net, num_filters, filter_size)
      out = tf.nn.conv2d(net, kernel, [1, strides, strides, 1], padding='SAME')
      out = _instance_norm(out)
      return tf.nn.relu(out) if relu else out
  # Transposed convolution ("deconvolution") layer.
  def _conv_tranpose_layer(net, num_filters, filter_size, strides):
      """Upsample with a SAME-padded transposed convolution, norm and ReLU.

      The output has ``strides`` times the input's rows and columns and
      ``num_filters`` channels.

      Args:
          net: Input tensor whose static shape has four dimensions; rows and
              columns must be statically known.
          num_filters: Number of output channels.
          filter_size: Side length of the square kernel.
          strides: Upsampling factor on both spatial axes.

      Returns:
          The upsampled tensor after instance normalization and ReLU.
      """
      weights_init = _conv_init_vars(net, num_filters, filter_size,
                                     transpose=True)

      batch_size, rows, cols, _ = [i.value for i in net.get_shape()]
      new_rows, new_cols = int(rows * strides), int(cols * strides)
      if batch_size is None:
          # The batch dimension is not known statically; read it at run time
          # so the layer still works with a variable batch size.
          batch_size = tf.shape(net)[0]

      tf_shape = tf.stack([batch_size, new_rows, new_cols, num_filters])
      strides_shape = [1, strides, strides, 1]

      net = tf.nn.conv2d_transpose(net, weights_init, tf_shape, strides_shape,
                                   padding='SAME')
      net = _instance_norm(net)
      return tf.nn.relu(net)
  # Residual block.
  def _residual_block(net, filter_size=3):
      """Add the output of two stacked 128-filter convolutions to the input.

      The first convolution uses ReLU and the second does not; both use
      stride 1.

      Args:
          net: Input tensor; it is added to the second convolution's output.
          filter_size: Kernel side length used by both convolutions.

      Returns:
          ``net`` plus the output of the second convolution.
      """
      hidden = _conv_layer(net, 128, filter_size, 1)
      residual = _conv_layer(hidden, 128, filter_size, 1, relu=False)
      return net + residual
  # Instance normalization (moments are taken over axes 1 and 2 only).
  def _instance_norm(net, train=True):
      """Normalize over axes 1 and 2, then apply a learned scale and shift.

      Mean and variance are computed with ``tf.nn.moments`` over axes
      ``[1, 2]``; one ``shift`` (zeros) and one ``scale`` (ones) variable of
      length ``channels`` are created on every call.

      Args:
          net: Tensor whose static shape has four dimensions; the last is
              read as the channel count.
          train: Not used by this function.

      Returns:
          ``scale * normalized + shift``.
      """
      _, _, _, channels = [dim.value for dim in net.get_shape()]
      # Mean and variance of the current feature maps.
      mu, sigma_sq = tf.nn.moments(net, [1, 2], keep_dims=True)
      shift = tf.Variable(tf.zeros([channels]))
      scale = tf.Variable(tf.ones([channels]))
      epsilon = 1e-3
      centered = net - mu
      normalized = centered / (sigma_sq + epsilon) ** (.5)
      return scale * normalized + shift
  def _conv_init_vars(net, out_channels, filter_size, transpose=False):
      """Create the kernel variable for a (transposed) convolution.

      The kernel is drawn from a truncated normal with standard deviation
      ``WEIGHTS_INIT_STDEV`` and a fixed seed of 1.

      Args:
          net: Input tensor whose static shape has four dimensions; the last
              is read as the input channel count.
          out_channels: Number of output channels.
          filter_size: Side length of the square kernel.
          transpose: When true, the two channel axes of the kernel shape are
              swapped, as used by the transposed convolution.

      Returns:
          A float32 ``tf.Variable`` holding the kernel.
      """
      _, _, _, in_channels = [dim.value for dim in net.get_shape()]
      if transpose:
          # Transposed convolution: output channels come before input channels.
          weights_shape = [filter_size, filter_size, out_channels, in_channels]
      else:
          weights_shape = [filter_size, filter_size, in_channels, out_channels]

      initial = tf.truncated_normal(weights_shape,
                                    stddev=WEIGHTS_INIT_STDEV, seed=1)
      return tf.Variable(initial, dtype=tf.float32)

由于代码过多,不易全部展示,完整Demo参见GitHub链接:

https://github.com/Whq123/Style-transfer-of-picture

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/364845
推荐阅读
相关标签
  

闽ICP备14008679号