赞
踩
原图:
风格图:
数据集链接(训练数据:COCO 2014 训练集,共 8 万多张图片,约 12 GB,体积较大):
http://msvocds.blob.core.windows.net/coco2014/train2014.zip
(若上述地址失效,可使用 COCO 官方下载地址:http://images.cocodataset.org/zips/train2014.zip)
训练代码:
- from __future__ import print_function
- import sys, os, pdb
- import numpy as np
- import scipy.misc
- from src.optimize import optimize
- from argparse import ArgumentParser
- from src.utils import save_img, get_img, exists, list_files
- import evaluate # 迭代优化
-
# Relative weights of the content, style and total-variation loss terms.
CONTENT_WEIGHT = 7.5
STYLE_WEIGHT = 100.0
TV_WEIGHT = 200.0

# Optimisation schedule.
LEARNING_RATE = 0.001
NUM_EPOCHS = 2
BATCH_SIZE = 4

# Checkpoint output directory and checkpoint interval (in iterations).
CHECKPOINT_DIR = 'checkpoints'
CHECKPOINT_ITERATIONS = 2000

# Input data locations.
VGG_PATH = 'data/imagenet-vgg-verydeep-19.mat'  # pretrained VGG-19 weights (.mat)
TRAIN_PATH = 'data/'  # directory holding the training (content) images

# Compute placement.
DEVICE = '/gpu:0'  # first GPU
FRAC_GPU = 1
-
def check_opts(opts):
    """Sanity-check the parsed training options before any work starts.

    Paths are checked with the project helper ``exists`` (each with its own
    message); numeric options are checked with bare ``assert`` statements.
    If either the test image or the test output directory is set, both are
    checked.
    """
    for attr, message in (('checkpoint_dir', "checkpoint dir not found!"),
                          ('style', "style path not found!"),
                          ('train_path', "train path not found!")):
        exists(getattr(opts, attr), message)

    # The test image and its output directory are validated as a pair.
    if opts.test or opts.test_dir:
        exists(opts.test, "test img not found!")
        exists(opts.test_dir, "test directory not found!")
    exists(opts.vgg_path, "vgg network data not found!")

    # Counts must be strictly positive ...
    for attr in ('epochs', 'batch_size', 'checkpoint_iterations'):
        assert getattr(opts, attr) > 0
    assert os.path.exists(opts.vgg_path)
    # ... while the loss weights and the learning rate must be non-negative.
    for attr in ('content_weight', 'style_weight', 'tv_weight',
                 'learning_rate'):
        assert getattr(opts, attr) >= 0
-
def _get_files(img_dir):
    """Return every entry reported by ``list_files(img_dir)``, joined onto ``img_dir``."""
    joined = []
    for name in list_files(img_dir):
        joined.append(os.path.join(img_dir, name))
    return joined
-
-
def main():
    """Parse the command line, validate it and run style-transfer training.

    In the default (fast) mode every file in the training directory is a
    content target. In slow mode the single test image is optimised
    directly, so a test image is mandatory. Losses are printed for every
    step yielded by ``optimize``. When a test image is given, a stylised
    preview is written to the test directory at each step.
    """
    parser = build_parser()
    options = parser.parse_args()
    check_opts(options)

    style_target = get_img(options.style)
    if not options.slow:
        content_targets = _get_files(options.train_path)
    elif options.test:
        content_targets = [options.test]
    else:
        # Previously this fell through and crashed later with a NameError
        # on ``content_targets``; report the misuse up front instead.
        parser.error('slow mode requires a test image to optimise')

    kwargs = {
        "slow": options.slow,
        "epochs": options.epochs,
        "print_iterations": options.checkpoint_iterations,
        "batch_size": options.batch_size,
        "save_path": os.path.join(options.checkpoint_dir, 'fns.ckpt'),
        "learning_rate": options.learning_rate,
    }

    if options.slow:
        # Slow mode: epoch counts below 10 are raised to 1000 and learning
        # rates below 1 are raised to 10.
        if options.epochs < 10:
            kwargs['epochs'] = 1000
        if options.learning_rate < 1:
            kwargs['learning_rate'] = 1e1

    args = [
        content_targets,
        style_target,
        options.content_weight,
        options.style_weight,
        options.tv_weight,
        options.vgg_path,
    ]

    for preds, losses, i, epoch in optimize(*args, **kwargs):
        style_loss, content_loss, tv_loss, loss = losses

        print('Epoch %d, Iteration: %d, Loss: %s' % (epoch, i, loss))
        to_print = (style_loss, content_loss, tv_loss)
        print('style: %s, content:%s, tv: %s' % to_print)
        if options.test:
            # ``!= False`` let None through; require a real directory.
            assert options.test_dir, 'a test directory is required with a test image'
            preds_path = '%s/%s_%s.png' % (options.test_dir, epoch, i)
            if not options.slow:
                # Fast mode: stylise the test image with the checkpoint in
                # checkpoint_dir (optimize() is given save_path inside it).
                evaluate.ffwd_to_img(options.test, preds_path,
                                     options.checkpoint_dir)
            else:
                # Slow mode: save the yielded predictions. The original code
                # referenced an undefined name ``img`` here.
                # NOTE(review): preds presumably carries a leading batch axis
                # of size one; drop it when present — confirm with optimize().
                image = preds[0] if np.ndim(preds) == 4 else preds
                save_img(preds_path, image)

    ckpt_dir = options.checkpoint_dir
    cmd_text = 'python evaluate.py --checkpoint %s ...' % ckpt_dir
    print("Training complete. For evaluation:\n `%s`" % cmd_text)


if __name__ == '__main__':
    main()
预训练好的 VGG-19 模型(imagenet-vgg-verydeep-19.mat):
http://www.vlfeat.org/matconvnet/models/beta16/imagenet-vgg-verydeep-19.mat
- from __future__ import print_function
- import sys
- sys.path.insert(0, 'src')
- import numpy as np, src.vgg, pdb, os
- from src import transform
- import scipy.misc
- import tensorflow as tf
- from src.utils import save_img, get_img, exists, list_files
- from argparse import ArgumentParser
- from collections import defaultdict
- import time
- import json
- import subprocess
- import numpy
-
# Defaults for the --batch-size and --device command-line options.
BATCH_SIZE, DEVICE = 4, '/gpu:0'
-
def _probe_video(in_path):
    """Return ``(width, height, fps)`` of the first video stream of *in_path*, via ffprobe."""
    # The frame rate is a rational string such as "30000/1001" taken from an
    # arbitrary input file, so parse it with Fraction rather than eval().
    from fractions import Fraction

    command = ["ffprobe",
               '-v', "quiet",
               '-print_format', 'json',
               '-select_streams', 'v:0',  # skip leading audio/subtitle streams
               '-show_streams', in_path]
    info = json.loads(subprocess.check_output(command).decode("utf8"))
    stream = info["streams"][0]
    width = int(stream["width"])
    height = int(stream["height"])
    fps = round(Fraction(stream["r_frame_rate"]))
    return width, height, fps


def _open_decoder(in_path):
    """Start an ffmpeg process that writes *in_path* as raw RGB24 frames to its stdout."""
    command = ["ffmpeg",
               '-loglevel', "quiet",
               '-i', in_path,
               '-f', 'image2pipe',
               '-pix_fmt', 'rgb24',
               '-vcodec', 'rawvideo', '-']
    return subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=10 ** 9,
                            stdin=None, stderr=None)


