赞
踩
在加载pytorch已有模型的时候,我们必须要明确的事情:
1 如何获取到pytorch所提供的模型,通过什么方式。
2 模型的结构,也就是模型的每个层的名字(key)。
3 我们要把需要加载的模型,尽量封装成一个类。
下面我们针对上面来给出答案。
答1:以 resnet18 举例
# ----------------1 导入库 -----------------
import torchvision.models as model
# -------------2 将resnet18导入到新模型。--------------
base_model = 'resnet18'
if 'resnet' in base_model:
model = getattr(model,base_model)
答2 :我们在了解模型的时候,经常使用 dict(model.named_parameters()),它会返回一个字典,我们通过 字典.items()来得到字典的key和value值。我们要知道最后一层的分类层名字叫什么。
for (key,value) in dict(model.named_parameters()).items():
print(key)
最后一层的名字叫 fc ,这样我们可以通过最后一层的名字来修改最后一层。
num_class = 51
fc = getattr(model, 'fc')
feature_dim = fc.in_features
setattr(model,'fc',nn.Linear(feature_dim,num_class))
print(model)
这样就把最后一层修改完成了。
答3 :最后封装成新的模型类
import torchvision.models as model import torch.nn as nn class Model(nn.Module): def __init__(self, num_class,base_model= 'resnet18'): super().__init__() self._prepare_base_model(num_class = num_class,base_model = base_model ) def _prepare_base_model(self, base_model,num_class): if 'resnet' in base_model: self.model = getattr(model, base_model)(pretrained=True) feature_dim = getattr(self.model, 'fc').in_features setattr(self.model,'fc',nn.Linear(feature_dim,num_class)) else: raise ValueError('Unknown base model: {}'.format(base_model)) def forward(self,x): out = self.model(x) return out net = Model(num_class=51,base_model='resnet18') print(net)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。