当前位置:   article > 正文

pytorch学习笔记(七)——pytorch中现有网络模型的使用、修改、模型的保存、加载_加载模型并修改模型结构pytorch

加载模型并修改模型结构pytorch

一、pytorch中现有网络模型的使用、修改

  1. 位于torchvision.models

  2. 使用vgg模型为例,采用的数据集是ImageNet,而ImageNet数据集使用前提需要有scipy包
    pip install scipy

    注意:ImageNet光训练集就有147.9G,而且不再能公开访问了

  3. pytorch中使用现有网络模型以及修改现有的网络模型代码示例

import torchvision

# train_data = torchvision.datasets.ImageNet("../data_image_net", split="train", download=True,
#                                            transform=torchvision.transforms.ToTensor())
from torch import nn

"""
理解:
1. pretrained=False时,相当于使用pytorch中现有的网络模型,其中各层的参数采用默认的
2. pretrained=True时,相当于使用pytorch中现有的网络模型,但其中各层的参数采用 我们在数据集上训练好的参数
"""

# 1.使用现有的网络模型
vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

# 2.在现有的网络模型中添加一层
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))

# 3.修改现有网络中的某层的参数
vgg16_false.classifier[7] = nn.Linear(4096, 10)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

二、模型的保存和加载

1. 模型的保存

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1,保存了网络模型的结构以及其中的参数
torch.save(vgg16, "vgg16_method1.pth")

# 保存方式2,把网络模型的参数保存成字典,不再保存网络模型的结构(官方推荐)占的空间小
torch.save(vgg16.state_dict(), "vgg16_method2.pth")


# 陷阱,用方式1保存自己写的神经网络
class MyNeural(nn.Module):
    def __init__(self):
        super(MyNeural, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return x


my_neural = MyNeural()
torch.save(my_neural, "my_neural_method1.pth")

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27

2.模型的加载

import torch
import torchvision
from c17_model_save import *

vgg16 = torchvision.models.vgg16(pretrained=False)

# 加载方式1,对应保存方式1
model = torch.load("vgg16_method1.pth")
print(model)

# 加载方式2,对应保存方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

# 陷阱1,
# 要让该.py文件加载自己定义的神经网络,需要引入自己定义的神经网络的模板类 from c17_model_save import *
model = torch.load("my_neural_method1.pth")
print(model)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/217578?site
推荐阅读
相关标签
  

闽ICP备14008679号