
[CVPR 2023] BiFormer: Vision Transformer with Bi-Level Routing Attention



Abstract

        As the core building block of vision Transformers, attention is a powerful tool for capturing long-range dependencies. However, this power comes at a cost: it incurs a huge computational burden and heavy memory footprint, since pairwise token interactions are computed across all spatial positions. A line of work attempts to alleviate this problem by introducing handcrafted, content-agnostic sparsity into attention, for example restricting the attention operation to local windows, axial stripes, or dilated windows. In contrast to these approaches, we propose a novel dynamic sparse attention via bi-level routing, which enables a more flexible, content-aware allocation of computation. Specifically, for a query, irrelevant key-value pairs are first filtered out at a coarse region level, and fine-grained token-to-token attention is then applied in the union of the remaining candidate regions (i.e., the routed regions). We provide a simple yet effective implementation of bi-level routing attention that exploits sparsity to save both computation and memory while involving only GPU-friendly dense matrix multiplications. On top of it, a new general-purpose vision Transformer, BiFormer, is proposed. Since BiFormer attends to a small subset of relevant tokens in a query-adaptive manner, without distraction from other irrelevant ones, it enjoys both good performance and high computational efficiency, especially in dense prediction tasks. Experimental results on several computer vision tasks, including image classification, object detection, and semantic segmentation, verify the effectiveness of our design.

1. BiFormer

        As is well known, a core advantage of Transformers over CNNs is their ability to capture long-range context through self-attention. But this capability comes at a price: as shown in Figure 1(a), in the original Transformer design, full attention does bring performance gains, yet it raises two long-standing issues:

  1. heavy memory footprint
  2. high computational cost

        Consequently, many works have sought to optimize this, including but not limited to restricting the attention operation to:

  1. local windows, e.g., Swin Transformer and CrossFormer (Figure 1(b))
  2. axial stripes, e.g., CSWin Transformer (Figure 1(c))
  3. dilated windows, e.g., MaxViT and CrossFormer (Figure 1(d))
  4. deformable windows, e.g., DAT (Figure 1(e))

        In short, the authors observe that most of these methods try to mitigate the problem by injecting handcrafted, content-agnostic sparsity into the attention mechanism. This paper instead proposes a novel dynamic sparse attention via bi-level routing, which allocates computation more flexibly and with content awareness, achieving dynamic, query-aware sparsity, as shown in Figure 1(f).

1.1 Bi-Level Routing Attention (BRA)

        This paper explores a dynamic, query-aware sparse attention mechanism. The key idea is to filter out most irrelevant key-value pairs at a coarse region level, so that only a small set of routed regions is retained; fine-grained token-to-token attention is then applied in the union of these routed regions. BRA consists of the following three steps (a minimal runnable sketch follows the equations below):

  1. Region partition and input projection. The feature map is first divided into $S \times S$ non-overlapping regions, and queries, keys, and values are obtained by linear projection:

$$\mathbf{Q}=\mathbf{X}^{r} \mathbf{W}^{q}, \quad \mathbf{K}=\mathbf{X}^{r} \mathbf{W}^{k}, \quad \mathbf{V}=\mathbf{X}^{r} \mathbf{W}^{v}$$

  2. Region-to-region routing with a directed graph. Attention weights are then computed over the coarse-grained (region-level) tokens, and only the top-k regions are kept as the relevant regions that participate in the fine-grained computation:

$$\mathbf{A}^{r}=\mathbf{Q}^{r}\left(\mathbf{K}^{r}\right)^{T}, \quad \mathbf{I}^{r}=\operatorname{topkIndex}\left(\mathbf{A}^{r}\right)$$

  3. Token-to-token attention. Finally, the keys and values gathered from the top-k most relevant coarse-grained regions take part in the final token-level attention. To enhance locality, a depth-wise convolution (the LCE term) is additionally applied to the values:

$$\mathbf{K}^{g}=\operatorname{gather}\left(\mathbf{K}, \mathbf{I}^{r}\right), \quad \mathbf{V}^{g}=\operatorname{gather}\left(\mathbf{V}, \mathbf{I}^{r}\right), \quad \mathbf{O}=\operatorname{Attention}\left(\mathbf{Q}, \mathbf{K}^{g}, \mathbf{V}^{g}\right)+\operatorname{LCE}(\mathbf{V})$$
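        To make these three steps concrete, here is a minimal single-head sketch of the BRA pipeline on random tensors (shapes assumed for illustration: 1 image, S = 7, i.e. 49 regions of 64 tokens each, 64 channels, top-4 routing; the LCE term and multi-head splitting are omitted). It mirrors the full implementation in Section 2.3.3.

import paddle
import paddle.nn.functional as F

n, S2, w2, c, topk = 1, 49, 64, 64, 4   # S = 7 regions per side -> S^2 = 49 regions

# Step 1: region-partitioned queries/keys/values after the linear projections
Q = paddle.randn([n, S2, w2, c])
K = paddle.randn([n, S2, w2, c])
V = paddle.randn([n, S2, w2, c])

# Step 2: region-to-region routing on region-averaged queries and keys
Qr, Kr = Q.mean(axis=2), K.mean(axis=2)        # (n, S^2, c) region-level tokens
Ar = Qr @ Kr.transpose([0, 2, 1])              # (n, S^2, S^2) region affinity graph
_, Ir = paddle.topk(Ar, k=topk, axis=-1)       # (n, S^2, topk) routed region indices

# Step 3: gather keys/values of the routed regions, then token-to-token attention
idx = Ir.reshape([n, S2, topk, 1, 1]).expand([n, S2, topk, w2, c])
Kg = paddle.take_along_axis(K.unsqueeze(1).expand([n, S2, S2, w2, c]), idx, axis=2)
Vg = paddle.take_along_axis(V.unsqueeze(1).expand([n, S2, S2, w2, c]), idx, axis=2)
Kg = Kg.reshape([n, S2, topk * w2, c])         # each query region now attends to
Vg = Vg.reshape([n, S2, topk * w2, c])         # topk * w^2 gathered tokens
attn = F.softmax(Q @ Kg.transpose([0, 1, 3, 2]) * c ** -0.5, axis=-1)
O = attn @ Vg
print(O.shape)   # [1, 49, 64, 64]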

1.2 BiFormer

        Built on the BRA module, the paper constructs a novel general-purpose vision Transformer, BiFormer. Following most vision Transformer designs, it adopts a four-stage pyramid structure with an overall downsampling factor of 32. Concretely, BiFormer uses overlapping patch embedding in the first stage and a merge module in stages two to four to reduce the spatial resolution while increasing the number of channels, followed by consecutive BiFormer blocks for feature transformation. Note that each block starts with a $3 \times 3$ depth-wise convolution to implicitly encode relative position information; a BRA module and a 2-layer MLP (Multi-Layer Perceptron) module with expansion ratio e are then applied in sequence, for cross-location relation modeling and per-location embedding, respectively.
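        As a small illustrative computation (not from the original notebook), the per-stage feature-map sizes for a 224 x 224 input follow directly from the 4x stem and the three 2x merge layers:

size = 224
for stage, stride in enumerate([4, 2, 2, 2], start=1):   # stem, then 3 merge layers
    size //= stride
    print(f"stage {stage}: {size} x {size}")   # 56, 28, 14, 7 -> 32x overall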

2. Code Reproduction

2.1 Install and import the required libraries

!pip install einops-0.3.0-py3-none-any.whl
!pip install paddlex
%matplotlib inline
import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex
from einops import rearrange

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the CIFAR-10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm)
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Build the model

2.3.1 Label smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
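A quick sanity check on random logits (illustrative, not part of the original notebook): with smoothing=0.1 the loss blends the negative log-likelihood of the target class with the mean over all classes, so it typically differs slightly from the plain cross entropy.

criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
logits = paddle.randn([4, 10])                 # 4 samples, 10 classes
labels = paddle.randint(0, 10, [4])
print(float(criterion(logits, labels)))        # smoothed loss
print(float(F.cross_entropy(logits, labels)))  # plain CE for comparison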
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
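A quick behavioural check (illustrative, not part of the original notebook): during training, DropPath zeroes the residual branch of roughly drop_prob of the samples and rescales the survivors by 1/keep_prob; at evaluation time it is the identity.

dp = DropPath(drop_prob=0.2)
x = paddle.ones([8, 4])
dp.train()
print(dp(x)[:, 0])                      # some rows 0, survivors scaled by 1/0.8
dp.eval()
print(bool(paddle.allclose(dp(x), x)))  # True: identity at eval time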
2.3.3 Build the BiFormer model
class TopkRouting(nn.Layer):
    """
    differentiable topk routing with scaling
    Args:
        qk_dim: int, feature dimension of query and key
        topk: int, the 'topk'
        qk_scale: int or None, temperature (multiply) of softmax activation
    """
    def __init__(self, qk_dim, topk=4, qk_scale=None):
        super().__init__()
        self.topk = topk
        self.qk_dim = qk_dim
        self.scale = qk_scale or qk_dim ** -0.5

    def forward(self, query, key):
        """
        Args:
            q, k: (n, p^2, c) tensor
        Return:
            topk_index: (n, p^2, topk) tensor
        """
        query_hat, key_hat = query.detach(), key.detach()
        attn_logit = (query_hat * self.scale) @ key_hat.transpose([0, 2, 1]) # (n, p^2, p^2)
        topk_attn_logit, topk_index = paddle.topk(attn_logit, k=self.topk, axis=-1) # (n, p^2, k), (n, p^2, k)

        return topk_index
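A quick shape check (illustrative, not part of the original notebook): for 49 regions described by 64-dim region-averaged tokens, the router returns, for every query region, the indices of its top-4 most relevant regions.

router = TopkRouting(qk_dim=64, topk=4)
q_win = paddle.randn([1, 49, 64])   # region-level queries
k_win = paddle.randn([1, 49, 64])   # region-level keys
print(router(q_win, k_win).shape)   # [1, 49, 4]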
class KVGather(nn.Layer):
    def __init__(self):
        super().__init__()

    def forward(self, r_idx, kv):
        """
        r_idx: (b, p^2, topk) tensor
        kv: (b, p^2, w^2, c_kq+c_v)

        Return:
            (n, p^2, topk, w^2, c_kq+c_v) tensor
        """
        # select kv according to routing index
        b, p2, w2, c_kv = kv.shape
        topk = r_idx.shape[-1]

        topk_kv = paddle.take_along_axis(
            kv.reshape((b, 1, p2, w2, c_kv)).expand((-1, p2, -1, -1, -1)), # (n, p^2, p^2, w^2, c_kv) without mem cpy
            r_idx.reshape((b, p2, topk, 1, 1)).expand((-1, -1, -1, w2, c_kv)), # (n, p^2, k, w^2, c_kv)
            axis=2
        )

        return topk_kv
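A quick shape check (illustrative, not part of the original notebook): gathering with top-4 routing expands each region's candidate key/value set from its own w^2 = 64 tokens to 4 windows of 64 tokens.

gather = KVGather()
kv = paddle.randn([1, 49, 64, 128])          # (b, p^2, w^2, c_qk + c_v)
r_idx = paddle.randint(0, 49, [1, 49, 4])    # (b, p^2, topk) routed indices
print(gather(r_idx=r_idx, kv=kv).shape)      # [1, 49, 4, 64, 128]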
class QKVLinear(nn.Layer):
    def __init__(self, dim, qk_dim, bias=True):
        super().__init__()
        self.dim = dim
        self.qk_dim = qk_dim
        self.qkv = nn.Linear(dim, qk_dim + qk_dim + dim, bias_attr=bias)

    def forward(self, x):
        q, kv = paddle.split(self.qkv(x), [self.qk_dim, self.qk_dim+self.dim], axis=-1)
        return q, kv
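A quick shape check (illustrative, not part of the original notebook): the fused projection emits a query with qk_dim channels and a concatenated key/value tensor with qk_dim + dim channels, which is split later inside the attention module.

proj = QKVLinear(dim=64, qk_dim=64)
x = paddle.randn([1, 49, 8, 8, 64])   # (b, p^2, h, w, c) region-partitioned input
q, kv = proj(x)
print(q.shape, kv.shape)              # [1, 49, 8, 8, 64] [1, 49, 8, 8, 128]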
class BiLevelRoutingAttention(nn.Layer):
    def __init__(self, dim, num_heads=8, n_win=7,  qk_dim=None, qk_scale=None, topk=4, side_dwconv=3):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.n_win = n_win
        self.qk_dim = qk_dim or dim
        assert self.qk_dim % num_heads == 0 and self.dim % num_heads==0, 'qk_dim and dim must be divisible by num_heads!'
        self.qk_scale = qk_scale or self.qk_dim ** -0.5

        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        self.lepe = nn.Conv2D(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim)

        self.topk = topk

        # route queries to the top-k relevant regions
        self.router = TopkRouting(qk_dim=self.qk_dim,
                                  qk_scale=self.qk_scale,
                                  topk=self.topk)
        self.kv_gather = KVGather()

        self.qkv = QKVLinear(self.dim, self.qk_dim)
        self.wo = nn.Linear(dim, dim)

        self.attn_act = nn.Softmax(axis=-1)

    def forward(self, x):
        """
        Input:
            x: NHWC tensor

        Return:
            NHWC tensor
        """

        N, H, W, C = x.shape

        x = rearrange(x, "b (j h) (i w) c -> b (j i) h w c", j=self.n_win, i=self.n_win)

        q, kv = self.qkv(x)

        q_pix = rearrange(q, 'b p2 h w c -> b p2 (h w) c')
        kv_pix = rearrange(kv, 'b p2 h w c -> b p2 (h w) c')

        q_win, k_win = q.mean([2, 3]), kv[..., 0:self.qk_dim].mean([2, 3]) # window-wise qk, (b, p^2, c_qk), (b, p^2, c_qk)

        ################side_dwconv (i.e. LCE in ShuntedTransformer)###########
        lepe = self.lepe(rearrange(kv[..., self.qk_dim:], 'b (j i) h w c -> b c (j h) (i w)', j=self.n_win, i=self.n_win))
        lepe = rearrange(lepe, 'b c h w -> b h w c')

        r_idx = self.router(q_win, k_win)
        kv_pix_sel = self.kv_gather(r_idx=r_idx, kv=kv_pix) #(n, p^2, topk, h_kv*w_kv, c_qk+c_v)
        k_pix_sel, v_pix_sel = paddle.split(kv_pix_sel, [self.qk_dim, self.dim], axis=-1)

        # MHSA
        k_pix_sel = rearrange(k_pix_sel, 'b p2 k w2 (mh c) -> (b p2) mh c (k w2)', mh=self.num_heads)
        v_pix_sel = rearrange(v_pix_sel, 'b p2 k w2 (mh c) -> (b p2) mh (k w2) c', mh=self.num_heads)
        q_pix = rearrange(q_pix, 'b p2 w2 (mh c) -> (b p2) mh w2 c', mh=self.num_heads)

        attn = q_pix @ k_pix_sel * self.qk_scale
        attn = self.attn_act(attn)
        out = attn @ v_pix_sel

        out = rearrange(out, '(b j i) mh (h w) c -> b (j h) (i w) (mh c)',
                        j=self.n_win, i=self.n_win, h=H // self.n_win, w=W // self.n_win)

        out = out + lepe

        out = self.wo(out)

        return out
model = BiLevelRoutingAttention(64, num_heads=2)
paddle.summary(model, (1, 56, 56, 64))
W0318 21:57:45.335470   399 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0318 21:57:45.340256   399 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.


----------------------------------------------------------------------------------------------------
 Layer (type)          Input Shape                      Output Shape                   Param #    
====================================================================================================
   Linear-1        [[1, 49, 8, 8, 64]]               [1, 49, 8, 8, 192]                12,480     
  QKVLinear-1      [[1, 49, 8, 8, 64]]     [[1, 49, 8, 8, 64], [1, 49, 8, 8, 128]]        0       
   Conv2D-1         [[1, 64, 56, 56]]                  [1, 64, 56, 56]                   640      
 TopkRouting-1  [[1, 49, 64], [1, 49, 64]]               [1, 49, 4]                       0       
  KVGather-1                []                       [1, 49, 4, 64, 128]                  0       
   Softmax-1        [[49, 2, 64, 256]]                [49, 2, 64, 256]                    0       
   Linear-2         [[1, 56, 56, 64]]                  [1, 56, 56, 64]                  4,160     
====================================================================================================
Total params: 17,280
Trainable params: 17,280
Non-trainable params: 0
----------------------------------------------------------------------------------------------------
Input size (MB): 0.77
Forward/backward pass size (MB): 36.75
Params size (MB): 0.07
Estimated Total Size (MB): 37.58
----------------------------------------------------------------------------------------------------






{'total_params': 17280, 'trainable_params': 17280}
class AttentionLePE(nn.Layer):
    """
    vanilla attention
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., side_dwconv=5):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.lepe = nn.Conv2D(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim)
        self.softmax = nn.Softmax(axis=-1)

    def forward(self, x):
        """
        args:
            x: NHWC tensor
        return:
            NHWC tensor
        """
        _, H, W, _ = x.shape
        x = rearrange(x, 'n h w c -> n (h w) c')

        #######################################
        B, N, C = x.shape
        qkv = self.qkv(x).reshape((B, N, 3, self.num_heads, C // self.num_heads)).transpose([2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        lepe = self.lepe(rearrange(x, 'n (h w) c -> n c h w', h=H, w=W))
        lepe = rearrange(lepe, 'n c h w -> n (h w) c')

        attn = (q @ k.transpose([0, 1, 3, 2])) * self.scale
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose([0, 2, 1, 3]).reshape((B, N, C))
        x = x + lepe

        x = self.proj(x)
        x = self.proj_drop(x)
        #######################################

        x = rearrange(x, 'n (h w) c -> n h w c', h=H, w=W)
        return x
class Block(nn.Layer):
    def __init__(self, dim, drop_path=0., layer_scale_init_value=-1,
                       num_heads=8, n_win=7, qk_dim=None, qk_scale=None,
                       topk=4, mlp_ratio=3, side_dwconv=5, before_attn_dwconv=3):
        super().__init__()
        qk_dim = qk_dim or dim

        # modules
        self.pos_embed = nn.Conv2D(dim, dim,  kernel_size=before_attn_dwconv, padding=1, groups=dim)

        self.norm1 = nn.LayerNorm(dim, epsilon=1e-6) # important to avoid attention collapsing
        if topk > 0:
            self.attn = BiLevelRoutingAttention(dim=dim, num_heads=num_heads, n_win=n_win, qk_dim=qk_dim,
                                        qk_scale=qk_scale, topk=topk, side_dwconv=side_dwconv)
        else:
            self.attn = AttentionLePE(dim=dim, side_dwconv=side_dwconv)

        self.norm2 = nn.LayerNorm(dim, epsilon=1e-6)

        self.mlp = nn.Sequential(nn.Linear(dim, int(mlp_ratio * dim)),
                                 nn.GELU(),
                                 nn.Linear(int(mlp_ratio * dim), dim))

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if layer_scale_init_value > 0:
            self.use_layer_scale = True
            self.gamma1 = self.create_parameter([dim], default_initializer=nn.initializer.Constant(layer_scale_init_value))
            self.gamma2 = self.create_parameter([dim], default_initializer=nn.initializer.Constant(layer_scale_init_value))
        else:
            self.use_layer_scale = False

    def forward(self, x):
        """
        x: NCHW tensor
        """
        # conv pos embedding
        x = x + self.pos_embed(x)
        # permute to NHWC tensor for attention & mlp
        x = x.transpose([0, 2, 3, 1])

        # attention & mlp
        if self.use_layer_scale:
            x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
            x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.attn(self.norm1(x)))
            x = x + self.drop_path(self.mlp(self.norm2(x)))

        # permute back
        x = x.transpose([0, 3, 1, 2])
        return x
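A quick smoke test (illustrative, not part of the original notebook): a Block preserves the NCHW shape, so blocks can be stacked freely within a stage.

blk = Block(dim=64, num_heads=2, n_win=7, topk=4)
x = paddle.randn([1, 64, 56, 56])     # 56 is divisible by n_win=7
print(blk(x).shape)                   # [1, 64, 56, 56]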
class BiFormer(nn.Layer):
    def __init__(self, depth=[3, 4, 8, 3], in_chans=3, num_classes=1000, embed_dim=[64, 128, 320, 512],
                 head_dim=64, qk_scale=None,
                 drop_path_rate=0., drop_rate=0.,
                 n_win=7,
                 topks=[8, 8, -1, -1],
                 side_dwconv=5,
                 layer_scale_init_value=-1,
                 qk_dims=[None, None, None, None],
                 pe_stages=[0],
                 before_attn_dwconv=3,
                 mlp_ratios=[3, 3, 3, 3]):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models

        ############ downsample layers (patch embeddings) ######################
        self.downsample_layers = nn.LayerList()
        # NOTE: uniformer uses two 3*3 conv, while in many other transformers this is one 7*7 conv
        stem = nn.Sequential(
            nn.Conv2D(in_chans, embed_dim[0] // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2D(embed_dim[0] // 2),
            nn.GELU(),
            nn.Conv2D(embed_dim[0] // 2, embed_dim[0], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
            nn.BatchNorm2D(embed_dim[0]),
        )

        self.downsample_layers.append(stem)

        for i in range(3):
            downsample_layer = nn.Sequential(
                nn.Conv2D(embed_dim[i], embed_dim[i+1], kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)),
                nn.BatchNorm2D(embed_dim[i+1])
            )
            self.downsample_layers.append(downsample_layer)
        ##########################################################################

        self.stages = nn.LayerList() # 4 feature resolution stages, each consisting of multiple residual blocks
        nheads= [dim // head_dim for dim in qk_dims]
        dp_rates=[x.item() for x in paddle.linspace(0, drop_path_rate, sum(depth))]
        cur = 0
        for i in range(4):
            stage = nn.Sequential(
                *[Block(dim=embed_dim[i], drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                        topk=topks[i],
                        num_heads=nheads[i],
                        n_win=n_win,
                        qk_dim=qk_dims[i],
                        qk_scale=qk_scale,
                        mlp_ratio=mlp_ratios[i],
                        side_dwconv=side_dwconv,
                        before_attn_dwconv=before_attn_dwconv) for j in range(depth[i])],
            )
            self.stages.append(stage)
            cur += depth[i]

        ##########################################################################
        self.norm = nn.BatchNorm2D(embed_dim[-1])

        # Classifier head
        self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        tn = nn.initializer.TruncatedNormal(std=.02)
        zeros = nn.initializer.Constant(0.)
        ones = nn.initializer.Constant(1.)
        if isinstance(m, nn.Linear):
            tn(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                zeros(m.bias)
        elif isinstance(m, (nn.LayerNorm)):
            zeros(m.bias)
            ones(m.weight)

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        for i in range(4):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = x.flatten(2).mean(-1)
        x = self.head(x)
        return x
2.3.4 Model configurations
num_classes = 10

def biformer_tiny():
    model = BiFormer(
        depth=[2, 2, 8, 2],
        num_classes=num_classes,
        embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 3],
        n_win=7,
        topks=[1, 4, 16, -2],
        side_dwconv=5,
        before_attn_dwconv=3,
        layer_scale_init_value=-1,
        qk_dims=[64, 128, 256, 512],
        head_dim=32)

    return model

def biformer_small():
    model = BiFormer(
        depth=[4, 4, 18, 4],
        num_classes=num_classes,
        embed_dim=[64, 128, 256, 512], mlp_ratios=[3, 3, 3, 3],
        n_win=7,
        topks=[1, 4, 16, -2],
        side_dwconv=5,
        before_attn_dwconv=3,
        layer_scale_init_value=-1,
        qk_dims=[64, 128, 256, 512],
        head_dim=32)

    return model

def biformer_base():
    model = BiFormer(
        depth=[4, 4, 18, 4],
        num_classes=num_classes,
        embed_dim=[96, 192, 384, 768], mlp_ratios=[3, 3, 3, 3],
        n_win=7,
        topks=[1, 4, 16, -2],
        side_dwconv=5,
        before_attn_dwconv=3,
        layer_scale_init_value=-1,
        qk_dims=[96, 192, 384, 768],
        head_dim=32)

    return model
model = biformer_tiny()
paddle.summary(model, (1, 3, 224, 224))

model = biformer_small()
paddle.summary(model, (1, 3, 224, 224))

model = biformer_base()
paddle.summary(model, (1, 3, 224, 224))

2.4 Training

learning_rate = 0.0001
n_epochs = 50
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# BiFormer-T
model = biformer_tiny()

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))

2.5 Result analysis

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot learning curve of your CNN '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    total_steps = len(record['train'][title])
    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

[Figure: learning curve of loss]

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')

[Figure: learning curve of acc]

import time
work_path = 'work/model'
model = biformer_tiny()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:214
def get_cifar10_labels(labels):
    """返回CIFAR10数据集的文本标签。"""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = biformer_tiny()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).  (repeated 18 times, once per image)

[Figure: 18 validation images with predicted (pt) and ground-truth (gt) labels]

!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = biformer_tiny()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [01:45<00:00, 95.06it/s]


[Figure: LIME interpretation result (main_files/main_53_1.png)]

Summary

        BiFormer offers a good approach to token sparsification: key regions are identified at a coarse granularity, and attention interaction is then performed at a fine granularity only within those regions, which reduces the influence of redundant and irrelevant tokens on each query. In practice, however, BiFormer still does not fully resolve the large memory footprint of attention.

References

  1. Paper: BiFormer: Vision Transformer with Bi-Level Routing Attention (CVPR 2023)
  2. Code: rayleizhu/BiFormer
  3. Blog: CVPR'2023 即插即用系列! | BiFormer: 基于动态稀疏注意力构建高效金字塔网络架构
