当前位置:   article > 正文

Pytroch网络模型:修改参数值,修改参数名,添加参数层,删除参数层_修改pytorch的pth模型中的参数名

修改pytorch的pth模型中的参数名

修改参数值

方法1

dict的类型是collecitons.OrderedDict,是一个有序字典,直接将新参数名称和初始值作为键值对插入,然后保存即可。

#修改前
dict = torch.load('./ckpt_dir//model_0.pth')
net.load_state_dict(dict)
for name,param in net.named_parameters():
	print(name,param)
#按参数名修改权重
dict["forward1.0.weight"] = torch.ones((1,1,3,3,3))
dict["forward1.0.bias"] = torch.ones(1)
torch.save(dict, './ckpt_dir//model_0_.pth')
#验证修改是否成功
net.load_state_dict(torch.load('./ckpt_dir//model_0_.pth'))
for param_tensor in net.state_dict():
	print(net.state_dict()[param_tensor])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
方法2(按条件修改)
net.load_state_dict(torch.load('./ckpt_dir//model_0.pth'))
for param_tensor in net.state_dict():
	print(net.state_dict()[param_tensor])
#按条件修改权重
for param in net.parameters():
	new = torch.zeros_like(param.data)
	param.data = torch.where(0, param.data, new)
#验证是否真的修改了权重值。
for param_tensor in net.state_dict():
	print(net.state_dict()[param_tensor])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

修改参数名

dict = torch.load(model_dir)
older_val = dict['旧名']
# 修改参数名
dict['新名'] = dict.pop('旧名')
torch.save(dict, './model_changed.pth')
#验证修改是否成功
changed_dict = torch.load('./model_changed.pth')
print(old_val)
print(changed_dict['新名'])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

添加参数层

dict = torch.load('./ckpt_dir//model_0.pth')
print(dict)
dict['forward1.0.weight1'] = None #把OrderedDict类型的dict当作普通字典使用即可
print(dict)
  • 1
  • 2
  • 3
  • 4

删除参数层

pre_model = "./results/model_2-9.pth"
dict = torch.load(pre_model)
for key in list(dict.keys()):
    if key.startswith('decoder1'):
        del dict[key]
torch.save(dict, './model_deleted.pth')
# # #验证修改是否成功
changed_dict = torch.load('./model_deleted.pth')
for key in dict.keys():
    print(key)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/217558
推荐阅读
相关标签
  

闽ICP备14008679号