赞
踩
使用pytorch进行迁移学习的时候,我们需要下载预训练的模型,但是这个模型通常很大,如果在代码中在线下载的话,很可能会中断,并且一中断之前也就白下载了,这篇文章里我介绍一种离线使用预训练模型的方法。
所谓离线使用预训练模型的方法,实际上就是使用浏览器将模型下载下来(通常浏览器下载会比较稳定,并且如果下载中断还能恢复),下面给出各种模型的下载地址,只需要将对应的url键入到浏览器中就可以建立下载
- 1. Resnet:
- model_urls = {
- 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
- 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
- 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
- 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
- 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
- }
-
- 2. inception:
- model_urls = {
- Inception v3 ported from TensorFlow
- 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
- }
-
- 3. Densenet:
- model_urls = {
- 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
- 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
- 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
- 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
- }
-
- 4. Alexnet:
- model_urls = {
- 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
- }
-
- 5. vggnet:
- model_urls = {
- 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
- 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
- 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
- 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
- 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
- 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
- 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
- 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
- }
下载完后我们就得到了我们想要的pth文件,使用方法如下,假设我们想要加载一个VGG16网络
- model_ft = models.vgg16(pretrained=False)
- pre = torch.load("vgg16-397923af.pth")
- model_ft.load_state_dict(pre)
model.vgg16()里面的参数pretrained要写成False,否则还是会在线下载模型,按照网上的说法,模型是加载在C:\Users\lvnianzu/.cache\torch\checkpoints 这样子的目录里面的,但是我的电脑上没有找到这个目录。
torch.load()里面的参数就是你要加载的模型的地址
下面用一个例子来演示加载和使用VGG16模型的整个过程
- import torch
- import torchvision.models as models
-
-
- # 加载模型
- model_ft = models.vgg16(pretrained=False)
- pre = torch.load("vgg16-397923af.pth")
- model_ft.load_state_dict(pre)
-
-
- from PIL import Image
- im = Image.open("dog.jpg")
- im = im.resize((224, 224))
-
- from torch.autograd import Variable
-
- import torchvision.transforms as tfs
- im = tfs.ToTensor()(im)
- # 将图片变成网络输出的维度
- tensor = Variable(torch.unsqueeze(im, dim=0).float(), requires_grad=False)
- print(tensor.shape)
- output = model_ft(tensor)
- print(output.shape)
- _,pred = torch.max(output, 1) # dim为1表示索引每行最大值
- print(pred)
最后得到的结果如下
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。