def _open_encoder(out_path, width, height, fps):
    """Start an ffmpeg process that encodes raw RGB24 frames from its stdin to *out_path* (libx264)."""
    command = ["ffmpeg",
               '-loglevel', "info",
               '-y',  # (optional) overwrite output file if it exists
               '-f', 'rawvideo',
               '-vcodec', 'rawvideo',
               '-s', str(width) + 'x' + str(height),  # size of one frame
               '-pix_fmt', 'rgb24',
               '-r', str(fps),  # frames per second
               '-i', '-',  # the input comes from a pipe
               '-an',  # tells FFMPEG not to expect any audio
               '-c:v', 'libx264',
               '-preset', 'slow',
               '-crf', '18',
               out_path]
    return subprocess.Popen(command, stdin=subprocess.PIPE, stdout=None,
                            stderr=None)


def _read_batch(frames_in, batch, frame_bytes, height, width):
    """Fill ``batch[0]``, ``batch[1]``, ... with frames from *frames_in*; return how many were read.

    A short or empty read is treated as end of stream. Slots at or past the
    returned count keep whatever they held before.
    """
    count = 0
    while count < batch.shape[0]:
        raw = frames_in.read(frame_bytes)
        if len(raw) != frame_bytes:
            break
        # frombuffer replaces the deprecated numpy.fromstring.
        batch[count] = np.frombuffer(raw, dtype=np.uint8).reshape((height, width, 3))
        count += 1
    return count


