
Principles and Implementation of Attention Mechanisms (PyTorch)

This article is an entry in the Rising Star Program, AI (PyTorch) track: https://bbs.csdn.net/topics/613989052

Spatial attention (Attention U-Net)

import torch
import torch.nn as nn

class Attention_block(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(Attention_block, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # 1x1 convolution on the gating signal from the decoder (coarser) path
        g1 = self.W_g(g)
        # 1x1 convolution on the skip-connection feature x from the encoder path
        x1 = self.W_x(x)
        # sum the two branches and apply ReLU
        psi = self.relu(g1 + x1)
        # reduce channels to 1 and apply Sigmoid to obtain the spatial attention map
        psi = self.psi(psi)
        print(psi.size())  # debug: shape of the attention map
        # return x weighted by the attention map
        return x * psi

(Figure: the Attention U-Net architecture)
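As a quick sanity check, the block can be driven with random tensors. This is only an illustrative sketch: the batch size, channel counts, and spatial size below are assumed values, and it presumes the gating signal has already been brought to the same spatial resolution as the skip feature.

    # Hypothetical shape check for Attention_block (illustrative sizes only).
    g = torch.randn(2, 64, 32, 32)   # gating signal from the decoder path
    x = torch.randn(2, 32, 32, 32)   # skip-connection feature from the encoder path
    att = Attention_block(F_g=64, F_l=32, F_int=16)
    out = att(g, x)
    print(out.shape)  # torch.Size([2, 32, 32, 32]) -- same shape as x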

Channel attention (SENet)

from torch import nn

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # batch size and number of channels of the input tensor
        b, c, _, _ = x.size()
        # squeeze: global average pooling reduces each channel map to a single value
        y = self.avg_pool(x).view(b, c)
        # excitation: the fully connected layers learn one weight per channel
        y = self.fc(y).view(b, c, 1, 1)
        # reweight: broadcast y to the shape of x and scale channel-wise
        return x * y.expand_as(x)

Given an input feature map, a series of convolutions and other standard transformations produces features with some number of channels. Unlike a conventional CNN, SENet then recalibrates these features with the following three operations.

1) Squeeze. Features are compressed along the spatial dimensions, turning each two-dimensional feature channel into a single real number. This number has, in a sense, a global receptive field, and the output dimension matches the number of input feature channels. It characterizes the global distribution of responses across the feature channels, and it also lets layers close to the input access global context, which is useful in many tasks.

2) Excitation. This is a mechanism similar to a gate in a recurrent neural network. A weight is generated for each feature channel from learned parameters, which are trained to explicitly model the correlations between feature channels.

3) Reweight. The weights output by the Excitation step are treated as the importance of each feature channel after feature selection and are multiplied channel-wise onto the earlier features, completing the recalibration of the original features along the channel dimension (a shape walkthrough follows below).
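A minimal sketch of these three steps using the SELayer defined above; the input shape is an assumed example chosen only to make the intermediate shapes concrete.

    # Hypothetical walkthrough of squeeze / excitation / reweight (illustrative sizes only).
    import torch

    x = torch.randn(2, 64, 32, 32)          # N x C x H x W
    se = SELayer(channel=64, reduction=16)

    squeezed = se.avg_pool(x).view(2, 64)   # squeeze:    [2, 64, 32, 32] -> [2, 64]
    weights = se.fc(squeezed)               # excitation: per-channel weights in (0, 1), shape [2, 64]
    out = x * weights.view(2, 64, 1, 1)     # reweight:   channel-wise scaling, shape [2, 64, 32, 32]
    print(out.shape)                        # torch.Size([2, 64, 32, 32])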

Spatial attention + channel attention (CBAM)

(Figure: the CBAM module)

(Figure: the channel attention (CAM) and spatial attention (SAM) sub-modules)

PyTorch implementation of CBAM

# ------------------------
# PyTorch implementation of the CBAM module
# ------------------------
import torch
import torch.nn as nn

# Channel attention module
class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, reduction=16):
        super(ChannelAttentionModule, self).__init__()
        mid_channel = channel // reduction
        # adaptive pooling shrinks the feature map to 1x1 while keeping the channel count
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.shared_MLP = nn.Sequential(
            nn.Linear(in_features=channel, out_features=mid_channel),
            nn.ReLU(),
            nn.Linear(in_features=mid_channel, out_features=channel)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        maxout = self.shared_MLP(self.max_pool(x).view(x.size(0), -1)).unsqueeze(2).unsqueeze(3)
        return self.sigmoid(avgout + maxout)

# Spatial attention module
class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # spatial size is kept; channels are reduced to 1 by mean / max over the channel dim
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out

# CBAM module
class CBAM(nn.Module):
    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        out = self.spatial_attention(out) * out
        return out
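As a quick check, the standalone CBAM block can be applied to a random feature map; the sizes below are assumed for illustration, and the output keeps the shape of the input (a fuller ResBlock integration follows in the next section).

    # Hypothetical sanity check for the standalone CBAM (illustrative sizes only).
    feat = torch.randn(2, 64, 32, 32)
    cbam = CBAM(channel=64)
    out = cbam(feat)
    print(out.shape)  # torch.Size([2, 64, 32, 32]) -- same shape as the input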

Using CBAM integrated with a ResBlock in ResNet

PyTorch implementation

# ------------------
# ResBlock + CBAM
# ------------------
import torch
import torch.nn as nn
import torchvision

class ChannelAttentionModule(nn.Module):
    def __init__(self, channel, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # 1x1 convolutions play the role of the shared MLP here
        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, channel // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(channel // ratio, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = self.shared_MLP(self.avg_pool(x))
        print(avgout.shape)  # debug: shape of the pooled and projected tensor
        maxout = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

class SpatialAttentionModule(nn.Module):
    def __init__(self):
        super(SpatialAttentionModule, self).__init__()
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avgout = torch.mean(x, dim=1, keepdim=True)
        maxout, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avgout, maxout], dim=1)
        out = self.sigmoid(self.conv2d(out))
        return out

class CBAM(nn.Module):
    def __init__(self, channel):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        out = self.channel_attention(x) * x
        print('outchannels:{}'.format(out.shape))  # debug: shape after channel attention
        out = self.spatial_attention(out) * out
        return out

class ResBlock_CBAM(nn.Module):
    def __init__(self, in_places, places, stride=1, downsampling=False, expansion=4):
        super(ResBlock_CBAM, self).__init__()
        self.expansion = expansion
        self.downsampling = downsampling
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels=in_places, out_channels=places, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(places),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=places, out_channels=places * self.expansion, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(places * self.expansion),
        )
        self.cbam = CBAM(channel=places * self.expansion)
        if self.downsampling:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels=in_places, out_channels=places * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(places * self.expansion)
            )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.bottleneck(x)
        print(x.shape)  # debug: shape of the block input
        out = self.cbam(out)
        if self.downsampling:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

model = ResBlock_CBAM(in_places=16, places=4)
print(model)
input = torch.randn(2, 16, 64, 64)
out = model(input)
print(out.shape)

Self-attention

# Implementation of a (single-head) self-attention mechanism
from math import sqrt
import torch
import torch.nn as nn

class Self_Attention(nn.Module):
    # input : batch_size * seq_len * input_dim
    # q : batch_size * input_dim * dim_k
    # k : batch_size * input_dim * dim_k
    # v : batch_size * input_dim * dim_v
    def __init__(self, input_dim, dim_k, dim_v):
        super(Self_Attention, self).__init__()
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        self._norm_fact = 1 / sqrt(dim_k)

    def forward(self, x):
        Q = self.q(x)  # Q: batch_size * seq_len * dim_k
        K = self.k(x)  # K: batch_size * seq_len * dim_k
        V = self.v(x)  # V: batch_size * seq_len * dim_v
        # scaled dot-product attention: softmax(Q @ K^T / sqrt(dim_k))
        atten = nn.Softmax(dim=-1)(torch.bmm(Q, K.permute(0, 2, 1)) * self._norm_fact)  # batch_size * seq_len * seq_len
        output = torch.bmm(atten, V)  # batch_size * seq_len * dim_v
        return output
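A minimal usage sketch of the class above; the batch size, sequence length, and dimensions are assumed example values.

    # Hypothetical usage of Self_Attention (illustrative sizes only).
    x = torch.randn(4, 10, 128)   # batch_size=4, seq_len=10, input_dim=128
    sa = Self_Attention(input_dim=128, dim_k=64, dim_v=64)
    out = sa(x)
    print(out.shape)              # torch.Size([4, 10, 64]) -- batch_size * seq_len * dim_v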

Vision Transformer (ViT) implementation

Input: [2, 3, 256, 256]

Split into patches: patch_size = 32x32, num_patches = (256 // 32) ** 2 = 64

Patchify: n, c, w, h --> n, w*h // (32*32), 32*32*c

Linear projection: n, w*h // (32*32), 32*32*c --> n, w*h // (32*32), dim

Position embedding

Class token
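A minimal sketch of the patchify and linear-projection steps described in the notes above, using einops; the [2, 3, 256, 256] input, 32x32 patches, and dim = 1024 are the example values from those notes.

    # Hypothetical walkthrough of the ViT patch embedding (example sizes from the notes above).
    import torch
    from torch import nn
    from einops.layers.torch import Rearrange

    img = torch.randn(2, 3, 256, 256)   # n, c, h, w
    patchify = Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    patches = patchify(img)
    print(patches.shape)  # torch.Size([2, 64, 3072]) -- 64 patches, each flattened to 32*32*3 values

    proj = nn.Linear(32 * 32 * 3, 1024)  # patch_dim -> dim
    tokens = proj(patches)
    print(tokens.shape)   # torch.Size([2, 64, 1024])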

import torch
from torch import nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

# return a (height, width) tuple
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)

Test script

import torch
from vit_pytorch import ViT
import numpy as np

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(2, 3, 256, 256)
preds = v(img)  # (2, 1000)
print(preds.shape)
print(np.argmax(preds.detach().numpy(), 1))
References
Spatial attention
https://blog.csdn.net/weixin_37737254/article/details/125863392
Code: https://github.com/Andy-zhujunwen/UNET-ZOO/blob/master/attention_unet.py
Channel attention
https://blog.csdn.net/gaoxueyi551/article/details/120233959
Code: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py
Spatial attention + channel attention (CBAM):
https://blog.csdn.net/weixin_41790863/article/details/123413303
Paper: https://arxiv.org/pdf/1807.06521.pdf
Self-attention (Vision Transformer)
Paper: https://arxiv.org/pdf/2010.11929.pdf
Reference blogs:
https://blog.csdn.net/weixin_42392454/article/details/122667271
https://zhuanlan.zhihu.com/p/410776234
