赞
踩
为了深度学习模型的移植,首先需要将训练好的模型.ckpt的三个模型保存成.pb模型,在网上找到了很多方法,但是困难重重,中间经历找不到输入输出node,找到之后输出的模型不能进行预测,后来终于找到了方法,这里记录一下。我成功啦~
步骤如下:
代码如下:
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from core.yolov3 import YOLOV3
input_size = 416
with tf.name_scope('input'):
img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data') # 输入节点
model = YOLOV3(img_input, False) # mYOLOv3模型
print(img_input)#输出输入节点名称
print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)#输出输出接电脑名称
会输出:
Tensor("input/input_data:0", shape=(?, 416, 416, 3), dtype=float32)
Tensor("pred_sbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_mbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_lbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32)
#! /usr/bin/env python # coding=utf-8 import tensorflow as tf import os from tensorflow.python.tools import freeze_graph from core.yolov3 import YOLOV3 input_size = 416 output_node_names = ["pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]#输出节点 ckpt_filename = "./pbmodel/new/yolov3.ckpt" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前 pb_file = "./pbmodel/yolov3_new.pb" def ckpt2pb2(): with tf.Graph().as_default() as graph_old: with tf.name_scope('input'): img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')#输入节点 model = YOLOV3(img_input, False)#mYOLOv3模型 print(img_input) print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox) isess = tf.InteractiveSession() isess.run(tf.global_variables_initializer()) saver = tf.train.Saver() saver.restore(isess, ckpt_filename) constant_graph = tf.graph_util.convert_variables_to_constants(isess, isess.graph_def, output_node_names) constant_graph = tf.graph_util.remove_training_nodes(constant_graph) with tf.gfile.GFile(pb_file, mode='wb') as f: f.write(constant_graph.SerializeToString()) print("%d ops in the final graph." % len(constant_graph.node)) # 得到当前图有几个操作节点 if __name__ == '__main__': ckpt2pb2()
import tensorflow as tf
with tf.Session() as sess:
model_filename ='./pbmodel/yolov3_new.pb'#模型路径
with tf.gfile.GFile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
g_in = tf.import_graph_def(graph_def)
train_writer = tf.summary.FileWriter("./pblog")#保存
train_writer.add_graph(sess.graph)
train_writer.flush()
train_writer.close()
我这里使用的是project的demo文件,直接修改了pb的名称
娃哈哈,四天的时间,第一天不知道熟悉代码,第二天查资料如果打开ckpt并知道怎么得到输入输出的节点名称,第三天终于能够生成了pb文件,但是不能进行预测,第四天,换了新的方法,终于成功啦!!!哈哈哈哈哈
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。