def _stylize_frames(opts, frames_in, frames_out, width, height):
    """Run every raw frame from *frames_in* through the restored network and write results to *frames_out*."""
    g = tf.Graph()
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True

    with g.as_default(), g.device(opts.device), \
            tf.Session(config=soft_config) as sess:
        batch_shape = (opts.batch_size, height, width, 3)
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                         name='img_placeholder')
        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()
        if os.path.isdir(opts.checkpoint):
            ckpt = tf.train.get_checkpoint_state(opts.checkpoint)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, opts.checkpoint)

        X = np.zeros(batch_shape, dtype=np.float32)
        nbytes = 3 * width * height
        while True:
            count = _read_batch(frames_in, X, nbytes, height, width)
            if count == 0:
                return
            # A short final batch reuses the full-size placeholder. The unused
            # slots hold stale frames, but each image is processed
            # independently (per-image convolutions, instance norm over axes
            # 1-2), so only the first ``count`` outputs are written. The old
            # code rebuilt transform.net here, which created fresh variables
            # that were never restored.
            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for frame in _preds[:count]:
                img = np.clip(frame, 0, 255).astype(np.uint8)
                try:
                    frames_out.write(img.tobytes())
                except IOError as err:
                    # The encoder's stderr is not captured (stderr=None), so
                    # ffmpeg's own diagnostics are already on the terminal.
                    # The old code crashed here reading that missing pipe.
                    print(str(err) + "\n\nFFMPEG encountered an error while "
                          "writing the output file (see its log above).")
                    return
            if count < opts.batch_size:
                return


def from_pipe(opts):
    """Stylise a whole video by streaming its frames through ffmpeg.

    Frames of ``opts.in_path`` are decoded by one ffmpeg process and passed
    through the transform network restored from ``opts.checkpoint`` (a
    checkpoint directory or file). They are processed in batches of
    ``opts.batch_size`` on ``opts.device`` and piped into a second ffmpeg
    process that encodes ``opts.out``.
    """
    width, height, fps = _probe_video(opts.in_path)
    pipe_in = _open_decoder(opts.in_path)
    try:
        pipe_out = _open_encoder(opts.out, width, height, fps)
        try:
            _stylize_frames(opts, pipe_in.stdout, pipe_out.stdin, width, height)
        finally:
            # Closing stdin signals end of stream so the encoder can flush
            # and finalise the file. The old code terminated it first, which
            # risks cutting off frames that were still being encoded.
            try:
                pipe_out.stdin.close()
            except IOError:
                pass  # encoder already exited (broken pipe); nothing to flush
            pipe_out.wait()
    finally:
        # All frames were consumed (or we are bailing out); stop the decoder.
        pipe_in.terminate()
        pipe_in.stdout.close()
        pipe_in.wait()
