当前位置:   article > 正文

python加载预训练权重_加载预训练的imagenet权重

加载预训练的imagenet权重

如何在pytorch中载入部分权重

主要有两个常见的点

  1. 当载入的预训练权重中的全连接层(最后一层)与自己的实例化模型的全连接层节点个数不一样时该如何载入

    比如自己实例化的resnet网络,针对花分类数据集,只有5类,所以最后一个全连接层的节点个数等于5。但是载入的预训练权重是基于imagenet 1k的权重,所以它的节点个数是1000,很明显不能直接去载入。

  2. 对网络的结构进行了一定的修改,但是还想再次载入这个权重,很明显也无法直接载入。

    那么能不能载入部分呢?当然这也要看如何去修改的,比如说在网络的高层上进行结构的修改的话,那么相对底层的一些没有改动过的权重还是可以去载入的,因为底层都是些比较通用的权重。

设置路径

from model import resnet34
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# load pretrain weights
# download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
model_weight_path = "./resnet34-pre.pth'
assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

方法1

# option1
net = resnet34()
net.load_state_dict(torch.load(model_weight_path, map_location=device))# change fc laver structure
in_channel = net.fc.in_features
net.fc = nn.Linear(in_channel, 5)
  • 1
  • 2
  • 3
  • 4
  • 5

首先创建resnet34(),注意这里没有传入num_classes参数,此时默认实例化创建的全连接层的节点数为1000,正好和预训练权重中的全连接层节点个数是一致的。载入之后,由于自己训练的类别个数不等于1000,比如等于5,接下来该怎么办呢?

在搭建 resnet 网络时,对全连接层是通过 nn.Linear 这个类实现的。在 nn.Linear 这个类当中,参数 in_features 对应输入该全连接层的节点个数,out_features 对应输出的节点个数。通过 net.fc.in_features 获得全连接层的输入节点个数,然后利用 nn.Linear 创建一个新的全连接层,它的输入节点个数为 in_channel ,输出节点个数为5。

方法2

# option2
net = resnet34(num_classes=5)
pre_weights = torch.load(model_weight_path, map_location=device)
del_key = []
for key, _ in pre_weights.items():
	if "fc" in key:
		del_key.append(key)
for key in del_key:
	del pre_weights[key]
missing_keys, unexpected_keys = net.load_state_dict(pre_weights, strict=False)
print("[missing_keys]:",*missing_keys, sep="\n")
print("[unexpected_keys]:"*unexpected_keys, sep="\n")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

在这个方法中,首先还是实例化 resnet34 ,但这里传入了 num_classes ,也就是说现在全连接层的节点个数一开始就等于5,如果此时强行载入,由于全连接层的节点个数不一样,就会发生报错。

我们可以通过 torch.load 先读取权重,它会以一个有序字典的形式存储。然后在 pre_weights 这个字典中删减全连接层的权重。(PS:这种方法需要实例化模型和载入模型的权重名字是一样的,如果不一样需要修改。可以通过 net_weight = net.static.dict()读取模型当中所以权重的名称和它所对应的权重)

通过遍历,查询 pre_weights 中所有包含"fc"字段的名称,并存入del_key当中,然后一次将这些key删除。最后同样还是通过 net.load_state_dict() 的方式载入权重,但是这里需要注意的是 strict 要设置为 False 。如果不传入这个参数,默认为 True,会严格载入每一个key值。由于删除了全连接层部分的权重,所以不能将 strict 设置为 True。

载入之后,它会给我们返回两个变量,一个叫做 missing_keys,一个叫做 unexpected_keys 。missing_keys的含义就是说在net网络当中,有一部分权重并没有在 pre_weights 这个预训练权重中出现,那么就相当于漏掉了这些权重,所以它存储在这个 missing_keys当中。unexpected_keys 就是说在 pre_weights当中有一部分权重,它不在net当中,那么此时就会存到 unexpected_keys 当中。

载入部分权重的方法有很多种,除了刚刚所说的。在载入的这个字典当中进行删减之外,其实可以自己新创建一个字典,新创建一个字典之后呢,可以通过自己组建key和value,然后同样用这个相同的方法进行载入就可以了,这样的话会更加的灵活一点。

毕设(TTSR)中加载预训练模型的方法


    def load(self, model_path=None):
        # 加载模型,传入参数 model_path,如果存在就加载模型
        if (model_path):
            self.logger.info('load_model_path: ' + model_path)
            # 加载模型参数,并更新到 self.model 中
            # model_state_dict_save = {k.replace('module.',''):v for k,v in torch.load(model_path).items()}
            model_state_dict_save = {k:v for k,v in torch.load(model_path, map_location=self.device).items()}
            model_state_dict = self.model.state_dict()
            model_state_dict.update(model_state_dict_save)
            self.model.load_state_dict(model_state_dict)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

官方教程

https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_models_for_inference.html

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/400902
推荐阅读
相关标签
  

闽ICP备14008679号