
CVPR 2023 BiFormer: A Self-Attention Mechanism with Low Computational Cost


Abstract:

As the core building block of vision transformers, attention is a powerful tool for capturing long-range dependency. However, this power comes at a cost: it incurs a huge computational burden and a heavy memory footprint, since pairwise token interactions across all spatial positions are computed. A series of works attempt to alleviate this problem by introducing handcrafted, content-agnostic sparsity into attention, for example by restricting the attention operation to local windows, axial stripes, or dilated windows. In contrast to these approaches, we propose a novel dynamic sparse attention via bi-level routing to enable a more flexible, content-aware allocation of computation. Specifically, for a query, irrelevant key-value pairs are first filtered out at a coarse region level, and fine-grained token-to-token attention is then applied within the union of the remaining candidate regions (i.e., the routed regions). We provide a simple yet effective implementation of the proposed bi-level routing attention, which exploits this sparsity to save both computation and memory while involving only GPU-friendly dense matrix multiplications. Built on the proposed bi-level routing attention, a new general-purpose vision transformer, BiFormer, is then presented. Because BiFormer attends to a small subset of relevant tokens in a query-adaptive manner, without being distracted by other irrelevant ones, it enjoys both good performance and high computational efficiency, especially on dense prediction tasks. Empirical results on several computer vision tasks, including image classification, object detection, and semantic segmentation, verify the effectiveness of our design.

Paper title: BiFormer: Vision Transformer With Bi-Level Routing Attention

Paper link: CVPR 2023 Open Access Repository (thecvf.com)

Source code: https://github.com/rayleizhu/BiFormer
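The mechanism described in the abstract boils down to three steps: (1) split the feature map into regions and compute a region-to-region affinity from pooled queries and keys; (2) for each query region, keep only the top-k most relevant regions (the routing step); (3) gather the key/value tokens of those routed regions and run ordinary token-to-token attention inside the gathered set. Before looking at the repository's Block wrapper below, here is a minimal, single-head sketch of that idea, written purely for illustration: it is not the authors' implementation (which is multi-head, supports key/value downsampling, and adds an optional LePE branch), and the class name SimpleBiLevelRoutingAttention is my own.

    import torch
    import torch.nn as nn

    class SimpleBiLevelRoutingAttention(nn.Module):
        """Simplified, single-head sketch of bi-level routing attention.

        Assumptions (not the official implementation): the feature map is split
        into n_win x n_win non-overlapping regions, H and W are divisible by
        n_win, and no key/value downsampling or LePE branch is used.
        """
        def __init__(self, dim, n_win=7, topk=4):
            super().__init__()
            self.n_win = n_win
            self.topk = topk
            self.scale = dim ** -0.5
            self.qkv = nn.Linear(dim, dim * 3)
            self.proj = nn.Linear(dim, dim)

        def forward(self, x):
            # x: (N, H, W, C)
            N, H, W, C = x.shape
            hw, ww = H // self.n_win, W // self.n_win           # tokens per region along H/W
            p = self.n_win * self.n_win                         # number of regions
            q, k, v = self.qkv(x).chunk(3, dim=-1)

            # reshape (N, H, W, C) -> (N, p, tokens_per_region, C)
            def to_regions(t):
                t = t.view(N, self.n_win, hw, self.n_win, ww, C)
                return t.permute(0, 1, 3, 2, 4, 5).reshape(N, p, hw * ww, C)
            q_r, k_r, v_r = map(to_regions, (q, k, v))

            # 1) coarse, region-level routing: affinity between pooled region descriptors
            q_mean = q_r.mean(dim=2)                             # (N, p, C)
            k_mean = k_r.mean(dim=2)                             # (N, p, C)
            affinity = q_mean @ k_mean.transpose(-1, -2)         # (N, p, p)
            topk_idx = affinity.topk(self.topk, dim=-1).indices  # (N, p, topk)

            # 2) gather key/value tokens of the routed regions for every query region
            idx = topk_idx[..., None, None].expand(-1, -1, -1, hw * ww, C)
            k_g = torch.gather(k_r.unsqueeze(1).expand(-1, p, -1, -1, -1), 2, idx)
            v_g = torch.gather(v_r.unsqueeze(1).expand(-1, p, -1, -1, -1), 2, idx)
            k_g = k_g.reshape(N, p, self.topk * hw * ww, C)
            v_g = v_g.reshape(N, p, self.topk * hw * ww, C)

            # 3) fine-grained token-to-token attention inside the routed regions
            attn = (q_r * self.scale) @ k_g.transpose(-1, -2)    # (N, p, hw*ww, topk*hw*ww)
            out = attn.softmax(dim=-1) @ v_g                     # (N, p, hw*ww, C)

            # merge regions back to (N, H, W, C)
            out = out.view(N, self.n_win, self.n_win, hw, ww, C)
            out = out.permute(0, 1, 3, 2, 4, 5).reshape(N, H, W, C)
            return self.proj(out)

    # quick shape check
    x = torch.randn(1, 14, 14, 64)                       # NHWC, 14 divisible by n_win=7
    print(SimpleBiLevelRoutingAttention(64, n_win=7, topk=2)(x).shape)  # (1, 14, 14, 64)

With n_win=7 and topk=4 (the defaults used by the Block below), each query region attends to only 4 of the 49 regions, which is where the computation and memory savings come from.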

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath

# BiLevelRoutingAttention, Attention, AttentionLePE and DWConv are defined in the
# BiFormer repo (https://github.com/rayleizhu/BiFormer); adjust the import paths
# to match your local checkout, e.g.:
# from ops.bra_legacy import BiLevelRoutingAttention
# from models._common import Attention, AttentionLePE, DWConv


class Block(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=-1,
                 num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None,
                 kv_downsample_mode='ada_avgpool',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False,
                 soft_routing=False, mlp_ratio=4, mlp_dwconv=False,
                 side_dwconv=5, before_attn_dwconv=3, pre_norm=True, auto_pad=False):
        super().__init__()
        qk_dim = qk_dim or dim

        # modules
        if before_attn_dwconv > 0:
            # depthwise conv used as a conditional position embedding (padding=1 assumes kernel size 3)
            self.pos_embed = nn.Conv2d(dim, dim, kernel_size=before_attn_dwconv, padding=1, groups=dim)
        else:
            self.pos_embed = lambda x: 0
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)  # important to avoid attention collapsing
        if topk > 0:
            # bi-level routing attention with `topk` routed regions per query region
            self.attn = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                                qk_scale=qk_scale, kv_per_win=kv_per_win,
                                                kv_downsample_ratio=kv_downsample_ratio,
                                                kv_downsample_kernel=kv_downsample_kernel,
                                                kv_downsample_mode=kv_downsample_mode,
                                                topk=topk, param_attention=param_attention,
                                                param_routing=param_routing,
                                                diff_routing=diff_routing, soft_routing=soft_routing,
                                                side_dwconv=side_dwconv,
                                                auto_pad=auto_pad)
        elif topk == -1:
            # vanilla full attention
            self.attn = Attention(dim=dim)
        elif topk == -2:
            # full attention with locally-enhanced positional encoding (LePE)
            self.attn = AttentionLePE(dim=dim, side_dwconv=side_dwconv)
        elif topk == 0:
            # convolution-only pseudo attention
            self.attn = nn.Sequential(Rearrange('n h w c -> n c h w'),  # compatibility
                                      nn.Conv2d(dim, dim, 1),  # pseudo qkv linear
                                      nn.Conv2d(dim, dim, 5, padding=2, groups=dim),  # pseudo attention
                                      nn.Conv2d(dim, dim, 1),  # pseudo out linear
                                      Rearrange('n c h w -> n h w c'))
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(dim, int(mlp_ratio * dim)),
                                 DWConv(int(mlp_ratio * dim)) if mlp_dwconv else nn.Identity(),
                                 nn.GELU(),
                                 nn.Linear(int(mlp_ratio * dim), dim))
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # tricks: layer scale & pre_norm/post_norm
        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.use_layer_scale = False
        self.pre_norm = pre_norm

    def forward(self, x):
        """
        x: NCHW tensor
        """
        # conv pos embedding
        x = x + self.pos_embed(x)
        # permute to NHWC tensor for attention & mlp
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:  # False with the default layer_scale_init_value=-1
                x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))  # (N, H, W, C)
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))   # (N, H, W, C)
            else:
                x = x + self.drop_path(self.attn(self.norm1(x)))  # (N, H, W, C)
                x = x + self.drop_path(self.mlp(self.norm2(x)))   # (N, H, W, C)
        else:  # post-norm variant, see https://kexue.fm/archives/9009
            if self.use_layer_scale:
                x = self.norm1(x + self.drop_path(self.gamma1 * self.attn(x)))  # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.gamma2 * self.mlp(x)))   # (N, H, W, C)
            else:
                x = self.norm1(x + self.drop_path(self.attn(x)))  # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.mlp(x)))   # (N, H, W, C)

        # permute back
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return x


if __name__ == "__main__":
    x2 = torch.randn(1, 128, 28, 28)  # (N, C, H, W); 28 is divisible by the default n_win=7
    net = Block(dim=128)
    out2 = net(x2)
    print(out2.shape)  # torch.Size([1, 128, 28, 28])