-
def ffwd(data_in, paths_out, checkpoint_dir, device_t='/gpu:0', batch_size=4):
    """Stylise images with a trained transform network and save the results.

    Args:
        data_in: either a sequence of input image paths (str), or a sequence
            or array of images indexed along its first axis (presumably
            ``(height, width, 3)`` each — TODO confirm with callers).
        paths_out: output path for each input, same length as ``data_in``.
        checkpoint_dir: a checkpoint directory (its latest checkpoint is
            restored) or a checkpoint file path.
        device_t: TensorFlow device string the graph is placed on.
        batch_size: images per forward pass; capped at ``len(paths_out)``.

    Raises:
        AssertionError: if ``paths_out`` is empty, the lengths differ, or a
            path input has a shape different from the first image's.
        Exception: if ``checkpoint_dir`` is a directory with no checkpoint.
    """
    assert len(paths_out) > 0
    is_paths = isinstance(data_in[0], str)
    if is_paths:
        assert len(data_in) == len(paths_out)
        # The graph is built for one fixed image shape: the first image's.
        img_shape = get_img(data_in[0]).shape
    else:
        # Previously ``data_in.size[0]`` (size is an int) raised TypeError,
        # and img_shape was never assigned for in-memory input.
        assert len(data_in) == len(paths_out)
        img_shape = np.shape(data_in[0])

    g = tf.Graph()
    batch_size = min(len(paths_out), batch_size)
    soft_config = tf.ConfigProto(allow_soft_placement=True)
    soft_config.gpu_options.allow_growth = True
    with g.as_default(), g.device(device_t), \
            tf.Session(config=soft_config) as sess:
        batch_shape = (batch_size,) + tuple(img_shape)
        img_placeholder = tf.placeholder(tf.float32, shape=batch_shape,
                                         name='img_placeholder')

        preds = transform.net(img_placeholder)
        saver = tf.train.Saver()
        if os.path.isdir(checkpoint_dir):
            ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                raise Exception("No checkpoint found...")
        else:
            saver.restore(sess, checkpoint_dir)

        num_iters = len(paths_out) // batch_size
        for i in range(num_iters):
            pos = i * batch_size
            curr_batch_out = paths_out[pos:pos + batch_size]
            if is_paths:
                curr_batch_in = data_in[pos:pos + batch_size]
                X = np.zeros(batch_shape, dtype=np.float32)
                for j, path_in in enumerate(curr_batch_in):
                    img = get_img(path_in)
                    assert img.shape == img_shape, \
                        'Images have different dimensions. ' + \
                        'Resize images or use --allow-different-dimensions.'
                    X[j] = img
            else:
                X = data_in[pos:pos + batch_size]

            _preds = sess.run(preds, feed_dict={img_placeholder: X})
            for j, path_out in enumerate(curr_batch_out):
                save_img(path_out, _preds[j])

    # Inputs that do not fill a whole batch are handled by a recursive call
    # with batch size 1, which builds its own graph.
    remaining_in = data_in[num_iters * batch_size:]
    remaining_out = paths_out[num_iters * batch_size:]
    if len(remaining_in) > 0:
        ffwd(remaining_in, remaining_out, checkpoint_dir,
             device_t=device_t, batch_size=1)
-
def ffwd_to_img(in_path, out_path, checkpoint_dir, device='/cpu:0'):
    """Stylise the single image at ``in_path`` and save it to ``out_path``.

    Delegates to ``ffwd`` with a batch of one. The default device is the CPU.
    """
    ffwd([in_path], [out_path], checkpoint_dir, batch_size=1, device_t=device)
-
def ffwd_different_dimensions(in_path, out_path, checkpoint_dir,
                              device_t=DEVICE, batch_size=4):
    """Stylise images of mixed sizes by bucketing them by shape.

    ``in_path`` and ``out_path`` are parallel lists of file paths. Each input
    is loaded once to read its ``HxWxC`` shape. Every bucket (in first-seen
    order) is then passed to ``ffwd`` on its own, because a single ``ffwd``
    graph handles one image shape.
    """
    groups_in = defaultdict(list)
    groups_out = defaultdict(list)
    for idx, source in enumerate(in_path):
        target = out_path[idx]
        key = "%dx%dx%d" % get_img(source).shape
        groups_in[key].append(source)
        groups_out[key].append(target)

    for key, sources in groups_in.items():
        print('Processing images of shape %s' % key)
        ffwd(sources, groups_out[key], checkpoint_dir, device_t, batch_size)
-
def check_opts(opts):
    """Validate the parsed command-line options for evaluation.

    The checkpoint and the input path must exist. When the input is a
    directory, the output path must exist as well, because results are then
    written to ``out_path/<input file name>``. The batch size must be
    positive.
    """
    exists(opts.checkpoint_dir, 'Checkpoint not found!')
    exists(opts.in_path, 'In path not found!')
    # The old check ran only when out_path was already a directory, so it
    # could never fail. The output directory is needed when the input is a
    # directory of images.
    if os.path.isdir(opts.in_path):
        exists(opts.out_path, 'out dir not found!')
    assert opts.batch_size > 0
