当前位置:   article > 正文

torch模型介绍,如何保存模型,torch.nn.Module类介绍和使用方法

torch.nn.module

仅供用于个人学习

使用pytorch 构建网络,经常会使用到工具箱torch.nn。
其中nn.Module是PyTorch提供的神经网络类,并在类中实现了网络各层的定义及前向计算与反向计算机制。

一、torch.nn.Module类的简介

class Module(object):
    def __init__(self):
    def forward(self, *input):
 
    def add_module(self, name, module):
    def cuda(self, device=None):
    def cpu(self):
    def __call__(self, *input, **kwargs):
    def parameters(self, recurse=True):
    def named_parameters(self, prefix='', recurse=True):
    def children(self):
    def named_children(self):
    def modules(self):  
    def named_modules(self, memo=None, prefix=''):
    def train(self, mode=True):
    def eval(self):
    def zero_grad(self):
    def __repr__(self):
    def __dir__(self):
		......
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

我们在定义自己的网络的时候,需要继承nn.Module类,并重新实现构造函数__init__和forward前向传播这两个方法。
(1)一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数__init__()中,当然也可以吧不具有参数的层也放在里面;
(2)一般把不具有可学习参数的层(如ReLU、dropout、BatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数__init__里面,则在forward方法里面可以使用nn.functional来代替
(3)forward方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

import torch
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1=torch.nn.ReLU()
        self.max_pooling1=torch.nn.MaxPool2d(2,1)
 
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu2=torch.nn.ReLU()
        self.max_pooling2=torch.nn.MaxPool2d(2,1)
 
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
 
model = MyNet()
print(model)
'''运行结果为:
MyNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (max_pooling1): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (max_pooling2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (dense2): Linear(in_features=128, out_features=10, bias=True)
)
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

上面的是将所有的层都放在了构造函数__init__里面,但是只是定义了一系列的层,包括了具有可学习参数的层和不具有可学习参数的层,各个层之间到底是什么连接关系并没有,而是在forward里面实现所有层的连接关系,当然这里依然是顺序连接的。

另一个例子

import torch
import torch.nn.functional as F
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
 
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
 
model = MyNet()
print(model)
'''运行结果为:
MyNet(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (dense2): Linear(in_features=128, out_features=10, bias=True)
)
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

此时,将没有训练参数的层没有放在构造函数里面了,所以这些层就不会出现在model里面,但是运行关系是在forward里面通过functional的方法实现的<>。

总结:所有放在构造函数__init__里面的层的都是这个模型的“固有属性”.

二、torch.nn.Module类的多种实现

Module类有很多种灵活的实现方式

2.1 通过Sequential类包装层

2.1.1 Sequential类介绍

当一个模型较简单的时候,我们可以使用torch.nn.Sequential类来实现简单的顺序连接模型。这个模型也是继承自Module类的,

Sequential可以看做是一个容器,作为一个容器包装机各层,我下面是它的定义:

class Sequential(Module): # 继承Module
    def __init__(self, *args):  # 重写了构造函数
    def _get_item_by_idx(self, iterator, idx):
    def __getitem__(self, idx):
    def __setitem__(self, idx, module):
    def __delitem__(self, idx):
    def __len__(self):
    def __dir__(self):
    def forward(self, input):  # 重写关键方法forward
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

2.1.2 Sequential类实现(3种实现方式)

1) 序贯模型:单输入单输出,一条路通到底,层与层之间只有相邻关系,没有跨层连接。

import torch.nn as nn
model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )
 
print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

注意: 这样做有一个问题,每一个层是没有名称,默认的是以0、1、2、3来命名,从上面的运行结果也可以看出。

2)使用有序字典类OrderedDict类,给每一层添加名称

import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
 
print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

注意: 从上面的结果中可以看出,这个时候每一个层都有了自己的名称,但是此时需要注意,我并不能够通过名称直接获取层,依然只能通过索引index,即

model[2] 是正确的

model[“conv2”] 是错误的

这其实是由它的定义实现的,看上面的Sequenrial定义可知,只支持index访问。

3)Sequential的第三种实现

import torch.nn as nn
from collections import OrderedDict
model = nn.Sequential()
model.add_module("conv1",nn.Conv2d(1,20,5))
model.add_module('relu1', nn.ReLU())
model.add_module('conv2', nn.Conv2d(20,64,5))
model.add_module('relu2', nn.ReLU())
 
