当前位置:   article > 正文

Pytorch预训练模型下载并加载(以VGG为例)自定义路径_怎么更改vgg下载路径

怎么更改vgg下载路径

简述

一般来讲,Pytorch用torchvision调用vgg之类的模型话,如果电脑在cache(Pytorch硬编码的一个地址)(如果在环境变量中添加了TORCH_HOMETORCH_MODEL_ZOO的话,就是在这两个位置的联合的路径下,比如TORCH_MODEL_ZOO\model)否则就是在TORCH_HOME\models或者是~/.torch/models

比如,我的就是C:\Users\lijy2/.torch\models\vgg11-bbd30ac9.pth

这很有可能并不是我们想要的下载模型放的地址,或者是这样的下载方式很慢等等。

而且这个地址不可以很容易的直接调用,非常不方便。

这点,在我现在用pytorch版本还是github上的最新版本都是没有做类似的改进的。

但是这种设计(可能对我这种强迫症来说),是有需求的。

解决办法

首先,先处理下载的问题。

读了下源码,是使用import torch.utils.model_zoo as model_zoo里面的函数来加载数据。
整理了下源码中涉及的这一部分

from urllib.parse import urlparse
import torch.utils.model_zoo as model_zoo
import re
import os
def download_model(url, dst_path):
    parts = urlparse(url)
    filename = os.path.basename(parts.path)
    
    HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')
    hash_prefix = HASH_REGEX.search(filename).group(1)
    
    model_zoo._download_url_to_file(url, os.path.join(dst_path, filename), hash_prefix, True)
    return filename
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

调用实例

model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}

import os

path = 'D:/Software/DataSet/models/vgg'
if not (os.path.exists(path)):
    os.makedirs(path)
for url in model_urls.values():
    download_model(url, path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

输出

100%|███████████████████████████████████████████████████████████████| 531456000/531456000 [01:14<00:00, 7114218.15it/s]
100%|███████████████████████████████████████████████████████████████| 532194478/532194478 [02:46<00:00, 3193007.52it/s]
100%|███████████████████████████████████████████████████████████████| 553433881/553433881 [01:13<00:00, 7536750.60it/s]
100%|██████████████████████████████████████████████████████████████| 574673361/574673361 [00:54<00:00, 10587712.79it/s]
100%|███████████████████████████████████████████████████████████████| 531503671/531503671 [01:10<00:00, 7548305.64it/s]
100%|███████████████████████████████████████████████████████████████| 532246301/532246301 [01:35<00:00, 5598996.73it/s]
100%|██████████████████████████████████████████████████████████████| 553507836/553507836 [00:50<00:00, 10900603.60it/s]
100%|███████████████████████████████████████████████████████████████| 574769405/574769405 [01:11<00:00, 8023263.07it/s]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

其他的模型地址,可以打开github里面对应模型的代码,一打开就看到了。

https://github.com/pytorch/vision/tree/master/torchvision/models
  • 1

再看看加载

import glob
import os
def load_model(model_name, model_dir):
    model  = eval('models.%s(init_weights=False)' % model_name)
    path_format = os.path.join(model_dir, '%s-[a-z0-9]*.pth' % model_name)
    
    model_path = glob.glob(path_format)[0]
    
    model.load_state_dict(torch.load(model_path))
    return model
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

使用实例:

model_dir = 'D:/Software/DataSet/models/vgg/'
model = load_model('vgg11', model_dir)
  • 1
  • 2

输出:

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (19): ReLU(inplace)
    (20): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/683280
推荐阅读
相关标签
  

闽ICP备14008679号