-
def build_parser():
    """Build the command-line parser for the evaluation (stylisation) script.

    ``--checkpoint``, ``--in-path`` and ``--out-path`` fall back to the
    defaults given below when omitted. ``--device`` and ``--batch-size``
    default to the module-level ``DEVICE`` and ``BATCH_SIZE``.
    """
    parser = ArgumentParser()
    # These three options used to combine required=True with a default.
    # argparse rejects a missing required option before any default applies,
    # so the defaults were unreachable. They are optional now so the defaults
    # take effect.
    parser.add_argument('--checkpoint', type=str,
                        dest='checkpoint_dir',
                        help='dir or .ckpt file to load checkpoint from',
                        metavar='CHECKPOINT', default='./model/la_muse.ckpt')

    parser.add_argument('--in-path', type=str,
                        dest='in_path', help='dir or file to transform',
                        metavar='IN_PATH', default='./examples/content/stata.jpg')

    help_out = 'destination (dir or file) of transformed file or files'
    parser.add_argument('--out-path', type=str,
                        dest='out_path', help=help_out, metavar='OUT_PATH',
                        default='./')

    parser.add_argument('--device', type=str,
                        dest='device', help='device to perform compute on',
                        metavar='DEVICE', default=DEVICE)

    parser.add_argument('--batch-size', type=int,
                        dest='batch_size', help='batch size for feedforwarding',
                        metavar='BATCH_SIZE', default=BATCH_SIZE)

    parser.add_argument('--allow-different-dimensions', action='store_true',
                        dest='allow_different_dimensions',
                        help='allow different image dimensions')
    return parser
-
def main():
    """Stylise one image, or every file in a directory, from the command line.

    A single input file is written to ``out_path`` itself, or to
    ``out_path/<input basename>`` when ``out_path`` is an existing directory.
    A directory of inputs is written file-by-file into ``out_path``. With
    --allow-different-dimensions, images are grouped by size first.
    """
    opts = build_parser().parse_args()
    # Validate the parsed options before doing any work.
    check_opts(opts)

    if os.path.isdir(opts.in_path):
        names = list_files(opts.in_path)
        full_in = [os.path.join(opts.in_path, name) for name in names]
        full_out = [os.path.join(opts.out_path, name) for name in names]
        if opts.allow_different_dimensions:
            ffwd_different_dimensions(full_in, full_out, opts.checkpoint_dir,
                                      device_t=opts.device,
                                      batch_size=opts.batch_size)
        else:
            ffwd(full_in, full_out, opts.checkpoint_dir,
                 device_t=opts.device, batch_size=opts.batch_size)
        return

    # Single file: an existing output directory receives a same-named file.
    if os.path.isdir(opts.out_path):
        out_path = os.path.join(opts.out_path, os.path.basename(opts.in_path))
    else:
        out_path = opts.out_path
    ffwd_to_img(opts.in_path, out_path, opts.checkpoint_dir,
                device=opts.device)


if __name__ == '__main__':
    main()
- import tensorflow as tf, pdb
-
# Standard deviation of the truncated-normal initialiser used for conv kernels.
WEIGHTS_INIT_STDEV = 0.1
def net(image):
    """Build the image transformation network and return its output tensor.

    Layout:
    - three conv layers: 9x9/1 -> 32, 3x3/2 -> 64, 3x3/2 -> 128;
    - five residual blocks (3x3 kernels);
    - two stride-2 transposed convolutions, -> 64 and -> 32;
    - a final 9x9/1 conv to 3 channels without ReLU.

    The output is ``tanh(x) * 150 + 127.5``.

    Variables are created unnamed (``tf.Variable`` without ``name=``), so
    their default names presumably follow creation order. The layer order
    below must stay fixed for existing checkpoints to restore.
    """
    x = image
    # Encoder: (filters, kernel size, stride) per conv layer.
    for filters, size, stride in ((32, 9, 1), (64, 3, 2), (128, 3, 2)):
        x = _conv_layer(x, filters, size, stride)
    # Residual trunk.
    for _ in range(5):
        x = _residual_block(x, 3)
    # Decoder.
    x = _conv_tranpose_layer(x, 64, 3, 2)
    x = _conv_tranpose_layer(x, 32, 3, 2)
    x = _conv_layer(x, 3, 9, 1, relu=False)
    return tf.nn.tanh(x) * 150 + 255. / 2
