
What model.state_dict(), model.modules(), model.children(), model.named_children(), and related methods mean in PyTorch


A summary of the various model methods in PyTorch:

First, define the network model Net:

import torch 
import torch.nn as nn 

class Net(nn.Module):

    def __init__(self, num_class=10):
        super().__init__()
        
        self.backbone = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(6),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(12),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.classifier = nn.Sequential(
            nn.Linear(9*8*8, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(128, num_class)
        )

    def forward(self, x):
        output = self.backbone(x)
        output = output.view(output.size()[0], -1)
        output = self.classifier(output)
    
        return output

model = Net()
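
The article never runs a forward pass, so here is a minimal sanity check of which input size fits the classifier. The first Linear layer expects 9*8*8 = 576 features, while the backbone emits 12 channels, so the flattened backbone output must be 12 x H x W = 576. The 3x32x40 input below is an assumption chosen to satisfy that (it yields 12 x 6 x 8); the original does not state an input size. The class is repeated so the snippet runs on its own.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Same architecture as the Net defined above."""

    def __init__(self, num_class=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=3), nn.ReLU(inplace=True),
            nn.BatchNorm2d(6), nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 12, kernel_size=3), nn.ReLU(inplace=True),
            nn.BatchNorm2d(12), nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(9 * 8 * 8, 128), nn.ReLU(inplace=True),
            nn.Dropout(), nn.Linear(128, num_class),
        )

    def forward(self, x):
        output = self.backbone(x)
        output = output.view(output.size(0), -1)
        return self.classifier(output)


model = Net().eval()              # eval mode: Dropout off, BN uses running stats
x = torch.randn(2, 3, 32, 40)     # assumed input size, not from the article

with torch.no_grad():
    feat = model.backbone(x)
    out = model(x)

print(feat.shape)  # torch.Size([2, 12, 6, 8])  -> 12 * 6 * 8 = 576
print(out.shape)   # torch.Size([2, 10])
```

With a square 32x32 input the backbone would produce 12 x 6 x 6 = 432 features instead, and the first Linear layer would raise a shape error.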

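The network Net is itself a subclass of nn.Module. It contains backbone and classifier, two nn.Module subclasses built from Sequential containers. backbone and classifier each contain several layers, and those layers are also nn.Module subclasses. From the outside in, there are therefore three levels:

  1. Net (an nn.Module subclass)
  2. backbone and classifier (Sequential, an nn.Module subclass), which are the sub-layers of Net
  3. The concrete layers such as conv, relu, and batchnorm (nn.Module subclasses), which are the sub-layers of backbone or classifier

Return values of the various model methods: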

model.modules()                                                         
>>> <generator object Module.modules at 0x7fb381953740>
model.named_modules()                                                                                                       
>>> <generator object Module.named_modules at 0x7fb3819537b0>

model.children()                                                        
>>> <generator object Module.children at 0x7fb381953ac0>
model.named_children()                                                 
>>> <generator object Module.named_children at 0x7fb3819539e0>

model.parameters()                                                     
>>> <generator object Module.parameters at 0x7fb381953f90>
model.named_parameters()                                               
>>> <generator object Module.named_parameters at 0x7fb3818b95f0>

model.state_dict()                                                     
>>> 
OrderedDict([('backbone.0.weight', tensor([[[[ 0.1200, -0.1627, -0.0841],
                        [-0.1369, -0.1525,  0.0541],
                        [ 0.1203,  0.0564,  0.0908]],
                      ……

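Except for model.state_dict(), which returns an ordered dictionary, every method above returns a generator. A list comprehension collects the contents of each generator into a list: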

model_modules = [m for m in model.modules()]          
model_named_modules = [m for m in model.named_modules()]     
   
model_children = [m for m in model.children()]         
model_named_children = [m for m in model.named_children()]  
                                                               
model_parameters = [m for m in model.parameters()]                                                                         
model_named_parameters = [m for m in model.named_parameters()]
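
One practical consequence of getting a generator back is that it can be consumed only once. A minimal sketch, using a single nn.Linear layer as a stand-in for the model (an assumption made to keep the example short):

```python
import torch.nn as nn

layer = nn.Linear(4, 2)          # stand-in model with a weight and a bias

gen = layer.parameters()         # a generator, nothing is materialized yet
first_pass = list(gen)           # consumes the generator
second_pass = list(gen)          # the same generator is now exhausted

print(len(first_pass), len(second_pass))  # 2 0

# Calling the method again returns a fresh generator.
fresh = list(layer.parameters())
print(len(fresh))  # 2
```

Storing the result in a list, as above, is the simple way to iterate over it more than once.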

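1. model.modules() and model.named_modules()

model.modules() iterates over all sub-layers of the model, where a sub-layer is any layer that inherits from nn.Module.

In the Net defined above, Net() itself, backbone(), classifier(), and every layer inside those two all inherit from nn.Module, so all of them are visited, and the traversal is depth-first. Calling .modules() on Net visits the modules in this order: Net --> backbone --> backbone layer --> classifier --> classifier layer

model.named_modules() is model.modules() with layer names: in addition to each layer, it also returns the layer's name. Each element is a tuple whose first element is the layer name and whose second element is the layer itself. Apart from backbone and classifier, which are named explicitly in the model definition, the other layers are named automatically by PyTorch: nn.Sequential uses each layer's index as its name, and nested names are joined with dots (for example backbone.0).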

>>> model_modules
>>> len(model_modules)   # 15
>>> model_named_modules
>>> len(model_named_modules)   # 15

##########################
## output model_modules ##
##########################                                                                                   
[Net(
   (backbone): Sequential(
     (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
     (1): ReLU(inplace=True)
     (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
     (4): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
     (5): ReLU(inplace=True)
     (6): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   )
   (classifier): Sequential(
     (0): Linear(in_features=576, out_features=128, bias=True)
     (1): ReLU(inplace=True)
     (2): Dropout(p=0.5, inplace=False)
     (3): Linear(in_features=128, out_features=10, bias=True)
   )
 ),
 Sequential(
   (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
   (1): ReLU(inplace=True)
   (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (4): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
   (5): ReLU(inplace=True)
   (6): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 ),
 Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1)),
 ReLU(inplace=True),
 BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1)),
 ReLU(inplace=True),
 BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
 Sequential(
   (0): Linear(in_features=576, out_features=128, bias=True)
   (1): ReLU(inplace=True)
   (2): Dropout(p=0.5, inplace=False)
   (3): Linear(in_features=128, out_features=10, bias=True)
 ),
 Linear(in_features=576, out_features=128, bias=True),
 ReLU(inplace=True),
 Dropout(p=0.5, inplace=False),
 Linear(in_features=128, out_features=10, bias=True)]    

################################
## output model_named_modules ##
################################
[('',
  Net(
    (backbone): Sequential(
      (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU(inplace=True)
      (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
      (5): ReLU(inplace=True)
      (6): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (classifier): Sequential(
      (0): Linear(in_features=576, out_features=128, bias=True)
      (1): ReLU(inplace=True)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=128, out_features=10, bias=True)
    )
  )),
 ('backbone',
  Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )),
 ('backbone.0', Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))),
 ('backbone.1', ReLU(inplace=True)),
 ('backbone.2', BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 ('backbone.3', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)),
 ('backbone.4', Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))),
 ('backbone.5', ReLU(inplace=True)),
 ('backbone.6', BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
 ('backbone.7', MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)),
 ('classifier',
  Sequential(
    (0): Linear(in_features=576, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=128, out_features=10, bias=True)
  )),
 ('classifier.0', Linear(in_features=576, out_features=128, bias=True)),
 ('classifier.1', ReLU(inplace=True)),
 ('classifier.2', Dropout(p=0.5, inplace=False)),
 ('classifier.3', Linear(in_features=128, out_features=10, bias=True))]                                                             
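
The traversal order and the dotted names can also be checked programmatically. The sketch below rebuilds the same Net structure (forward() is omitted on purpose, since traversal depends only on the registered submodules) and compares the names against the order described above.

```python
import torch.nn as nn


class Net(nn.Module):
    """Same submodule layout as the Net above; forward() omitted for brevity."""

    def __init__(self, num_class=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=3), nn.ReLU(inplace=True),
            nn.BatchNorm2d(6), nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 12, kernel_size=3), nn.ReLU(inplace=True),
            nn.BatchNorm2d(12), nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(9 * 8 * 8, 128), nn.ReLU(inplace=True),
            nn.Dropout(), nn.Linear(128, num_class),
        )


model = Net()
names = [name for name, _ in model.named_modules()]

# Root first (empty name), then each child followed by its own children.
expected = (
    ["", "backbone"] + [f"backbone.{i}" for i in range(8)]
    + ["classifier"] + [f"classifier.{i}" for i in range(4)]
)
print(names == expected)  # True
print(len(names))         # 15

# A dotted name maps back to the very same module object.
lookup = dict(model.named_modules())
print(lookup["backbone.0"] is model.backbone[0])  # True
```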

Both model.modules() and model.named_modules() can be used to modify specific layers.

1) With model.modules(), use isinstance() to pick out layers of a specific type:

for layer in model.modules():
    if isinstance(layer, nn.Conv2d):
        pass  # process `layer` here
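
As one concrete instance of "processing" a layer, the sketch below re-initializes every Conv2d layer. The small Sequential model and the choice of Kaiming initialization are assumptions for illustration; the article does not say what the processing should be.

```python
import torch
import torch.nn as nn

# Hypothetical small model, used only to demonstrate the pattern.
model = nn.Sequential(
    nn.Conv2d(3, 6, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(6, 12, kernel_size=3),
    nn.Flatten(),
)

for layer in model.modules():
    if isinstance(layer, nn.Conv2d):
        # Re-initialize the weights in place and zero the bias.
        nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
        nn.init.zeros_(layer.bias)

conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
print(len(conv_layers))  # 2
```

Selecting by type works regardless of how the layers are named, which makes it the more robust option for models built from index-named Sequential containers.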

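2) With model.named_modules(), if every layer was given a name in the model definition (for example, the convolutional layers are named conv1, conv2, ...), you can select layers by name: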

for name, layer in model.named_modules():
    if 'conv' in name:
        pass  # process `layer` here
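
Name-based selection only works when the names actually contain the substring. In the Net defined earlier the layers are named backbone.0, backbone.1, and so on, so 'conv' would match nothing. The sketch below therefore uses a hypothetical NamedNet whose attributes are called conv1, conv2, and fc, and freezes the convolutional layers as an example of processing.

```python
import torch.nn as nn


class NamedNet(nn.Module):
    """Hypothetical model with explicitly named layers (not from the article)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, kernel_size=3)
        self.conv2 = nn.Conv2d(6, 12, kernel_size=3)
        self.fc = nn.Linear(12, 10)


model = NamedNet()
frozen = []

for name, layer in model.named_modules():
    if "conv" in name:
        # Freeze this layer: its parameters stop receiving gradients.
        for p in layer.parameters():
            p.requires_grad_(False)
        frozen.append(name)

print(frozen)                         # ['conv1', 'conv2']
print(model.fc.weight.requires_grad)  # True
```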

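2. model.children() and model.named_children()

As noted earlier, Net has three levels: 1) Net, 2) Net's sub-layers backbone/classifier, and 3) the sub-layers of backbone/classifier such as conv, relu, and batchnorm.

model.modules() visits all sub-layers of model, including the sub-layers of sub-layers. Loosely speaking, it visits every node of the tree from the root to the leaves. In the example above, it visits every element of the three-level structure.

model.children() returns only the second level of the model, that is, its direct children. In the example above it yields only backbone and classifier: neither Net itself nor the sub-layers of backbone/classifier. As before, model.named_children() is model.children() with layer names.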

>>> model_children
>>> len(model_children)   # 2
>>> model_named_children
>>> len(model_named_children)   # 2

###########################
## output model_children ##
###########################
[Sequential(
   (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
   (1): ReLU(inplace=True)
   (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
   (4): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
   (5): ReLU(inplace=True)
   (6): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 ),
 Sequential(
   (0): Linear(in_features=576, out_features=128, bias=True)
   (1): ReLU(inplace=True)
   (2): Dropout(p=0.5, inplace=False)
   (3): Linear(in_features=128, out_features=10, bias=True)
 )]

#################################
## output model_named_children ##
#################################
[('backbone',
  Sequential(
    (0): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(6, 12, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU(inplace=True)
    (6): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )),
 ('classifier',
  Sequential(
    (0): Linear(in_features=576, out_features=128, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=128, out_features=10, bias=True)
  ))]
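
Because named_children() stops at the direct children, it is a convenient handle for treating whole blocks differently. A minimal sketch, assuming a reduced TinyNet with the same two attribute names (backbone, classifier) as the article's Net; freezing the backbone is an illustrative use, not something the article prescribes.

```python
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical reduced model with the same top-level block names."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 6, 3), nn.BatchNorm2d(6))
        self.classifier = nn.Sequential(nn.Linear(6, 10))


model = TinyNet()
child_names = [name for name, _ in model.named_children()]

for name, child in model.named_children():
    if name == "backbone":
        # Freeze every parameter inside this block.
        for p in child.parameters():
            p.requires_grad_(False)

trainable = [n for n, p in model.named_parameters() if p.requires_grad]

print(child_names)  # ['backbone', 'classifier']
print(trainable)    # ['classifier.0.weight', 'classifier.0.bias']

# children() can be called on any module, e.g. to go one level deeper.
print(len(list(model.backbone.children())))  # 2
```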

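3. model.parameters() and model.named_parameters()

model.parameters() iterates over all learnable parameters of the model. Some layers have no learnable parameters (for example relu and maxpool), so they contribute nothing to model.parameters().

Correspondingly, model.named_parameters() is model.parameters() with names: each tuple packs two elements, the parameter name and the parameter itself. The name is the layer name plus a .weight or .bias suffix, which distinguishes weights from biases.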

>>> model_parameters
>>> len(model_parameters)   # 12
>>> model_named_parameters
>>> len(model_named_parameters)   # 12

#############################
## output model_parameters ##
#############################                                                               
[Parameter containing:
 tensor([[[[ 0.1871,  0.0998, -0.1136],
           [ 0.0104,  0.1804,  0.0761],
           [ 0.0331,  0.0311,  0.0843]],
           	...
          [[ 0.0990,  0.0997,  0.0398],
           [ 0.1182, -0.0016,  0.1722],
           [-0.1830,  0.0451,  0.0737]]]], requires_grad=True),
 Parameter containing:
 tensor([ 0.0617,  0.1688, -0.0237, -0.1017,  0.0201,  0.0849],
        requires_grad=True),
 Parameter containing:
 tensor([1., 1., 1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0.], requires_grad=True),
 Parameter containing:
 tensor([[[[ 0.1346,  0.0129, -0.1315],
           [-0.0621, -0.1044,  0.0091],
           [-0.0638, -0.0477, -0.0327]],
           ...
          [[-0.1067, -0.1073,  0.1203],
           [-0.1091, -0.0542, -0.0008],
           [ 0.0517,  0.0297,  0.1107]]]], requires_grad=True),
 Parameter containing:
 tensor([-0.0344,  0.1320,  0.0165,  0.0100,  0.0784, -0.0792,  0.0044,  0.0419,
          0.0234, -0.0159, -0.0053, -0.1342], requires_grad=True),
 Parameter containing:
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True),
 Parameter containing:
 tensor([[ 0.0234,  0.0177,  0.0184,  ..., -0.0057, -0.0067,  0.0385],
         [-0.0355,  0.0340, -0.0090,  ...,  0.0243,  0.0241, -0.0264],
         [-0.0203,  0.0172, -0.0239,  ..., -0.0279, -0.0294, -0.0038],
         ...,
         [ 0.0090, -0.0009,  0.0363,  ...,  0.0019, -0.0086, -0.0304],
         [ 0.0032,  0.0007,  0.0056,  ..., -0.0060,  0.0083, -0.0253],
         [-0.0102,  0.0276,  0.0365,  ...,  0.0016,  0.0248,  0.0273]],
        requires_grad=True),
 Parameter containing:
 tensor([-0.0288, -0.0141,  0.0232, -0.0006, -0.0001,  0.0311, -0.0135, -0.0219,
         ...
         -0.0285, -0.0271, -0.0366,  0.0029, -0.0206,  0.0245, -0.0203, -0.0208],
        requires_grad=True),
 Parameter containing:
 tensor([[ 0.0330, -0.0665, -0.0036,  ..., -0.0092,  0.0171,  0.0699],
         [ 0.0871, -0.0311,  0.0330,  ...,  0.0013, -0.0871,  0.0667],
         [-0.0146, -0.0383, -0.0370,  ...,  0.0261,  0.0599,  0.0240],
         ...,
         [ 0.0058, -0.0125, -0.0157,  ..., -0.0055, -0.0823, -0.0664],
         [-0.0488,  0.0545, -0.0859,  ..., -0.0786, -0.0524, -0.0451],
         [ 0.0201, -0.0197, -0.0538,  ..., -0.0369, -0.0202, -0.0865]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0654,  0.0766, -0.0597, -0.0595, -0.0724, -0.0484,  0.0121, -0.0212,
          0.0234, -0.0146], requires_grad=True)]


###################################
## output model_named_parameters ##
###################################
# The data is the same as above; for brevity, only the names are printed
for k, v in model.named_parameters():
	print(k)
	
backbone.0.weight
backbone.0.bias
backbone.2.weight
backbone.2.bias
backbone.4.weight
backbone.4.bias
backbone.6.weight
backbone.6.bias
classifier.0.weight
classifier.0.bias
classifier.3.weight
classifier.3.bias
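
Two common uses of parameters() are counting the learnable values and handing them to an optimizer. The sketch below rebuilds the Net structure (forward() omitted, since it is not needed here); the optimizer choice and the learning rates are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    """Same submodule layout as the Net above; forward() omitted for brevity."""

    def __init__(self, num_class=10):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=3), nn.ReLU(inplace=True),
            nn.BatchNorm2d(6), nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 12, kernel_size=3), nn.ReLU(inplace=True),
            nn.BatchNorm2d(12), nn.MaxPool2d(kernel_size=2, stride=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(9 * 8 * 8, 128), nn.ReLU(inplace=True),
            nn.Dropout(), nn.Linear(128, num_class),
        )


model = Net()

# Total number of learnable scalars across the 12 parameter tensors.
total = sum(p.numel() for p in model.parameters())
print(total)  # 76010

# Per-block parameter groups with different learning rates (values are arbitrary).
optimizer = torch.optim.SGD(
    [
        {"params": model.backbone.parameters(), "lr": 1e-3},
        {"params": model.classifier.parameters()},
    ],
    lr=1e-2,
)
print(len(optimizer.param_groups))  # 2
```

The total breaks down as 168 + 12 + 660 + 24 (backbone) plus 73856 + 1290 (classifier).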

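4. model.state_dict()

model.state_dict() returns all parameters of the model, both learnable and non-learnable. Its return value is an ordered dictionary, OrderedDict.

As the output below shows, model.state_dict() collects all learnable parameters of model (weight, bias) as well as the non-learnable ones (the running mean and running var of the BN layers, and so on), which PyTorch registers as buffers. You can think of model.state_dict() as model.parameters() plus all of these non-learnable entries.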

OrderedDict([('backbone.0.weight',
              tensor([[[[ 0.1796,  0.0621,  0.1027],
                        [-0.0723, -0.0971,  0.0218],
                        [-0.0835, -0.0479,  0.0305]],
                        ...
                       [[-0.0544, -0.1858,  0.1559],
                        [-0.0589,  0.0146, -0.1285],
                        [-0.1033,  0.0743,  0.1137]]]])),
             ('backbone.0.bias',
              tensor([ 0.0202,  0.1326,  0.0124, -0.1895, -0.1094, -0.1045])),
             ('backbone.2.weight', tensor([1., 1., 1., 1., 1., 1.])),
             ('backbone.2.bias', tensor([0., 0., 0., 0., 0., 0.])),
             ('backbone.2.running_mean', tensor([0., 0., 0., 0., 0., 0.])),
             ('backbone.2.running_var', tensor([1., 1., 1., 1., 1., 1.])),
             ('backbone.2.num_batches_tracked', tensor(0)),
             ('backbone.4.weight',
              tensor([[[[ 1.3451e-01, -7.3591e-02, -1.0690e-01],
                        [-5.4909e-02, -3.3993e-02,  3.3203e-02],
                        [-6.4427e-02,  1.2523e-01, -3.7897e-02]],
                        ...
                       [[-1.0125e-01,  1.7249e-02, -6.3623e-02],
                        [ 4.0353e-02, -7.0894e-02,  6.0606e-03],
                        [ 6.2089e-02,  8.5485e-02,  1.0689e-01]]]])),
             ('backbone.4.bias',
              tensor([ 0.0999, -0.1271,  0.0010,  0.1151, -0.1221,  0.0144,  0.1088,  0.1214,
                      -0.0175, -0.1071,  0.0937, -0.0058])),
             ('backbone.6.weight',
              tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])),
             ('backbone.6.bias',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])),
             ('backbone.6.running_mean',
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])),
             ('backbone.6.running_var',
              tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])),
             ('backbone.6.num_batches_tracked', tensor(0)),
             ('classifier.0.weight',
              tensor([[ 0.0359,  0.0245,  0.0020,  ...,  0.0282, -0.0255, -0.0319],
                      [ 0.0020,  0.0196,  0.0011,  ..., -0.0412,  0.0179,  0.0288],
                      [ 0.0251, -0.0245,  0.0152,  ...,  0.0136,  0.0084, -0.0052],
                      ...,
                      [ 0.0235, -0.0100, -0.0348,  ...,  0.0160, -0.0249, -0.0007],
                      [-0.0385,  0.0202, -0.0359,  ...,  0.0367,  0.0155, -0.0367],
                      [ 0.0092,  0.0375, -0.0229,  ..., -0.0322, -0.0065,  0.0008]])),
             ('classifier.0.bias',
              tensor([ 3.7528e-02, -2.4906e-02, -3.0417e-02, -2.9277e-02,  3.8544e-02,
                      ...
                      -1.4599e-02,  3.6207e-02,  1.8414e-02])),
             ('classifier.3.weight',
              tensor([[-0.0793, -0.0080,  0.0755,  ...,  0.0225,  0.0632,  0.0223],
                      [-0.0861, -0.0295,  0.0301,  ..., -0.0664, -0.0458,  0.0044],
                      [-0.0646,  0.0225, -0.0640,  ..., -0.0004,  0.0289, -0.0165],
                      ...,
                      [-0.0760, -0.0517, -0.0625,  ...,  0.0393, -0.0475, -0.0070],
                      [ 0.0558, -0.0860, -0.0813,  ..., -0.0578, -0.0843, -0.0303],
                      [-0.0077,  0.0227,  0.0247,  ..., -0.0424,  0.0134, -0.0196]])),
             ('classifier.3.bias',
              tensor([-0.0307,  0.0848,  0.0686,  0.0819,  0.0455,  0.0711,  0.0073,  0.0117,
                       0.0293,  0.0431]))])
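
The usual reason to call state_dict() is to save a model and restore it later. The sketch below round-trips a state dict through an in-memory buffer (an assumption made so the example writes no files; a file path works the same way) and checks that the BN buffers travel along with the weights. The two-layer model is a hypothetical stand-in.

```python
import io

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small model that contains a BatchNorm layer.
    return nn.Sequential(nn.Conv2d(3, 6, 3), nn.BatchNorm2d(6))


src = make_model()
src.train()
src(torch.randn(4, 3, 8, 8))      # one training-mode pass updates the BN running stats

buffer = io.BytesIO()
torch.save(src.state_dict(), buffer)
buffer.seek(0)

dst = make_model()                # fresh model with different random weights
result = dst.load_state_dict(torch.load(buffer))

print(result.missing_keys, result.unexpected_keys)          # [] []
print(torch.equal(dst[0].weight, src[0].weight))            # True
print(torch.equal(dst[1].running_mean, src[1].running_mean))  # True
print(dst[1].num_batches_tracked.item())                    # 1
```

Restoring from parameters() alone would lose the running statistics, which is why state_dict() is the object to serialize.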

5. Differences between model.state_dict() and model.parameters()

1. Different return types

model.parameters() returns a generator (generator object), whereas model.state_dict() returns an ordered dictionary (OrderedDict).

model.parameters()                                                     
>>> <generator object Module.parameters at 0x7fb381953f90>

model.state_dict()                                                     
>>> 
OrderedDict([('backbone.0.weight', tensor([[[[ 0.1200, -0.1627, -0.0841],
                        [-0.1369, -0.1525,  0.0541],
                        [ 0.1203,  0.0564,  0.0908]],
                        ...

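2. Different kinds of model parameters stored

To make the difference easy to see, model.named_parameters() stands in for model.parameters() here and is compared with model.state_dict():

model.parameters() returns all learnable parameters of the model, whereas model.state_dict() additionally returns all non-learnable parameters (the running mean and running var of the BN layers, and so on).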

model_state_dict = model.state_dict()
model_named_parameters = model.named_parameters()

for k,v in model_named_parameters:
    print(k)

for k in model_state_dict:
    print(k)

###################################
## output model_named_parameters ##
###################################
backbone.0.weight
backbone.0.bias
backbone.2.weight
backbone.2.bias
backbone.4.weight
backbone.4.bias
backbone.6.weight
backbone.6.bias
classifier.0.weight
classifier.0.bias
classifier.3.weight
classifier.3.bias

#############################
## output model_state_dict ##
#############################
backbone.0.weight
backbone.0.bias
backbone.2.weight
backbone.2.bias
backbone.2.running_mean
backbone.2.running_var
backbone.2.num_batches_tracked
backbone.4.weight
backbone.4.bias
backbone.6.weight
backbone.6.bias
backbone.6.running_mean
backbone.6.running_var
backbone.6.num_batches_tracked
classifier.0.weight
classifier.0.bias
classifier.3.weight
classifier.3.bias
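
The extra keys in the state dict are exactly the model's buffers, which can be listed with named_buffers(). A minimal sketch on a hypothetical two-layer model (index-named by nn.Sequential, so the BN layer is "1"):

```python
import torch.nn as nn

# Hypothetical small model: one conv layer followed by one BN layer.
model = nn.Sequential(nn.Conv2d(3, 6, 3), nn.BatchNorm2d(6))

param_names = [n for n, _ in model.named_parameters()]
buffer_names = [n for n, _ in model.named_buffers()]
state_keys = list(model.state_dict().keys())

# Keys that are in the state dict but are not learnable parameters.
extra = [k for k in state_keys if k not in param_names]

print(param_names)
# ['0.weight', '0.bias', '1.weight', '1.bias']
print(extra)
# ['1.running_mean', '1.running_var', '1.num_batches_tracked']
print(extra == buffer_names)  # True

# By default state_dict() returns detached tensors, so none require grad,
# whereas the parameters themselves do.
print(any(v.requires_grad for v in model.state_dict().values()))  # False
print(all(p.requires_grad for p in model.parameters()))           # True
```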