当前位置:   article > 正文

TensorFlow:使用 Python 加载 pb(SavedModel)模型做预估

from tensorflow.contrib.layers import flatten

TensorFlow 中将 ckpt 模型转换为 pb 模型的代码,参考链接:https://blog.csdn.net/dulingtingzi/article/details/90790282

为了便于理解(两段代码中的一些名称需要保持一致),下面先贴出与后文 pb 模型预估代码配套的 ckpt 转 pb 代码。注意:这段代码导出的是 SavedModel 格式(导出目录下包含 saved_model.pb 和 variables/ 子目录),而不是冻结成单个 .pb 文件的计算图:

  # Listing 1: restore a TF1 checkpoint and export it as a SavedModel.
import os

# numpy / slim / flatten / graph_util are not used below; presumably the
# elided DenseNet implementation needs them -- TODO confirm.
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.layers import flatten
from tensorflow.python.framework import graph_util
from tensorflow.python.saved_model import (
    signature_constants,
    signature_def_utils,
    tag_constants,
    utils,
)

# DenseNet hyperparameters; they must match the values used for training.
growth_rate = 6
depth = 50
compression = 0.5
weight_decay = 0.0001
nb_blocks = (depth - 4) // 6

num_classes = 2
# Checkpoint prefix. Newer TensorFlow writes several files per checkpoint,
# so give the path only up to the step number (no file suffix).
model_path = "./version1/checkpoint/2_class.ckpt-1"
# SavedModel export directory. SavedModelBuilder will not overwrite an
# existing export directory, so delete it before re-running this script.
save_path = "./version1/model_pb/test"


def dense_net(img_input, num_classes, nb_blocks, growth_rate, weight_decay,
              compression, flag):
    """Build the classifier graph on ``img_input`` and return its output tensor.

    ``densenet`` is the author's own DenseNet implementation, which the
    article does not include; define or import it before running this script.
    ``flag`` is forwarded unchanged (here: the training-mode flag).
    """
    return densenet(growth_rate, img_input, num_classes, weight_decay,
                    nb_blocks, compression, flag)


def set_config():
    """Expose only GPU 0 and cap TensorFlow's GPU memory; return a ConfigProto."""
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    # Let this process allocate at most 30% of the GPU memory
    # (for example about 4.8 GB on a 16 GB card).
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)
    return tf.ConfigProto(gpu_options=gpu_options)


def main():
    """Build the inference graph, restore the checkpoint and export a SavedModel."""
    # Export in inference mode: the training flag is a constant False
    # instead of a placeholder.
    is_training = tf.constant(False, dtype=tf.bool)
    # The input tensor name must match the one used when training the model.
    inputs = tf.placeholder(tf.float32, shape=[None, 30, 280, 3],
                            name="placeholder_x")
    pred = dense_net(inputs, num_classes, nb_blocks, growth_rate,
                     weight_decay, compression, is_training)

    cfg = set_config()
    with tf.Session(config=cfg) as sess:
        saver = tf.train.Saver()
        saver.restore(sess, model_path)
        print("ckpt loaded")

        # The "input" / "pred" keys must match the ones the inference
        # script looks up in the signature.
        model_signature = signature_def_utils.build_signature_def(
            inputs={"input": utils.build_tensor_info(inputs)},
            outputs={"pred": utils.build_tensor_info(pred)},
            method_name=signature_constants.PREDICT_METHOD_NAME,
        )
        builder = tf.saved_model.builder.SavedModelBuilder(save_path)
        legacy_init_op = tf.group(tf.tables_initializer(),
                                  name="legacy_init_op")
        builder.add_meta_graph_and_variables(
            sess,
            [tag_constants.SERVING],
            clear_devices=True,
            signature_def_map={
                signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                    model_signature,
            },
            legacy_init_op=legacy_init_op,
        )
        builder.save()
        print("saved_model saved")


if __name__ == "__main__":
    main()

 

TensorFlow 使用 pb(SavedModel)模型做预估的代码:

  # Listing 2: load the exported SavedModel and run batched inference.
import os
from time import time

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants, tag_constants

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

EXPORT_DIR = "./version1/model_pb/test/"
IMG_DIR = "./images"
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")
BATCH_SIZE = 8


def preprocessing_crop_batch(images, height=30, width=280, depth=3):
    """Pack a list of images into one float32 batch of shape (N, height, width, depth).

    NOTE(review): the original article elides its real preprocessing. The
    default below scales each image to ``height`` while keeping its aspect
    ratio, then crops (or zero-pads) it on the right to ``width``. Replace it
    with exactly the preprocessing (resize, normalization, channel order)
    that was used at training time. Assumes 3-channel images, as returned by
    ``cv2.imread`` with its default flag.
    """
    img_canvas = np.zeros([len(images), height, width, depth], dtype=np.float32)
    for i, img in enumerate(images):
        h, w = img.shape[:2]
        scaled_w = max(1, int(round(w * height / float(h))))
        resized = cv2.resize(img, (scaled_w, height))  # cv2 takes (width, height)
        cropped = resized[:, :width]
        img_canvas[i, :, :cropped.shape[1], :] = cropped
    return img_canvas


def load_images(img_dir, limit):
    """Read up to ``limit`` image files from ``img_dir``, in file-name order.

    Raises IOError if a file cannot be decoded or no image is found.
    """
    names = sorted(n for n in os.listdir(img_dir)
                   if n.lower().endswith(IMG_EXTENSIONS))
    images = []
    for name in names[:limit]:
        path = os.path.join(img_dir, name)
        img = cv2.imread(path)
        if img is None:  # cv2.imread returns None instead of raising
            raise IOError("cannot read image: %s" % path)
        images.append(img)
    if not images:
        raise IOError("no images found in %s" % img_dir)
    return images


def main():
    """Load the SavedModel, classify one batch of images and report latency."""
    with tf.Session() as sess:
        meta_graph = tf.saved_model.loader.load(
            sess, tags=[tag_constants.SERVING], export_dir=EXPORT_DIR)
        graph = sess.graph
        signature = meta_graph.signature_def[
            signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        # The "input" / "pred" keys must match those used when exporting.
        x0 = graph.get_tensor_by_name(signature.inputs["input"].name)
        y0 = graph.get_tensor_by_name(signature.outputs["pred"].name)

        img_batch = load_images(IMG_DIR, BATCH_SIZE)
        n_images = len(img_batch)

        # Warm-up run: the first sess.run pays one-off setup costs, which
        # would otherwise inflate the measured per-frame time.
        sess.run(y0, {x0: preprocessing_crop_batch(img_batch)})

        t0 = time()
        img_preprocess = preprocessing_crop_batch(img_batch)
        ans = np.argmax(sess.run(y0, {x0: img_preprocess}), axis=1)
        t1 = time()

    print(ans)
    print("ms per frame %.2f" % ((t1 - t0) * 1000 / n_images))


if __name__ == "__main__":
    main()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/154335
推荐阅读
相关标签
  

闽ICP备14008679号