-
def _conv_layer(net, num_filters, filter_size, strides, relu=True):
    """Apply a 'SAME'-padded 2-D convolution, instance norm and optional ReLU.

    The kernel is a new variable of shape
    ``(filter_size, filter_size, in_channels, num_filters)``. ``strides`` is
    used for both spatial axes.
    """
    kernel = _conv_init_vars(net, num_filters, filter_size)
    out = tf.nn.conv2d(net, kernel, [1, strides, strides, 1], padding='SAME')
    out = _instance_norm(out)
    return tf.nn.relu(out) if relu else out
-
def _conv_tranpose_layer(net, num_filters, filter_size, strides):
    """Upsample with a 'SAME'-padded transposed convolution, then instance norm and ReLU.

    The output has ``num_filters`` channels and its rows/cols are the input's
    multiplied by ``strides``. The output shape is computed from the static
    shape of ``net``, so rows and cols must be known when the graph is built.
    The batch dimension is passed through as-is.
    """
    # transpose=True lays the kernel out as (h, w, out_channels, in_channels).
    kernel = _conv_init_vars(net, num_filters, filter_size, transpose=True)

    batch_size, rows, cols, _ = [dim.value for dim in net.get_shape()]
    output_shape = tf.stack([batch_size,
                             int(rows * strides),
                             int(cols * strides),
                             num_filters])
    out = tf.nn.conv2d_transpose(net, kernel, output_shape,
                                 [1, strides, strides, 1], padding='SAME')
    return tf.nn.relu(_instance_norm(out))
-
def _residual_block(net, filter_size=3):
    """Residual unit: ``net + conv2(conv1(net))``.

    Both convs have 128 filters, stride 1 and instance norm. Only the first
    applies ReLU. The input is expected to have 128 channels so the sum lines
    up.
    """
    hidden = _conv_layer(net, 128, filter_size, 1)
    residual = _conv_layer(hidden, 128, filter_size, 1, relu=False)
    return net + residual
-
def _instance_norm(net, train=True):
    """Instance normalisation with a learned per-channel scale and shift.

    Mean and variance are computed over the spatial axes (1 and 2)
    separately for every sample and channel. The normalised tensor is then
    mapped to ``scale * x_hat + shift``. (This is not batch normalisation.)
    ``train`` is unused.
    """
    _, _, _, channels = [dim.value for dim in net.get_shape()]
    mean, variance = tf.nn.moments(net, [1, 2], keep_dims=True)
    # Created unnamed, shift before scale; keep this order, since default
    # variable names presumably depend on it (checkpoint compatibility).
    shift = tf.Variable(tf.zeros([channels]))
    scale = tf.Variable(tf.ones([channels]))
    epsilon = 1e-3
    normalized = (net - mean) / (variance + epsilon) ** .5
    return scale * normalized + shift
-
def _conv_init_vars(net, out_channels, filter_size, transpose=False):
    """Create a conv-kernel variable initialised from a truncated normal.

    The shape is ``(filter_size, filter_size, in_channels, out_channels)``,
    with the last two axes swapped when ``transpose`` is set (the layout
    ``tf.nn.conv2d_transpose`` takes). ``in_channels`` is read from the
    static shape of ``net``. Values use stddev ``WEIGHTS_INIT_STDEV`` and a
    fixed seed of 1.
    """
    _, _, _, in_channels = [dim.value for dim in net.get_shape()]
    channel_axes = [out_channels, in_channels] if transpose else [in_channels, out_channels]
    weights_shape = [filter_size, filter_size] + channel_axes
    initial = tf.truncated_normal(weights_shape, stddev=WEIGHTS_INIT_STDEV, seed=1)
    return tf.Variable(initial, dtype=tf.float32)
由于代码较多,不便全部展示,完整 Demo 请参见 GitHub 链接:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。