当前位置:   article > 正文

加载pytorch已有模型,修改最后分类头_pytorch最后一层分类层

pytorch最后一层分类层

在加载pytorch已有模型的时候,我们必须要明确的事情:

1 如何获取到pytorch所提供的模型,通过什么方式。
2 模型的结构,也就是模型的每个层的名字(key)。
3 我们要把需要加载的模型,尽量封装成一个类。

下面我们针对上面来给出答案。
答1:以 resnet18 举例

# ----------------1 导入库 -----------------
import torchvision.models as model
# -------------2 将resnet18导入到新模型。--------------
base_model = 'resnet18'
if 'resnet' in base_model:
	model = getattr(model,base_model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

答2 :我们在了解模型的时候,经常使用 dict(model.named_parameters()),它会返回一个字典,我们通过 字典.items()来得到字典的key和value值。我们要知道最后一层的分类层名字叫什么。

for (key,value) in dict(model.named_parameters()).items():
    print(key)
  • 1
  • 2

在这里插入图片描述
最后一层的名字叫 fc ,这样我们可以通过最后一层的名字来修改最后一层。

num_class = 51
fc = getattr(model, 'fc')
feature_dim = fc.in_features
setattr(model,'fc',nn.Linear(feature_dim,num_class))
print(model)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

在这里插入图片描述
这样就把最后一层修改完成了。

答3 :最后封装成新的模型类

import torchvision.models as model
import torch.nn as nn


class Model(nn.Module):
    def __init__(self, num_class,base_model= 'resnet18'):
        super().__init__()
        self._prepare_base_model(num_class = num_class,base_model = base_model )
  
    def _prepare_base_model(self, base_model,num_class):  
        if 'resnet' in base_model:
            self.model = getattr(model, base_model)(pretrained=True)
            feature_dim = getattr(self.model, 'fc').in_features
            setattr(self.model,'fc',nn.Linear(feature_dim,num_class))      
        else:
            raise ValueError('Unknown base model: {}'.format(base_model))
            
    def forward(self,x):
        out = self.model(x)
        return out

net = Model(num_class=51,base_model='resnet18')
print(net)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

在这里插入图片描述

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小蓝xlanll/article/detail/217575
推荐阅读
相关标签
  

闽ICP备14008679号