赞
踩
在 PyTorch 中,你可以使用如下方式获取神经网络中的某一部分参数:
Module.named_parameters()
函数,这将返回一个迭代器,包含网络中所有可学习参数的名称和数值。例如:import torch import torch.nn as nn # 定义一个简单的神经网络 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 6, 3) self.conv2 = nn.Conv2d(6, 16, 3) self.fc1 = nn.Linear(16 * 6 * 6, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) # 创建网络实例并获取参数 net = Net() for name, param in net.named_parameters(): print(name, param.size())
输出结果如下:
conv1.weight torch.Size
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。