print(model)
print(model[2]) # 通过索引获取第几个层
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

2.1.3 通过sequential包装层

(1)方式一:

import torch.nn as nn
from collections import OrderedDict
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2))
        self.dense_block = nn.Sequential(
            nn.Linear(32 * 3 * 3, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )
    # 在这里实现层之间的连接关系,其实就是所谓的前向传播
    def forward(self, x):
        conv_out = self.conv_block(x)
        res = conv_out.view(conv_out.size(0), -1)
        out = self.dense_block(res)
        return out
 
model = MyNet()
print(model)
'''运行结果为:
MyNet(
  (conv_block): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dense_block): Sequential(
    (0): Linear(in_features=288, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=10, bias=True)
  )
)
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

这里在每一个包装块里面,各个层是没有名称的,默认按照0、1、2、3、4来排名。

(2)方式二:
使用OrdereDict类,添加每一层的名字

import torch.nn as nn
from collections import OrderedDict
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv_block = nn.Sequential(
            OrderedDict(
                [
                    ("conv1", nn.Conv2d(3, 32, 3, 1, 1)),
                    ("relu1", nn.ReLU()),
                    ("pool", nn.MaxPool2d(2))
                ]
            ))
 
        self.dense_block = nn.Sequential(
            OrderedDict([
                ("dense1", nn.Linear(32 * 3 * 3, 128)),
                ("relu2", nn.ReLU()),
                ("dense2", nn.Linear(128, 10))
            ])
        )
 
    def forward(self, x):
        conv_out = self.conv_block(x)
        res = conv_out.view(conv_out.size(0), -1)
        out = self.dense_block(res)
        return out
 
model = MyNet()
print(model)
'''运行结果为:
MyNet(
  (conv_block): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dense_block): Sequential(
    (dense1): Linear(in_features=288, out_features=128, bias=True)
    (relu2): ReLU()
    (dense2): Linear(in_features=128, out_features=10, bias=True)
  )
)
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44

(3)方式三:
通过add_module方法添加层

import torch.nn as nn
from collections import OrderedDict
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv_block=torch.nn.Sequential()
        self.conv_block.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
        self.conv_block.add_module("relu1",torch.nn.ReLU())
        self.conv_block.add_module("pool1",torch.nn.MaxPool2d(2))
 
        self.dense_block = torch.nn.Sequential()
        self.dense_block.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
        self.dense_block.add_module("relu2",torch.nn.ReLU())
        self.dense_block.add_module("dense2",torch.nn.Linear(128, 10))
 
    def forward(self, x):
        conv_out = self.conv_block(x)
        res = conv_out.view(conv_out.size(0), -1)
        out = self.dense_block(res)
        return out
 
