赞
踩
给定一张指定图片“22.png”,在指定图片库“database目录”中检索出与其相似度最高的3张图片。
# -*- coding: UTF-8 -*-
import os

import h5py
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
-
class VGGNet:
    """Image feature extractor backed by a VGG16 network.

    The network is created with ``include_top=False`` and a pooling mode, and
    the first row of its ``predict`` output is used as the feature vector of
    the image.
    """

    def __init__(self):
        """Build the VGG16 model used by :meth:`vgg_extract_feat`."""
        self.input_shape = (224, 224, 3)
        # 'imagenet' loads pretrained weights; None would mean random
        # initialisation, i.e. no pretrained weights are loaded.
        self.weight = 'imagenet'
        self.pooling = 'max'  # 'avg' is the alternative
        self.model_vgg = VGG16(
            weights=self.weight,
            input_shape=self.input_shape,
            pooling=self.pooling,
            include_top=False,
        )

    @staticmethod
    def _l2_normalize(vec):
        """Return ``vec`` divided by its L2 norm.

        An all-zero vector is returned unchanged instead of being divided by
        zero, which would turn every component into NaN and poison any later
        similarity ranking.
        """
        vec = np.asarray(vec)
        norm = np.linalg.norm(vec)
        if norm == 0:
            return vec
        return vec / norm

    def vgg_extract_feat(self, img_path):
        """Extract the normalized VGG16 feature vector of one image.

        Args:
            img_path: Path of the image file to load.

        Returns:
            The first row of the model output, scaled to unit L2 norm.
        """
        height, width = self.input_shape[0], self.input_shape[1]
        img = image.load_img(img_path, target_size=(height, width))
        # The model expects a batch, so add a leading axis of size 1.
        batch = np.expand_dims(image.img_to_array(img), axis=0)
        batch = preprocess_input_vgg(batch)
        feat = self.model_vgg.predict(batch)
        return self._l2_normalize(feat[0])
