赞
踩
当载入的预训练权重中的全连接层(最后一层)与自己的实例化模型的全连接层节点个数不一样时该如何载入
比如自己实例化的resnet网络,针对花分类数据集,只有5类,所以最后一个全连接层的节点个数等于5。但是载入的预训练权重是基于imagenet 1k的权重,所以它的节点个数是1000,很明显不能直接去载入。
对网络的结构进行了一定的修改,但是还想再次载入这个权重,很明显也无法直接载入。
那么能不能载入部分呢?当然这也要看如何去修改的,比如说在网络的高层上进行结构的修改的话,那么相对底层的一些没有改动过的权重还是可以去载入的,因为底层都是些比较通用的权重。
设置路径
from model import resnet34
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth'
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
方法1
# option1
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))# change fc laver structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
首先创建resnet34(),注意这里没有传入num_classes参数,此时默认实例化创建的全连接层的节点数为1000,正好和预训练权重中的全连接层节点个数是一致的。载入之后,由于自己训练的类别个数不等于1000,比如等于5,接下来该怎么办呢?
在搭建 resnet 网络时,对全连接层是通过 nn.Linear 这个类实现的。在 nn.Linear 这个类当中,参数 in_features 对应输入该全连接层的节点个数,out_features 对应输出的节点个数。通过 net.fc.in_features 获得全连接层的输入节点个数,然后利用 nn.Linear 创建一个新的全连接层,它的输入节点个数为 in_channel ,输出节点个数为5。
方法2
# option2
net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
if "fc" in key:
del_key.append(key)
for key in del_key:
del pre_weights[key]
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:",*missing_keys, sep="\n")
print("[unexpected_keys]:"*unexpected_keys, sep="\n")
在这个方法中,首先还是实例化 resnet34 ,但这里传入了 num_classes ,也就是说现在全连接层的节点个数一开始就等于5,如果此时强行载入,由于全连接层的节点个数不一样,就会发生报错。
我们可以通过 torch.load 先读取权重,它会以一个有序字典的形式存储。然后在 pre_weights 这个字典中删减全连接层的权重。(PS:这种方法需要实例化模型和载入模型的权重名字是一样的,如果不一样需要修改。可以通过 net_weight = net.static.dict()读取模型当中所以权重的名称和它所对应的权重)
通过遍历,查询 pre_weights 中所有包含"fc"字段的名称,并存入del_key当中,然后一次将这些key删除。最后同样还是通过 net.load_state_dict() 的方式载入权重,但是这里需要注意的是 strict 要设置为 False 。如果不传入这个参数,默认为 True,会严格载入每一个key值。由于删除了全连接层部分的权重,所以不能将 strict 设置为 True。
载入之后,它会给我们返回两个变量,一个叫做 missing_keys,一个叫做 unexpected_keys 。missing_keys的含义就是说在net网络当中,有一部分权重并没有在 pre_weights 这个预训练权重中出现,那么就相当于漏掉了这些权重,所以它存储在这个 missing_keys当中。unexpected_keys 就是说在 pre_weights当中有一部分权重,它不在net当中,那么此时就会存到 unexpected_keys 当中。
载入部分权重的方法有很多种,除了刚刚所说的。在载入的这个字典当中进行删减之外,其实可以自己新创建一个字典,新创建一个字典之后呢,可以通过自己组建key和value,然后同样用这个相同的方法进行载入就可以了,这样的话会更加的灵活一点。
def load(self, model_path=None):
# 加载模型,传入参数 model_path,如果存在就加载模型
if (model_path):
self.logger.info('load_model_path: ' + model_path)
# 加载模型参数,并更新到 self.model 中
# model_state_dict_save = {k.replace('module.',''):v for k,v in torch.load(model_path).items()}
model_state_dict_save = {k:v for k,v in torch.load(model_path, map_location=self.device).items()}
model_state_dict = self.model.state_dict()
model_state_dict.update(model_state_dict_save)
self.model.load_state_dict(model_state_dict)
https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。