Suppose you have only 10 classes, but VGG16 from torchvision.models outputs 1000 classes. How should you adapt the model?
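One way is to replace the last layer of the classifier (`classifier[6]`, a `Linear(4096, 1000)`) with a 10-output layer:

```python
import torch
import torchvision
from torch.nn import Linear

# Load VGG16 with ImageNet weights (torchvision >= 0.13; older versions use pretrained=True)
Vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
# classifier[6] is the final Linear(4096 -> 1000); swap it for a 10-class layer
Vgg16.classifier[6] = Linear(in_features=4096, out_features=10)
if torch.cuda.is_available():
    Vgg16 = Vgg16.cuda()  # .cuda() moves the module in place and returns it
```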
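Alternatively, keep the pretrained 1000-way layer and append a 1000 to 10 linear layer after it, here on ResNet-101:

```python
import torch
import torchvision
from torch.nn import Linear, Sequential

# Load ResNet-101 with ImageNet weights (torchvision >= 0.13; older versions use pretrained=True)
res = torchvision.models.resnet101(weights=torchvision.models.ResNet101_Weights.DEFAULT, progress=True)
# res.fc is a single Linear(2048 -> 1000). Calling res.fc.add_module(...) would only register
# the new layer; Linear.forward never calls it, so the output would stay 1000-dimensional.
# Wrapping both layers in a Sequential puts the new layer into the forward pass.
res.fc = Sequential(res.fc, Linear(in_features=1000, out_features=10))
if torch.cuda.is_available():
    res = res.cuda()
```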
Copyright © 2003-2013 www.wpsshop.cn. All rights reserved.