当前位置:   article > 正文

综述|计算机视觉中的注意力机制

计算机视觉自注意力综述
  1. 点击上方“小白学视觉”,选择加"星标"或“置顶”
  2. 重磅干货,第一时间送达

f6b17087d98c0f0c4a9bf0788298e371.png

作者丨HUST小菜鸡@知乎

来源丨https://zhuanlan.zhihu.com/p/146130215

编辑丨极市平台

之前在看DETR这篇论文中的self_attention,然后结合之前实验室组会经常提起的注意力机制,所以本周时间对注意力机制进行了相关的梳理,以及相关的源码阅读了解其实现的机制。

一、注意力机制(attention mechanism)

attention机制可以它认为是一种资源分配的机制,可以理解为对于原本平均分配的资源根据attention对象的重要程度重新分配资源,重要的单位就多分一点,不重要或者不好的单位就少分一点,在深度神经网络的结构设计中,attention所要分配的资源基本上就是权重了。

视觉注意力分为几种,核心思想是基于原有的数据找到其之间的关联性,然后突出其某些重要特征,有通道注意力,像素注意力,多阶注意力等,也有把NLP中的自注意力引入。

二、自注意力(self-attention)

参考文献:http://papers.nips.cc/paper/7181-attention-is-all-you-need

参考资料:https://zhuanlan.zhihu.com/p/48508221

GitHub:https://github.com/huggingface/transformers

自注意力有时候也称为内部注意力,是一个与单个序列的不同位置相关的注意力机制,目的是计算序列的表达形式,因为解码器的位置不变性,以及在DETR中,每个像素不仅仅包含数值信息,并且每个像素的位置信息也很重要。

4066a27cdae4391f13c34e15bb8f1b2e.png

所有的编码器在结构上都是相同的,但它们没有共享参数。每个编码器都可以分解成两个子层:

98eb6cb28309ba41d2a0b4253629be89.png

在transformer中,每个encoder子层有Multi-head self-attention和position-wise FFN组成。

28f9583a1817a69b052d2d7e8da7072c.png

输入的每个单词通过嵌入的方式形成词向量,通过自注意进行编码,然后再送入FFN得出一个层级的编码。

d28a911fb30c780502d410ce24a69d90.png

解码器在结构上也是多个相同的堆叠而成,在有和encoder相似的结构的Multi-head self-attention和position-wise FFN,同时还多了一个注意力层用来关注输入句子的相关部分。

Self-Attention

Self-Attention是Transformer最核心的内容,可以理解位将队列和一组值与输入对应,即形成querry,key,value向output的映射,output可以看作是value的加权求和,加权值则是由Self-Attention来得出的。

具体实施细节如下:

在self-attention中,每个单词有3个不同的向量,它们分别是Query向量,Key向量和Value向量,长度均是64。它们是通过3个不同的权值矩阵由嵌入向量X乘以三个不同的权值矩阵得到,其中三个矩阵的尺寸也是相同的。均是512×64。

9e89f76578bb7e26daac10e65af7b461.png

Self_attention的计算过程如下

  1. 将输入单词转化成嵌入向量;

  2. 根据嵌入向量得到q,k,v三个向量;

  3. 为每个向量计算一个score:score=q×v;

  4. 为了梯度的稳定,Transformer使用了score归一化,即除以sqrt(dk);

  5. 对score施以softmax激活函数;

  6. softmax点乘Value值v,得到加权的每个输入向量的评分v;

  7. 相加之后得到最终的输出结果z。

690c4651a4226722bf4fbc01881c9dd9.png

矩阵形式的计算过程:

2904ea56decfb3c0f42a218e0ac4a9a6.png

对于Multi-head self-attention,通过论文可以看出就是将单个点积注意力进行融合,两者相结合得出了transformer

3365d6f0cca881f0636d3c8649e17f31.png

a6ddba22c520886f980de8078036bc2f.png

