当前位置:   article > 正文

Pytorch学习-调整torchvision.models中模型输出类别数

Pytorch学习-调整torchvision.models中模型输出类别数

假设你的类别只有10个,而torchvision.models中Vgg16的输出类别为1000,这时应该如何调整呢?

方法一,直接修改模型中类别的输出。

  1. from torch.nn import Linear
  2. import torchvision
  3. import torch
  4. Vgg16=torchvision.models.vgg16(pretrained=True)
  5. Vgg16.classifier[6]=Linear(in_features=4096,out_features=10)
  6. if torch.cuda.is_available():
  7. T=Vgg16.cuda()

方法二,再模型的最后增加全连接层,改变输出类别。

  1. from torch.nn import Linear
  2. import torchvision
  3. import torch
  4. res=torchvision.models.resnet101(pretrained=True,progress=True)
  5. res.fc.add_module('linelayer',Linear(in_features=1000,out_features=10))
  6. if torch.cuda.is_available():
  7. T=res.cuda()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Cpp五条/article/detail/582731
推荐阅读
相关标签
  

闽ICP备14008679号