赞
踩
一、修改预训练模型中的全连接层参数:
方式1:修改字典的方式
- import torch
- import torch.nn as nn
- import torchvision
-
- class ResNet(nn.Module):
- def __init__(self):
- super(ResNet, self).__init__()
-
- if pretrained == True:
- # 获取ResNet34的预训练权重
- resnet34 = torchvision.models.resnet34(pretrained=True)
- pretrained_dict = resnet34.state_dict()
- """加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
- 也可以直接从官方下载:
- pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
- # 获取当前模型的参数字典
- model_dict = [自己模型名称].state_dict()
- # 将pretrained_dict里不属于model_dict的键剔除掉
- pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
- # 更新现有的model_dict
- model_dict.update(pretrained_dict)
- # 加载我们真正需要的state_dict
- self.load_state_dict(model_dict)
- print('成功加载预训练权重')
-
- if __name__ == '__main__':
- resnet = ResNet()
- resnet.init_weights(pretrained=True)
- -----------------------------------
方式2:修改网络层,
对于简单的参数修改,这里以resnet预训练模型举例,resnet源代码在Github点击打开链接
resnet网络最后一层分类层fc是对1000种类型进行划分,对于自己的数据集,如果只有9类,修改的代码如下:
- # coding=UTF-8
- import torchvision.models as models
-
- #调用模型
- model = models.resnet50(pretrained=True)
- #提取fc层中固定的参数
- fc_features = model.fc.in_features
- #修改类别为9
- model.fc = nn.Linear(fc_features, 9)
有关网络属性参数的使用
- pretrained_dict = torch.load(model_weight_path, map_location=device)
- model_dict = net.state_dict()
- pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
- model_dict.update(pretrained_dict)
- # 加载我们真正需要的state_dict
- net.load_state_dict(model_dict)
-
- for k, v in pretrained_dict.items():
- print(k)
- for k, v in model_dict.items():
- print(k)
- #上面两种的for的输出将会是一样的,一个来自定义的网络结构,一个来自网络的预训练模型.pth中
二:网络关键字比预训练模型关键字多前缀
- net = parsingNet(pretrained = True, backbone=cfg.backbone,cls_dim = (cfg.griding_num+1,cls_num_per_lane, cfg.num_lanes),use_aux=cfg.use_aux).cuda()
- # -------------------------
- #加载hornet_backbone 预训练权重
- # -------------------------
- pretrained_dict = torch.load('hornet_tiny_7x7.pth',map_location='cpu')
- # state_dict = torch.load('tusimple_18.pth',map_location='cpu')
- model_dict = net.state_dict()
此时网络的关键字前缀比预训练关键字多一个前缀,即大部分关键字都是相同的。
- for k,v in pretrained_dict['model'].items():
- print(k)
-
- downsample_layers.0.0.weight
- downsample_layers.0.0.bias
- downsample_layers.0.1.weight
- downsample_layers.0.1.bias
- ···
- -----------------------------------------------------------
- for k,v in model_dict.items():
- print(k)
-
- model.downsample_layers.0.0.weight
- model.downsample_layers.0.0.bias
- model.downsample_layers.0.1.weight
- model.downsample_layers.0.1.bias
- #此时可以看到前缀多了model,要和使用多GPU训练出现的module区分开
为了可以正确将预训练权重加载进网络中,我们可以将预训练中的权重加载到网络对应的关键字中:
- keys = []
- for k, v in pretrained_dict['model'].items():
- keys.append(k)
- i = 0
- for k, v in model_dict.items():
- if v.size() == pretrained_dict['model'][keys[i]].size():
- model_dict[k] = pretrained_dict['model'][keys[i]] #权重
- #print(model_dict[k])
- i = i + 1
- net.load_state_dict(model_dict)
这种方法只能将预训练中的参数加载到网络中,对于网络的其他部分(自己加的如pool,droop等层的参数)需要自己重新训练或者提前初始化。
验证:
- >>>pretrained_dict['model']['downsample_layers.2.0.weight']
-
- tensor([0.2275, 0.2362, 0.7987, 0.6742, 1.1917, 0.7307, 0.7106, 0.9949, 0.4314,
- 0.4631, 1.0183, 0.2990, 0.2896, 0.2745, 1.0126, 0.7838, 0.5174, 0.3277,
- ···
- >>>model_dict['model.downsample_layers.2.0.weight']
-
- tensor([0.2275, 0.2362, 0.7987, 0.6742, 1.1917, 0.7307, 0.7106, 0.9949, 0.4314,
- 0.4631, 1.0183, 0.2990, 0.2896, 0.2745, 1.0126, 0.7838, 0.5174, 0.3277,
- ···
- #由此我们知道权重正确加载到网络中
修改关键字方法:(不建议)
- pretrained_dict = torch.load('hornet_tiny_7x7.pth',map_location='cpu')
-
- model_dict = net.state_dict()
-
- from collections import OrderedDict
- new_state_dict = OrderedDict()
-
- # for k, v in mgn_state_dict.items():
- # name = k[7:] # remove `module.`
- # new_state_dict[name] = v
- # self.model = self.model.load_state_dict(new_state_dict)
-
- for k, v in pretrained_dict['model'].items():
- name = "model." + k # add `model.`
- print(name)
- new_state_dict[name] = v #这种方式有弊端就是前面的正常加载,对于后面net中新加入的层并不会有对应的权重,最好使用对应赋值的方式
- net.load_state_dict(new_state_dict) # 此处会报错,这种情况需要修改网络内容
其他博主提供的方法:网络关键字比预训练多一个前缀,在我的网络中并未能使用,供参考。
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。