当前位置:   article > 正文

深度学习模型ckpt转成pb模型(YOLOv3为例)_tensorflow模型文件(ckpt)转pb文件 yolov3

tensorflow模型文件(ckpt)转pb文件 yolov3

为了深度学习模型的移植,首先需要将训练好的模型.ckpt的三个模型保存成.pb模型,在网上找到了很多方法,但是困难重重,中间经历找不到输入输出node,找到之后输出的模型不能进行预测,后来终于找到了方法,这里记录一下。我成功啦~
步骤如下:

1.首先获得输入输出节点的名字

代码如下:

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from core.yolov3 import YOLOV3
input_size = 416
with tf.name_scope('input'):
    img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')  # 输入节点
model = YOLOV3(img_input, False)  # mYOLOv3模型
print(img_input)#输出输入节点名称
print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)#输出输出接电脑名称
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

会输出:

Tensor("input/input_data:0", shape=(?, 416, 416, 3), dtype=float32)
Tensor("pred_sbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_mbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32) Tensor("pred_lbbox/concat_2:0", shape=(?, ?, ?, 3, 10), dtype=float32)
  • 1
  • 2

2.转换成pb文件

#! /usr/bin/env python
# coding=utf-8
import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
from core.yolov3 import YOLOV3
input_size = 416
output_node_names = ["pred_sbbox/concat_2", "pred_mbbox/concat_2", "pred_lbbox/concat_2"]#输出节点
ckpt_filename = "./pbmodel/new/yolov3.ckpt" #设置model的路径,因新版tensorflow会生成三个文件,只需写到数字前
pb_file       = "./pbmodel/yolov3_new.pb"
def ckpt2pb2():
    with tf.Graph().as_default() as graph_old:
        with tf.name_scope('input'):
            img_input = tf.placeholder(dtype=tf.float32, shape=(None, input_size, input_size, 3), name='input_data')#输入节点
        model = YOLOV3(img_input, False)#mYOLOv3模型
        print(img_input)
        print(model.pred_sbbox, model.pred_mbbox, model.pred_lbbox)

        isess = tf.InteractiveSession()

        isess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(isess, ckpt_filename)

        constant_graph = tf.graph_util.convert_variables_to_constants(isess, isess.graph_def,
                                                                      output_node_names)
        constant_graph = tf.graph_util.remove_training_nodes(constant_graph)

        with tf.gfile.GFile(pb_file, mode='wb') as f:
            f.write(constant_graph.SerializeToString())
        print("%d ops in the final graph." % len(constant_graph.node))  # 得到当前图有几个操作节点
if __name__ == '__main__':
    ckpt2pb2()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

3.使用tensorboard查看pb文件的图

1.先使用如下代码读取pb文件,在pblog文件中生成了event,然后再tensorboard上查看图

import tensorflow as tf
with tf.Session() as sess:
    model_filename ='./pbmodel/yolov3_new.pb'#模型路径
    with tf.gfile.GFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        g_in = tf.import_graph_def(graph_def)
        train_writer = tf.summary.FileWriter("./pblog")#保存
        train_writer.add_graph(sess.graph)
        train_writer.flush()
        train_writer.close()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

2.打开prompt,输入tensorboard --logdir=“路径”,然后就可以看啦。

在这里插入图片描述

4.使用pb模型进行预测

我这里使用的是project的demo文件,直接修改了pb的名称
在这里插入图片描述

总结

娃哈哈,四天的时间,第一天不知道熟悉代码,第二天查资料如果打开ckpt并知道怎么得到输入输出的节点名称,第三天终于能够生成了pb文件,但是不能进行预测,第四天,换了新的方法,终于成功啦!!!哈哈哈哈哈

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/297901
推荐阅读
相关标签
  

闽ICP备14008679号