
[Paper Notes] BiFormer: Vision Transformer with Bi-Level Routing Attention

Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention

Code: https://github.com/rayleizhu/BiFormer

Attention is a crucial module in vision transformers, but it has a major drawback: its computational cost is very high.

BiFormer proposes Bi-Level Routing Attention, which attends only to the most relevant tokens during the attention computation, thereby reducing the computational cost.

1. Bi-Level Routing Attention

The figure below compares the regions attended to by several attention modules: (a) is vanilla full attention, while the others are sparse attention designs. Bi-Level Routing Attention is shown in (f).

Unlike other sparse attention designs, Bi-Level Routing Attention first partitions the feature map into regions of size SxS. Each region is passed through a linear projection to obtain Q, K, and V; Q and K are then averaged within each SxS window to form that region's token (see the code below), yielding Q^{r} and K^{r} (r denotes region, i.e., the SxS window). The region-level adjacency matrix A^{r} is obtained by the matrix product of Q^{r} and K^{r}:

A^{r} = Q^{r}(K^{r})^{T}
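
As a rough illustration of this region-level step, here is a toy sketch (shapes and variable names are made up for illustration; the actual implementation is the TopkRouting class in the code below):

import torch

# toy shapes: n = batch, p2 = number of SxS regions, w2 = tokens per region, c = channels
n, p2, w2, c = 2, 49, 64, 32
q = torch.randn(n, p2, w2, c)              # per-token queries inside each region
k = torch.randn(n, p2, w2, c)              # per-token keys inside each region
q_r, k_r = q.mean(dim=2), k.mean(dim=2)    # region-level tokens Q^{r}, K^{r}: (n, p2, c)
A_r = q_r @ k_r.transpose(-2, -1)          # region-to-region adjacency A^{r}: (n, p2, p2)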

With the adjacency matrix A^{r} in hand, we take, for each region, the indices of its k most correlated regions, I^{r} = topkIndex(A^{r}), so that each window knows which k windows it is most correlated with.

Given I^{r}, a gather operation collects the routed keys and values: K^{g} = gather(K, I^{r}) and V^{g} = gather(V, I^{r}).
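
The top-k selection and gather can be sketched as follows (again a standalone toy example with made-up shapes; A^{r} is replaced by a random placeholder, and the repository implements this step in KVGather):

import torch

n, p2, w2, c, topk = 2, 49, 64, 32, 4
A_r = torch.randn(n, p2, p2)                     # placeholder for the adjacency matrix A^{r}
kv = torch.randn(n, p2, w2, 2 * c)               # per-token keys and values, concatenated
_, I_r = torch.topk(A_r, k=topk, dim=-1)         # routing indices I^{r}: (n, p2, topk)
kv_sel = torch.gather(
    kv.unsqueeze(1).expand(-1, p2, -1, -1, -1),                       # (n, p2, p2, w2, 2c)
    dim=2,
    index=I_r.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, 2 * c),  # broadcast indices
)                                                # gathered tokens: (n, p2, topk, w2, 2c)
K_g, V_g = kv_sel.split([c, c], dim=-1)          # K^{g} and V^{g}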

Finally, attention is computed over the gathered key-value pairs:

O = Attention(Q, K^{g}, V^{g}) + LCE(V)

where LCE is the local context enhancement term, implemented as a depth-wise convolution (lepe in the code).

The overall computation of Bi-Level Routing Attention is illustrated in the figure below, where k is the parameter used when selecting the top-k correlation indices.
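
The final step can be condensed into a few lines (a single-head toy sketch without the LCE term; the full multi-head version is in the code below):

import torch

n, p2, w2, c, topk = 2, 49, 64, 32, 4
q = torch.randn(n, p2, w2, c)                    # per-token queries inside each region
K_g = torch.randn(n, p2, topk, w2, c)            # gathered keys K^{g}
V_g = torch.randn(n, p2, topk, w2, c)            # gathered values V^{g}
k_flat = K_g.reshape(n, p2, topk * w2, c)        # flatten the topk routed regions
v_flat = V_g.reshape(n, p2, topk * w2, c)
attn = torch.softmax((q * c ** -0.5) @ k_flat.transpose(-2, -1), dim=-1)  # (n, p2, w2, topk*w2)
out = attn @ v_flat                              # (n, p2, w2, c); the paper additionally adds LCE(V)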

2. Code

The code of Bi-Level Routing Attention is as follows:

"""
Core of BiFormer, Bi-Level Routing Attention.

To be refactored.

author: ZHU Lei
github: https://github.com/rayleizhu
email: ray.leizhu@outlook.com

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor


class TopkRouting(nn.Module):
    """
    differentiable topk routing with scaling
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
        with_param: bool, whether to incorporate learnable params in the routing unit
        diff_routing: bool, whether to make routing differentiable
        soft_routing: bool, whether to multiply output values by routing weights
    """
    def __init__(self, qk_dim, topk=4, qk_scale=None, param_routing=False, diff_routing=False):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5
        self.diff_routing = diff_routing
        # TODO: norm layer before/after linear?
        self.emb = nn.Linear(qk_dim, qk_dim) if param_routing else nn.Identity()
        # routing activation
        self.routing_act = nn.Softmax(dim=-1)

    def forward(self, query: Tensor, key: Tensor) -> Tuple[Tensor]:
        """
        Args:
            q, k: (n, p^2, c) tensor
        Return:
            r_weight, topk_index: (n, p^2, topk) tensor
        """
        if not self.diff_routing:
            query, key = query.detach(), key.detach()
        query_hat, key_hat = self.emb(query), self.emb(key)  # per-window pooling -> (n, p^2, c)
        attn_logit = (query_hat * self.scale) @ key_hat.transpose(-2, -1)  # (n, p^2, p^2)
        topk_attn_logit, topk_index = torch.topk(attn_logit, k=self.topk, dim=-1)  # (n, p^2, k), (n, p^2, k)
        r_weight = self.routing_act(topk_attn_logit)  # (n, p^2, k)
        return r_weight, topk_index


class KVGather(nn.Module):
    def __init__(self, mul_weight='none'):
        super().__init__()
        assert mul_weight in ['none', 'soft', 'hard']
        self.mul_weight = mul_weight

    def forward(self, r_idx: Tensor, r_weight: Tensor, kv: Tensor):
        """
        r_idx: (n, p^2, topk) tensor
        r_weight: (n, p^2, topk) tensor
        kv: (n, p^2, w^2, c_kq+c_v)
        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        n, p2, w2, c_kv = kv.size()
        topk = r_idx.size(-1)
        # print(r_idx.size(), r_weight.size())
        # FIXME: gather consumes much memory (topk times redundancy), write cuda kernel?
        topk_kv = torch.gather(kv.view(n, 1, p2, w2, c_kv).expand(-1, p2, -1, -1, -1),  # (n, p^2, p^2, w^2, c_kv) without mem cpy
                               dim=2,
                               index=r_idx.view(n, p2, topk, 1, 1).expand(-1, -1, -1, w2, c_kv)  # (n, p^2, k, w^2, c_kv)
                               )
        if self.mul_weight == 'soft':
            topk_kv = r_weight.view(n, p2, topk, 1, 1) * topk_kv  # (n, p^2, k, w^2, c_kv)
        elif self.mul_weight == 'hard':
            raise NotImplementedError('differentiable hard routing TBA')
        # else: #'none'
        #     topk_kv = topk_kv  # do nothing
        return topk_kv


class QKVLinear(nn.Module):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias=bias)

    def forward(self, x):
        q, kv = self.qkv(x).split([self.qk_dim, self.qk_dim + self.dim], dim=-1)
        return q, kv
        # q, k, v = self.qkv(x).split([self.qk_dim, self.qk_dim, self.dim], dim=-1)
        # return q, k, v


class BiLevelRoutingAttention(nn.Module):
    """
    n_win: number of windows per side (so the actual number of windows is n_win*n_win)
    kv_per_win: for kv_downsample_mode='ada_xxxpool' only, number of key/values per window. Similar to n_win, the actual number is kv_per_win*kv_per_win.
    topk: topk for window filtering
    param_attention: 'qkvo'-linear for q,k,v and o, 'none': param free attention
    param_routing: extra linear for routing
    diff_routing: whether to make routing differentiable
    soft_routing: whether to multiply soft routing weights
    """
    def __init__(self, dim, num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                 kv_per_win=4, kv_downsample_ratio=4, kv_downsample_kernel=None, kv_downsample_mode='identity',
                 topk=4, param_attention="qkvo", param_routing=False, diff_routing=False, soft_routing=False, side_dwconv=3,
                 auto_pad=False):
        super().__init__()
        # local attention setting
        self.dim = dim
        self.n_win = n_win  # Wh, Ww
        self.num_heads = num_heads
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads == 0, 'qk_dim and dim must be divisible by num_heads!'
        self.scale = qk_scale or self.qk_dim ** -0.5

        ################ side_dwconv (i.e. LCE in ShuntedTransformer) ###########
        self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv // 2, groups=dim) if side_dwconv > 0 else \
            lambda x: torch.zeros_like(x)

        ################ global routing setting #################
        self.topk = topk
        self.param_routing = param_routing
        self.diff_routing = diff_routing
        self.soft_routing = soft_routing
        # router
        assert not (self.param_routing and not self.diff_routing)  # cannot be with_param=True and diff_routing=False
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.scale,
                                  topk=self.topk,
                                  diff_routing=self.diff_routing,
                                  param_routing=self.param_routing)
        if self.soft_routing:  # soft routing, always differentiable (if no detach)
            mul_weight = 'soft'
        elif self.diff_routing:  # hard differentiable routing
            mul_weight = 'hard'
        else:  # hard non-differentiable routing
            mul_weight = 'none'
        self.kv_gather = KVGather(mul_weight=mul_weight)

        # qkv mapping (shared by both global routing and local attention)
        self.param_attention = param_attention
        if self.param_attention == 'qkvo':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Linear(dim, dim)
        elif self.param_attention == 'qkv':
            self.qkv = QKVLinear(self.dim, self.qk_dim)
            self.wo = nn.Identity()
        else:
            raise ValueError(f'param_attention mode {self.param_attention} is not supported!')

        self.kv_downsample_mode = kv_downsample_mode
        self.kv_per_win = kv_per_win
        self.kv_downsample_ratio = kv_downsample_ratio
        self.kv_downsample_kenel = kv_downsample_kernel
        if self.kv_downsample_mode == 'ada_avgpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveAvgPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'ada_maxpool':
            assert self.kv_per_win is not None
            self.kv_down = nn.AdaptiveMaxPool2d(self.kv_per_win)
        elif self.kv_downsample_mode == 'maxpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.MaxPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'avgpool':
            assert self.kv_downsample_ratio is not None
            self.kv_down = nn.AvgPool2d(self.kv_downsample_ratio) if self.kv_downsample_ratio > 1 else nn.Identity()
        elif self.kv_downsample_mode == 'identity':  # no kv downsampling
            self.kv_down = nn.Identity()
        elif self.kv_downsample_mode == 'fracpool':
            # assert self.kv_downsample_ratio is not None
            # assert self.kv_downsample_kenel is not None
            # TODO: fracpool
            # 1. kernel size should be input size dependent
            # 2. there is a random factor, need to avoid independent sampling for k and v
            raise NotImplementedError('fracpool policy is not implemented yet!')
        elif kv_downsample_mode == 'conv':
            # TODO: need to consider the case where k != v so that need two downsample modules
            raise NotImplementedError('conv policy is not implemented yet!')
        else:
            raise ValueError(f'kv_down_sample_mode {self.kv_downsample_mode} is not supported!')

        # softmax for local attention
        self.attn_act = nn.Softmax(dim=-1)

        self.auto_pad = auto_pad

    def forward(self, x, ret_attn_mask=False):
        """
        x: NHWC tensor
        Return:
            NHWC tensor
        """
        # NOTE: use padding for semantic segmentation
        ###################################################
        if self.auto_pad:
            N, H_in, W_in, C = x.size()

            pad_l = pad_t = 0
            pad_r = (self.n_win - W_in % self.n_win) % self.n_win
            pad_b = (self.n_win - H_in % self.n_win) % self.n_win
            x = F.pad(x, (0, 0,  # dim=-1
                          pad_l, pad_r,  # dim=-2
                          pad_t, pad_b))  # dim=-3
            _, H, W, _ = x.size()  # padded size
        else:
            N, H, W, C = x.size()
            assert H % self.n_win == 0 and W % self.n_win == 0
        ###################################################

        # patchify, (n, p^2, w, w, c), keep 2d window as we need 2d pooling to reduce kv size
        x = rearrange(x, "n (j h) (i w) c -> n (j i) h w c", j=self.n_win, i=self.n_win)

        ################# qkv projection ###################
        # q: (n, p^2, w, w, c_qk)
        # kv: (n, p^2, w, w, c_qk+c_v)
        # NOTE: separate kv if there were memory leak issues caused by gather
        q, kv = self.qkv(x)

        # pixel-wise qkv
        # q_pix: (n, p^2, w^2, c_qk)
        # kv_pix: (n, p^2, h_kv*w_kv, c_qk+c_v)
        q_pix = rearrange(q, 'n p2 h w c -> n p2 (h w) c')
        kv_pix = self.kv_down(rearrange(kv, 'n p2 h w c -> (n p2) c h w'))
        kv_pix = rearrange(kv_pix, '(n j i) c h w -> n (j i) (h w) c', j=self.n_win, i=self.n_win)

        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3])  # window-wise qk, (n, p^2, c_qk), (n, p^2, c_qk)

        ################## side_dwconv (lepe) ##################
        # NOTE: call contiguous to avoid gradient warning when using ddp
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'n (j i) h w c -> n c (j h) (i w)', j=self.n_win, i=self.n_win).contiguous())
        lepe = rearrange(lepe, 'n c (j h) (i w) -> n (j h) (i w) c', j=self.n_win, i=self.n_win)

        ############ gather q dependent k/v #################
        r_weight, r_idx = self.router(q_win, k_win)  # both are (n, p^2, topk) tensors

        kv_pix_sel = self.kv_gather(r_idx=r_idx, r_weight=r_weight, kv=kv_pix)  # (n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = kv_pix_sel.split([self.qk_dim, self.dim], dim=-1)
        # k_pix_sel: (n, p^2, topk, h_kv*w_kv, c_qk)
        # v_pix_sel: (n, p^2, topk, h_kv*w_kv, c_v)

        ######### do attention as normal ####################
        k_pix_sel = rearrange(k_pix_sel, 'n p2 k w2 (m c) -> (n p2) m c (k w2)', m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_kq//m) transpose here?
        v_pix_sel = rearrange(v_pix_sel, 'n p2 k w2 (m c) -> (n p2) m (k w2) c', m=self.num_heads)  # flatten to BMLC, (n*p^2, m, topk*h_kv*w_kv, c_v//m)
        q_pix = rearrange(q_pix, 'n p2 w2 (m c) -> (n p2) m w2 c', m=self.num_heads)  # to BMLC tensor (n*p^2, m, w^2, c_qk//m)

        # param-free multihead attention
        attn_weight = (q_pix * self.scale) @ k_pix_sel  # (n*p^2, m, w^2, c) @ (n*p^2, m, c, topk*h_kv*w_kv) -> (n*p^2, m, w^2, topk*h_kv*w_kv)
        attn_weight = self.attn_act(attn_weight)
        out = attn_weight @ v_pix_sel  # (n*p^2, m, w^2, topk*h_kv*w_kv) @ (n*p^2, m, topk*h_kv*w_kv, c) -> (n*p^2, m, w^2, c)
        out = rearrange(out, '(n j i) m (h w) c -> n (j h) (i w) (m c)', j=self.n_win, i=self.n_win,
                        h=H // self.n_win, w=W // self.n_win)

        out = out + lepe
        # output linear
        out = self.wo(out)

        # NOTE: use padding for semantic segmentation
        # crop padded region
        if self.auto_pad and (pad_r > 0 or pad_b > 0):
            out = out[:, :H_in, :W_in, :].contiguous()

        if ret_attn_mask:
            return out, r_weight, r_idx, attn_weight
        else:
            return out
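
A minimal usage sketch (the shapes are a toy example of mine, not from the repository): the module expects an NHWC tensor whose height and width are divisible by n_win, or auto_pad=True must be set.

import torch

attn = BiLevelRoutingAttention(dim=64, num_heads=8, n_win=7, topk=4)  # qk_dim defaults to dim
x = torch.randn(2, 56, 56, 64)   # (N, H, W, C); 56 = 7 * 8, so each of the 7x7 windows is 8x8
out = attn(x)                    # NHWC output with the same shape: (2, 56, 56, 64)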