当前位置:   article > 正文

(四)快速图像风格迁移训练模型载入及处理图像_快速风格迁移模型训练

快速风格迁移模型训练

系列文章
(一)图像风格迁移
(二)快速图像风格转换
(三)快速图像风格转换代码解析

1 神经网络模型

def net(image, training):
    '''图像填充'''
    image = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')
    '''
    (4, 276, 276, 3)
    :params 4: 每组图像数量
    :params [276, 276, 3] : 图像尺寸.
    '''
    print("image shape after padding: {}".format(image.shape))

    with tf.variable_scope('conv1'):
        '''
        :params 3: 当前图像深度
        :params 32: 下一网络层图像深度
        :params 9:填充和滑动窗口内核.
        :parmas 1: 滑动窗口平移步长
        '''
        '''[276, 276, 32]'''
        conv1 = relu(instance_norm(conv2d(image, 3, 32, 9, 1)))
        print("conv1 shape: {}".format(conv1.shape))
    with tf.variable_scope('conv2'):
        '''[]'''
        conv2 = relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
        print("conv2 shape: {}".format(conv2.shape))
    with tf.variable_scope('conv3'):
        conv3 = relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))
    with tf.variable_scope('res1'):
        res1 = residual(conv3, 128, 3, 1)
    with tf.variable_scope('res2'):
        res2 = residual(res1, 128, 3, 1)
    with tf.variable_scope('res3'):
        res3 = residual(res2, 128, 3, 1)
    with tf.variable_scope('res4'):
        res4 = residual(res3, 128, 3, 1)
    with tf.variable_scope('res5'):
        res5 = residual(res4, 128, 3, 1)
        print("NN processed shape: {}".format(res5.get_shape()))
    with tf.variable_scope('deconv1'):
        # deconv1 = relu(instance_norm(conv2d_transpose(res5, 128, 64, 3, 2)))
        deconv1 = relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training)))
        print("deconv1 shape: {}".format(deconv1.shape))
    with tf.variable_scope('deconv2'):
        # deconv2 = relu(instance_norm(conv2d_transpose(deconv1, 64, 32, 3, 2)))
        deconv2 = relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
        print("deconv2 shape: {}".format(deconv2.shape))
    with tf.variable_scope('deconv3'):
        # deconv_test = relu(instance_norm(conv2d(deconv2, 32, 32, 2, 1)))
        deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1)))
        print("deconv3 shape: {}".format(deconv3.shape))
        print("deconv3 value: {}".format(deconv3))

    y = (deconv3 + 1) * 127.5
    print("processed value: {}".format(y))

    # Remove border effect reducing padding.
    height = tf.shape(y)[1]
    width = tf.shape(y)[2]
    y = tf.slice(y, [0, 10, 10, 0], tf.stack([-1, height - 20, width - 20, -1]))
    '''final y: Tensor("Slice_1:0", shape=(4, 256, 256, 3), dtype=float32)'''
    print("final y: {}".format(y))
    return y
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
'
运行

2 模型载入及读取

import tensorflow as tf
import os
from preprocessing import preprocessing_factory
import reader
import model
import time
import base64
'''基本路径'''
basedir = os.path.abspath(os.path.dirname(__name__))
'''图像路径'''
image_path = "./process/xdqtest_resize.png"
height = 0
width = 0
'''读取图像,获取图像尺寸width和height'''
with open(image_path, 'rb') as img:
    with tf.Session().as_default() as sess:
        if image_path.lower().endswith('png'):
            image = sess.run(tf.image.decode_png(img.read()))
        else:
            image = sess.run(tf.image.decode_jpeg(img.read()))
        height = image.shape[0]
        width = image.shape[1]
if __name__ == "__main__":
    with tf.Session() as sess:
    	'''处理图像的闭包函数'''
        image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
            "vgg_16",
            is_training=False)
        '''读取图像'''
        image = reader.get_image(image_path, height, width, image_preprocessing_fn)
        '''增加图像维度'''
        image = tf.expand_dims(image, 0)
		'''建立网络结构'''
        generated = model.net(image, training=False)
        '''载入网络的全局变量'''
        saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1)
        '''
        初始化全局和本地变量,其中,
        全局变量为网络中的变量,
        本地变量为session中的变量.
        '''
        sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])
        '''模型路径:只包含训练的参数'''
        model_path = basedir + "/models/mosaic.ckpt-done"
        print("model path: {}".format(model_path))
        '''载入模型参数'''
        saver.restore(sess, model_path)
        '''模型参数读取:获取模型ckpt的图结构graph_def部分'''
        read_graph = sess.graph.as_graph_def()
        for node in read_graph.node:
            print("node name: {}----->node operation: {}".format(node.name, node.op))
        '''保存图像路径'''
        generated_file = 'generated/processed.jpg' +
        if os.path.exists('generated') is False:
            os.makedir('generated')
        '''保存图像'''
        with open(generated_file, 'wb') as img:
            start_time = time.time()
            img.write(sess.run(tf.image.encode_jpeg(generated)))
            end_time = time.time()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • Result
...
'''卷积层'''
node name: conv1/conv/truncated_normal/shape----->node operation: Const
node name: conv1/conv/truncated_normal/mean----->node operation: Const
node name: conv1/conv/truncated_normal/stddev----->node operation: Const
...
'''参差层'''
node name: res1/residual/conv/MirrorPad----->node operation: MirrorPad
node name: res1/residual/conv/conv----->node operation: Conv2D
node name: res1/residual/Relu----->node operation: Relu
node name: res1/residual/Equal----->node operation: Equal
'''图像恢复'''
node name: deconv1/conv_transpose/Shape----->node operation: Const
...
node name: deconv1/conv_transpose/conv/weight----->node operation: VariableV2
node name: deconv1/conv_transpose/conv/weight/Assign----->node operation: Assign
node name: deconv1/conv_transpose/conv/weight/read----->node operation: Identity
'''保存模型'''
node name: save/Const----->node operation: Const
...
node name: save/Assign_15----->node operation: Assign
node name: save/restore_all----->node operation: NoOp
node name: init----->node operation: NoOp
node name: init_1----->node operation: NoOp
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • Analysis
    (1) 只截取了神经网络部分的结构;
    (2) 训练的神经网络是新建的model.py文件中的网络,slim-vgg网络用于提取图像的内容和风格;
    (3) 载入模型前先建立网络结构,再载入保存的模型参数,即可利用模型计算图像风格.

3 处理结果

在这里插入图片描述

图3.1 图像风格 图3.2 待转换的图像

在这里插入图片描述

图3.3 转换结果

4 总结

(1) 载入模型前先确认模型类型即该模型中包含的是参数还是结构,若模型只含有参数,则载入模型前需要先建立网络结构;
(2) 图像风格迁移训练的网络是新建的神经网络,不是slimvgg网络,风格网络结构中最后的三层是将深度处理的图像进行去深度化,以获取正常三通道(RGB)的图像;
(3) 使用训练模型处理抓换图片,直接将图像数据传入神经网络即可,没有严格使用sess.run(variable, feed_dict={x: x_data}).


[参考文献]
[1]https://blog.csdn.net/Xin_101/article/details/88883977
[2]https://blog.csdn.net/Xin_101/article/details/87854371
[3]https://blog.csdn.net/Xin_101/article/details/88581250
[4]https://blog.csdn.net/Xin_101/article/details/84981890


声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号