具体的实施可以参照detr的models/transformer

  1. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  2. """
  3. DETR Transformer class.
  4. Copy-paste from torch.nn.Transformer with modifications:
  5. * positional encodings are passed in MHattention
  6. * extra LN at the end of encoder is removed
  7. * decoder returns a stack of activations from all decoding layers
  8. """
  9. import copy
  10. from typing import Optional, List
  11. import torch
  12. import torch.nn.functional as F
  13. from torch import nn, Tensor
  14. class Transformer(nn.Module):
  15. def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
  16. num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
  17. activation="relu", normalize_before=False,
  18. return_intermediate_dec=False):
  19. super().__init__()
  20. encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
  21. dropout, activation, normalize_before)
  22. encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
  23. self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
  24. decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
  25. dropout, activation, normalize_before)
  26. decoder_norm = nn.LayerNorm(d_model)
  27. self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
  28. return_intermediate=return_intermediate_dec)
  29. self._reset_parameters()
  30. self.d_model = d_model
  31. self.nhead = nhead
  32. def _reset_parameters(self):
  33. for p in self.parameters():
  34. if p.dim() > 1:
  35. nn.init.xavier_uniform_(p)
  36. def forward(self, src, mask, query_embed, pos_embed):
  37. # flatten NxCxHxW to HWxNxC
  38. bs, c, h, w = src.shape
  39. src = src.flatten(2).permute(2, 0, 1)
  40. pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
  41. query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
  42. mask = mask.flatten(1)
  43. tgt = torch.zeros_like(query_embed)
  44. memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
  45. hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
  46. pos=pos_embed, query_pos=query_embed)
  47. return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
  48. class TransformerEncoder(nn.Module):
  49. def __init__(self, encoder_layer, num_layers, norm=None):
  50. super().__init__()
  51. self.layers = _get_clones(encoder_layer, num_layers)
  52. self.num_layers = num_layers
  53. self.norm = norm
  54. def forward(self, src,
  55. mask: Optional[Tensor] = None,
  56. src_key_padding_mask: Optional[Tensor] = None,
  57. pos: Optional[Tensor] = None):
  58. output = src
  59. for layer in self.layers:
  60. output = layer(output, src_mask=mask,
  61. src_key_padding_mask=src_key_padding_mask, pos=pos)
  62. if self.norm is not None:
  63. output = self.norm(output)
  64. return output
  65. class TransformerDecoder(nn.Module):
  66. def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
  67. super().__init__()
  68. self.layers = _get_clones(decoder_layer, num_layers)
  69. self.num_layers = num_layers
  70. self.norm = norm
  71. self.return_intermediate = return_intermediate
  72. def forward(self, tgt, memory,
  73. tgt_mask: Optional[Tensor] = None,
  74. memory_mask: Optional[Tensor] = None,
  75. tgt_key_padding_mask: Optional[Tensor] = None,
  76. memory_key_padding_mask: Optional[Tensor] = None,
  77. pos: Optional[Tensor] = None,
  78. query_pos: Optional[Tensor] = None):
  79. output = tgt
  80. intermediate = []
  81. for layer in self.layers:
  82. output = layer(output, memory, tgt_mask=tgt_mask,
  83. memory_mask=memory_mask,
  84. tgt_key_padding_mask=tgt_key_padding_mask,
  85. memory_key_padding_mask=memory_key_padding_mask,
  86. pos=pos, query_pos=query_pos)
  87. if self.return_intermediate:
  88. intermediate.append(self.norm(output))
  89. if self.norm is not None:
  90. output = self.norm(output)
  91. if self.return_intermediate:
  92. intermediate.pop()
  93. intermediate.append(output)
  94. if self.return_intermediate:
  95. return torch.stack(intermediate)
  96. return output
  97. class TransformerEncoderLayer(nn.Module):
  98. def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
  99. activation="relu", normalize_before=False):
  100. super().__init__()
  101. self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
  102. # Implementation of Feedforward model
  103. self.linear1 = nn.Linear(d_model, dim_feedforward)
  104. self.dropout = nn.Dropout(dropout)
  105. self.linear2 = nn.Linear(dim_feedforward, d_model)
  106. self.norm1 = nn.LayerNorm(d_model)
  107. self.norm2 = nn.LayerNorm(d_model)
  108. self.dropout1 = nn.Dropout(dropout)
  109. self.dropout2 = nn.Dropout(dropout)
  110. self.activation = _get_activation_fn(activation)
  111. self.normalize_before = normalize_before
  112. def with_pos_embed(self, tensor, pos: Optional[Tensor]):
  113. return tensor if pos is None else tensor + pos
  114. def forward_post(self,
  115. src,
  116. src_mask: Optional[Tensor] = None,
  117. src_key_padding_mask: Optional[Tensor] = None,
  118. pos: Optional[Tensor] = None):
  119. q = k = self.with_pos_embed(src, pos)
  120. src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
  121. key_padding_mask=src_key_padding_mask)[0]
  122. src = src + self.dropout1(src2)
  123. src = self.norm1(src)
  124. src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
  125. src = src + self.dropout2(src2)
  126. src = self.norm2(src)
  127. return src
  128. def forward_pre(self, src,
  129. src_mask: Optional[Tensor] = None,
  130. src_key_padding_mask: Optional[Tensor] = None,
  131. pos: Optional[Tensor] = None):
  132. src2 = self.norm1(src)
  133. q = k = self.with_pos_embed(src2, pos)
  134. src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
  135. key_padding_mask=src_key_padding_mask)[0]
  136. src = src + self.dropout1(src2)
  137. src2 = self.norm2(src)
  138. src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
  139. src = src + self.dropout2(src2)
  140. return src
  141. def forward(self, src,
  142. src_mask: Optional[Tensor] = None,
  143. src_key_padding_mask: Optional[Tensor] = None,
  144. pos: Optional[Tensor] = None):
  145. if self.normalize_before:
  146. return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
  147. return self.forward_post(src, src_mask, src_key_padding_mask, pos)
  148. class TransformerDecoderLayer(nn.Module):
  149. def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
  150. activation="relu", normalize_before=False):
  151. super().__init__()
  152. self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
  153. self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
  154. # Implementation of Feedforward model
  155. self.linear1 = nn.Linear(d_model, dim_feedforward)
  156. self.dropout = nn.Dropout(dropout)
  157. self.linear2 = nn.Linear(dim_feedforward, d_model)
  158. self.norm1 = nn.LayerNorm(d_model)
  159. self.norm2 = nn.LayerNorm(d_model)
  160. self.norm3 = nn.LayerNorm(d_model)
  161. self.dropout1 = nn.Dropout(dropout)
  162. self.dropout2 = nn.Dropout(dropout)
  163. self.dropout3 = nn.Dropout(dropout)
  164. self.activation = _get_activation_fn(activation)
  165. self.normalize_before = normalize_before
  166. def with_pos_embed(self, tensor, pos: Optional[Tensor]):
  167. return tensor if pos is None else tensor + pos
  168. def forward_post(self, tgt, memory,
  169. tgt_mask: Optional[Tensor] = None,
  170. memory_mask: Optional[Tensor] = None,
  171. tgt_key_padding_mask: Optional[Tensor] = None,
  172. memory_key_padding_mask: Optional[Tensor] = None,
  173. pos: Optional[Tensor] = None,
  174. query_pos: Optional[Tensor] = None):
  175. q = k = self.with_pos_embed(tgt, query_pos)
  176. tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
  177. key_padding_mask=tgt_key_padding_mask)[0]
  178. tgt = tgt + self.dropout1(tgt2)
  179. tgt = self.norm1(tgt)
  180. tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
  181. key=self.with_pos_embed(memory, pos),
  182. value=memory, attn_mask=memory_mask,
  183. key_padding_mask=memory_key_padding_mask)[0]
  184. tgt = tgt + self.dropout2(tgt2)
  185. tgt = self.norm2(tgt)
  186. tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
  187. tgt = tgt + self.dropout3(tgt2)
  188. tgt = self.norm3(tgt)
  189. return tgt
  190. def forward_pre(self, tgt, memory,
  191. tgt_mask: Optional[Tensor] = None,
  192. memory_mask: Optional[Tensor] = None,
  193. tgt_key_padding_mask: Optional[Tensor] = None,
  194. memory_key_padding_mask: Optional[Tensor] = None,
  195. pos: Optional[Tensor] = None,
  196. query_pos: Optional[Tensor] = None):
  197. tgt2 = self.norm1(tgt)
  198. q = k = self.with_pos_embed(tgt2, query_pos)
  199. tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
  200. key_padding_mask=tgt_key_padding_mask)[0]
  201. tgt = tgt + self.dropout1(tgt2)
  202. tgt2 = self.norm2(tgt)
  203. tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
  204. key=self.with_pos_embed(memory, pos),
  205. value=memory, attn_mask=memory_mask,
  206. key_padding_mask=memory_key_padding_mask)[0]
  207. tgt = tgt + self.dropout2(tgt2)
  208. tgt2 = self.norm3(tgt)
  209. tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
  210. tgt = tgt + self.dropout3(tgt2)
  211. return tgt
  212. def forward(self, tgt, memory,
  213. tgt_mask: Optional[Tensor] = None,
  214. memory_mask: Optional[Tensor] = None,
  215. tgt_key_padding_mask: Optional[Tensor] = None,
  216. memory_key_padding_mask: Optional[Tensor] = None,
  217. pos: Optional[Tensor] = None,
  218. query_pos: Optional[Tensor] = None):
  219. if self.normalize_before:
  220. return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
  221. tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
  222. return self.forward_post(tgt, memory, tgt_mask, memory_mask,
  223. tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
  224. def _get_clones(module, N):
  225. return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
  226. def build_transformer(args):
  227. return Transformer(
  228. d_model=args.hidden_dim,
  229. dropout=args.dropout,
  230. nhead=args.nheads,
  231. dim_feedforward=args.dim_feedforward,
  232. num_encoder_layers=args.enc_layers,
  233. num_decoder_layers=args.dec_layers,
  234. normalize_before=args.pre_norm,
  235. return_intermediate_dec=True,
  236. )
  237. def _get_activation_fn(activation):
  238. """Return an activation function given a string"""
  239. if activation == "relu":
  240. return F.relu
  241. if activation == "gelu":
  242. return F.gelu
  243. if activation == "glu":
  244. return F.glu
  245. raise RuntimeError(F"activation should be relu/gelu, not {activation}.")

三、软注意力(soft-attention)

软注意力是一个[0,1]间的连续分布问题,更加关注区域或者通道,软注意力是确定性注意力,学习完成后可以通过网络生成,并且是可微的,可以通过神经网络计算出梯度并且可以前向传播和后向反馈来学习得到注意力的权重。

1、空间域注意力(spatial transformer network)

论文地址:http://papers.nips.cc/paper/5854-spatial-transformer-networks

GitHub地址:https://github.com/fxia22/stn.pytorch

空间区域注意力可以理解为让神经网络在看哪里。通过注意力机制,将原始图片中的空间信息变换到另一个空间中并保留了关键信息,在很多现有的方法中都有使用这种网络,自己接触过的一个就是ALPHA Pose。spatial transformer其实就是注意力机制的实现,因为训练出的spatial transformer能够找出图片信息中需要被关注的区域,同时这个transformer又能够具有旋转、缩放变换的功能,这样图片局部的重要信息能够通过变换而被框盒提取出来。

2b0dcd496de9b865642f4ef337c85e2d.png

7188f706e1484b05d723ba9e4cfe5670.png

主要在于空间变换矩阵的学习

  1. class STN(Module):
  2. def __init__(self, layout = 'BHWD'):
  3. super(STN, self).__init__()
  4. if layout == 'BHWD':
  5. self.f = STNFunction()
  6. else:
  7. self.f = STNFunctionBCHW()
  8. def forward(self, input1, input2):
  9. return self.f(input1, input2)
  10. class STNFunction(Function):
  11. def forward(self, input1, input2):
  12. self.input1 = input1
  13. self.input2 = input2
  14. self.device_c = ffi.new("int *")
  15. output = torch.zeros(input1.size()[0], input2.size()[1], input2.size()[2], input1.size()[3])
  16. #print('decice %d' % torch.cuda.current_device())
  17. if input1.is_cuda:
  18. self.device = torch.cuda.current_device()
  19. else:
  20. self.device = -1
  21. self.device_c[0] = self.device
  22. if not input1.is_cuda:
  23. my_lib.BilinearSamplerBHWD_updateOutput(input1, input2, output)
  24. else:
  25. output = output.cuda(self.device)
  26. my_lib.BilinearSamplerBHWD_updateOutput_cuda(input1, input2, output, self.device_c)
  27. return output
  28. def backward(self, grad_output):
  29. grad_input1 = torch.zeros(self.input1.size())
  30. grad_input2 = torch.zeros(self.input2.size())
  31. #print('backward decice %d' % self.device)
  32. if not grad_output.is_cuda:
  33. my_lib.BilinearSamplerBHWD_updateGradInput(self.input1, self.input2, grad_input1, grad_input2, grad_output)
  34. else:
  35. grad_input1 = grad_input1.cuda(self.device)
  36. grad_input2 = grad_input2.cuda(self.device)
  37. my_lib.BilinearSamplerBHWD_updateGradInput_cuda(self.input1, self.input2, grad_input1, grad_input2, grad_output, self.device_c)
  38. return grad_input1, grad_input2

2、通道注意力(Channel Attention,CA)

通道注意力可以理解为让神经网络在看什么,典型的代表是SENet。卷积网络的每一层都有好多卷积核,每个卷积核对应一个特征通道,相对于空间注意力机制,通道注意力在于分配各个卷积通道之间的资源,分配粒度上比前者大了一个级别。

cafd63b296592f96e1800d0df7721438.png

论文:Squeeze-and-Excitation Networks(https://arxiv.org/abs/1709.01507)

GitHub地址:https://github.com/moskomule/senet.pytorch

Squeeze操作:将各通道的全局空间特征作为该通道的表示,使用全局平均池化生成各通道的统计量

Excitation操作:学习各通道的依赖程度,并根据依赖程度对不同的特征图进行调整,得到最后的输出,需要考察各通道的依赖程度

整体的结构如图所示:

f046249d2d0819bf75cbe647803f1273.png

卷积层的输出并没有考虑对各通道的依赖,SEBlock的目的在于然根网络选择性的增强信息量最大的特征,是的后期处理充分利用这些特征并抑制无用的特征。

77d60848e5cfbe7c6bd2290b81b15b9d.png

c371ac8d4ddea5c8552e8ab819ff4fe2.png

  1. 将输入特征进行 Global avgpooling,得到1×1×Channel

  2. 然后bottleneck特征交互一下,先压缩channel数,再重构回channel数

  3. 最后接个sigmoid,生成channel间0~1的attention weights,最后scale乘回原输入特征

SE-ResNet的SE-Block

  1. class SEBasicBlock(nn.Module):
  2. expansion = 1
  3. def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
  4. base_width=64, dilation=1, norm_layer=None,
  5. *, reduction=16):
  6. super(SEBasicBlock, self).__init__()
  7. self.conv1 = conv3x3(inplanes, planes, stride)
  8. self.bn1 = nn.BatchNorm2d(planes)
  9. self.relu = nn.ReLU(inplace=True)
  10. self.conv2 = conv3x3(planes, planes, 1)
  11. self.bn2 = nn.BatchNorm2d(planes)
  12. self.se = SELayer(planes, reduction)
  13. self.downsample = downsample
  14. self.stride = stride
  15. def forward(self, x):
  16. residual = x
  17. out = self.conv1(x)
  18. out = self.bn1(out)
  19. out = self.relu(out)
  20. out = self.conv2(out)
  21. out = self.bn2(out)
  22. out = self.se(out)
  23. if self.downsample is not None:
  24. residual = self.downsample(x)
  25. out += residual
  26. out = self.relu(out)
  27. return out
  28. class SELayer(nn.Module):
  29. def __init__(self, channel, reduction=16):
  30. super(SELayer, self).__init__()
  31. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  32. self.fc = nn.Sequential(
  33. nn.Linear(channel, channel // reduction, bias=False),
  34. nn.ReLU(inplace=True),
  35. nn.Linear(channel // reduction, channel, bias=False),
  36. nn.Sigmoid()
  37. )
  38. def forward(self, x):
  39. b, c, _, _ = x.size()
  40. y = self.avg_pool(x).view(b, c)
  41. y = self.fc(y).view(b, c, 1, 1)
  42. return x * y.expand_as(x)

ResNet的Basic Block

  1. class BasicBlock(nn.Module):
  2. def __init__(self, inplanes, planes, stride=1):
  3. super(BasicBlock, self).__init__()
  4. self.conv1 = conv3x3(inplanes, planes, stride)
  5. self.bn1 = nn.BatchNorm2d(planes)
  6. self.relu = nn.ReLU(inplace=True)
  7. self.conv2 = conv3x3(planes, planes)
  8. self.bn2 = nn.BatchNorm2d(planes)
  9. if inplanes != planes:
  10. self.downsample = nn.Sequential(nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
  11. nn.BatchNorm2d(planes))
  12. else:
  13. self.downsample = lambda x: x
  14. self.stride = stride
  15. def forward(self, x):
  16. residual = self.downsample(x)
  17. out = self.conv1(x)
  18. out = self.bn1(out)
  19. out = self.relu(out)
  20. out = self.conv2(out)
  21. out = self.bn2(out)
  22. out += residual
  23. out = self.relu(out)
  24. return out

两者的差别主要体现在多了一个SElayer,详细可以查看源码

3、混合域模型(融合空间域和通道域注意力)

(1)论文:Residual Attention Network for image classification(CVPR 2017 Open Access Repository)

文章中注意力的机制是软注意力基本的加掩码(mask)机制,但是不同的是,这种注意力机制的mask借鉴了残差网络的想法,不只根据当前网络层的信息加上mask,还把上一层的信息传递下来,这样就防止mask之后的信息量过少引起的网络层数不能堆叠很深的问题。

该文章的注意力机制的创新点在于提出了残差注意力学习(residual attention learning),不仅只把mask之后的特征张量作为下一层的输入,同时也将mask之前的特征张量作为下一层的输入,这时候可以得到的特征更为丰富,从而能够更好的注意关键特征。同时采用三阶注意力模块来构成整个的注意力。

8fb33a5524332f6515a3532001550f2e.png

(2)Dual Attention Network for Scene Segmentation(CVPR 2019 Open Access Repository)

66705dc9e5e0fc58f5187d9c396aa478.png

4、Non-Local

论文:non-local neural networks(CVPR 2018 Open Access Repository)

GitHub地址:https://github.com/AlexHex7/Non-local_pytorch

Local这个词主要是针对感受野(receptive field)来说的。以单一的卷积操作为例,它的感受野大小就是卷积核大小,而我们一般都选用3*3,5*5之类的卷积核,它们只考虑局部区域,因此都是local的运算。同理,池化(Pooling)也是。相反的,non-local指的就是感受野可以很大,而不是一个局部领域。全连接就是non-local的,而且是global的。但是全连接带来了大量的参数,给优化带来困难。卷积层的堆叠可以增大感受野,但是如果看特定层的卷积核在原图上的感受野,它毕竟是有限的。这是local运算不能避免的。然而有些任务,它们可能需要原图上更多的信息,比如attention。如果在某些层能够引入全局的信息,就能很好地解决local操作无法看清全局的情况,为后面的层带去更丰富的信息。

文章定义的对于神经网络通用的Non-Local计算如下所示:

12da575bfa2999d964c2ada13e691088.png

如果按照上面的公式,用for循环实现肯定是很慢的。此外,如果在尺寸很大的输入上应用non-local layer,也是计算量很大的。后者的解决方案是,只在高阶语义层中引入non-local layer。还可以通过对embedding(θ,ϕ,g)的结果加pooling层来进一步地减少计算量。

9e3e434cd45aa688d473cb436922784a.png

  1. 首先对输入的 feature map X 进行线性映射(通过1x1卷积,来压缩通道数),然后得到 θ,ϕ,g 特征

  2. 通过reshape操作,强行合并上述的三个特征除通道数外的维度,然后对 进行矩阵点乘操作,得到类似协方差矩阵的东西(这个过程很重要,计算出特征中的自相关性,即得到每帧中每个像素对其他所有帧所有像素的关系)

  3. 然后对自相关特征 以列or以行(具体看矩阵 g 的形式而定) 进行 Softmax 操作,得到0~1的weights,这里就是我们需要的 Self-attention 系数

  4. 最后将 attention系数,对应乘回特征矩阵g中,然后再上扩channel 数,与原输入feature map X残差

5、位置注意力(position-wise attention)

论文:CCNet: Criss-Cross Attention for Semantic Segmentation(ICCV 2019 Open Access Repository)

Github地址:https://github.com/speedinghzl/CCNet

本篇文章的亮点在于用了巧妙的方法减少了参数量。在上面的DANet中,attention map计算的是所有像素与所有像素之间的相似性,空间复杂度为(HxW)x(HxW),而本文采用了criss-cross思想,只计算每个像素与其同行同列即十字上的像素的相似性,通过进行循环(两次相同操作),间接计算到每个像素与每个像素的相似性,将空间复杂度降为(HxW)x(H+W-1)

efac89837f0dccdbbe99fbda45671f4e.png

在计算矩阵相乘时每个像素只抽取特征图中对应十字位置的像素进行点乘,计算相似度。和non-local的方法相比极大的降低了计算量,同时采用二阶注意力,能够从所有像素中获取全图像的上下文信息,以生成具有密集且丰富的上下文信息的新特征图。在计算矩阵相乘时,每个像素只抽取特征图中对应十字位置的像素进行点乘,计算相似度。

2fd7991053c19317aa2341a9c82978ba.png

  1. def _check_contiguous(*args):
  2. if not all([mod is None or mod.is_contiguous() for mod in args]):
  3. raise ValueError("Non-contiguous input")
  4. class CA_Weight(autograd.Function):
  5. @staticmethod
  6. def forward(ctx, t, f):
  7. # Save context
  8. n, c, h, w = t.size()
  9. size = (n, h+w-1, h, w)
  10. weight = torch.zeros(size, dtype=t.dtype, layout=t.layout, device=t.device)
  11. _ext.ca_forward_cuda(t, f, weight)
  12. # Output
  13. ctx.save_for_backward(t, f)
  14. return weight
  15. @staticmethod
  16. @once_differentiable
  17. def backward(ctx, dw):
  18. t, f = ctx.saved_tensors
  19. dt = torch.zeros_like(t)
  20. df = torch.zeros_like(f)
  21. _ext.ca_backward_cuda(dw.contiguous(), t, f, dt, df)
  22. _check_contiguous(dt, df)
  23. return dt, df
  24. class CA_Map(autograd.Function):
  25. @staticmethod
  26. def forward(ctx, weight, g):
  27. # Save context
  28. out = torch.zeros_like(g)
  29. _ext.ca_map_forward_cuda(weight, g, out)
  30. # Output
  31. ctx.save_for_backward(weight, g)
  32. return out
  33. @staticmethod
  34. @once_differentiable
  35. def backward(ctx, dout):
  36. weight, g = ctx.saved_tensors
  37. dw = torch.zeros_like(weight)
  38. dg = torch.zeros_like(g)
  39. _ext.ca_map_backward_cuda(dout.contiguous(), weight, g, dw, dg)
  40. _check_contiguous(dw, dg)
  41. return dw, dg
  42. ca_weight = CA_Weight.apply
  43. ca_map = CA_Map.apply
  44. class CrissCrossAttention(nn.Module):
  45. """ Criss-Cross Attention Module"""
  46. def __init__(self,in_dim):
  47. super(CrissCrossAttention,self).__init__()
  48. self.chanel_in = in_dim
  49. self.query_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
  50. self.key_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim//8 , kernel_size= 1)
  51. self.value_conv = nn.Conv2d(in_channels = in_dim , out_channels = in_dim , kernel_size= 1)
  52. self.gamma = nn.Parameter(torch.zeros(1))
  53. def forward(self,x):
  54. proj_query = self.query_conv(x)
  55. proj_key = self.key_conv(x)
  56. proj_value = self.value_conv(x)
  57. energy = ca_weight(proj_query, proj_key)
  58. attention = F.softmax(energy, 1)
  59. out = ca_map(attention, proj_value)
  60. out = self.gamma*out + x
  61. return out
  62. __all__ = ["CrissCrossAttention", "ca_weight", "ca_map"]

四、强注意力(hard attention)

0/1问题,哪些被attention,哪些不被attention。更加关注点,图像中的每个点都可能延伸出注意力,同时强注意力是一个随机预测的过程,更加强调动态变化,并且是不可微,所以训练过程往往通过增强学习。

参考资料

https://blog.csdn.net/xys430381_1/article/details/89323444

Gapeng:Non-local neural networks

NX-8MAA09148HY:双注意力网络,是丰富了还是牵强了attention?

  1. 下载1:OpenCV-Contrib扩展模块中文版教程
  2. 在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。
  3. 下载2:Python视觉实战项目52
  4. 在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。
  5. 下载3:OpenCV实战项目20
  6. 在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。
  7. 交流群
  8. 欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/101366
推荐阅读
相关标签
  

闽ICP备14008679号