当前位置:   article > 正文

Pytorch中的self.parameters()

Pytorch中的self.parameters()

1. 作用

在 PyTorch 中,self.parameters() 是一个模型方法,它返回模型中所有需要优化的参数。这些参数通常是模型中的权重和偏置项。

当你定义一个 PyTorch 模型类时,你会将模型的各个层(如全连接层、卷积层等)定义在 __init__ 方法中,这些层中的参数都会被 PyTorch 自动识别为模型的可训练参数。self.parameters() 方法就是用来访问这些可训练参数的。

2. 例子

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        
        # nn.Linear()   -->    y = xA^T + b
        self.fc1 = nn.Linear(10, 5)  # 定义一个全连接层,输入维度为10,输出维度为5 
        self.fc2 = nn.Linear(5, 2)   # 定义另一个全连接层,输入维度为5,输出维度为2

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数进行前向传播
        x = self.fc2(x)
        return x

# 创建一个模型实例
model = SimpleModel()

# 使用self.parameters()获取模型中的所有参数
params = model.parameters()

# 遍历并输出模型中的参数及其形状
for param in params:
    print(param.shape)

# torch.Size([5, 10])  第一个全连接层的A
# torch.Size([5])      第一个全连接层的b
# torch.Size([2, 5])   第二个全连接层的A
# torch.Size([2])      第二个全连接层的b
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

3.与.state_dict()的区别

model.parameters()model.state_dict() 是 PyTorch 中用于获取模型参数的两种不同方式,它们之间有一些区别:

  1. model.parameters()
    • model.parameters() 是一个方法,用于获取模型中所有需要训练的参数。
    • 返回一个迭代器,可以用来访问模型中的参数张量。
    • 这个方法返回的是参数张量本身,不包含参数的名称信息。
  2. model.state_dict()
    • model.state_dict() 是一个方法,用于获取模型的状态字典。
    • 返回一个字典,其中包含了模型中所有有参数的名称及其对应的参数张量。
    • 这个字典中的键是参数的名称,值是参数张量。

通常情况下,当你需要保存或加载模型的参数时,model.state_dict() 是更常用的选择,因为它提供了模型参数及其名称的完整信息,方便了保存和加载模型的状态。而 model.parameters() 则更适用于需要直接对参数进行操作的情况,比如初始化参数或手动更新参数等。

4.一个对比的例子

import torch
import torch.nn as nn

# 定义一个简单的神经网络模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)  # 定义一个全连接层,输入维度为10,输出维度为5
        self.fc2 = nn.Linear(5, 2)   # 定义另一个全连接层,输入维度为5,输出维度为2

    def forward(self, x):
        x = torch.relu(self.fc1(x))  # 使用ReLU激活函数进行前向传播
        x = self.fc2(x)
        return x

# 创建一个模型实例
model = SimpleModel()

# 打印模型结构
print("模型结构:")
print(model)
# 模型结构:
# SimpleModel(
#   (fc1): Linear(in_features=10, out_features=5, bias=True)
#   (fc2): Linear(in_features=5, out_features=2, bias=True)
# )

# 通过 model.parameters() 获取模型中的参数
print("\n所有参数:")
for param in model.parameters():
    print(param.shape)

# 所有参数:
# torch.Size([5, 10])
# torch.Size([5])
# torch.Size([2, 5])
# torch.Size([2])

# 通过 model.state_dict() 获取模型的状态字典
print("\n模型状态字典:")
print(model.state_dict())

# 模型状态字典:
# OrderedDict([('fc1.weight', tensor([[ 0.2434,  0.1585, -0.0489, -0.2854,  0.0958,  0.0450,  0.0235, -0.0228,
#           0.2934,  0.1910],
#         [-0.1329,  0.1001, -0.0748, -0.2244, -0.2213, -0.0490, -0.2735, -0.0396,
#          -0.2985, -0.0525],
#         [-0.2757, -0.2826, -0.1690,  0.0196, -0.1237, -0.0701,  0.0759, -0.0892,
#          -0.0736,  0.1501],
#         [-0.3107,  0.1578,  0.2759,  0.1827,  0.1034,  0.2269,  0.0864,  0.2918,
#          -0.2557,  0.0274],
#         [ 0.1479,  0.1868,  0.2288, -0.2756,  0.2752, -0.1571,  0.1131,  0.1191,
#           0.1174,  0.2341]])), ('fc1.bias', tensor([ 0.2031,  0.0612,  0.2677,  0.2544, -0.0595])), ('fc2.weight', tensor([[-0.3650, -0.1921,  0.0852, -0.0216,  0.0677],
#         [ 0.2857,  0.2233,  0.1513, -0.2641,  0.2005]])), ('fc2.bias', tensor([0.1477, 0.1283]))])

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/548394
推荐阅读
相关标签
  

闽ICP备14008679号