当前位置:   article > 正文

注意力机制——ECANet(Efficient Channel Attention Network)_eca注意力机制

eca注意力机制

ECANet(Efficient Channel Attention Network)是一种新颖的注意力机制,用于深度神经网络中的特征提取,它可以有效地减少模型参数量和计算量,提高模型的性能。

ECANet注意力机制是针对通道维度的注意力加权机制。它的基本思想是,通过学习通道之间的相关性,自适应地调整通道的权重,以提高网络的性能。ECANet通过两个步骤实现通道注意力加权:      1.提取通道特征             2.计算通道权重

用pytorch实现ECANet注意力机制:

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class ECANet(nn.Module):
  5. def __init__(self, in_channels, r=8):
  6. super(ECANet, self).__init__()
  7. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  8. self.fc1 = nn.Linear(in_channels, in_channels // r, bias=False)
  9. self.relu = nn.ReLU(inplace=True)
  10. self.fc2 = nn.Linear(in_channels // r, in_channels, bias=False)
  11. self.sigmoid = nn.Sigmoid()
  12. def forward(self, x):
  13. b, c, _, _ = x.size()
  14. y = self.avg_pool(x).view(b, c)
  15. y = self.fc1(y)
  16. y = self.relu(y)
  17. y = self.fc2(y)
  18. y = self.sigmoid(y).view(b, c, 1, 1)
  19. return x * y
  • nn.AdaptiveAvgPool2d(1)用于将输入的特征图转换为1x1大小的特征图,以进行全局平均池化。
  • nn.Linear(in_channels, in_channels // r, bias=False)是线性层,将输入通道数降低到输入通道数的r分之一,其中r是一个超参数。
  • nn.ReLU(inplace=True)是激活函数,将线性层的输出通过非线性变换。
  • nn.Linear(in_channels // r, in_channels, bias=False)是另一个线性层,将通道数恢复到原始数量。
  • nn.Sigmoid()是一个非线性函数,将输出值限制在0到1之间。

将ECANet注意力机制添加到神经网络中:

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class Net(nn.Module):
  5. def __init__(self):
  6. super(Net, self).__init__()
  7. self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
  8. self.ecanet1 = ECANet(64)
  9. self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
  10. self.ecanet2 = ECANet(128)
  11. self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
  12. self.ecanet3 = ECANet(256)
  13. self.fc1 = nn.Linear(256 * 8 * 8, 512)
  14. self.fc2 = nn.Linear(512, 10)
  15. def forward(self, x):
  16. x = F.relu(self.conv1(x))
  17. x = self.ecanet1(x)
  18. x = F.max_pool2d(x, 2)
  19. x = F.relu(self.conv2(x))
  20. x = self.ecanet2(x)
  21. x = F.max_pool2d(x, 2)
  22. x = F.relu(self.conv3(x))
  23. x = self.ecanet3(x)
  24. x = F.max_pool2d(x, 2)
  25. x = x.view(-1, 256 * 8 * 8)
  26. x = F.relu(self.fc1(x))
  27. x = self.fc2(x)
  28. return x

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/307133
推荐阅读
相关标签
  

闽ICP备14008679号