
Attention Mechanisms: Non-local Networks (NLNet)

The Non-local Attention Module

Non-local Networks (NLNet): NLNet is a non-local attention model that captures global information by computing a weighted sum of features over the entire input space.

Traditional convolutional neural networks (CNNs) only consider pixels within a local neighborhood when processing an image, ignoring interactions between distant positions. NLNet addresses this by introducing a non-local block, which contains a self-attention module that learns the pairwise interactions between pixels.

The self-attention module uses an attention mechanism to compute the dependency between every pixel and every other pixel, then uses these dependencies to compute a weighted aggregation of all pixels' feature representations. This global interaction lets the model establish long-range dependencies between arbitrary positions, improving its representational power.
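The weighted aggregation described above can be sketched in plain NumPy for a toy feature map (the dimensions and random embeddings here are illustrative, not taken from the paper): the affinity between positions i and j is the dot product of their theta and phi embeddings, and a row-wise softmax turns each row of affinities into weights that sum to 1, so every output position is a weighted average over all positions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, c = 4, 3                        # toy values: 4 spatial positions, 3 embedding channels
theta = rng.normal(size=(n_pos, c))    # theta embedding, one row per position
phi = rng.normal(size=(n_pos, c))      # phi embedding
g = rng.normal(size=(n_pos, c))        # g embedding (the features to aggregate)

f = theta @ phi.T                      # pairwise affinities f_ij, shape (4, 4)
f_div_C = np.exp(f) / np.exp(f).sum(axis=1, keepdims=True)  # row-wise softmax
y = f_div_C @ g                        # each y_i is a weighted sum of all g_j

print(f_div_C.sum(axis=1))             # every row sums to 1
```

Because every row of `f_div_C` sums to 1, each output feature `y_i` mixes information from all positions, which is exactly the long-range interaction the non-local block provides.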

PyTorch implementation of the NonLocalBlock module:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class NonLocalBlock(nn.Module):
        def __init__(self, in_channels, inter_channels=None):
            super(NonLocalBlock, self).__init__()
            self.in_channels = in_channels
            self.inter_channels = inter_channels or in_channels // 2
            # 1x1 convolutions for the g, theta, phi embeddings and the output projection
            self.g = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
            self.theta = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
            self.phi = nn.Conv2d(self.in_channels, self.inter_channels, kernel_size=1, stride=1, padding=0)
            self.out = nn.Conv2d(self.inter_channels, self.in_channels, kernel_size=1, stride=1, padding=0)
            # softmax layer that normalizes each row of the affinity matrix f_ij
            self.softmax = nn.Softmax(dim=-1)

        def forward(self, x):
            batch_size = x.size(0)
            # g(x): (B, C', H*W) -> (B, H*W, C')
            g_x = self.g(x).view(batch_size, self.inter_channels, -1)
            g_x = g_x.permute(0, 2, 1)
            # theta(x): (B, C', H*W) -> (B, H*W, C')
            theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
            theta_x = theta_x.permute(0, 2, 1)
            # phi(x): (B, C', H*W)
            phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
            # pairwise affinities f_ij, shape (B, H*W, H*W)
            f = torch.matmul(theta_x, phi_x)
            # normalize f_ij so each row of attention weights sums to 1
            f_div_C = self.softmax(f)
            # y_i: attention-weighted sum of g(x) over all positions
            y = torch.matmul(f_div_C, g_x)
            y = y.permute(0, 2, 1).contiguous()
            y = y.view(batch_size, self.inter_channels, *x.size()[2:])
            # z_i: project back to in_channels and add the residual connection
            y = self.out(y)
            z = y + x
            return z
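As a quick sanity check (a minimal sketch, not the author's test code), the block's logic can be exercised on a random tensor. The functional rewrite below mirrors the class above, with the four 1x1 convolutions created inline; because of the output projection back to in_channels and the residual connection, the output shape must match the input shape exactly.

```python
import torch
import torch.nn as nn

# Minimal functional version of the non-local block (names are illustrative).
def non_local(x, g, theta, phi, out):
    b, c, h, w = x.shape
    g_x = g(x).flatten(2).permute(0, 2, 1)          # (B, HW, C')
    theta_x = theta(x).flatten(2).permute(0, 2, 1)  # (B, HW, C')
    phi_x = phi(x).flatten(2)                       # (B, C', HW)
    f = torch.softmax(theta_x @ phi_x, dim=-1)      # (B, HW, HW) attention map
    y = (f @ g_x).permute(0, 2, 1).reshape(b, -1, h, w)
    return out(y) + x                               # residual keeps the shape

in_c, inter_c = 64, 32
g = nn.Conv2d(in_c, inter_c, 1)
theta = nn.Conv2d(in_c, inter_c, 1)
phi = nn.Conv2d(in_c, inter_c, 1)
out = nn.Conv2d(inter_c, in_c, 1)

x = torch.randn(2, in_c, 8, 8)
z = non_local(x, g, theta, phi, out)
print(z.shape)  # torch.Size([2, 64, 8, 8])
```

Note that the attention map is (H*W) x (H*W), so memory grows quadratically with the number of spatial positions; this is why non-local blocks are usually placed on downsampled feature maps, as the network below does after pooling.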

Adding the NonLocalBlock module to a network:

    class NLNet(nn.Module):
        def __init__(self, num_classes=10):
            super(NLNet, self).__init__()
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(64)
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
            self.nonlocal1 = NonLocalBlock(64)
            self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(128)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
            self.nonlocal2 = NonLocalBlock(128)
            self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
            self.bn3 = nn.BatchNorm2d(256)
            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
            # fc expects a 4x4 feature map, i.e. a 32x32 input (e.g. CIFAR-10)
            self.fc = nn.Linear(256*4*4, num_classes)

        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(self.bn1(x))
            x = self.pool1(x)
            x = self.nonlocal1(x)
            x = self.conv2(x)
            x = F.relu(self.bn2(x))
            x = self.pool2(x)
            x = self.nonlocal2(x)
            x = self.conv3(x)
            x = F.relu(self.bn3(x))
            x = self.pool3(x)
            x = x.view(-1, 256*4*4)
            x = self.fc(x)
            return x
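The fully connected layer's in_features of 256*4*4 assumes a 32x32 input such as CIFAR-10 (an assumption, since the source does not state the input size): each of the three 2x2 max-pools with stride 2 halves the spatial resolution, which can be checked with simple arithmetic:

```python
size = 32                       # assumed input resolution (e.g. CIFAR-10)
for _ in range(3):              # pool1, pool2, pool3: kernel 2, stride 2
    size = size // 2            # each pooling halves the spatial size
print(size, 256 * size * size)  # 4 4096
```

For any other input resolution, the first argument of nn.Linear would need to change accordingly (or be replaced with adaptive pooling).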
