赞
踩
报错场景:在使用 densenet121 模型时,事先下载好模型,并存储到自定义目录下,当导入该模型时报错;
原程序片段:
# net = models.densenet121(pretrained=True) # 自动下载模型并存储到主目录的 .cache 目录下
net = models.densenet121(pretrained=False) # 只加载模型结构
pthfile = "/mypath/densenet121-a639ec97.pth"
net.load_state_dict(state_dict) # 通过文件加载参数
报错说明:RuntimeError: Error(s) in loading state_dict for DenseNet:
原因:导入的模型所使用的 pytorch 版本与现在所用的不一致,部分关键字可能随着版本的升级而更改
解决方案:参照 pretrained=True
的官方源码实现,因为它也是下载完了再加载,加载过程是一致的
# 源码加载代码片段 if pretrained: # '.'s are no longer allowed in module names, but pervious _DenseLayer # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # They are also in the checkpoints in model_urls. This pattern is used # to find such keys. pattern = re.compile( r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') state_dict = model_zoo.load_url(model_urls['densenet121']) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] model.load_state_dict(state_dict) return model # 根据上述源码可知,部分参数的名字发生变化,所以稍加修改即可,修改后的代码为 # net = models.densenet121(pretrained=True) net = models.densenet121(pretrained=False) pthfile = "/mypath/densenet121-a639ec97.pth" pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') state_dict = torch.load(pthfile) for key in list(state_dict.keys()): res = pattern.match(key) if res: new_key = res.group(1) + res.group(2) state_dict[new_key] = state_dict[key] del state_dict[key] net.load_state_dict(state_dict)
程序在本地运行正常,放在服务器上遭遇非法指令,经过调试发现,出现问题的语句为 torch.max(a,1)
,在这里也有人遇到过这种问题,貌似都是出现在 torch 1.6.0 这个版本中,可能是机器的 CPU 太老,一些底层指令不兼容造成的,报错机器 CPU 信息如下:
$ lscpu Architecture: x86_64 CPU op-mode(s): 32-bit, 64-bit Byte Order: Little Endian CPU(s): 8 On-line CPU(s) list: 0-7 Thread(s) per core: 1 Core(s) per socket: 4 座: 2 NUMA 节点: 1 厂商 ID: GenuineIntel CPU 系列: 15 型号: 6 型号名称: Common KVM processor 步进: 1 CPU MHz: 2659.998 BogoMIPS: 5319.99 超管理器厂商: KVM 虚拟化类型: 完全 L1d 缓存: 32K L1i 缓存: 32K L2 缓存: 4096K L3 缓存: 16384K NUMA 节点0 CPU: 0-7 Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx lm constant_tsc nopl xtopology eagerfpu pni cx16 x2apic hypervisor lahf_lm
# 报错内容示例
>>> import torch
>>> a = torch.tensor([[1,5,62,54], [2,6,2,6], [2,65,2,6]])
>>> a
tensor([[ 1, 5, 62, 54],
[ 2, 6, 2, 6],
[ 2, 65, 2, 6]])
>>> torch.max(a, 1)
非法指令
解决办法是降低版本至 1.5.0。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。