当前位置:   article > 正文

PyTorch：修改模型 state_dict（权重字典）中的 key

修改pytorch 模型中state_dict 的key

想为 Winograd 改造后的模型训练加载预训练权重。由于修改卷积层的 kernel 尺寸后，这些层使用了新的 key 名（例如 features.0.weight 变为 features.0.inner_conv2d.weight），所以需要相应地修改 state_dict 中的 key。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import *
import os
import argparse
from model.vggnet_4_bn import VGG

# Command-line options: only the path of the checkpoint to re-key.
# (An already-renamed alternative input was 'rename_winograd/rename_winograd.pth'.)
parser = argparse.ArgumentParser()
parser.add_argument(
    '--pre_weights',
    type=str,
    default='ckp_bn_01_vgg4/model_5_0.9785.pth',
    help='pretrained weights',
)
opt = parser.parse_args()
print(opt)

# Build the network, restore the checkpoint into it, then move it to the GPU.
model = VGG()
checkpoint = torch.load(opt.pre_weights)
model.load_state_dict(checkpoint)
model.cuda()
# Rename the parameters of the re-keyed conv layers and keep the old names too:
#   features.N.weight -> features.N.inner_conv2d.weight (features.N.weight kept)
#   features.N.bias   -> features.N.inner_conv2d.bias   (features.N.bias kept)
# Every other entry is copied through unchanged, in the original key order.
from collections import OrderedDict

# Prefixes of the conv layers whose keys changed after the kernel-size edit.
RENAMED_CONV_LAYERS = ("features.0", "features.4", "features.8", "features.12")
RENAMED_KEYS = {
    layer + "." + param
    for layer in RENAMED_CONV_LAYERS
    for param in ("weight", "bias")
}

# Call model.state_dict() once: each call walks the module tree and rebuilds
# the mapping, so calling it per key (as before) repeats that work n times.
state_dict = model.state_dict()
new_dict = OrderedDict()
for key, value in state_dict.items():
    if key in RENAMED_KEYS:
        # "features.0.weight" -> ("features.0", "weight")
        layer, param = key.rsplit(".", 1)
        # The new inner_conv2d name is inserted before the original name,
        # matching the insertion order of the previous hand-written branches.
        new_dict[layer + ".inner_conv2d." + param] = value
    new_dict[key] = value
		
print(new_dict.keys())

# Output directory for the renamed checkpoint (hard-coded, machine-specific).
MODEL_PATH = "/home/aiden00/pytorch_classfication_person/personvscar_pytorch_pq/rename_winograd/"
# exist_ok=True replaces the exists()/makedirs() pair, which could race if the
# directory appeared between the check and the creation.
os.makedirs(MODEL_PATH, exist_ok=True)
# Same file as before: <MODEL_PATH>/model_winograd.pth
torch.save(new_dict, os.path.join(MODEL_PATH, 'model_' + 'winograd' + '.pth'))


声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/160072
推荐阅读
相关标签
  

闽ICP备14008679号