Abstract:
As the core building block of vision transformers, attention is a powerful tool for capturing long-range dependencies. However, this power comes at a cost: computing pairwise token interactions across all spatial positions incurs a huge computational burden and a heavy memory footprint. A series of works attempts to alleviate this problem by introducing handcrafted, content-agnostic sparsity into attention, such as restricting the attention operation to local windows, axial stripes, or dilated windows. In contrast to these approaches, we propose a novel dynamic sparse attention via bi-level routing, which enables a more flexible, content-aware allocation of computation. Specifically, for a query, irrelevant key-value pairs are first filtered out at a coarse region level, and fine-grained token-to-token attention is then applied within the union of the remaining candidate regions (i.e., the routed regions). We provide a simple yet effective implementation of the proposed bi-level routing attention, which exploits sparsity to save both computation and memory while involving only GPU-friendly dense matrix multiplications. Built upon the proposed bi-level routing attention, a new general-purpose vision transformer, BiFormer, is then presented. Since BiFormer attends to a small subset of relevant tokens in a query-adaptive manner, without distraction from other irrelevant ones, it enjoys both good performance and high computational efficiency, especially in dense prediction tasks. Empirical results on several computer vision tasks, including image classification, object detection, and semantic segmentation, verify the effectiveness of our design.
Paper title: BiFormer: Vision Transformer With Bi-Level Routing Attention
Paper link: CVPR 2023 Open Access Repository (thecvf.com)
Source code: https://github.com/rayleizhu/BiFormer
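Before walking through the full Block code below, here is a minimal, single-head sketch of the bi-level routing idea from the abstract: region-level affinities first route each query region to its top-k most relevant regions, and fine-grained token-to-token attention is then computed only against the key/value tokens gathered from those routed regions, using nothing but dense matrix multiplications. This is an illustrative simplification, not the official BiLevelRoutingAttention module: it omits the multi-head split, the key/value downsampling (kv_per_win, kv_downsample_ratio), the optional routing MLP, and the local-context depthwise branch (side_dwconv), and the class name BiLevelRoutingAttentionSketch is made up for this sketch.

import torch
import torch.nn as nn


class BiLevelRoutingAttentionSketch(nn.Module):
    def __init__(self, dim, n_win=7, topk=4):
        super().__init__()
        self.n_win = n_win            # regions per spatial side (n_win x n_win regions in total)
        self.topk = topk              # how many regions each query region routes to
        self.scale = dim ** -0.5
        self.qkv = nn.Linear(dim, 3 * dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # x: (N, H, W, C) with H and W divisible by n_win
        N, H, W, C = x.shape
        S = self.n_win
        h, w = H // S, W // S                                  # tokens per region side
        R, T = S * S, h * w                                    # number of regions / tokens per region

        # partition the feature map into R regions of T tokens each
        x = x.view(N, S, h, S, w, C).permute(0, 1, 3, 2, 4, 5).reshape(N, R, T, C)
        q, k, v = self.qkv(x).chunk(3, dim=-1)                 # each (N, R, T, C)

        # ---- level 1: coarse region-to-region routing ----
        q_r, k_r = q.mean(dim=2), k.mean(dim=2)                # region-level queries/keys (N, R, C)
        affinity = q_r @ k_r.transpose(-1, -2)                 # region affinity graph (N, R, R)
        topk_idx = affinity.topk(self.topk, dim=-1).indices    # routed regions per region (N, R, k)

        # gather the key/value tokens of the routed regions for every query region
        idx = topk_idx[..., None, None].expand(-1, -1, -1, T, C)                          # (N, R, k, T, C)
        k_g = torch.gather(k.unsqueeze(1).expand(-1, R, -1, -1, -1), 2, idx).reshape(N, R, self.topk * T, C)
        v_g = torch.gather(v.unsqueeze(1).expand(-1, R, -1, -1, -1), 2, idx).reshape(N, R, self.topk * T, C)

        # ---- level 2: fine-grained token-to-token attention within the routed regions ----
        attn = ((q * self.scale) @ k_g.transpose(-1, -2)).softmax(dim=-1)   # (N, R, T, k*T)
        out = attn @ v_g                                                    # (N, R, T, C)

        # merge the regions back to the original spatial layout
        out = out.reshape(N, S, S, h, w, C).permute(0, 1, 3, 2, 4, 5).reshape(N, H, W, C)
        return self.proj(out)

With an NHWC input of shape (1, 28, 28, 128) and n_win=7, the map is split into 7x7 regions of 4x4 tokens; with topk=4, every query token attends to 4 x 16 = 64 gathered key tokens instead of all 784, and the output keeps the input shape, matching the shapes printed by the Block test at the end of this post.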
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from timm.models.layers import DropPath

# BiLevelRoutingAttention, Attention, AttentionLePE and DWConv are defined in the
# BiFormer repository (https://github.com/rayleizhu/BiFormer) and are assumed to be
# importable from it.


class Block(nn.Module):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=-1,
                 num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='ada_avgpool',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False,
                 mlp_ratio=4, mlp_dwconv=False,
                 side_dwconv=5, before_attn_dwconv=3, pre_norm=True, auto_pad=False):
        super().__init__()
        qk_dim = qk_dim or dim

        # modules
        if before_attn_dwconv > 0:
            # depthwise conv as convolutional position embedding (padding=1 matches the default kernel_size=3)
            self.pos_embed = nn.Conv2d(dim, dim, kernel_size=before_attn_dwconv, padding=1, groups=dim)
        else:
            self.pos_embed = lambda x: 0
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)  # important to avoid attention collapsing
        if topk > 0:  # bi-level routing attention
            self.attn = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                                qk_scale=qk_scale, kv_per_win=kv_per_win,
                                                kv_downsample_ratio=kv_downsample_ratio,
                                                kv_downsample_kernel=kv_downsample_kernel,
                                                kv_downsample_mode=kv_downsample_mode,
                                                topk=topk, param_attention=param_attention,
                                                param_routing=param_routing, diff_routing=diff_routing,
                                                soft_routing=soft_routing, side_dwconv=side_dwconv,
                                                auto_pad=auto_pad)
        elif topk == -1:  # vanilla full attention
            self.attn = Attention(dim=dim)
        elif topk == -2:  # full attention + LePE
            self.attn = AttentionLePE(dim=dim, side_dwconv=side_dwconv)
        elif topk == 0:  # convolutional pseudo-attention baseline
            self.attn = nn.Sequential(Rearrange('n h w c -> n c h w'),  # compatibility
                                      nn.Conv2d(dim, dim, 1),  # pseudo qkv linear
                                      nn.Conv2d(dim, dim, 5, padding=2, groups=dim),  # pseudo attention
                                      nn.Conv2d(dim, dim, 1),  # pseudo out linear
                                      Rearrange('n c h w -> n h w c')
                                      )
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = nn.Sequential(nn.Linear(dim, int(mlp_ratio * dim)),
                                 DWConv(int(mlp_ratio * dim)) if mlp_dwconv else nn.Identity(),
                                 nn.GELU(),
                                 nn.Linear(int(mlp_ratio * dim), dim)
                                 )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # tricks: layer scale & pre_norm/post_norm
        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
        else:
            self.use_layer_scale = False
        self.pre_norm = pre_norm

    def forward(self, x):
        """
        x: NCHW tensor
        """
        # conv pos embedding
        x = x + self.pos_embed(x)
        # permute to NHWC tensor for attention & mlp
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)

        # attention & mlp
        if self.pre_norm:
            if self.use_layer_scale:  # False with the default layer_scale_init_value=-1
                x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))  # (N, H, W, C)
                x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))  # (N, H, W, C)
            else:
                x = x + self.drop_path(self.attn(self.norm1(x)))  # (N, H, W, C)
                x = x + self.drop_path(self.mlp(self.norm2(x)))  # (N, H, W, C)
        else:  # post-norm, see https://kexue.fm/archives/9009
            if self.use_layer_scale:  # False with the default layer_scale_init_value=-1
                x = self.norm1(x + self.drop_path(self.gamma1 * self.attn(x)))  # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.gamma2 * self.mlp(x)))  # (N, H, W, C)
            else:
                x = self.norm1(x + self.drop_path(self.attn(x)))  # (N, H, W, C)
                x = self.norm2(x + self.drop_path(self.mlp(x)))  # (N, H, W, C)

        # permute back
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        return x


if __name__ == "__main__":
    x2 = torch.randn(1, 128, 28, 28)
    net = Block(dim=128)
    out2 = net(x2)
    print(out2.shape)
    """
    x2:   [1, 128, 28, 28]
    out2: [1, 128, 28, 28]
    """
