当前位置:   article > 正文

使用pytorch中的resnet预训练模型进行特征提取,以及查找相似图像_resnet提取图像特征

resnet提取图像特征

想通过缩略图找原图?之前p过的图像想找原图?汇报时使用压缩过的图像现在想找原图?如何从大量图像文件中快速找到与目标图像相似的那个?pytorch 只需要几行代码就可以搞定。

模型的选取

一般进行特征提取使用图像分类网络即可。参看上一篇 使用pytorch中的resnet预训练模型进行快速图像分类

代码如下

提取查询图像和候选图像的特征,计算二者的余弦相似度,相似度越大则图像越相似。输出图像的路径,将相似的图像保存到指定目录下。

# load model
import torch
import torchvision
model = torchvision.models.resnet101(pretrained=True)
# or any of these variants
# resnet18, resnet34, resnet50, resnet101, resnet152
model.eval()


from PIL import Image
from torchvision import transforms
from tqdm import tqdm

# args
path_to_query = 'target.jpg' # given one query image
path_to_data = '/path/to/data/' # gallery images
BATCH_SIZE = 256
target_dir = 'targetdir' # we can save similar images in target directory
threshold = 0.8 # pick out if the cosine similarity > threshold.

# build dataloader
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
test_data = torchvision.datasets.ImageFolder(path_to_data, preprocess)
image_names = test_data.samples
data_loader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)

# load model to GPU
model.to('cuda')
count = 0
result = []

# test and query
with torch.no_grad():

    # load query image
    input_image = Image.open(path_to_query)
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model
    input_batch = input_batch.to('cuda')
    
    # build feature extractor
    resnet50_feature_extractor = model
    resnet50_feature_extractor.fc = torch.nn.Linear(2048,2048)  # 512,512 2048,1024 ... 
    # the size varies for different models. Refer to official implementations for the size of feature maps
    
    # ---以下几行必须要有:---
    # torch.nn.init.eye_(resnet50_feature_extractor.module.fc.weight) # for parallel distributed training
    # torch.nn.init.eye_(resnet50_feature_extractor.module.fc.weight)
    torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
    torch.nn.init.zeros_(resnet50_feature_extractor.fc.bias)
    for param in resnet50_feature_extractor.parameters():
        param.requires_grad = False
    # ---------------------
    
    # extract feature
    resnet50_feature_extractor = resnet50_feature_extractor.cuda()
    q_feature = resnet50_feature_extractor(input_batch)
    
    # load gallery images
    for (x, y) in tqdm(data_loader, desc="Evaluating", leave=False):
        x = x.to('cuda')
        y = y.to('cuda')
        
        # extract fature
        output = resnet50_feature_extractor(x)
        
        # calculate cosine similarity to query
        similarity = torch.cosine_similarity(q_feature, output, dim=1)
        for index in range(output.shape[0]):
            if similarity[index] > threshold:
                result.append(image_names[count*BATCH_SIZE+index][0])
        count += 1
        
# from shutil import copyfile
# import os
# os.makedirs(target_dir, exist_ok=True)
# for r in result:
#     copyfile(r, target_dir+'/'+r.split('/')[-1])

# print the results
for r in result:
    print(r)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87

参考

pytorch-resnet 提取特征

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/289229
推荐阅读
相关标签
  

闽ICP备14008679号