![](https://csdnimg.cn/release/blogv2/dist/pc/img/newCodeMoreWhite.png)
- from keras.applications.resnet50 import ResNet50
- from keras.applications.resnet50 import preprocess_input as preprocess_input_resnet
-
# NOTE: tutorial fragment. This assignment reads self.weight,
# self.input_shape and self.pooling, so it belongs inside VGGNet.__init__
# after those attributes have been set.
self.model_resnet = ResNet50(
    weights=self.weight,
    input_shape=self.input_shape,
    pooling=self.pooling,
    include_top=False,
)
-
# NOTE: tutorial fragment. This is a method of VGGNet; indent it into the
# class body when merging.
def resnet_extract_feat(self, img_path):
    """Extract the normalized ResNet50 feature vector of one image.

    The original excerpt showed only the preprocessing and predict lines; the
    image loading and normalisation steps are restored here following
    ``vgg_extract_feat``.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The first row of the model output divided by its L2 norm.
    """
    height, width = self.input_shape[0], self.input_shape[1]
    img = image.load_img(img_path, target_size=(height, width))
    img = image.img_to_array(img)
    # The model expects a batch, so add a leading axis of size 1.
    img = np.expand_dims(img, axis=0)
    img = preprocess_input_resnet(img)
    feat = self.model_resnet.predict(img)
    norm_feat = feat[0] / np.linalg.norm(feat[0])
    return norm_feat
- from keras.applications.densenet import DenseNet121
- from keras.applications.densenet import preprocess_input as preprocess_input_densenet
-
# NOTE: tutorial fragment. This assignment reads self.weight,
# self.input_shape and self.pooling, so it belongs inside VGGNet.__init__
# after those attributes have been set.
self.model_densenet = DenseNet121(
    weights=self.weight,
    input_shape=self.input_shape,
    pooling=self.pooling,
    include_top=False,
)
-
# NOTE: tutorial fragment. This is a method of VGGNet; indent it into the
# class body when merging.
def densenet_extract_feat(self, img_path):
    """Extract the normalized DenseNet121 feature vector of one image.

    The original excerpt showed only the preprocessing and predict lines; the
    image loading and normalisation steps are restored here following
    ``vgg_extract_feat``.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The first row of the model output divided by its L2 norm.
    """
    height, width = self.input_shape[0], self.input_shape[1]
    img = image.load_img(img_path, target_size=(height, width))
    img = image.img_to_array(img)
    # The model expects a batch, so add a leading axis of size 1.
    img = np.expand_dims(img, axis=0)
    img = preprocess_input_densenet(img)
    feat = self.model_densenet.predict(img)
    norm_feat = feat[0] / np.linalg.norm(feat[0])
    return norm_feat
使用深度神经网络从database目录提取每张图片的name和feature。
def save_features(database='database', index='models/vgg_featureCNN.h5'):
    """Extract features for every ``.jpg`` in ``database`` and store them in HDF5.

    Args:
        database: Directory that holds the images to index.
        index: Path of the HDF5 file to write. Its parent directory is
            created when missing.

    The file receives two datasets: ``dataset_1`` with the feature vectors
    and ``dataset_2`` with the matching file names as UTF-8 bytes.
    """
    # Sort the listing so the stored order does not depend on the
    # arbitrary order returned by os.listdir.
    img_list = sorted(
        os.path.join(database, f)
        for f in os.listdir(database)
        if f.endswith('.jpg')
    )
    feats, names = extract_features_and_images_index(img_list)

    index_dir = os.path.dirname(index)
    if index_dir:
        os.makedirs(index_dir, exist_ok=True)

    # Encode explicitly: np.string_ was removed in NumPy 2.0 and its implicit
    # ASCII encoding rejects non-ASCII file names.
    encoded_names = np.array([name.encode('utf-8') for name in names], dtype='S')

    # The context manager closes the file even when a write fails.
    with h5py.File(index, 'w') as h5f:
        h5f.create_dataset('dataset_1', data=np.array(feats))
        h5f.create_dataset('dataset_2', data=encoded_names)
'运行
结果写入hdf5
def extract_features_and_images_index(img_path_list):
    """Extract a feature vector and the file name for each image path.

    Args:
        img_path_list: Iterable of image file paths.

    Returns:
        A ``(feats, names)`` pair of lists in input order: the feature vector
        returned by ``VGGNet.vgg_extract_feat`` and the base file name of
        each path. Both lists are empty for empty input.
    """
    img_paths = list(img_path_list)
    if not img_paths:
        # Nothing to process, so do not build the model.
        return [], []

    total = len(img_paths)
    model = VGGNet()
    feats = []
    names = []
    for position, img_path in enumerate(img_paths, start=1):
        # Change the method called here to switch the feature network.
        feats.append(model.vgg_extract_feat(img_path))
        names.append(os.path.basename(img_path))
        print("extracting feature from image No. %d , %d images in total" % (position, total))
    return feats, names
'运行
检索出三张相似度最高的图片
def get_similarity_top3_picture(image_path='22.png', maxres=3):
    """Retrieve the indexed images most similar to a query image.

    Args:
        image_path: Path of the query image.
        maxres: Number of top-ranked images to print and display.

    Returns:
        ``[top_1_name, top_1_score]``: the best match's file name up to its
        first dot, and its similarity score.

    Raises:
        ValueError: If the feature index contains no images.
    """
    path = 'models/vgg_featureCNN.h5'
    feats, names = get_feature_from_hdf5(path)
    if len(names) == 0:
        raise ValueError("feature index %r contains no images" % path)

    model = VGGNet()

    # Change the method called here to switch the feature network.
    img_feat = model.vgg_extract_feat(image_path)
    # Score every indexed image by the dot product with the query feature,
    # then rank from highest to lowest.
    scores = np.dot(img_feat, feats.T)
    rank_ID = np.argsort(scores)[::-1]
    rank_score = scores[rank_ID]
    print(rank_ID)
    print(rank_score)

    imlist = []
    for i, index in enumerate(rank_ID[:maxres]):
        name = names[index]
        imlist.append(name)
        # rank_score is already sorted, so position i pairs with this name.
        print("image names: " + str(name) + " scores: %f" % rank_score[i])
        # plot_img decodes a file name, so pass the name and not the row index.
        plot_img(i, name)

    top_1_score = rank_score[0]
    top_name = imlist[0]
    if isinstance(top_name, bytes):
        top_name = top_name.decode('utf-8')
    top_1_md5 = top_name.split(".")[0].strip()
    return [top_1_md5, top_1_score]
'运行
从hdf5读取特征
def get_feature_from_hdf5(path='models/vgg_featureCNN.h5'):
    """Read the indexed feature vectors and image names from an HDF5 file.

    Args:
        path: Path of the HDF5 file to read.

    Returns:
        A ``(feats, names)`` pair read in full from ``dataset_1`` and
        ``dataset_2`` respectively.
    """
    # Slicing with [:] copies the data out before the file is closed.
    with h5py.File(path, 'r') as h5f:
        feats = h5f['dataset_1'][:]
        names = h5f['dataset_2'][:]
    return feats, names
'运行
显示图片
- import matplotlib.image as mpimg
-
def plot_img(i, index, database='database'):
    """Display one retrieved image in a matplotlib window.

    Args:
        i: Zero-based rank of the result; shown in the title as ``i + 1``.
        index: File name of the image inside ``database``, as UTF-8 bytes
            or as ``str``.
        database: Directory that holds the image.
    """
    if isinstance(index, bytes):
        file_name = index.decode('utf-8')
    else:
        file_name = str(index)
    picture = mpimg.imread(os.path.join(database, file_name))
    plt.title("search output %d" % (i + 1))
    plt.imshow(picture)
    plt.show()
'运行
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。