PyTorch nn.Module class and its parameters explained: state_dict and parameters

```python
import torch


class MyNet(torch.nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()  # must come first: call the parent-class constructor
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.max_pooling1 = torch.nn.MaxPool2d(2, 1)

        # A Sequential container replaces separate conv2 / relu2 / max_pooling2 layers.
        # Its sub-layers are named by index (mlp.0, mlp.1, mlp.2).
        # in_channels is 32 so that it matches the output channels of conv1.
        self.mlp = torch.nn.Sequential(
            torch.nn.Conv2d(32, 32, 3, 2, 1),
            torch.nn.Sigmoid(),
            torch.nn.MaxPool2d(3, 1),
        )

        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        # 32 * 3 * 3 in dense1 assumes a 3x10x10 input:
        # conv1 -> 10x10, max_pooling1 -> 9x9, mlp conv (stride 2) -> 5x5, mlp pool -> 3x3
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)
        x = self.mlp(x)
        x = torch.flatten(x, 1)  # (N, 32, 3, 3) -> (N, 288) before the Linear layers
        x = self.dense1(x)
        x = self.dense2(x)
        return x


model = MyNet()  # build the model
print(model.parameters())  # parameters() returns a generator, not a list

print('\n')
# named_parameters() yields (name, Parameter) pairs
for name, para in model.named_parameters():
    print(name)
    print(para.shape)
    print("---------------------------")

print('\n')
# keep only the bias parameters and print their data type
for name, para in model.named_parameters():
    if 'bias' in name:
        print(name)
        print(para.dtype)
        print("---------------------------")

print('\n')
# parameters whose names contain any of these substrings are excluded from weight decay
no_decay = ["bias", "LayerNorm.weight"]
for i, (n, p) in enumerate(model.named_parameters(), start=1):
    print(i)
    print(any(nd in n for nd in no_decay))
params_decay = [n for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)]
params_nodecay = [n for n, p in model.named_parameters() if any(nd in n for nd in no_decay)]
print(params_decay)
print(params_nodecay)
```

Output:

```
<generator object Module.parameters at 0x000002018E623048>


conv1.weight
torch.Size([32, 3, 3, 3])
---------------------------
conv1.bias
torch.Size([32])
---------------------------
mlp.0.weight
torch.Size([32, 32, 3, 3])
---------------------------
mlp.0.bias
torch.Size([32])
---------------------------
dense1.weight
torch.Size([128, 288])
---------------------------
dense1.bias
torch.Size([128])
---------------------------
dense2.weight
torch.Size([10, 128])
---------------------------
dense2.bias
torch.Size([10])
---------------------------


conv1.bias
torch.float32
---------------------------
mlp.0.bias
torch.float32
---------------------------
dense1.bias
torch.float32
---------------------------
dense2.bias
torch.float32
---------------------------


1
False
2
True
3
False
4
True
5
False
6
True
7
False
8
True
['conv1.weight', 'mlp.0.weight', 'dense1.weight', 'dense2.weight']
['conv1.bias', 'mlp.0.bias', 'dense1.bias', 'dense2.bias']
```
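The last part of the script only collects parameter *names*. A `no_decay` list like this is commonly used (for example in Transformer fine-tuning scripts) to build optimizer parameter groups, so that biases and LayerNorm weights get no weight decay. The sketch below shows that pattern under some assumptions: `TinyNet` is a hypothetical model, its LayerNorm attribute is deliberately named `LayerNorm` so that the `"LayerNorm.weight"` substring actually matches, and `weight_decay=0.01` and `lr=0.1` are arbitrary example values.

```python
import torch
from torch.optim import SGD


class TinyNet(torch.nn.Module):
    """Hypothetical model used only to illustrate parameter groups."""

    def __init__(self):
        super().__init__()
        self.dense1 = torch.nn.Linear(16, 32)
        # Attribute deliberately named "LayerNorm" so its weight is "LayerNorm.weight"
        self.LayerNorm = torch.nn.LayerNorm(32)
        self.dense2 = torch.nn.Linear(32, 10)

    def forward(self, x):
        return self.dense2(self.LayerNorm(self.dense1(x)))


model = TinyNet()
no_decay = ["bias", "LayerNorm.weight"]

# Group 1: parameters that receive weight decay; group 2: parameters that do not.
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,  # arbitrary example value
    },
    {
        "params": [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
optimizer = SGD(optimizer_grouped_parameters, lr=0.1)

# Each entry in optimizer.param_groups keeps its own weight_decay setting
for group in optimizer.param_groups:
    print(group["weight_decay"], len(group["params"]))
```

This prints `0.01 2` (the two Linear weights) and `0.0 4` (the two Linear biases plus the LayerNorm weight and bias). Note that the matching is plain substring matching on names, so it only works if the module names actually contain the listed strings.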

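The title contrasts `state_dict()` with `parameters()`, but the script above only calls the parameter methods. The main difference: `parameters()` / `named_parameters()` yield only the learnable `nn.Parameter` objects, while `state_dict()` returns an ordered dict that also contains registered buffers (such as BatchNorm's running statistics), and by default its tensors are detached (they still share storage with the parameters). A minimal sketch, using a hypothetical `NetWithBN` model chosen because BatchNorm owns buffers:

```python
import torch


class NetWithBN(torch.nn.Module):
    """Hypothetical model: a conv layer followed by BatchNorm, which owns buffers."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(8)

    def forward(self, x):
        return self.bn(self.conv(x))


model = NetWithBN()

param_names = [name for name, _ in model.named_parameters()]
state_keys = list(model.state_dict().keys())

print(param_names)  # ['conv.weight', 'conv.bias', 'bn.weight', 'bn.bias']
print(state_keys)   # the four names above plus the BatchNorm buffers
# Entries in state_dict that are not learnable parameters, i.e. buffers
print([k for k in state_keys if k not in param_names])
# ['bn.running_mean', 'bn.running_var', 'bn.num_batches_tracked']

# By default state_dict() stores detached tensors, so they do not require grad;
# keep_vars=True returns the Parameter objects themselves instead.
print(model.state_dict()["conv.weight"].requires_grad)                # False
print(model.state_dict(keep_vars=True)["conv.weight"].requires_grad)  # True
```

This is why `state_dict()` is what gets saved with `torch.save` for checkpoints (it captures buffers too), while `parameters()` is what gets passed to an optimizer.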
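Disclaimer: This article was contributed by a user and does not represent the position of the wpsshop blog; copyright belongs to the original author, and this site assumes no legal responsibility. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/盐析白兔/article/detail/361190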