当前位置:   article > 正文

深度学习笔记--修改替换Pytorch权重文件的Key值_pytorch yolo训练修改网络后没有权重文件

pytorch yolo训练修改网络后没有权重文件

目录

1--前言

2--问题描述

2--代码

3--测试


1--前言

        最近复现一篇 Paper,需要使用预训练的模型,但预训练模型和自定义模型的 key 值不匹配,导致无法顺利加载预训练权重文件;

2--问题描述

        需要使用的预训练模型如下:

  1. import torch
  2. if __name__ == "__main__":
  3. weights_files = './joint_model_stgcn.pt' # 权重文件路径
  4. weights = torch.load(weights_files) # 加载权重文件
  5. for k, v in weights.items(): # key, value
  6. print(k) # 打印 key(参数名)

        原权重文件的 key 值如下:

A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance.0
edge_importance.1
edge_importance.2
edge_importance.3
edge_importance.4
edge_importance.5
edge_importance.6
edge_importance.7
edge_importance.8
edge_importance.9

fcn.weight
fcn.bias

        需求是修改以下 key 值,以适配自定义模型:

  1. edge_importance.0 -> edge_importance0
  2. edge_importance.1 -> edge_importance1
  3. edge_importance.2 -> edge_importance2
  4. edge_importance.3 -> edge_importance3
  5. edge_importance.4 -> edge_importance4
  6. edge_importance.5 -> edge_importance5
  7. edge_importance.6 -> edge_importance6
  8. edge_importance.7 -> edge_importance7
  9. edge_importance.8 -> edge_importance8
  10. edge_importance.9 -> edge_importance9

2--代码

        基于原权重文件,利用 collections.OrderedDict() 创建新的权重文件:

  1. import torch
  2. import collections
  3. if __name__ == "__main__":
  4. # 加载原权重文件
  5. weights_files = './joint_model_stgcn.pt'
  6. weights = torch.load(weights_files)
  7. # 修改
  8. new_d = weights
  9. for i in range(10):
  10. new_d = collections.OrderedDict([('edge_importance'+str(i), v) if k == 'edge_importance.'+str(i) else (k, v) for k, v in new_d.items()])
  11. # 测试
  12. for k, v in new_d.items(): # key, value
  13. print(k) # 打印参数名
  14. # 保存
  15. torch.save(new_d, 'new_joint_model_stgcn.pt')

        修改后的 key 值:

A
...
st_gcn_networks.9.gcn.conv.weight
st_gcn_networks.9.gcn.conv.bias
st_gcn_networks.9.tcn.0.weight
st_gcn_networks.9.tcn.0.bias
st_gcn_networks.9.tcn.0.running_mean
st_gcn_networks.9.tcn.0.running_var
st_gcn_networks.9.tcn.0.num_batches_tracked
st_gcn_networks.9.tcn.2.weight
st_gcn_networks.9.tcn.2.bias
st_gcn_networks.9.tcn.3.weight
st_gcn_networks.9.tcn.3.bias
st_gcn_networks.9.tcn.3.running_mean
st_gcn_networks.9.tcn.3.running_var
st_gcn_networks.9.tcn.3.num_batches_tracked
edge_importance0
edge_importance1
edge_importance2
edge_importance3
edge_importance4
edge_importance5
edge_importance6
edge_importance7
edge_importance8
edge_importance9

fcn.weight
fcn.bias

3--测试

        测试原权重文件和新权重文件的 value 是否相同:

  1. import torch
  2. if __name__ == "__main__":
  3. origin_weights_files = './joint_model_stgcn.pt'
  4. origin_weights = torch.load(origin_weights_files)
  5. new_weights_files = './new_joint_model_stgcn.pt'
  6. new_weights = torch.load(new_weights_files)
  7. print(origin_weights['A'] == new_weights['A'])
  8. print(origin_weights['edge_importance.0'] == new_weights['edge_importance0'])

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/160109
推荐阅读
相关标签
  

闽ICP备14008679号