当前位置:   article > 正文

resnet提取图片特征

resnet提取图片特征

下载模型

wget https://digix-algo-challenge.obs.cn-east-2.myhuaweicloud.com/2020/cv/6rKDTsB6sX8A1O2DA2IAq7TgHPdSPxJF/train_data.zip -o train_data.zip
  • 1

ResNet提取图片特征向量



from torch.autograd import Variable
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import pretrainedmodels
from PIL import Image

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

TARGET_IMG_SIZE = 224
img_to_tensor = transforms.ToTensor()
def get_seresnet50():
    encoder = pretrainedmodels.se_resnet50()
    model = nn.Sequential(encoder.layer0,
                          encoder.layer1,
                          encoder.layer2,
                          encoder.layer3,
                          encoder.layer4,
                          encoder.avg_pool   # 平均池化,张成一个[batchSize,2048]的特征向量
                          )
    for param in model.parameters():
        param.requires_grad = False
    # model.cuda()   # 使用GPU,CPU版去掉
    model.eval()
    return model

# 特征提取
def extract_feature(model, imgpath):
    img = Image.open(imgpath)  # 读取图片
    img = img.resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE))

    tensor = img_to_tensor(img)  # 将图片矩阵转化成tensor
    # tensor = tensor.cuda()  # GPU 使用GPU放开此注释
    tensor = torch.unsqueeze(tensor, 0)
    result = model(Variable(tensor))
    result_npy = result.data.cpu().numpy()[0].ravel().tolist()
    return result_npy

if __name__ == '__main__':
    model = get_seresnet50()
    feature = extract_feature(model, "1.jpeg")
    print(len(feature))
    print(feature)
    # 下载模型可能会报错,将上面下载的模型放在下载的目录
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

参考

https://www.codenong.com/cs107121229/

参考图片特征提取

本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号