当前位置:   article > 正文

PyTorch学习笔记(7)网络模型的使用与修改_pytorch 自带模型修改分类

pytorch 自带模型修改分类

本文使用 pytorch 自带的网络模型,并略作修改,使它能够适应其它的数据集。


下载

这里以 vgg16 为例子。

pretrained=True 表示下载已经训练好的模型,即包含相关参数,progress=True 表示显示进度条。

import torchvision
from torch import nn

# vgg16_false = torchvision.models.vgg16(pretrained=False)    #未训练好的模型
vgg16_true = torchvision.models.vgg16(pretrained=True,progress=True)

print(vgg16_true) # 输出网络结构
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

通过调试可以查看模型具体属性:

image-20220819173114335

通过打印,我们可以看到详细的网络结构。然而 vgg 模型是对 1000 个类别进行分类的,我们应该怎样使用呢?

image-20220819173747760

修改模型

假如我们想使用 CIFAR10 数据集,去套用 vgg 的网络,那么就需要在最后加一个线性层 nn.Linear(1000,10),或者将原有的线性层输出改为 10。
代码如下:

import torchvision
from torch import nn

# vgg16_false = torchvision.models.vgg16(pretrained=False)    #未训练好的模型
vgg16_true = torchvision.models.vgg16(pretrained=True,progress=True)

# 如何应用这个网络模型?
# out_features=1000改为10
train_data = torchvision.datasets.CIFAR10("./dataset_CIFAR10",train=True,transform=torchvision.transforms.ToTensor(), download=False)

# 增加一个线性层
vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print(vgg16_true)

# 修改线性层
# vgg16_true.classifier[6] = nn.Linear(4096,10)
# print(vgg16_true)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

模型的保存与读取

有两种方式,一个是同时保存模型结构参数,一个是只保存参数。
不同的保存方式,对应的加载方式也略有不同。

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1,模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
model = torch.load("vgg16_method1.pth")
print(model)

# 保存方式2,模型参数(官方推荐)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
model = torch.load("vgg16_method2.pth")
print(model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

如果只保存了参数,那么在使用时需要加载模型结构:

import torch
import torchvision
from torch import nn

# 方式2,加载模型
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/217555
推荐阅读
相关标签
  

闽ICP备14008679号