赞
踩
在 PyTorch 中,self.parameters()
是一个模型方法,它返回模型中所有需要优化的参数。这些参数通常是模型中的权重和偏置项。
当你定义一个 PyTorch 模型类时,你会将模型的各个层(如全连接层、卷积层等)定义在 __init__
方法中,这些层中的参数都会被 PyTorch 自动识别为模型的可训练参数。self.parameters()
方法就是用来访问这些可训练参数的。
import torch import torch.nn as nn class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() # nn.Linear() --> y = xA^T + b self.fc1 = nn.Linear(10, 5) # 定义一个全连接层,输入维度为10,输出维度为5 self.fc2 = nn.Linear(5, 2) # 定义另一个全连接层,输入维度为5,输出维度为2 def forward(self, x): x = torch.relu(self.fc1(x)) # 使用ReLU激活函数进行前向传播 x = self.fc2(x) return x # 创建一个模型实例 model = SimpleModel() # 使用self.parameters()获取模型中的所有参数 params = model.parameters() # 遍历并输出模型中的参数及其形状 for param in params: print(param.shape) # torch.Size([5, 10]) 第一个全连接层的A # torch.Size([5]) 第一个全连接层的b # torch.Size([2, 5]) 第二个全连接层的A # torch.Size([2]) 第二个全连接层的b
model.parameters()
和 model.state_dict()
是 PyTorch 中用于获取模型参数的两种不同方式,它们之间有一些区别:
model.parameters()
:
model.parameters()
是一个方法,用于获取模型中所有需要训练的参数。model.state_dict()
:
model.state_dict()
是一个方法,用于获取模型的状态字典。通常情况下,当你需要保存或加载模型的参数时,model.state_dict()
是更常用的选择,因为它提供了模型参数及其名称的完整信息,方便了保存和加载模型的状态。而 model.parameters()
则更适用于需要直接对参数进行操作的情况,比如初始化参数或手动更新参数等。
import torch import torch.nn as nn # 定义一个简单的神经网络模型 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.fc1 = nn.Linear(10, 5) # 定义一个全连接层,输入维度为10,输出维度为5 self.fc2 = nn.Linear(5, 2) # 定义另一个全连接层,输入维度为5,输出维度为2 def forward(self, x): x = torch.relu(self.fc1(x)) # 使用ReLU激活函数进行前向传播 x = self.fc2(x) return x # 创建一个模型实例 model = SimpleModel() # 打印模型结构 print("模型结构:") print(model) # 模型结构: # SimpleModel( # (fc1): Linear(in_features=10, out_features=5, bias=True) # (fc2): Linear(in_features=5, out_features=2, bias=True) # ) # 通过 model.parameters() 获取模型中的参数 print("\n所有参数:") for param in model.parameters(): print(param.shape) # 所有参数: # torch.Size([5, 10]) # torch.Size([5]) # torch.Size([2, 5]) # torch.Size([2]) # 通过 model.state_dict() 获取模型的状态字典 print("\n模型状态字典:") print(model.state_dict()) # 模型状态字典: # OrderedDict([('fc1.weight', tensor([[ 0.2434, 0.1585, -0.0489, -0.2854, 0.0958, 0.0450, 0.0235, -0.0228, # 0.2934, 0.1910], # [-0.1329, 0.1001, -0.0748, -0.2244, -0.2213, -0.0490, -0.2735, -0.0396, # -0.2985, -0.0525], # [-0.2757, -0.2826, -0.1690, 0.0196, -0.1237, -0.0701, 0.0759, -0.0892, # -0.0736, 0.1501], # [-0.3107, 0.1578, 0.2759, 0.1827, 0.1034, 0.2269, 0.0864, 0.2918, # -0.2557, 0.0274], # [ 0.1479, 0.1868, 0.2288, -0.2756, 0.2752, -0.1571, 0.1131, 0.1191, # 0.1174, 0.2341]])), ('fc1.bias', tensor([ 0.2031, 0.0612, 0.2677, 0.2544, -0.0595])), ('fc2.weight', tensor([[-0.3650, -0.1921, 0.0852, -0.0216, 0.0677], # [ 0.2857, 0.2233, 0.1513, -0.2641, 0.2005]])), ('fc2.bias', tensor([0.1477, 0.1283]))])
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。