
Implementing ResNet-50 in TensorFlow (training + testing + model conversion)

This article trains ResNet-50 with TensorFlow, using handwritten-digit (MNIST-style) images as the dataset.

Dataset:

Code project:

1.train.py

    import argparse
    import numpy as np
    import tensorflow as tf
    from create_model import resnet_v2_50
    from data_loader import get_data, get_data_list

    def txt_save(data, output_file):
        # helper: dump a list (e.g. variable names) to a text file, one item per line
        file = open(output_file, 'a')
        for i in data:
            file.write(str(i) + '\n')
        file.close()

    def get_parms():
        parser = argparse.ArgumentParser(description='')
        parser.add_argument('--train_data', type=str, default='dataset/train_data.txt')
        parser.add_argument('--test_data', type=str, default='data/test/')
        parser.add_argument('--checkpoint_dir', type=str, default='./model/')
        parser.add_argument('--epoch', type=int, default=5)
        parser.add_argument('--batch_size', type=int, default=1)
        parser.add_argument('--save_epoch', type=int, default=1)
        args = parser.parse_args()
        return args

    args = get_parms()

    inputs = tf.placeholder(tf.float32, (None, 28, 28, 1), name='input_images')
    labels = tf.placeholder(tf.int64, [None, 10])

    # resnet_v2_50 returns a (None, 1, 1, 10) tensor; squeeze the two spatial
    # dimensions to get (None, 10) logits (the original indexed [0][0], which
    # only works for batch_size=1)
    net, end_points = resnet_v2_50(inputs, 10)
    logits = tf.squeeze(net, [1, 2])
    probs = tf.nn.softmax(logits)
    pred = tf.argmax(probs, 1)
    correct_prediction = tf.equal(tf.argmax(probs, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

    # softmax_cross_entropy expects raw logits, not softmax outputs
    cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    cross_entropy_cost = tf.reduce_mean(cross_entropy)
    # train all layers, both resnet_v2_50 and anything built on top of it
    train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy_cost)

    # start training
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        img_list, label_list = get_data_list(args.train_data)
        for i in range(args.epoch):
            for j in range(int(len(label_list) / args.batch_size)):
                data, true_label = get_data(img_list, label_list, args.batch_size)
                _, loss, acc = sess.run([train_step, cross_entropy_cost, accuracy],
                                        feed_dict={inputs: data, labels: true_label})
            print('epoch:', i, 'loss:', loss, 'acc:', acc)
            if i % args.save_epoch == 0:
                saver.save(sess, 'model/model')

        # write the graph for TensorBoard: tensorboard --logdir=d:/log --host=127.0.0.1
        # keep the log directory simple; deeply nested paths sometimes make
        # TensorBoard fail to find the files
        writer = tf.summary.FileWriter('d:/log/', sess.graph)
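After training, the checkpoint can be reloaded from its meta graph to run a quick prediction. This is a minimal sketch, assuming the checkpoint prefix model/model saved above; the tensor names input_images:0 and resnet_v2_50/logits/BiasAdd:0 are the placeholder and logits node defined in this project:

    import numpy as np
    import tensorflow as tf
    from data_loader import get_data, get_data_list

    img_list, label_list = get_data_list('dataset/train_data.txt')
    data, true_label = get_data(img_list, label_list, 1)

    with tf.Session() as sess:
        # rebuild the graph from the meta file and load the latest weights
        saver = tf.train.import_meta_graph('model/model.meta')
        saver.restore(sess, tf.train.latest_checkpoint('model/'))
        graph = tf.get_default_graph()
        x = graph.get_tensor_by_name('input_images:0')
        logits = graph.get_tensor_by_name('resnet_v2_50/logits/BiasAdd:0')
        out = sess.run(logits, feed_dict={x: data})  # shape (1, 1, 1, 10)
        print('predicted:', np.argmax(out[0][0][0]), 'true:', np.argmax(true_label[0]))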

2.create_model.py

    import collections
    import tensorflow as tf

    slim = tf.contrib.slim

    def model(inputs):
        # a plain 3-conv CNN baseline (not used by train.py)
        w1 = tf.Variable(tf.random_normal([3, 3, 1, 32], stddev=0.01))
        w2 = tf.Variable(tf.random_normal([3, 3, 32, 64], stddev=0.01))
        w3 = tf.Variable(tf.random_normal([3, 3, 64, 128], stddev=0.01))
        w4 = tf.Variable(tf.random_normal([2048, 625], stddev=0.01))
        w5 = tf.Variable(tf.random_normal([625, 10], stddev=0.01))
        l1_conv = tf.nn.conv2d(inputs, w1, strides=[1, 1, 1, 1], padding='SAME')
        l1_relu = tf.nn.relu(l1_conv)
        l1_pool = tf.nn.max_pool(l1_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l1_drop = tf.nn.dropout(l1_pool, 0.5)
        l2_conv = tf.nn.conv2d(l1_drop, w2, strides=[1, 1, 1, 1], padding='SAME')
        l2_relu = tf.nn.relu(l2_conv)
        l2_pool = tf.nn.max_pool(l2_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l2_drop = tf.nn.dropout(l2_pool, 0.5)
        l3_conv = tf.nn.conv2d(l2_drop, w3, strides=[1, 1, 1, 1], padding='SAME')
        l3_relu = tf.nn.relu(l3_conv)
        l3_pool = tf.nn.max_pool(l3_relu, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l3_out = tf.reshape(l3_pool, [-1, 2048])
        l3_drop = tf.nn.dropout(l3_out, 0.5)
        l4 = tf.nn.relu(tf.matmul(l3_drop, w4))
        l4 = tf.nn.dropout(l4, 0.5)
        out = tf.matmul(l4, w5)
        return out

    def model2(inputs):
        # a LeNet-style CNN baseline (not used by train.py)
        w1 = tf.Variable(tf.random_normal([5, 5, 1, 6], stddev=0.01))
        b1 = tf.Variable(tf.truncated_normal([6]))
        w2 = tf.Variable(tf.random_normal([5, 5, 6, 16], stddev=0.01))
        b2 = tf.Variable(tf.truncated_normal([16]))
        w3 = tf.Variable(tf.random_normal([5, 5, 16, 120], stddev=0.01))
        b3 = tf.Variable(tf.truncated_normal([120]))
        w4 = tf.Variable(tf.truncated_normal([7 * 7 * 120, 80]))
        b4 = tf.Variable(tf.truncated_normal([80]))
        w5 = tf.Variable(tf.truncated_normal([80, 10]))
        b5 = tf.Variable(tf.truncated_normal([10]))
        l1_conv = tf.nn.conv2d(inputs, w1, strides=[1, 1, 1, 1], padding='SAME')
        l1_sigmoid = tf.nn.sigmoid(l1_conv + b1)
        l1_pool = tf.nn.max_pool(l1_sigmoid, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l2_conv = tf.nn.conv2d(l1_pool, w2, strides=[1, 1, 1, 1], padding='SAME')
        l2_sigmoid = tf.nn.sigmoid(l2_conv + b2)
        l2_pool = tf.nn.max_pool(l2_sigmoid, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        l3_conv = tf.nn.conv2d(l2_pool, w3, strides=[1, 1, 1, 1], padding='SAME')
        l3_sigmoid = tf.nn.sigmoid(l3_conv + b3)
        l3_out = tf.reshape(l3_sigmoid, [-1, 7 * 7 * 120])
        l4 = tf.nn.sigmoid(tf.matmul(l3_out, w4) + b4)
        out = tf.nn.softmax(tf.matmul(l4, w5) + b5)
        return out

    # Block describes one ResNet block via collections.namedtuple:
    #   scope   - the block's name
    #   unit_fn - the unit function (the residual unit)
    #   args    - a list such as [(256, 64, 1)] * 2 + [(256, 64, 2)], meaning
    #             two (256, 64, 1) units followed by one (256, 64, 2) unit
    Block = collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])

    # downsampling, implemented with max_pool2d
    def subsample(inputs, factor, scope=None):
        if factor == 1:
            return inputs
        else:
            return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)

    # a convolution layer with explicit padding when stride > 1
    def conv2d_same(inputs, num_outputs, kernel_size, stride, scope=None):
        if stride == 1:
            return slim.conv2d(inputs, num_outputs, kernel_size, stride=1,
                               padding='SAME', scope=scope)
        else:
            pad_total = kernel_size - 1
            pad_beg = pad_total // 2
            pad_end = pad_total - pad_beg
            inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end],
                                     [pad_beg, pad_end], [0, 0]])
            return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
                               padding='VALID', scope=scope)

    # stack the blocks
    @slim.add_arg_scope
    def stack_blocks_dense(net, blocks, outputs_collections=None):
        for block in blocks:
            with tf.variable_scope(block.scope, 'block', [net]) as sc:
                for i, unit in enumerate(block.args):
                    with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
                        unit_depth, unit_depth_bottleneck, unit_stride = unit
                        net = block.unit_fn(net,
                                            depth=unit_depth,
                                            depth_bottleneck=unit_depth_bottleneck,
                                            stride=unit_stride)
                net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
        return net

    # default argument values
    def resnet_arg_scope(is_training=True,
                         weight_decay=0.0001,
                         batch_norm_decay=0.997,
                         batch_norm_epsilon=1e-5,
                         batch_norm_scale=True):
        batch_norm_params = {
            'is_training': is_training,
            'decay': batch_norm_decay,
            'epsilon': batch_norm_epsilon,
            'scale': batch_norm_scale,
            'updates_collections': tf.GraphKeys.UPDATE_OPS,
        }
        with slim.arg_scope(
                [slim.conv2d],
                weights_regularizer=slim.l2_regularizer(weight_decay),
                weights_initializer=slim.variance_scaling_initializer(),
                activation_fn=tf.nn.relu,
                normalizer_fn=slim.batch_norm,
                normalizer_params=batch_norm_params):
            with slim.arg_scope([slim.batch_norm], **batch_norm_params):
                with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
                    return arg_sc

    # the residual learning unit
    @slim.add_arg_scope
    def bottleneck(inputs, depth, depth_bottleneck, stride,
                   outputs_collections=None, scope=None):
        with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
            depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
            preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu,
                                     scope='preact')
            # shortcut is the identity branch X
            if depth == depth_in:
                shortcut = subsample(inputs, stride, 'shortcut')
            else:
                shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
                                       normalizer_fn=None, activation_fn=None,
                                       scope='shortcut')
            residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
                                   scope='conv1')
            residual = conv2d_same(residual, depth_bottleneck, 3, stride,
                                   scope='conv2')
            residual = slim.conv2d(residual, depth, [1, 1], stride=1,
                                   normalizer_fn=None, activation_fn=None,
                                   scope='conv3')
            # add the identity branch X to the residual to get the output
            output = shortcut + residual
            return slim.utils.collect_named_outputs(outputs_collections,
                                                    sc.name, output)

    # the main ResNet function
    def resnet_v2(inputs,
                  blocks,
                  num_classes=None,
                  global_pool=True,
                  include_root_block=True,
                  reuse=None,
                  scope=None):
        with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
            end_points_collection = sc.original_name_scope + '_end_points'
            with slim.arg_scope([slim.conv2d, bottleneck, stack_blocks_dense],
                                outputs_collections=end_points_collection):
                net = inputs
                if include_root_block:
                    with slim.arg_scope([slim.conv2d], activation_fn=None, normalizer_fn=None):
                        net = conv2d_same(net, 64, 7, stride=2, scope='conv1')
                    net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
                net = stack_blocks_dense(net, blocks)
                net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
                if global_pool:
                    net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
                if num_classes is not None:
                    net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                                      normalizer_fn=None, scope='logits')
                end_points = slim.utils.convert_collection_to_dict(end_points_collection)
                if num_classes is not None:
                    end_points['predictions'] = slim.softmax(net, scope='predictions')
                return net, end_points

    # the 50-layer ResNet
    def resnet_v2_50(inputs,
                     num_classes=None,
                     global_pool=True,
                     reuse=None,
                     scope='resnet_v2_50'):
        blocks = [
            Block('block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
            Block('block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
            Block('block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
            Block('block4', bottleneck, [(2048, 512, 1)] * 3)]
        return resnet_v2(inputs, blocks, num_classes, global_pool,
                         include_root_block=True, reuse=reuse, scope=scope)
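As a standalone sanity check (a sketch, not part of the training script), the output shape of resnet_v2_50 can be inspected before wiring it into train.py. With a 28x28x1 input and num_classes=10, the logits tensor comes out as (?, 1, 1, 10), which is why train.py squeezes the two spatial dimensions:

    import tensorflow as tf
    from create_model import resnet_v2_50

    inputs = tf.placeholder(tf.float32, (None, 28, 28, 1))
    net, end_points = resnet_v2_50(inputs, num_classes=10)
    print(net.shape)                        # expected: (?, 1, 1, 10)
    print(end_points['predictions'].shape)  # softmax over the 10 classes

Note that resnet_arg_scope is defined above but never applied in train.py; wrapping the call in slim.arg_scope(resnet_arg_scope(is_training=False)) would switch batch norm to inference mode.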

3.data_loader.py

    import cv2
    import os
    import numpy as np
    import random

    # one-off helper to build the label file (one "image_path label" pair per line):
    # f1 = open('dataset/train_data.txt', 'w+')
    # path = 'dataset/'
    # for file in os.listdir(path):
    #     if file.endswith('png'):
    #         line = path + file + ' 1' + '\n'
    #         f1.write(line)
    #         print(line)

    def one_hot(data, num_classes):
        return np.squeeze(np.eye(num_classes)[data.reshape(-1)])

    def get_data_list(path):
        # read "image_path label" lines and load every image into memory
        f1 = open(path, 'r')
        lines = f1.readlines()
        img_list = []
        label_list = []
        for line in lines:
            label = int(line.strip().split(' ')[1])
            label = one_hot(np.array(label), 10)
            label_list.append(label)
            file_name = line.strip().split(' ')[0]
            img = cv2.imread(file_name, 0)  # grayscale
            img = np.reshape(img, [28, 28, 1])
            img_list.append(img)
        return img_list, label_list

    def get_data(img_list, label_list, batch_size):
        # draw a random batch without replacement
        lens = len(label_list)
        nums = random.sample(range(lens), batch_size)
        data = []
        label = []
        for index in nums:
            data.append(img_list[index])
            label.append(label_list[index])
        return np.array(data), np.array(label)

    # quick self-test:
    # img_list, label_list = get_data_list('dataset/mnist/train/train_data.txt')
    # print(len(img_list), len(label_list))
    # data, label = get_data(img_list, label_list, 1)
    # print(type(data[0]), label)
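get_data_list expects one sample per line: an image path, a space, and an integer class label. A train_data.txt might therefore look like this (the filenames and labels are illustrative):

    dataset/00000.png 5
    dataset/00001.png 0
    dataset/00002.png 4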

4.查看ckpt网络.py (inspect the ckpt graph)

    from tensorflow.python import pywrap_tensorflow
    import os
    import tensorflow as tf

    # list the tensors stored in a checkpoint:
    # checkpoint_path = os.path.join('model/model.ckpt')
    # reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    # var_to_shape_map = reader.get_variable_to_shape_map()
    # for key in var_to_shape_map:
    #     print('tensor_name: ', key)

    ckpt_path = os.path.join('model/model')
    saver = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
    graph = tf.get_default_graph()
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, ckpt_path)
        # write the graph for TensorBoard: tensorboard --logdir=d:/log --host=127.0.0.1
        # keep the log directory simple; deeply nested paths sometimes make
        # TensorBoard fail to find the files
        writer = tf.summary.FileWriter('d:/log/', sess.graph)
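Besides TensorBoard, the node names needed for the next step (freezing) can be printed directly once the meta graph is imported. A minimal standalone sketch:

    import tensorflow as tf

    saver = tf.train.import_meta_graph('model/model.meta', clear_devices=True)
    for op in tf.get_default_graph().get_operations():
        print(op.name)  # look for input_images and resnet_v2_50/logits/BiasAdd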

5.ckpt2pb.py

    import tensorflow as tf
    from tensorflow.python.framework import graph_util
    from tensorflow.python import pywrap_tensorflow

    def freeze_graph(ckpt_path, pb_path):
        checkpoint = tf.train.get_checkpoint_state('model')  # check the checkpoint files are usable
        ckpt_path2 = checkpoint.model_checkpoint_path        # path of the latest checkpoint
        print('checkpoint:', checkpoint, ckpt_path2)
        # the output node name must be a node that exists in the original graph
        output_node_names = 'resnet_v2_50/logits/BiasAdd'
        saver = tf.train.import_meta_graph(ckpt_path + '.meta', clear_devices=True)
        graph = tf.get_default_graph()
        input_graph_def = graph.as_graph_def()
        with tf.Session() as sess:
            saver.restore(sess, ckpt_path)  # restore the graph and its weights
            # freeze the model: turn the variables into constants
            pb_path_def = graph_util.convert_variables_to_constants(
                sess=sess,
                input_graph_def=input_graph_def,                 # same as sess.graph_def
                output_node_names=output_node_names.split(','))  # separate multiple output nodes with commas
            with tf.gfile.GFile(pb_path, 'wb') as fgraph:
                fgraph.write(pb_path_def.SerializeToString())    # serialize the frozen graph
            print('%d ops in the final graph.' % len(pb_path_def.node))

    if __name__ == '__main__':
        ckpt_path = 'model/model'  # input path (checkpoint prefix)
        pb_path = 'model/test.pb'  # output path (pb model)
        freeze_graph(ckpt_path, pb_path)

        # # list the tensor names stored in the checkpoint:
        # reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
        # var_to_shape_map = reader.get_variable_to_shape_map()
        # for key in var_to_shape_map:
        #     print('tensor_name: ', key)
        # # read the weights of one specific tensor:
        # w0 = reader.get_tensor('finetune/dense_1/bias')
        # print(w0.shape, type(w0))
        # print(w0[0])
        # # inspect the restored graph:
        # with tf.Session() as sess:
        #     saver = tf.train.import_meta_graph('model/model.meta')
        #     # option 1: restore the latest checkpoint in a directory
        #     saver.restore(sess, tf.train.latest_checkpoint('model/'))
        #     # option 2: restore a specific checkpoint (no file suffix)
        #     # saver.restore(sess, os.path.join(path, 'model.ckpt-1000'))
        #     # list the trainable variables
        #     for v in tf.trainable_variables():
        #         print(v.name)
        #     # list all global variables
        #     # for v in tf.global_variables():
        #     #     print(v.name)
        #     # list nearly all operations in the graph
        #     # for o in sess.graph.get_operations():
        #     #     print(o.name)
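To confirm the freeze worked before moving on, the pb file can be parsed back and scanned for the expected input and output nodes. A small sketch, using the paths from above:

    import tensorflow as tf

    graph_def = tf.GraphDef()
    with tf.gfile.GFile('model/test.pb', 'rb') as f:
        graph_def.ParseFromString(f.read())
    names = [n.name for n in graph_def.node]
    print(len(names), 'nodes in the frozen graph')
    print('input_images' in names, 'resnet_v2_50/logits/BiasAdd' in names)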

6.pb2pbtxt.py (convert between pb and pbtxt to edit the input shape)

Edit the pbtxt file to change the dynamic input shape to a static one, then convert it back to a pb model.
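In the generated test.pbtxt, the placeholder's unknown batch dimension appears as size: -1; changing it to a fixed size makes the shape static. Roughly like this (shown compactly; the exact attribute layout may vary with the TensorFlow version):

    node {
      name: "input_images"
      op: "Placeholder"
      attr {
        key: "shape"
        value {
          shape {
            dim { size: -1 }   # change to: dim { size: 1 }
            dim { size: 28 }
            dim { size: 28 }
            dim { size: 1 }
          }
        }
      }
    }

After editing, run convert_pbtxt_to_pb (the call commented out in __main__ below) to write the static-shape pb back out.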

    import tensorflow as tf
    from tensorflow.python.platform import gfile
    from google.protobuf import text_format

    def convert_pb_to_pbtxt(root_path, pb_path, pbtxt_path):
        with gfile.FastGFile(root_path + pb_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
            tf.import_graph_def(graph_def, name='')
            tf.train.write_graph(graph_def, root_path, pbtxt_path, as_text=True)

    def convert_pbtxt_to_pb(root_path, pb_path, pbtxt_path):
        with tf.gfile.FastGFile(root_path + pbtxt_path, 'r') as f:
            graph_def = tf.GraphDef()
            file_content = f.read()
            # merge the human-readable string in file_content into graph_def
            text_format.Merge(file_content, graph_def)
            tf.train.write_graph(graph_def, root_path, pb_path, as_text=False)

    if __name__ == '__main__':
        # model paths
        root_path = 'model/'
        pb_path = 'test.pb'
        pbtxt_path = 'test.pbtxt'
        # model conversion
        convert_pb_to_pbtxt(root_path, pb_path, pbtxt_path)
        # convert_pbtxt_to_pb(root_path, pb_path, pbtxt_path)

7.test_pb.py

    import numpy as np
    import tensorflow as tf

    def recognize(jpg_path, pb_file_path):
        with tf.Graph().as_default():
            output_graph_def = tf.GraphDef()
            # step 1: parse the pb file into a GraphDef
            with open(pb_file_path, 'rb') as f:
                output_graph_def.ParseFromString(f.read())
                # step 2: import it into the current graph
                _ = tf.import_graph_def(output_graph_def, name='')
            with tf.Session() as sess:
                # step 3: get the current graph
                graph = tf.get_default_graph()
                # step 4: fetch the needed tensors with get_tensor_by_name
                x = graph.get_tensor_by_name('input_images:0')
                y_out = graph.get_tensor_by_name('resnet_v2_50/logits/BiasAdd:0')
                img = np.random.normal(size=(1, 28, 28, 1))
                # to test with a real image instead (requires cv2; the model
                # expects 28x28 grayscale, matching the training data):
                # img = cv2.imread(jpg_path, 0)
                # img = np.reshape(img, (1, 28, 28, 1))
                # run inference; the output has shape (1, 1, 1, 10)
                output = sess.run(y_out, feed_dict={x: img})
                pred = np.argmax(output[0][0], axis=1)
                print('output:', output.shape, output, 'predicted label:', pred)

    recognize('dataset/00000.PNG', 'model/test.pb')
