当前位置:   article > 正文

pytorch如何获取神经网络其中一部分的参数_怎么输出神经网络的参数

怎么输出神经网络的参数

在 PyTorch 中,你可以使用如下方式获取神经网络中的某一部分参数:

1. 使用 Module.named_parameters() 函数,这将返回一个迭代器,包含网络中所有可学习参数的名称和数值。例如:

import torch
import torch.nn as nn

# 定义一个简单的神经网络
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

# 创建网络实例并获取参数
net = Net()
for name, param in net.named_parameters():
    print(name, param.size())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

输出结果如下:

conv1.weight torch.Size
    声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/717063
    推荐阅读
    相关标签
      

    闽ICP备14008679号