赞
踩
目录
按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。
Linear类是torch.nn中最基础的模块之一,其作用是用于构建线性全连接神经元网络。本文说明其计算原理及使用方法。
在PyTorch中,全连接层(Fully Connected Layer)通常通过torch.nn.Linear
模块来实现。全连接层是神经网络中的基本构建块之一,它将输入数据与一组可学习的权重进行矩阵乘法,然后可能添加一个偏置项,生成输出。
torch.nn.Linear
的工作原理基于线性代数中的矩阵乘法和向量加法。在神经网络中,一个线性层可以表示为输入数据与权重矩阵的乘积,再加上一个可选的偏置项。数学上,对于一个输入向量x
和一个权重矩阵A
,以及一个偏置向量b
,线性层的输出y
可以通过以下公式计算:
此部分也可以参见PyTorch官网:
Linear的作用是用于线性计算: 其中为输出,为输入,为权重,为偏置。
在调用Linear模块时,需要输入的参数有3个:
- in_features:输入数据的size(关于这个size的定义后面还会提到)
- out_features:输出数据的size
- bias:设定为True则会生成偏置,如果设定为False偏置为0
关于权重和偏置的初始值,设定为从到的随机值。
在PyTorch中,使用torch.nn.Linear
的步骤如下:
首先需要导入PyTorch的相关模块。
- import torch
- from torch import nn
使用nn.Linear
类创建一个线性层对象。在实例化时,需要指定两个参数:in_features
和out_features
。in_features
表示输入数据的特征维度(或说是上一层神经元的个数),out_features
表示输出数据的特征维度(或说是这一层神经元的个数)。
linear_layer = nn.Linear(in_features=5, out_features=3)
这将创建一个从5维输入到3维输出的线性层。
在模型的前向传播过程中,我们将输入数据传递给线性层以得到输出。
- input_data = torch.randn(10, 5) # 假设我们有10个样本,每个样本有5个特征
- output_data = linear_layer(input_data)
线性层通常作为更大的神经网络的一部分进行训练。在训练过程中,我们会使用优化器(如torch.optim.SGD
或torch.optim.Adam
)来更新线性层的权重和偏置。
- optimizer = torch.optim.SGD(linear_layer.parameters(), lr=0.01)
-
- for epoch in range(num_epochs):
- # 前向传播
- output = linear_layer(input_data)
-
- # 计算损失
- loss = some_loss_function(output, target_data)
-
- # 反向传播和优化
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
需要注意的是,在实际使用中,可能还需要考虑将模型和数据移动到GPU上以加速计算,这可以通过.to(device)
方法实现,其中device
是代表GPU的torch.device
对象。
这里可以通过.state_dict()方法打印出权重数据,用手算验证PyTorch输出结果。
这里值得注意的一点是:in_features和out_features并不是真正的输入输出数据的“大小”,而是输入输出数据的“最后一个维度的大小”,例如下面这段代码:
- import torch
-
- linear = torch.nn.Linear(in_features=3, out_features=1, bias=False)
-
- b = torch.tensor([[1,1,1]], dtype=torch.float32)
- c = torch.tensor([[1,1,1],
- [1,1,1],
- [1,1,1]], dtype=torch.float32)
-
- out2 = linear(b)
- out3 = linear(c)
-
- print(out2)
- print(out3)
-
- ------------------输出-----------------------
- tensor([[0.3057]], grad_fn=<MmBackward0>)
- tensor([[0.3057],
- [0.3057],
- [0.3057]], grad_fn=<MmBackward0>)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。