model = MyNet()
print(model)
'''运行结果为:
MyNet(
  (conv_block): Sequential(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu1): ReLU()
    (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (dense_block): Sequential(
    (dense1): Linear(in_features=288, out_features=128, bias=True)
    (relu2): ReLU()
    (dense2): Linear(in_features=128, out_features=10, bias=True)
  )
)
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37

2.2 Module类的几个常见方法的使用

Sequential类虽然继承自Module类,二者有相似部分,但是也有很多不同的部分,集中体现在:
Sequential 类 实现了整数索引,故而可以使用model[index] 这样的方式获取一个层,但是Module类并没有实现整数索引,不能够通过整数索引来获得层,那该怎么办呢?它提供了几个主要的方法,model.children()、model.named_children()、model.modules()、model.named_modules()

(1)方式一:model.children()

import torch.nn as nn
from collections import OrderedDict
class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.conv_block=torch.nn.Sequential()
        self.conv_block.add_module("conv1",torch.nn.Conv2d(3, 32, 3, 1, 1))
        self.conv_block.add_module("relu1",torch.nn.ReLU())
        self.conv_block.add_module("pool1",torch.nn.MaxPool2d(2))
 
        self.dense_block = torch.nn.Sequential()
        self.dense_block.add_module("dense1",torch.nn.Linear(32 * 3 * 3, 128))
        self.dense_block.add_module("relu2",torch.nn.ReLU())
        self.dense_block.add_module("dense2",torch.nn.Linear(128, 10))
 
    def forward(self, x):
        conv_out = self.conv_block(x)
        res = conv_out.view(conv_out.size(0), -1)
        out = self.dense_block(res)
        return out
 
model = MyNet()
 
for i in model.children():
    print(i)
    print(type(i)) # 查看每一次迭代的元素到底是什么类型,实际上是 Sequential 类型,所以有可以使用下标index索引来获取每一个Sequenrial 里面的具体层
 
'''运行结果为:
Sequential(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
<class 'torch.nn.modules.container.Sequential'>
Sequential(
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (relu2): ReLU()
  (dense2): Linear(in_features=128, out_features=10, bias=True)
)
<class 'torch.nn.modules.container.Sequential'>
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41

如果把这个网络模型MyNet按层次从外到内进行划分的话,conv_block和dense_block是MyNet的子层,而conv2d, ReLU, Maxpool2d这些又是conv_block的子层, Linear, ReLU等是dense_block的子层,下面的model.modules()不但会遍历模型的子层,还会遍历子层的子层,以及所有子层。
而model.children()只会遍历模型的子层,这里即是conv_block和dense_block。

(2)方式二:model.named_children()方法

for i in model.named_children():
    print(i)
    print(type(i))
    print(i[0])
    print(i[1])# 查看每一次迭代的元素到底是什么类型,实际上是 返回一个tuple,tuple 的第一个元素是模块名称,第二个原始是sequential里面的具体层
 
'''运行结果为:
('conv_block', Sequential(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
))
<class 'tuple'>
conv_block
Sequential(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
('dense_block', Sequential(
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (relu2): ReLU()
  (dense2): Linear(in_features=128, out_features=10, bias=True)
))
<class 'tuple'>
dense_block
Sequential(
  (dense1): Linear(in_features=288, out_features=128, bias=True)
  (relu2): ReLU()
  (dense2): Linear(in_features=128, out_features=10, bias=True)
)
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

对比上面的model.children(), 这里的model.named_children()还返回了两个子层的名称:conv_block 和 dense_block .
(1)model.children()和model.named_children()方法返回的是迭代器iterator

(2)model.children():每一次迭代返回的每一个元素实际上是 Sequential 类型,而Sequential类型又可以使用下标index索引来获取每一个Sequenrial 里面的具体层,比如conv层、dense层等;

(3)model.named_children():每一次迭代返回的每一个元素实际上是 一个元组类型,元组的第一个元素是名称,第二个元素就是对应的层或者是Sequential。


(1)model.modules()和model.named_modules()方法返回的是迭代器iterator;

(2)model的modules()方法和named_modules()方法都会将整个模型的所有构成(包括包装层、单独的层、自定义层等)由浅入深依次遍历出来,只不过modules()返回的每一个元素是直接返回的层对象本身,而named_modules()返回的每一个元素是一个元组,第一个元素是名称,第二个元素才是层对象本身。

(3)如何理解children和modules之间的这种差异性。注意pytorch里面不管是模型、层、激活函数、损失函数都可以当成是Module的拓展,所以modules和named_modules会层层迭代,由浅入深,将每一个自定义块block、然后block里面的每一个层都当成是module来迭代。而children就比较直观,就表示的是所谓的“孩子”,所以没有层层迭代深入。

三、nn.Module查看模型 state_dict和parameters方法

model.state_dict()方法和model.parameters()方法会获取到模型中训练的参数(权值矩阵、偏置bias)

3.1model.state_dict()方法

pytorch 中的 state_dict 是一个简单的python的字典对象,将每一层与它的对应参数建立映射关系.(如model的每一层的weights及偏置等等)

注意:
(1)只有那些参数可以训练的layer才会被保存到模型的state_dict中,如卷积层,线性层等等,像什么池化层、BN层这些本身没有参数的层是没有在这个字典中的;

(2)这个方法的作用一方面是方便查看某一个层的权值和偏置数据,另一方面更多的是在模型保存的时候使用。

import torch
import torch.nn.functional as F
from torch.optim import SGD
 
class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # 第一句话,调用父类的构造函数
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1=torch.nn.ReLU()
        self.max_pooling1=torch.nn.MaxPool2d(2,1)
 
        self.conv2 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu2=torch.nn.ReLU()
        self.max_pooling2=torch.nn.MaxPool2d(2,1)
 
        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
 
model = MyNet() # 构造模型
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

3.1.1 Module的层的权值以及bias查看

print(type(model.state_dict()))  # 查看state_dict所返回的类型,是一个“顺序字典OrderedDict”

print(model.state_dict().keys()) # 查看state_dict中的键值

model.state_dict()['conv1.weight'].size()
print('\n')

for param_tensor in model.state_dict(): # 字典的遍历默认是遍历 key,所以param_tensor实际上是键值
    print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
'''
<class 'collections.OrderedDict'>
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'dense1.weight', 'dense1.bias', 'dense2.weight', 'dense2.bias'])


