赞
踩
在PyTorch中,nn.Module 类扮演着核心角色,它是构建任何自定义神经网络层、复杂模块或完整神经网络架构的基础构建块。通过继承 nn.Module 并在其子类中定义模型结构和前向传播逻辑(forward() 方法),开发者能够方便地搭建并训练深度学习模型。
关于 nn.Module
的更多介绍可以参考博客:PyTorch之nn.Module、nn.Sequential、nn.ModuleList使用详解
这里,我们基于nn.Module
创建一个简单的神经网络模型,实现代码如下:
import torch
import torch.nn as nn
class MyModel(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(MyModel, self).__init__()
self.layer1 = nn.Linear(input_size, hidden_size)
self.layer2 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = torch.relu(self.layer1(x))
x = self.layer2(x)
return x
nn.functional
是PyTorch中一个重要的模块,它包含了许多用于构建神经网络的函数。与 nn.Module
不同,nn.functional
中的函数不具有可学习的参数。这些函数通常用于执行各种非线性操作、损失函数、激活函数等。
如何在神经网络中使用nn.functional?
在PyTorch中,你可以轻松地在神经网络中使用 nn.functional
函数。通常,你只需将输入数据传递给这些函数,并将它们作为网络的一部分。
以下是一个简单的示例,演示如何在一个全连接神经网络中使用ReLU激活函数:
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.fc1 = nn.Linear(64, 128)
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
在上述示例中,我们首先导入nn.functional 模块,然后在网络的forward 方法中使用F.relu 函数作为激活函数。
nn.functional
的主要优势是它的计算效率和灵活性,因为它允许你以函数的方式直接调用这些操作,而不需要创建额外的层。
(1)激活函数
激活函数是神经网络中的关键组件,它们引入非线性性,使网络能够拟合复杂的数据。以下是一些常见的激活函数:
output = F.relu(input)
output = F.sigmoid(input)
output = F.tanh(input)
(2)损失函数
损失函数用于度量模型的预测与真实标签之间的差距。PyTorch的nn.functional 模块包含了各种常用的损失函数,例如:
loss = F.cross_entropy(input, target)
loss = F.mse_loss(input, target)
loss = F.l1_loss(input, target)
(3)非线性操作
nn.functional 模块还包含了许多非线性操作,如池化、归一化等。
output = F.max_pool2d(input, kernel_size)
output = F.batch_norm(input, mean, std, weight, bias)
nn.Module 与 nn.functional 的主要区别在于:
注意:
nn.ReLU() :
import torch.nn as nn
'''
nn.ReLU()
F.relu():
import torch.nn.functional as F
'''
out = F.relu(input)
其实这两种方法都是使用relu激活,只是使用的场景不一样,F.relu()是函数调用,一般使用在foreward函数里。而nn.ReLU()是模块调用,一般在定义网络层的时候使用。
当用print(net)输出时,nn.ReLU()会有对应的层,而F.ReLU()是没有输出的。
import torch.nn as nn import torch.nn.functional as F class NET1(nn.Module): def __init__(self): super(NET1, self).__init__() self.conv = nn.Conv2d(3, 16, 3, 1, 1) self.bn = nn.BatchNorm2d(16) self.relu = nn.ReLU() # 模块的激活函数 def forward(self, x): out = self.conv(x) x = self.bn(x) out = self.relu() return out class NET2(nn.Module): def __init__(self): super(NET2, self).__init__() self.conv = nn.Conv2d(3, 16, 3, 1, 1) self.bn = nn.BatchNorm2d(16) def forward(self, x): x = self.conv(x) x = self.bn(x) out = F.relu(x) # 函数的激活函数 return out net1 = NET1() net2 = NET2() print(net1) print(net2)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。