赞
踩
# Download train_data.zip from the DIGIX 2020 CV challenge bucket.
# `-O` (capital) sets the output file name; lowercase `-o` would instead write
# wget's log to train_data.zip and save the archive under the URL's basename.
wget "https://digix-algo-challenge.obs.cn-east-2.myhuaweicloud.com/2020/cv/6rKDTsB6sX8A1O2DA2IAq7TgHPdSPxJF/train_data.zip" -O train_data.zip
"""Extract SE-ResNet50 image features with the ``pretrainedmodels`` package.

Usage: ``python <script> [image_path]`` (image_path defaults to ``1.jpeg``).
Prints the feature length followed by the feature vector.
"""
import ssl
import sys

import pretrainedmodels
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

# NOTE(review): this disables HTTPS certificate verification for the WHOLE
# process, not just the weight download. It is presumably here so that
# pretrainedmodels can fetch its weights in environments with broken or
# intercepting TLS; every other HTTPS request made by this process becomes
# open to man-in-the-middle attacks. Remove it if the download works without it.
ssl._create_default_https_context = ssl._create_unverified_context

TARGET_IMG_SIZE = 224

# Channel statistics the pretrainedmodels SE-ResNet50 ImageNet weights were
# trained with (RGB input scaled to [0, 1], then normalized by mean/std).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

img_to_tensor = transforms.ToTensor()
_normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)


def get_seresnet50(device="cpu"):
    """Build a frozen SE-ResNet50 feature extractor with the classifier removed.

    Args:
        device: torch device to place the model on, e.g. ``"cpu"`` (the
            default, matching the previous behavior) or ``"cuda"``.

    Returns:
        An ``nn.Sequential`` in eval mode, with gradients disabled, whose output
        is the average-pooled ``layer4`` activation.
    """
    # May download the pretrained weights on first call.
    encoder = pretrainedmodels.se_resnet50()
    model = nn.Sequential(
        encoder.layer0,
        encoder.layer1,
        encoder.layer2,
        encoder.layer3,
        encoder.layer4,
        # Average pooling to a 2048-channel feature per image. Assumes
        # TARGET_IMG_SIZE (224) makes the pooled spatial size 1x1 so the
        # flattened output is 2048 long — TODO confirm if the size changes.
        encoder.avg_pool,
    )
    for param in model.parameters():
        param.requires_grad = False  # inference only; no training
    model.to(device)
    model.eval()
    return model


def extract_feature(model, imgpath, *, normalize=True):
    """Return the feature vector of a single image as a flat list of floats.

    Args:
        model: feature extractor, e.g. the one returned by ``get_seresnet50``.
        imgpath: path to an image file readable by PIL.
        normalize: apply ImageNet mean/std normalization, which the pretrained
            weights expect. Pass ``False`` to reproduce features produced by
            earlier versions of this script, which skipped it.

    Returns:
        The model output for the image, flattened into a ``list`` of floats.
    """
    # Feed the input to whatever device the model lives on (CPU or GPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # The context manager closes the underlying file. convert("RGB") makes
    # grayscale / RGBA / palette images produce the 3 channels the model needs.
    with Image.open(imgpath) as src:
        img = src.convert("RGB").resize((TARGET_IMG_SIZE, TARGET_IMG_SIZE))

    tensor = img_to_tensor(img)  # image matrix -> tensor
    if normalize:
        tensor = _normalize(tensor)
    batch = torch.unsqueeze(tensor, 0).to(device)  # add a batch dimension of 1

    with torch.no_grad():  # no autograd graph needed for feature extraction
        result = model(batch)
    return result[0].flatten().cpu().tolist()


if __name__ == "__main__":
    # If downloading the pretrained model fails, download the weights file
    # manually and place it in the directory the downloader saves to
    # (presumably torch's hub checkpoint cache — confirm for your setup).
    model = get_seresnet50("cuda" if torch.cuda.is_available() else "cpu")
    image_path = sys.argv[1] if len(sys.argv) > 1 else "1.jpeg"
    feature = extract_feature(model, image_path)
    print(len(feature))
    print(feature)
参考
https://www.codenong.com/cs107121229/
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。