当前位置:   article > 正文

使用pytorch进行迁移学习模型下载失败解决办法_torchvision.models模型下载失败

torchvision.models模型下载失败

使用pytorch进行迁移学习的时候,我们需要下载预训练的模型,但是这个模型通常很大,如果在代码中在线下载的话,很可能会中断,并且一中断之前也就白下载了,这篇文章里我介绍一种离线使用预训练模型的方法。

所谓离线使用预训练模型的方法,实际上就是使用浏览器将模型下载下来(通常浏览器下载会比较稳定,并且如果下载中断还能恢复),下面给出各种模型的下载地址,只需要将对应的url键入到浏览器中就可以建立下载

  1. 1. Resnet:
  2. model_urls = {
  3. 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
  4. 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
  5. 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
  6. 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
  7. 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
  8. }
  9. 2. inception:
  10. model_urls = {
  11. Inception v3 ported from TensorFlow
  12. 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
  13. }
  14. 3. Densenet:
  15. model_urls = {
  16. 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
  17. 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
  18. 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
  19. 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
  20. }
  21. 4. Alexnet:
  22. model_urls = {
  23. 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
  24. }
  25. 5. vggnet:
  26. model_urls = {
  27. 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
  28. 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
  29. 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
  30. 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
  31. 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
  32. 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
  33. 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
  34. 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
  35. }

下载完后我们就得到了我们想要的pth文件,使用方法如下,假设我们想要加载一个VGG16网络

  1. model_ft = models.vgg16(pretrained=False)
  2. pre = torch.load("vgg16-397923af.pth")
  3. model_ft.load_state_dict(pre)

model.vgg16()里面的参数pretrained要写成False,否则还是会在线下载模型,按照网上的说法,模型是加载在C:\Users\lvnianzu/.cache\torch\checkpoints 这样子的目录里面的,但是我的电脑上没有找到这个目录。

torch.load()里面的参数就是你要加载的模型的地址

下面用一个例子来演示加载和使用VGG16模型的整个过程

  1. import torch
  2. import torchvision.models as models
  3. # 加载模型
  4. model_ft = models.vgg16(pretrained=False)
  5. pre = torch.load("vgg16-397923af.pth")
  6. model_ft.load_state_dict(pre)
  7. from PIL import Image
  8. im = Image.open("dog.jpg")
  9. im = im.resize((224, 224))
  10. from torch.autograd import Variable
  11. import torchvision.transforms as tfs
  12. im = tfs.ToTensor()(im)
  13. # 将图片变成网络输出的维度
  14. tensor = Variable(torch.unsqueeze(im, dim=0).float(), requires_grad=False)
  15. print(tensor.shape)
  16. output = model_ft(tensor)
  17. print(output.shape)
  18. _,pred = torch.max(output, 1) # dim为1表示索引每行最大值
  19. print(pred)

最后得到的结果如下

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/351085
推荐阅读
相关标签
  

闽ICP备14008679号