赞
踩
加载分类模型：hub.load_module_spec("mobilenet_v2_100_224")
# Classify a few sample images with a TF-Hub MobileNet V2 module, using the
# TensorFlow 1.x graph/session API, and show each original image next to the
# image that was actually fed to the network.
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # the script uses placeholders and tf.Session
import tensorflow_hub as hub

# Read the label file line by line; commas inside a line are replaced by
# spaces. The file is decoded as GBK.
# NOTE(review): assumes line i of the CSV is the label of class index i of
# the model output -- confirm against the label file.
with open('中文标签.csv', encoding='gbk') as f:
    labels = [line.rstrip('\r\n').replace(',', ' ') for line in f]
print(len(labels), type(labels), labels[:5])

sample_images = ['hy.jpg', 'ps.jpg', '72.jpg']  # paths of the images to classify

# Load the classification module spec.
module_spec = hub.load_module_spec("mobilenet_v2_100_224")
# Input image size expected by the module.
height, width = hub.get_expected_image_size(module_spec)

input_imgs = tf.placeholder(tf.float32, [None, height, width, 3])  # raw 0..255 pixels
images = 2 * (input_imgs / 255.0) - 1.0  # rescale 0..255 to -1..1
module = hub.Module(module_spec)  # add the module to the graph
logits = module(images)
y = tf.argmax(logits, axis=1)


def preimg(img):
    """Return *img* as a float32 array of shape (height, width, 3).

    The image is converted to RGB first so that grayscale, palette and RGBA
    inputs also yield three channels, then resized to the module's input
    size. PIL's ``resize`` takes ``(width, height)``.
    """
    resized = img.convert('RGB').resize((width, height))
    return np.asarray(resized, dtype=np.float32)


def showresult(yy, img_norm, img_org):
    """Plot the original and the network-input image, then print the label.

    Args:
        yy: predicted class index, used to index ``labels``.
        img_norm: network input scaled to [-1, 1] (the ``images`` tensor).
        img_org: the original PIL image.
    """
    # Map [-1, 1] back to 0..255 so the array can be displayed as uint8.
    shown_input = np.clip((img_norm + 1.0) * 127.5, 0, 255).astype(np.uint8)
    panels = [img_org, shown_input]
    titles = ["organizaton image", "input image"]
    plt.figure()
    for position, (panel, title) in enumerate(zip(panels, titles), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel)
        plt.axis('off')
        plt.title(title)
    plt.show()
    print(yy, labels[yy])


# Open every file once; copy() loads the pixel data so the file handle can be
# closed right away.
orgImg = []
for imgfilename in sample_images:
    with Image.open(imgfilename) as opened:
        orgImg.append(opened.copy())
batchImg = [preimg(img) for img in orgImg]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())
    print(input_imgs)
    # Feed the batch; also fetch the normalized images for display.
    yv, img_norm = sess.run([y, images], feed_dict={input_imgs: batchImg})

print(yv, np.shape(yv))

# Pair each prediction with the normalized image the network actually saw.
for yy, img1, img2 in zip(yv, img_norm, orgImg):
    showresult(yy, img1, img2)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。