当前位置:   article > 正文

重新理解一个类中的forward()和__init__()函数_类里的forward

类里的forward

forward()函数和__init__()的关系

__init__() 是一个类的构造函数,用于初始化对象的属性。它会在创建对象时自动调用,而且通常在这里完成对象所需的所有初始化操作。

forward() 是一个神经网络模型中的方法,用于定义数据流的向前传播过程。它接受输入数据,通过网络的各个层进行计算,最终返回输出结果。

在神经网络的 PyTorch 实现中,__init__() 方法通常用于实例化各个网络层(例如卷积层、池化层、全连接层的维度等【这里只是执行了初始化,但是可以通过后面实例化时调用的forward()重新给神经网络维度赋值】),并设置各层的超参数(例如卷积核大小、步幅、填充等)。而 forward() 方法则定义了这些网络层之间的计算顺序与逻辑,它负责将输入数据传递到网络中,并返回计算结果【这里输入进forward的数据维度要和forward()接收的第一个参数维度相同,虽然你看它只接受了一个参数‘x’,但是这个x的维度是多维的(在本代码中就是(input_dim, hidden_dim)两个大维度),而不是普通意义上的一个自然数

因此,两个方法通常一起使用,__init__() 用于设置网络结构和超参数,forward() 则定义了从输入到输出的完整计算流程。

例子:

定义类:

  1. import torch
  2. import torch.nn as nn
  3. class SimpleNet(nn.Module):
  4. def __init__(self, input_dim, hidden_dim, output_dim):
  5. super(SimpleNet, self).__init__()
  6. self.fc1 = nn.Linear(input_dim, hidden_dim)
  7. self.relu = nn.ReLU()
  8. self.fc2 = nn.Linear(hidden_dim, output_dim)
  9. def forward(self, x):
  10. out = self.fc1(x)
  11. out = self.relu(out)
  12. out = self.fc2(out)
  13. return out

在上面的代码中,我们定义了一个名为 SimpleNet 的神经网络模型,它继承自 PyTorch 中的 nn.Module 类。我们在 __init__() 方法中定义了三层网络结构,分别是输入层 fc1、激活层 relu 和输出层 fc2。其中,输入层和输出层都使用了全连接层(nn.Linear),而激活层使用了 ReLU 激活函数。

forward() 方法中,我们按照输入数据 x 经过 fc1relufc2 三层的顺序进行计算,最终返回输出结果 out

调用

调用上述代码的 forward() 方法需要先创建一个 SimpleNet 类的对象,并将输入数据传递给该对象。以下是一个简单的示例:

  1. # 创建一个 SimpleNet 对象,设置输入维度为 10,隐藏层维度为 20,输出维度为 5
  2. net = SimpleNet(10, 20, 5)
  3. # 构造一个随机的输入张量,大小为 [batch_size, input_dim],这里令 batch_size=1
  4. input_tensor = torch.randn(1, 10)
  5. # 将输入张量传入网络中,得到输出张量
  6. output_tensor = net(input_tensor)
  7. # 打印输出张量的形状
  8. print(output_tensor.shape)

为什么上面的代没有看到 __init__()、forword()函数的出现就完成了上述代码的调用呢?

初始化一个类时,则自动调用了该类的 __init__() 方法【net = SimpleNet(10, 20, 5)】

调用一个类的实例时,会自动调用该类的forward() 方法【output_tensor = net(input_tensor)】

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/361177
推荐阅读
  

闽ICP备14008679号