当前位置:   article > 正文

PyTorch 使用过程中问题修复(不定期更新)_densenet121 = models.densenet121(pretrained=true)报

densenet121 = models.densenet121(pretrained=true)报错he parameter 'pretraine

1、预先下载了预训练模型,使用过程中报错

报错场景:在使用 densenet121 模型时,事先下载好模型,并存储到自定义目录下,当导入该模型时报错;

原程序片段

# net = models.densenet121(pretrained=True)  # 自动下载模型并存储到主目录的 .cache 目录下
net = models.densenet121(pretrained=False)   # 只加载模型结构
pthfile = "/mypath/densenet121-a639ec97.pth"
net.load_state_dict(state_dict)              # 通过文件加载参数
  • 1
  • 2
  • 3
  • 4

报错说明:RuntimeError: Error(s) in loading state_dict for DenseNet:

原因:导入的模型所使用的 pytorch 版本与现在所用的不一致,部分关键字可能随着版本的升级而更改

解决方案:参照 pretrained=True 的官方源码实现,因为它也是下载完了再加载,加载过程是一致的

# 源码加载代码片段
    if pretrained:
        # '.'s are no longer allowed in module names, but pervious _DenseLayer
        # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
        # They are also in the checkpoints in model_urls. This pattern is used
        # to find such keys.
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        model.load_state_dict(state_dict)
    return model

  
# 根据上述源码可知,部分参数的名字发生变化,所以稍加修改即可,修改后的代码为
# net = models.densenet121(pretrained=True)
net = models.densenet121(pretrained=False)
pthfile = "/mypath/densenet121-a639ec97.pth"
pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = torch.load(pthfile)
for key in list(state_dict.keys()):
    res = pattern.match(key)
    if res:
        new_key = res.group(1) + res.group(2)
        state_dict[new_key] = state_dict[key]
        del state_dict[key]
net.load_state_dict(state_dict)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

2、torch.max() 非法指令

程序在本地运行正常,放在服务器上遭遇非法指令,经过调试发现,出现问题的语句为 torch.max(a,1),在这里也有人遇到过这种问题,貌似都是出现在 torch 1.6.0 这个版本中,可能是机器的 CPU 太老,一些底层指令不兼容造成的,报错机器 CPU 信息如下:

$ lscpu
Architecture:          x86_64
CPU op-mode(s):        32-bit, 64-bit
Byte Order:            Little Endian
CPU(s):                8
On-line CPU(s) list:   0-7
Thread(s) per core:    1
Core(s) per socket:    4
座:                 2
NUMA 节点:         1
厂商 ID:           GenuineIntel
CPU 系列:          15
型号:              6
型号名称:        Common KVM processor
步进:              1
CPU MHz:             2659.998
BogoMIPS:            5319.99
超管理器厂商:  KVM
虚拟化类型:     完全
L1d 缓存:          32K
L1i 缓存:          32K
L2 缓存:           4096K
L3 缓存:           16384K
NUMA 节点0 CPU:    0-7
Flags:                 fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx lm constant_tsc nopl xtopology eagerfpu pni cx16 x2apic hypervisor lahf_lm
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
# 报错内容示例
>>> import torch
>>> a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
>>> a
tensor([[ 1,  5, 62, 54],
        [ 2,  6,  2,  6],
        [ 2, 65,  2,  6]])
>>> torch.max(a, 1)
非法指令
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

解决办法是降低版本至 1.5.0。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/217596
推荐阅读
相关标签
  

闽ICP备14008679号