当前位置:   article > 正文

检索出与指定图片相似度最高的n张图片_resnet50图片特征 不准确

resnet50图片特征 不准确

给定一张指定图片“22.png”,在指定图片库“database目录”中检索出与其相似度最高的3张图片。

1. 使用深度神经网络提取图片特征

1.1 vgg16提取图片特征

  # -*- coding: UTF-8 -*-
  import os

  import h5py
  import matplotlib.image as mpimg
  import matplotlib.pyplot as plt
  import numpy as np
  from keras.applications.vgg16 import VGG16
  from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
  from keras.preprocessing import image
  class VGGNet:
      """Image feature extractor built on a Keras VGG16 network.

      The network is created without its top (classification) layers and with
      global pooling, so ``predict`` yields one feature vector per image.
      """

      def __init__(self):
          # Input size handed to the network and used to resize loaded images.
          self.input_shape = (224, 224, 3)
          # 'imagenet' loads pre-trained weights; None means random
          # initialisation (no pre-trained weights are loaded).
          self.weight = 'imagenet'
          # Global pooling mode for the final feature map: 'max' (or 'avg').
          self.pooling = 'max'
          self.model_vgg = VGG16(
              include_top=False,
              pooling=self.pooling,
              input_shape=tuple(self.input_shape[:3]),
              weights=self.weight,
          )

      def vgg_extract_feat(self, img_path):
          """Return the L2-normalised VGG16 feature vector of the image at ``img_path``.

          The image is resized to the configured input size, turned into a
          batch of one, run through the VGG16 preprocessing and the network,
          and the resulting vector is divided by its Euclidean norm.
          """
          target = (self.input_shape[0], self.input_shape[1])
          pixels = image.img_to_array(image.load_img(img_path, target_size=target))
          batch = preprocess_input_vgg(np.expand_dims(pixels, axis=0))
          raw_feat = self.model_vgg.predict(batch)[0]
          return raw_feat / np.linalg.norm(raw_feat)

1.2 resnet50提取图片特征

  1. from keras.applications.resnet50 import ResNet50
  2. from keras.applications.resnet50 import preprocess_input as preprocess_input_resnet
  # Builds a ResNet50 network without its top layers (include_top=False), reusing the
  # self.weight / self.input_shape / self.pooling attributes that VGGNet.__init__ sets.
  # NOTE(review): presumably this statement is meant to be added inside VGGNet.__init__ - confirm.
  3. self.model_resnet = ResNet50(weights = self.weight,
  4. input_shape = (self.input_shape[0], self.input_shape[1], self.input_shape[2]),
  5. pooling = self.pooling,
  6. include_top = False)
  # Extract the feature vector produced by ResNet50's last convolutional stage.
  def resnet_extract_feat(self, img_path):
      """Return the L2-normalised ResNet50 feature vector of the image at ``img_path``.

      Mirrors ``vgg_extract_feat``: the image is loaded and resized to
      ``self.input_shape``, expanded to a batch of one, passed through the
      ResNet50-specific preprocessing and ``self.model_resnet``, and the
      resulting vector is divided by its Euclidean norm.

      Args:
          img_path: Path of the image file to describe.

      Returns:
          The normalised feature vector as a 1-D NumPy array.
      """
      # Load and batch the image first; preprocessing needs an array, not a path.
      img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
      img = image.img_to_array(img)
      img = np.expand_dims(img, axis=0)
      img = preprocess_input_resnet(img)
      feat = self.model_resnet.predict(img)
      norm_feat = feat[0] / np.linalg.norm(feat[0])
      return norm_feat

1.3 densenet121提取图片特征

  1. from keras.applications.densenet import DenseNet121
  2. from keras.applications.densenet import preprocess_input as preprocess_input_densenet
  # Builds a DenseNet121 network without its top layers (include_top=False), reusing the
  # self.weight / self.input_shape / self.pooling attributes that VGGNet.__init__ sets.
  # NOTE(review): presumably this statement is meant to be added inside VGGNet.__init__ - confirm.
  3. self.model_densenet = DenseNet121(weights = self.weight,
  4. input_shape = (self.input_shape[0], self.input_shape[1], self.input_shape[2]),
  5. pooling = self.pooling,
  6. include_top = False)
  # Extract the feature vector produced by DenseNet121's last convolutional stage.
  def densenet_extract_feat(self, img_path):
      """Return the L2-normalised DenseNet121 feature vector of the image at ``img_path``.

      Mirrors ``vgg_extract_feat``: the image is loaded and resized to
      ``self.input_shape``, expanded to a batch of one, passed through the
      DenseNet-specific preprocessing and ``self.model_densenet``, and the
      resulting vector is divided by its Euclidean norm.

      Args:
          img_path: Path of the image file to describe.

      Returns:
          The normalised feature vector as a 1-D NumPy array.
      """
      # Load and batch the image first; preprocessing needs an array, not a path.
      img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
      img = image.img_to_array(img)
      img = np.expand_dims(img, axis=0)
      img = preprocess_input_densenet(img)
      feat = self.model_densenet.predict(img)
      norm_feat = feat[0] / np.linalg.norm(feat[0])
      return norm_feat

2. 从图像库抽取特征