conv1.weight 	 torch.Size([32, 3, 3, 3])
conv1.bias 	 torch.Size([32])
conv2.weight 	 torch.Size([32, 3, 3, 3])
conv2.bias 	 torch.Size([32])
dense1.weight 	 torch.Size([128, 288])
dense1.bias 	 torch.Size([128])
dense2.weight 	 torch.Size([10, 128])
dense2.bias 	 torch.Size([10])
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

这里之查看每一个参数的维度信息,如果是查看具体的值也是一样的,跟字典的操作是一样的。

3.1.2 优化器optimizer的state_dict()方法

优化器对象Optimizer也有一个state_dict,它包含了优化器的状态以及被使用的超参数(如lr, momentum,weight_decay等)

optimizer = SGD(model.parameters(),lr=0.001,momentum=0.9)
for var_name in optimizer.state_dict():
    print(var_name,'\t',optimizer.state_dict()[var_name])
 
'''
state    {}
param_groups     [{'lr': 0.001, 
                   'momentum': 0.9,
                   'dampening': 0, 
                   'weight_decay': 0, 
                   'nesterov': False, 
                   'params': [1412966600640, 1412966613064, 1412966613136, 1412966613208, 
                             1412966613280, 1412966613352, 1412966613496, 1412966613568]
                 }]
'''

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

3.1.3 model.parameters()方法

print(type(model.parameters()))  # 返回的是一个generator
 
for para in model.parameters():
    print(para.size())  # 只查看形状
'''
<class 'generator'>
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([32, 3, 3, 3])
torch.Size([32])
torch.Size([128, 288])
torch.Size([128])
torch.Size([10, 128])
torch.Size([10])
'''
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

从这里可以看出,其实这个state_dict方法所得到结果差不多,不同的是,model.parameters()方法返回的是一个生成器generator,每一个元素是从开头到结尾的参数,parameters没有对应的key名称,是一个由纯参数组成的generator,而state_dict是一个字典,包含了一个key。

四、保存模型.pth

.pt
.pt: 文件是PyTorch 1.6及以上版本中引入的新的模型文件格式,它可以保存整个PyTorch模型,包括模型结构、模型参数以及优化器状态等信息。.pt文件是一个二进制文件,可以通过torch.save()函数来保存模型,以及通过torch.load()函数来加载模型。

.pth
**.pth:**文件是PyTorch旧版本中使用的模型文件格式,。.pth文件同样是一个二进制文件,可以通过torch.save()函数来保存模型参数,以及通过torch.load()函数来加载模型参数。

.pth.tar
**.pth.tar:**文件是一个压缩文件,它包含一个.pth文件以及其他相关信息,比如模型结构、优化器状态、超参数等。.pth.tar文件可以通过Python的标准库tarfile来解压,然后通过torch.load()函数来加载模型。

总的来说,.pt文件是最新的、最全面的模型保存格式,可以保存整个PyTorch模型,包括模型结构、参数、优化器状态等信息。.pth文件只保存了模型参数,而.pth.tar文件则是在.pth基础上加入了一些元数据信息,可以方便地保存和加载整个模型状态。在实际应用中,我们可以根据需要选择适合自己的模型保存格式。

模型保存与调用方式一:

保存:

torch.save(model.state_dict(), mymodel.pth)#只保存模型权重参数,不保存模型结构
  • 1

调用:

model = My_model(*args, **kwargs)  #这里需要重新模型结构,My_model
model.load_state_dict(torch.load(mymodel.pth))#这里根据模型结构,调用存储的模型参数
model.eval()
  • 1
  • 2
  • 3

模型保存与调用方式一:

保存:

torch.save(model, mymodel.pth)#保存整个model的状态
  • 1

调用:

model=torch.load(mymodel.pth)#这里已经不需要重构模型结构了,直接load就可以
model.eval()
  • 1
  • 2
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/772741
推荐阅读
相关标签
  

闽ICP备14008679号