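I was recently reproducing a paper that relies on a pretrained model. The keys in the pretrained weight file did not match the keys of my custom model, so the pretrained weights could not be loaded as-is.
The pretrained weights are loaded and their keys printed as follows: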
```python
import torch

if __name__ == "__main__":
    weights_files = './joint_model_stgcn.pt'  # path to the weight file
    weights = torch.load(weights_files)  # load the weight file (a state_dict)

    for k, v in weights.items():  # key, value
        print(k)  # print the key (parameter name)
```
The keys in the original weight file are:
A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance.0
edge_importance.1
edge_importance.2
edge_importance.3
edge_importance.4
edge_importance.5
edge_importance.6
edge_importance.7
edge_importance.8
edge_importance.9
fcn.weight
fcn.bias
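Before renaming anything, you can let PyTorch report the mismatch itself. The sketch below is self-contained. It uses two hypothetical toy modules, not the real ST-GCN or the custom model. `OriginalStyle` keeps `edge_importance` in an `nn.ParameterList`, which produces keys of the form `edge_importance.0`. `CustomStyle` registers `edge_importance0`, `edge_importance1`, ... as separate parameters. With `strict=False`, `load_state_dict` does not raise. Instead it returns the keys it could not match.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the original network: an nn.ParameterList
# named edge_importance yields keys "edge_importance.0", "edge_importance.1", ...
class OriginalStyle(nn.Module):
    def __init__(self, n=3):
        super().__init__()
        self.edge_importance = nn.ParameterList(
            [nn.Parameter(torch.ones(2, 2)) for _ in range(n)]
        )

# Hypothetical stand-in for the custom network: separately registered
# parameters yield keys "edge_importance0", "edge_importance1", ...
class CustomStyle(nn.Module):
    def __init__(self, n=3):
        super().__init__()
        for i in range(n):
            self.register_parameter(f"edge_importance{i}", nn.Parameter(torch.ones(2, 2)))

checkpoint = OriginalStyle().state_dict()
model = CustomStyle()

# strict=False: no exception; the result lists keys the model expected but
# did not get (missing) and keys the checkpoint had but the model lacks (unexpected)
result = model.load_state_dict(checkpoint, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```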
The goal is to rename the following keys so that they match the custom model:
- edge_importance.0 -> edge_importance0
- edge_importance.1 -> edge_importance1
- edge_importance.2 -> edge_importance2
- edge_importance.3 -> edge_importance3
- edge_importance.4 -> edge_importance4
- edge_importance.5 -> edge_importance5
- edge_importance.6 -> edge_importance6
- edge_importance.7 -> edge_importance7
- edge_importance.8 -> edge_importance8
- edge_importance.9 -> edge_importance9
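All ten renames follow one pattern: drop the dot between `edge_importance` and the index. They could therefore also be done in a single pass with a regular expression. The sketch below assumes the checkpoint is an ordered mapping from key to value. `rename_keys` is a hypothetical helper, not a PyTorch API. The toy dict holds plain integers in place of tensors.

```python
import re
from collections import OrderedDict

def rename_keys(state_dict, pattern, repl):
    """Return a new OrderedDict with each key replaced by re.sub(pattern, repl, key).
    Key order and values are preserved; two keys mapping to the same name raise."""
    renamed = OrderedDict()
    for key, value in state_dict.items():
        new_key = re.sub(pattern, repl, key)
        if new_key in renamed:
            raise KeyError(f"rename collision: {key!r} -> {new_key!r}")
        renamed[new_key] = value
    return renamed

# Toy stand-in for the checkpoint (integers instead of tensors)
sd = OrderedDict([
    ("fcn.weight", 1),
    ("edge_importance.0", 2),
    ("edge_importance.9", 3),
    ("st_gcn_networks.9.tcn.0.bias", 4),
])
# "edge_importance.<digits>" (whole key) -> "edge_importance<digits>"
new_sd = rename_keys(sd, r"^edge_importance\.(\d+)$", r"edge_importance\1")
print(list(new_sd))
```

Anchoring the pattern with `^` and `$` keeps it from touching unrelated keys such as `st_gcn_networks.9.tcn.0.bias`.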
Starting from the original weight file, build a new weight file with collections.OrderedDict():
```python
import torch
import collections

if __name__ == "__main__":
    # Load the original weight file
    weights_files = './joint_model_stgcn.pt'
    weights = torch.load(weights_files)

    # Rename: for each i, replace key 'edge_importance.i' with 'edge_importancei'.
    # Rebuilding an OrderedDict keeps the original key order and values.
    new_d = weights
    for i in range(10):
        new_d = collections.OrderedDict(
            [('edge_importance' + str(i), v) if k == 'edge_importance.' + str(i) else (k, v)
             for k, v in new_d.items()]
        )

    # Check the result
    for k, v in new_d.items():  # key, value
        print(k)  # print the parameter name

    # Save
    torch.save(new_d, 'new_joint_model_stgcn.pt')
```
The keys after renaming:
A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance0
edge_importance1
edge_importance2
edge_importance3
edge_importance4
edge_importance5
edge_importance6
edge_importance7
edge_importance8
edge_importance9
fcn.weight
fcn.bias
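To confirm that the renamed keys really fit, load them with `strict=True` (the default). This setting raises `RuntimeError` on any missing or unexpected key. The sketch below uses a hypothetical `CustomStyle` module and an in-memory dict in place of the real custom model and `new_joint_model_stgcn.pt`. The same call also accepts the in-memory `new_d`, so writing a new file is optional.

```python
import torch
import torch.nn as nn
from collections import OrderedDict

class CustomStyle(nn.Module):  # hypothetical custom model
    def __init__(self, n=3):
        super().__init__()
        for i in range(n):
            self.register_parameter(f"edge_importance{i}", nn.Parameter(torch.zeros(2, 2)))

# Hypothetical checkpoint whose keys were already renamed
renamed = OrderedDict((f"edge_importance{i}", torch.full((2, 2), float(i))) for i in range(3))

model = CustomStyle()
# strict=True: returns normally only if every key matches a parameter exactly
model.load_state_dict(renamed, strict=True)
print(model.edge_importance2)

# The old dotted names are rejected under strict loading
rejected = False
try:
    CustomStyle().load_state_dict({"edge_importance.0": torch.zeros(2, 2)})
except RuntimeError:
    rejected = True
print("old key names rejected:", rejected)
```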
Check whether the values in the original and new weight files are identical:
```python
import torch

if __name__ == "__main__":
    origin_weights_files = './joint_model_stgcn.pt'
    origin_weights = torch.load(origin_weights_files)
    new_weights_files = './new_joint_model_stgcn.pt'
    new_weights = torch.load(new_weights_files)

    # `==` on tensors compares element-wise and prints a tensor of booleans;
    # torch.equal returns a single True/False (same shape and same values)
    print(torch.equal(origin_weights['A'], new_weights['A']))
    print(torch.equal(origin_weights['edge_importance.0'], new_weights['edge_importance0']))
```
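The check above spot-checks only `A` and one `edge_importance` entry. To check every value, compare the entries pairwise. This sketch assumes the new file keeps the original key order, which the OrderedDict rebuild does. It uses toy tensors rather than the real checkpoints. `same_values` is a hypothetical helper name.

```python
import torch

def same_values(old_sd, new_sd):
    """True if both state dicts have the same number of entries and the
    i-th tensors (in insertion order) match in shape and content."""
    if len(old_sd) != len(new_sd):
        return False
    return all(torch.equal(a, b) for a, b in zip(old_sd.values(), new_sd.values()))

# Toy stand-ins: same values, only the key names differ
old = {"A": torch.ones(3), "edge_importance.0": torch.arange(4.0)}
new = {"A": torch.ones(3), "edge_importance0": torch.arange(4.0)}
print(same_values(old, new))
```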
Copyright © 2003-2013 www.wpsshop.cn. All rights reserved.