赞
踩
想通过缩略图找原图?之前 P 过的图像想找原图?汇报时使用了压缩过的图像,现在想找原图?如何从大量图像文件中快速找到与目标图像相似的那个?使用 PyTorch 只需要几行代码就可以搞定。
一般进行特征提取使用图像分类网络即可。参看上一篇《使用 PyTorch 中的 ResNet 预训练模型进行快速图像分类》。
提取查询图像和候选图像的特征,计算二者的余弦相似度,相似度越大则图像越相似。输出相似度超过阈值的图像的路径,并可选择将这些相似的图像保存到指定目录下。
# Find gallery images that look like a given query image.
#
# A pretrained ResNet is used as a feature extractor: its classification head
# is replaced by a pass-through, so the network outputs the pooled feature
# vector. Gallery images whose feature vector has a cosine similarity to the
# query's above `threshold` are reported.
import torch
import torchvision
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# ----- args -----
path_to_query = 'target.jpg'  # the single query image
path_to_data = '/path/to/data/'  # gallery images (ImageFolder layout: one sub-directory per class)
BATCH_SIZE = 256
target_dir = 'targetdir'  # similar images can optionally be copied here (see the end of the script)
threshold = 0.8  # pick an image out if its cosine similarity to the query is > threshold

# Fall back to the CPU so the script still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ----- load model -----
model = torchvision.models.resnet101(pretrained=True)
# or any of these variants: resnet18, resnet34, resnet50, resnet101, resnet152

# ----- build feature extractor -----
# Replace the classification head with a pass-through so the model returns the
# pooled features unchanged. This is equivalent to a square Linear layer with
# identity weights and zero bias, but does not require knowing the feature
# size (512 for resnet18/34, 2048 for resnet50/101/152).
feature_extractor = model
feature_extractor.fc = torch.nn.Identity()
for param in feature_extractor.parameters():
    param.requires_grad = False
feature_extractor.eval()
feature_extractor.to(device)

# ----- build dataloader -----
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_data = torchvision.datasets.ImageFolder(path_to_data, preprocess)
image_names = test_data.samples  # list of (path, class_index), in dataset order
# shuffle must stay False: batches are mapped back to `image_names` by position.
data_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

result = []  # paths of the gallery images that match the query

# ----- test and query -----
with torch.no_grad():
    # Load the query image. Force 3 channels so grayscale / RGBA / palette
    # files work with the 3-channel Normalize above.
    with Image.open(path_to_query) as query_file:
        input_image = query_file.convert('RGB')
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0).to(device)  # mini-batch of one, as expected by the model

    # Extract the query feature once; it is broadcast against every gallery batch.
    q_feature = feature_extractor(input_batch)

    offset = 0  # index in `image_names` of the first image of the current batch
    for x, _ in tqdm(data_loader, desc="Evaluating", leave=False):
        x = x.to(device)

        # extract features of the gallery batch
        output = feature_extractor(x)

        # cosine similarity of every gallery image in the batch to the query
        similarity = torch.cosine_similarity(q_feature, output, dim=1)

        hits = torch.nonzero(similarity > threshold).flatten().tolist()
        result.extend(image_names[offset + i][0] for i in hits)
        offset += x.shape[0]

# Optionally copy the similar images into target_dir:
# import os
# from shutil import copyfile
# os.makedirs(target_dir, exist_ok=True)
# for r in result:
#     copyfile(r, os.path.join(target_dir, os.path.basename(r)))

# print the results
for r in result:
    print(r)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。