当前位置:   article > 正文

torch学习 (十五):读取和存储模型_torch存储模型

torch存储模型

引入

  本节介绍如何把在内存中训练好的模型参数进行存储,以及后续的读取 [ 1 ] ^{[1]} [1]

1 读写Tensor

  torch的save与load函数与numpy的类似:

import torch
from torch import nn


if __name__ == '__main__':
    # Main
    x = torch.ones(3)
    torch.save(x, 'x.pt')
    x = torch.load('x.pt')
    print(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

  输出如下:

tensor([1., 1., 1.])
  • 1

2 读写模型

  torch中,Module的可学习参数 (权重和偏差),以及模块模型包含在参数中,可通过model.parameters()访问。

2.1 state_dict

  state_dict是一个从参数名称映射到参数Tensor的字典:

import torch
from torch import nn


class Test(nn.Module):

    def __init__(self):
        super(Test, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        """
        The forward function.
        """
        return self.output(self.act(self.hidden(x)))


if __name__ == '__main__':
    # Main
    net = Test()
    print(net.state_dict())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

  输出如下:

OrderedDict([('hidden.weight', tensor([[-0.1209, -0.1974, -0.2399],
        [-0.3348,  0.5283, -0.5134]])), ('hidden.bias', tensor([ 0.1019, -0.1037])), ('output.weight', tensor([[-0.4111, -0.1848]])), ('output.bias', tensor([-0.4113]))])
  • 1
  • 2

  注:只有具有可学习参数的层才有state_dict条目。优化器也有一个state_dict,其中包含关于优化器状态以及琐事有超参数的信息:

import torch
from torch import nn


class Test(nn.Module):

    def __init__(self):
        super(Test, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        """
        The forward function.
        """
        return self.output(self.act(self.hidden(x)))


if __name__ == '__main__':
    # Main
    net = Test()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.2)
    print(optimizer.state_dict())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

  输出如下:

{'state': {}, 'param_groups': [{'lr': 0.1, 'momentum': 0.2, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3]}]}
  • 1

2.2 保存和加载模型

  torch保存和加载模型有两种常见的方法:
  1)仅保存和加载模型参数,即state_dict;
  2)保存和加载整个模型。

2.2.1 保存和加载state_dict (推荐)

import torch
from torch import nn


class Test(nn.Module):

    def __init__(self):
        super(Test, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        """
        The forward function.
        """
        return self.output(self.act(self.hidden(x)))


if __name__ == '__main__':
    # Main
    net = Test()
    torch.save(net.state_dict(), 'net.pt') # .pt or .pth
    model = Test()
    model.load_state_dict(torch.load('net.pt'))
    print(model)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

  输出如下:

Test(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (act): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)
  • 1
  • 2
  • 3
  • 4
  • 5

2.2.2 保存整个模型

import torch
from torch import nn


class Test(nn.Module):

    def __init__(self):
        super(Test, self).__init__()
        self.hidden = nn.Linear(3, 2)
        self.act = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        """
        The forward function.
        """
        return self.output(self.act(self.hidden(x)))


if __name__ == '__main__':
    # Main
    net = Test()
    torch.save(net, 'net.pt') # .pt or .pth
    net = torch.load('net.pt')
    print(net)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25

  输出如下:

Test(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (act): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)
  • 1
  • 2
  • 3
  • 4
  • 5

参考文献
[1] 李沐、Aston Zhang等老师的这本《动手学深度学习》一书。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/328406
推荐阅读
相关标签
  

闽ICP备14008679号