赞
踩
本文使用 pytorch 自带的网络模型,并略作修改,使它能够适应其它的数据集。
这里以 vgg16
为例子。
pretrained=True
表示下载已经训练好的模型,即包含相关参数,progress=True
表示显示进度条。
import torchvision
from torch import nn
# vgg16_false = torchvision.models.vgg16(pretrained=False) #未训练好的模型
vgg16_true = torchvision.models.vgg16(pretrained=True,progress=True)
print(vgg16_true) # 输出网络结构
通过调试可以查看模型具体属性:
通过打印,我们可以看到详细的网络结构。然而 vgg
模型是对 1000 个类别进行分类的,我们应该怎样使用呢?
假如我们想使用 CIFAR10
数据集,去套用 vgg 的网络,那么就需要在最后加一个线性层 nn.Linear(1000,10)
,或者将原有的线性层输出改为 10。
代码如下:
import torchvision from torch import nn # vgg16_false = torchvision.models.vgg16(pretrained=False) #未训练好的模型 vgg16_true = torchvision.models.vgg16(pretrained=True,progress=True) # 如何应用这个网络模型? # out_features=1000改为10 train_data = torchvision.datasets.CIFAR10("./dataset_CIFAR10",train=True,transform=torchvision.transforms.ToTensor(), download=False) # 增加一个线性层 vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10)) print(vgg16_true) # 修改线性层 # vgg16_true.classifier[6] = nn.Linear(4096,10) # print(vgg16_true)
有两种方式,一个是同时保存模型结构和参数,一个是只保存参数。
不同的保存方式,对应的加载方式也略有不同。
import torch
import torchvision
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
model = torch.load("vgg16_method1.pth")
print(model)
# 保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
model = torch.load("vgg16_method2.pth")
print(model)
如果只保存了参数,那么在使用时需要加载模型结构:
import torch
import torchvision
from torch import nn
# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。