赞
踩
想对winograd后的训练添加预训练权重,因为修改卷积层kernel尺寸后,用了新的key名,所以修改了一下。
# Convert a pretrained VGG checkpoint so that selected ``features.N`` conv
# parameters are also available under ``features.N.inner_conv2d.*`` keys.
#
# The checkpoint is loaded into the project's VGG model, every
# ``features.{0,4,8,12}.{weight,bias}`` entry is duplicated under the
# ``inner_conv2d`` name (the original key is kept), and the resulting
# state dict is written to ``<out_dir>/model_winograd.pth``.
import argparse
import os
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import *

# Default checkpoint to convert. An earlier (commented-out) default pointed at
# 'rename_winograd/rename_winograd.pth'.
DEFAULT_PRE_WEIGHTS = 'ckp_bn_01_vgg4/model_5_0.9785.pth'

# Default output directory (kept from the original script).
MODEL_PATH = "/home/aiden00/pytorch_classfication_person/personvscar_pytorch_pq/rename_winograd/"

# Indices inside ``features`` whose weight/bias receive an extra
# ``inner_conv2d`` alias.
CONV_LAYER_INDICES = (0, 4, 8, 12)


def add_inner_conv2d_aliases(state_dict, layer_indices=CONV_LAYER_INDICES):
    """Return a copy of *state_dict* with ``inner_conv2d`` aliases added.

    For every ``features.<i>.weight`` / ``features.<i>.bias`` key with ``i``
    in *layer_indices*, the result contains the same value under
    ``features.<i>.inner_conv2d.<param>`` (inserted first) and under the
    original key. All other entries are copied unchanged, in their original
    order.

    Args:
        state_dict: Mapping from parameter name to value.
        layer_indices: Iterable of ``features`` indices to alias.

    Returns:
        A new ``OrderedDict``; the input mapping is not modified.
    """
    # Build the old-name -> new-name table once instead of spelling out one
    # branch per key.
    alias_for = {
        "features.{}.{}".format(idx, param):
            "features.{}.inner_conv2d.{}".format(idx, param)
        for idx in layer_indices
        for param in ("weight", "bias")
    }

    renamed = OrderedDict()
    for key, value in state_dict.items():
        alias = alias_for.get(key)
        if alias is not None:
            # Alias goes in before the original key, as in the original script.
            renamed[alias] = value
        renamed[key] = value
    return renamed


def parse_args(argv=None):
    """Parse command-line options (``--pre_weights``, ``--out_dir``)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--pre_weights', type=str, default=DEFAULT_PRE_WEIGHTS,
                        help='pretrained weights')
    parser.add_argument('--out_dir', type=str, default=MODEL_PATH,
                        help='directory the renamed checkpoint is written to')
    return parser.parse_args(argv)


def main(argv=None):
    """Load the checkpoint, add the alias keys and save the new state dict."""
    # Imported here so the renaming helper stays importable without the
    # project's model package on the path.
    from model.vggnet_4_bn import VGG

    opt = parse_args(argv)
    print(opt)

    model = VGG()
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only host.
    model.load_state_dict(torch.load(opt.pre_weights, map_location="cpu"))
    # Keep the original behaviour (CUDA tensors in the saved file) when a GPU
    # is present, but do not fail when it is not.
    if torch.cuda.is_available():
        model.cuda()

    # state_dict() is built once rather than on every loop iteration.
    new_dict = add_inner_conv2d_aliases(model.state_dict())
    print(new_dict.keys())

    os.makedirs(opt.out_dir, exist_ok=True)
    torch.save(new_dict, os.path.join(opt.out_dir, 'model_winograd.pth'))


if __name__ == "__main__":
    main()
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。