使用深度神经网络从database目录提取每张图片的name和feature。

  def save_features(database='database', index='models/vgg_featureCNN.h5', extensions=('.jpg',)):
      """Extract features for every image in ``database`` and write them to an HDF5 file.

      Args:
          database: Directory holding the image library.
          index: Path of the HDF5 file to (over)write with the results.
          extensions: File-name suffixes selecting which files are indexed.

      The file receives two datasets: ``dataset_1`` with the feature vectors
      and ``dataset_2`` with the matching file names as UTF-8 byte strings,
      in the same order.
      """
      # Sort so the stored order does not depend on the directory listing order.
      img_list = sorted(
          os.path.join(database, f)
          for f in os.listdir(database)
          if f.endswith(tuple(extensions))
      )
      feats, names = extract_features_and_images_index(img_list)

      # Create the output directory up front; opening the file fails otherwise.
      index_dir = os.path.dirname(index)
      if index_dir:
          os.makedirs(index_dir, exist_ok=True)

      # np.string_ was removed in NumPy 2.0; an explicit byte-string array is
      # the portable way to store the names.
      encoded_names = np.array([name.encode('utf-8') for name in names], dtype='S')

      # The context manager closes the file even if a write fails.
      with h5py.File(index, 'w') as h5f:
          h5f.create_dataset('dataset_1', data=np.array(feats))
          h5f.create_dataset('dataset_2', data=encoded_names)
'
运行

逐张提取图片特征(提取结果由 save_features 写入 hdf5)

  def extract_features_and_images_index(img_path_list):
      """Extract a normalised feature vector and the file name for each image path.

      Args:
          img_path_list: List of image file paths.

      Returns:
          A ``(feats, names)`` tuple of two lists of equal length: the feature
          vector of each image and its file name (without directory), in the
          order of ``img_path_list``.
      """
      total = len(img_path_list)
      if total == 0:
          # Nothing to do: skip building the network altogether.
          return [], []

      feats = []
      names = []
      model = VGGNet()
      for position, img_path in enumerate(img_path_list, start=1):
          # Change this call to switch the network used for feature extraction.
          feats.append(model.vgg_extract_feat(img_path))
          names.append(os.path.basename(img_path))
          print("extracting feature from image No. %d , %d images in total" % (position, total))
      return feats, names
'
运行

3. 加载特征检索相似图片

检索出三张相似度最高的图片

  def get_similarity_top3_picture(image_path='22.png', maxres=3):
      """Find the ``maxres`` indexed images most similar to the query image.

      Features and names are read from ``models/vgg_featureCNN.h5``; the query
      image is described with the same VGG16 extractor, every indexed image is
      scored against it, and the best ``maxres`` matches are printed and shown.

      Args:
          image_path: Path of the query image.
          maxres: Number of top-ranked images to print and display.

      Returns:
          ``[top_1_md5, top_1_score]``: the best match's file name up to its
          first dot, and its similarity score,
          e.g. ``['bf43ddd28d6a2544b4ba8f95002674ed', 0.5255763]``.

      Raises:
          ValueError: If the feature index contains no images.
      """
      path = 'models/vgg_featureCNN.h5'
      feats, names = get_feature_from_hdf5(path)
      if len(names) == 0:
          raise ValueError("feature index %r contains no images" % path)

      model = VGGNet()
      # Change this call to switch the network used for feature extraction.
      img_feat = model.vgg_extract_feat(image_path)

      # Query and indexed vectors are both L2-normalised by the extractor, so
      # the dot product is already the cosine similarity.
      scores = np.dot(img_feat, feats.T)
      rank_ID = np.argsort(scores)[::-1]  # best match first
      rank_score = scores[rank_ID]
      print(rank_ID)
      print(rank_score)

      imlist = []
      for i, index in enumerate(rank_ID[0:maxres]):
          # ``index`` is the position in the database; ``i`` is only the rank.
          name = names[index]
          imlist.append(name)
          print("image names: " + _name_to_str(name) + " scores: %f" % rank_score[i])
          # plot_img needs the file name, not the numeric database position.
          plot_img(i, name)

      top_1_score = rank_score[0]
      top_1_md5 = _name_to_str(imlist[0]).split(".")[0].strip()
      return [top_1_md5, top_1_score]


  def _name_to_str(name):
      """Return an indexed file name as text, decoding UTF-8 bytes when needed."""
      if isinstance(name, bytes):
          return name.decode('utf-8')
      return str(name)
'
运行

从hdf5读取特征

  def get_feature_from_hdf5(path='models/vgg_featureCNN.h5'):
      """Read the indexed feature vectors and image names from an HDF5 file.

      Args:
          path: HDF5 file to read. The argument is honoured; it is no longer
              overwritten by a hard-coded path inside the function.

      Returns:
          A ``(feats, names)`` tuple holding the full contents of
          ``dataset_1`` (feature vectors) and ``dataset_2`` (image names).
      """
      # ``[:]`` copies the data into memory, so it stays valid after the file
      # is closed by the context manager.
      with h5py.File(path, 'r') as h5f:
          feats = h5f['dataset_1'][:]
          names = h5f['dataset_2'][:]
      return feats, names
'
运行

显示图片

  1. import matplotlib.image as mpimg
  def plot_img(i, index, database='database'):
      """Show one retrieved image, titled with its rank.

      Args:
          i: Zero-based rank of the result; the title shows ``i + 1``.
          index: File name of the image inside ``database``, either as UTF-8
              bytes or as ``str``.
          database: Directory holding the image library.
      """
      # Accept both bytes and str names instead of failing on non-bytes input.
      name = index.decode('utf-8') if isinstance(index, bytes) else str(index)
      # Named ``picture`` to avoid shadowing the keras ``image`` module.
      picture = mpimg.imread(os.path.join(database, name))
      plt.title("search output %d" % (i + 1))
      plt.imshow(picture)
      plt.show()
'
运行

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号