[CVPR 2023] EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention


Abstract

Vision transformers have achieved great success thanks to their strong modeling capacity. However, their remarkable performance comes with heavy computational cost, which makes them unsuitable for real-time applications. In this paper, we propose a family of high-speed vision transformers named EfficientViT. We find that the speed of existing transformer models is commonly bounded by memory-inefficient operations, especially the tensor reshaping and element-wise functions in MHSA. We therefore design a new building block with a sandwich layout, i.e., a single memory-bound MHSA between efficient FFN layers, which improves memory efficiency while enhancing channel communication. Moreover, we find that the attention maps share high similarity across heads, leading to computational redundancy. To address this, we present a cascaded group attention module that feeds the attention heads with different splits of the full feature, which not only saves computation cost but also improves attention diversity. Extensive experiments demonstrate that EfficientViT outperforms existing efficient models, striking a good trade-off between speed and accuracy. For instance, our EfficientViT-M5 surpasses MobileNetV3-Large by 1.9% in accuracy while achieving 40.4% and 45.2% higher throughput on an Nvidia V100 GPU and an Intel Xeon CPU, respectively. Compared with the recent efficient model MobileViT-XXS, EfficientViT-M2 is 1.8% more accurate while running 5.8×/3.7× faster on GPU/CPU, and 7.4× faster when converted to ONNX format.

1. EfficientViT

As shown in Figure 2, the paper first profiles the runtime of two representative architectures, DeiT and Swin, and finds that the speed of transformer architectures is typically bounded by memory access. To address this, the paper proposes a sandwich layout: a single cascaded group attention block placed between 2N FFN layers. It also finds that the existing multi-head design yields highly similar attention maps across heads, which wastes computation, and proposes a new head-splitting strategy to alleviate this.

By analyzing the DeiT and Swin transformer architectures, the paper draws the following conclusions:

  1. Appropriately reducing the proportion of MHSA layers improves model performance while also improving memory-access efficiency (see Figure 3).
  2. Feeding different heads different splits of the channels, rather than the same full feature as in standard MHSA, effectively reduces redundancy in the attention computation (see Figure 4; a minimal similarity check is sketched right after this list).
  3. The typical channel configurations, i.e., doubling the channel count after each stage or keeping it equal across all blocks, may produce substantial redundancy in the last few blocks (see Figure 5).
  4. At the same dimension, Q and K are far more redundant than V (see Figure 5).
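
To make finding 2 concrete, head redundancy can be measured directly: take the per-head attention maps of one layer and average the cosine similarity of each head against the others. The helper below is a minimal sketch, assuming the attention maps have already been extracted from a model into a tensor of shape [num_heads, N, N]:

import paddle

def mean_head_cosine_similarity(attn_maps):
    # attn_maps: [num_heads, N, N], the attention maps of one layer (assumed given)
    h = attn_maps.shape[0]
    flat = attn_maps.reshape([h, -1])
    flat = flat / paddle.sqrt((flat ** 2).sum(axis=1, keepdim=True))  # L2-normalize each head
    sim = flat @ paddle.transpose(flat, [1, 0])    # [h, h] pairwise cosine similarities
    off_diag = sim * (1 - paddle.eye(h))           # zero out each head's self-similarity
    return off_diag.sum() / (h * (h - 1))          # average over ordered head pairs

A value close to 1 indicates that the heads are learning nearly identical attention patterns.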



The overall framework, shown in Figure 6, consists of three stages, each stacking several sandwich blocks. A sandwich block is built from 2N depthwise convolutions (DWConv, for local spatial communication) and FFNs (for channel communication) around a cascaded group attention layer. Unlike standard MHSA, cascaded group attention first splits the input into heads and only then generates Q, K, and V. To learn richer feature maps and increase model capacity, the output of each head is added to the input of the next head. Finally, the head outputs are concatenated and mapped by a linear layer to produce the final output:

$$\tilde{X}_{ij} = \operatorname{Attn}\left(X_{ij}W_{ij}^{Q},\ X_{ij}W_{ij}^{K},\ X_{ij}W_{ij}^{V}\right)$$

$$\tilde{X}_{i+1} = \operatorname{Concat}\left[\tilde{X}_{ij}\right]_{j=1:h} W_{i}^{P}$$

$$X'_{ij} = X_{ij} + \tilde{X}_{i(j-1)}, \quad 1 < j \le h$$
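
In code, the cascade is just a running residual over the head loop. The following is a minimal, runnable sketch of the information flow only (the identity lambdas stand in for the per-head attention; the full implementation, including the depthwise convolution on Q and the learned attention biases, is in section 2.3.3):

import paddle

num_heads = 4
x = paddle.randn([1, 64, 14, 14])
head_attn = [lambda t: t for _ in range(num_heads)]  # stand-in for Attn(.W^Q, .W^K, .W^V)
feats_in = paddle.chunk(x, num_heads, axis=1)        # one channel slice per head
feat = feats_in[0]
feats_out = []
for j in range(num_heads):
    if j > 0:
        feat = feat + feats_in[j]    # cascade: X'_ij = X_ij + X~_i(j-1)
    feat = head_attn[j](feat)        # X~_ij = Attn(X'_ij W^Q_ij, X'_ij W^K_ij, X'_ij W^V_ij)
    feats_out.append(feat)
out = paddle.concat(feats_out, axis=1)               # Concat[X~_ij]_{j=1:h}, then W^P projects it
print(out.shape)                                     # [1, 64, 14, 14]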

2. Code Reproduction

2.1 Install and import the required libraries

!pip install paddlex
%matplotlib inline
import paddle
import numpy as np
import matplotlib.pyplot as plt
from paddle.vision.datasets import Cifar10
from paddle.vision.transforms import Transpose
from paddle.io import Dataset, DataLoader
from paddle import nn
import paddle.nn.functional as F
import paddle.vision.transforms as transforms
import os
from matplotlib.pyplot import figure
import paddlex
import itertools

2.2 Create the dataset

train_tfm = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0)),
    transforms.ColorJitter(brightness=0.2,contrast=0.2, saturation=0.2),
    transforms.RandomHorizontalFlip(0.5),
    transforms.RandomRotation(20),
    paddlex.transforms.MixupImage(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

test_tfm = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
paddle.vision.set_image_backend('cv2')
# Use the Cifar10 dataset
train_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='train', transform = train_tfm, )
val_dataset = Cifar10(data_file='data/data152754/cifar-10-python.tar.gz', mode='test',transform = test_tfm)
print("train_dataset: %d" % len(train_dataset))
print("val_dataset: %d" % len(val_dataset))
train_dataset: 50000
val_dataset: 10000
batch_size=256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=4)

2.3 Model creation

2.3.1 Label smoothing
class LabelSmoothingCrossEntropy(nn.Layer):
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, pred, target):

        confidence = 1. - self.smoothing
        log_probs = F.log_softmax(pred, axis=-1)
        idx = paddle.stack([paddle.arange(log_probs.shape[0]), target], axis=1)
        nll_loss = paddle.gather_nd(-log_probs, index=idx)
        smooth_loss = paddle.mean(-log_probs, axis=-1)
        loss = confidence * nll_loss + self.smoothing * smooth_loss

        return loss.mean()
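
A quick sanity check for this loss (illustrative; run it after the imports in section 2.1): with smoothing set to 0 it should reduce to the standard cross entropy.

# Illustrative: smoothing=0.0 recovers plain cross entropy.
logits = paddle.randn([4, 10])
target = paddle.randint(0, 10, [4])
print(LabelSmoothingCrossEntropy(smoothing=0.0)(logits, target))
print(F.cross_entropy(logits, target))  # should match the line above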
2.3.2 DropPath
def drop_path(x, drop_prob=0.0, training=False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = paddle.to_tensor(1 - drop_prob)
    shape = (paddle.shape(x)[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + paddle.rand(shape, dtype=x.dtype)
    random_tensor = paddle.floor(random_tensor)  # binarize
    output = x.divide(keep_prob) * random_tensor
    return output


class DropPath(nn.Layer):
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
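
A brief behavioral check (illustrative). Note that the blocks below wrap their sublayers in the Residual class with ordinary Dropout; DropPath is defined here but not wired into the model:

# Illustrative: DropPath is the identity in eval mode; in train mode each sample
# is either zeroed entirely or rescaled by 1/keep_prob, so the expectation is unchanged.
dp = DropPath(drop_prob=0.5)
x = paddle.ones([8, 4])
dp.eval()
print(bool((dp(x) == x).all()))  # True: no-op at inference
dp.train()
print(dp(x))                     # each row is all zeros or all 2.0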
2.3.3 Building the EfficientViT model
class Conv2D_BN(nn.Sequential):
    def __init__(self, in_channel, out_channel, ks=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.add_sublayer('conv', nn.Conv2D(in_channel, out_channel, ks, stride=stride, padding=padding, groups=groups, dilation=dilation))
        self.add_sublayer('bn', nn.BatchNorm2D(out_channel))
        init = nn.initializer.Constant(bn_weight_init)
        init(self.bn.weight)
        zero = nn.initializer.Constant(0)
        zero(self.bn.bias)


class BN_Linear(nn.Sequential):
    def __init__(self, in_channel, out_channel,bias=True, std=0.02):
        super().__init__()
        self.add_sublayer('bn', nn.BatchNorm1D(in_channel))
        self.add_sublayer('linear', nn.Linear(in_channel, out_channel, bias_attr=bias))
        tn = nn.initializer.TruncatedNormal(std=std)
        tn(self.linear.weight)
        if bias:
            zero = nn.initializer.Constant(0.0)
            zero(self.linear.bias)
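
Both helpers pair a linear operator with batch normalization and apply the usual initializations (in the official repository these pairs can additionally be fused into a single operator for deployment; that step is omitted in this reproduction). A quick shape check (illustrative):

# Illustrative shape checks for the conv+BN and BN+linear helpers.
x = paddle.randn([2, 3, 32, 32])
print(Conv2D_BN(3, 16, ks=3, stride=2, padding=1)(x).shape)  # [2, 16, 16, 16]
print(BN_Linear(64, 10)(paddle.randn([2, 64])).shape)        # [2, 10]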
class SqueezeExcite(nn.Layer):
    def __init__(
            self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
            bias=True, act_layer=nn.ReLU, norm_layer=None, gate_layer=nn.Sigmoid):
        super().__init__()

        self.fc1 = nn.Conv2D(channels, int(channels * rd_ratio), kernel_size=1, bias_attr=bias)
        self.act = act_layer()
        self.fc2 = nn.Conv2D(int(channels * rd_ratio), channels, kernel_size=1, bias_attr=bias)
        self.gate = gate_layer()

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)
        x_se = self.fc1(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        return x * self.gate(x_se)


class PatchMerging(nn.Layer):
    def __init__(self, dim, out_dim):
        super().__init__()
        hid_dim = int(dim * 4)
        self.conv1 = Conv2D_BN(dim, hid_dim, 1, 1, 0)
        self.act = nn.ReLU()
        self.conv2 = Conv2D_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim)
        self.se = SqueezeExcite(hid_dim, .25)
        self.conv3 = Conv2D_BN(hid_dim, out_dim, 1, 1, 0)

    def forward(self, x):
        x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
        return x
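
PatchMerging is the downsampling block between stages: a pointwise expansion, a stride-2 depthwise convolution, squeeze-and-excitation, and a pointwise projection, so it halves the spatial resolution while changing the channel count. For example (illustrative):

# Illustrative: PatchMerging halves H and W and maps dim -> out_dim channels.
pm = PatchMerging(dim=64, out_dim=128)
print(pm(paddle.randn([2, 64, 14, 14])).shape)  # [2, 128, 7, 7]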
class Residual(nn.Layer):
    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
       return x + self.dropout(self.m(x))
class FFN(nn.Layer):
    def __init__(self, ed, h):
        super().__init__()
        self.pw1 = Conv2D_BN(ed, h)
        self.act = nn.ReLU()
        self.pw2 = Conv2D_BN(h, ed, bn_weight_init=0)

    def forward(self, x):
        x = self.pw2(self.act(self.pw1(x)))
        return x
class CascadedGroupAttention(nn.Layer):
    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 kernels=[5, 5, 5, 5]):
        super().__init__()
        self.resolution = resolution
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.d = int(attn_ratio * key_dim)
        self.attn_ratio = attn_ratio

        qkvs = []
        dws = []
        for i in range(num_heads):
            qkvs.append(Conv2D_BN(dim // (num_heads), self.key_dim * 2 + self.d))
            dws.append(Conv2D_BN(self.key_dim, self.key_dim, kernels[i], stride=1, padding=kernels[i]//2, groups=self.key_dim))
        self.qkvs = nn.LayerList(qkvs)
        self.dws = nn.LayerList(dws)
        self.proj = nn.Sequential(nn.ReLU(), Conv2D_BN(
            self.d * num_heads, dim, bn_weight_init=0))

        points = list(itertools.product(range(resolution), range(resolution)))
        N = len(points)
        self.N = N
        attention_offsets = {}
        idxs = []
        for p1 in points:
            for p2 in points:
                offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
                if offset not in attention_offsets:
                    attention_offsets[offset] = len(attention_offsets)
                idxs.append(attention_offsets[offset])
        self.attention_biases = self.create_parameter((len(attention_offsets), num_heads), default_initializer=nn.initializer.Constant(0.0))
        self.attention_bias_idxs = idxs

    def forward(self, x):  # x (B,C,H,W)
        B, C, H, W = x.shape
        trainingab = self.attention_biases[self.attention_bias_idxs].transpose((1, 0)).reshape((self.num_heads, self.N, self.N))
        feats_in = paddle.chunk(x, len(self.qkvs), axis=1)
        feats_out = []
        feat = feats_in[0]
        for i, qkv in enumerate(self.qkvs):
            if i > 0: # add the previous output to the input
                feat = feat + feats_in[i]
            feat = qkv(feat)
            q, k, v = feat.split([self.key_dim, self.key_dim, self.d], axis=1) # B, C/h, H, W
            q = self.dws[i](q)
            q, k, v = q.flatten(2), k.flatten(2), v.flatten(2) # B, C/h, N
            attn = (q.transpose([0, 2, 1]) @ k) * self.scale
            attn = attn + trainingab[i]
            attn = F.softmax(attn, axis=-1) # BNN
            feat = (v @ attn.transpose([0, 2, 1])).reshape((B, self.d, H, W)) # BCHW
            feats_out.append(feat)
        x = self.proj(paddle.concat(feats_out, axis=1))
        return x
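
This implements the cascaded group attention equations from section 1: the input is chunked channel-wise into one slice per head, each head's output is added to the next head's input, and the concatenated head outputs are projected back to dim. A quick shape check (illustrative; dim must be divisible by num_heads, and kernels must supply one kernel size per head):

# Illustrative: with dim=64 and 4 heads, each head sees a 16-channel slice and
# emits d = attn_ratio * key_dim channels; the concat is projected back to 64.
cga = CascadedGroupAttention(dim=64, key_dim=16, num_heads=4, attn_ratio=1, resolution=14)
print(cga(paddle.randn([2, 64, 14, 14])).shape)  # [2, 64, 14, 14]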
class LocalWindowAttention(nn.Layer):
    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.resolution = resolution
        assert window_resolution > 0, 'window_size must be greater than 0'
        self.window_resolution = window_resolution

        window_resolution = min(window_resolution, resolution)
        self.attn = CascadedGroupAttention(dim, key_dim, num_heads,
                                attn_ratio=attn_ratio,
                                resolution=window_resolution,
                                kernels=kernels,)

    def forward(self, x):
        H = W = self.resolution
        B, C, H_, W_ = x.shape
        # Only check this for classification models
        assert H == H_ and W == W_, 'input feature has wrong size, expect {}, got {}'.format((H, W), (H_, W_))

        if H <= self.window_resolution and W <= self.window_resolution:
            x = self.attn(x)
        else:
            x = x.transpose([0, 2, 3, 1])
            pad_b = (self.window_resolution - H %
                     self.window_resolution) % self.window_resolution
            pad_r = (self.window_resolution - W %
                     self.window_resolution) % self.window_resolution
            padding = pad_b > 0 or pad_r > 0

            if padding:
                x = F.pad(x, (0, pad_r, 0, pad_b))

            pH, pW = H + pad_b, W + pad_r
            nH = pH // self.window_resolution
            nW = pW // self.window_resolution
            # window partition, BHWC -> B(nHh)(nWw)C -> BnHnWhwC -> (BnHnW)hwC -> (BnHnW)Chw
            x = x.reshape((B, nH, self.window_resolution, nW, self.window_resolution, C)).transpose([0, 1, 3, 2, 4, 5]).reshape(
                (B * nH * nW, self.window_resolution, self.window_resolution, C)
            ).transpose([0, 3, 1, 2])
            x = self.attn(x)
            # window reverse, (BnHnW)Chw -> (BnHnW)hwC -> BnHnWhwC -> B(nHh)(nWw)C -> BHWC
            x = x.transpose((0, 2, 3, 1)).reshape((B, nH, nW, self.window_resolution, self.window_resolution,
                       C)).transpose([0, 1, 3, 2, 4, 5]).reshape((B, pH, pW, C))
            if padding:
                x = x[:, :H, :W]
            x = x.transpose([0, 3, 1, 2])
        return x
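
LocalWindowAttention keeps the attention cost bounded: the feature map is partitioned into non-overlapping window_resolution × window_resolution windows (with padding if needed), cascaded group attention runs inside each window, and the windows are stitched back together. For example (illustrative):

# Illustrative: a 14x14 map is split into four 7x7 windows, attended, and merged back.
lwa = LocalWindowAttention(dim=64, key_dim=16, num_heads=4, attn_ratio=1,
                           resolution=14, window_resolution=7)
print(lwa(paddle.randn([2, 64, 14, 14])).shape)  # [2, 64, 14, 14]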
class EfficientViTBlock(nn.Layer):
    def __init__(self, type,
                 ed, kd, nh=8,
                 ar=4,
                 resolution=14,
                 window_resolution=7,
                 kernels=[5, 5, 5, 5],):
        super().__init__()

        self.dw0 = Residual(Conv2D_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
        self.ffn0 = Residual(FFN(ed, int(ed * 2)))

        if type == 's':
            self.mixer = Residual(LocalWindowAttention(ed, kd, nh, attn_ratio=ar, \
                    resolution=resolution, window_resolution=window_resolution, kernels=kernels))
        else:
            # only the 's' type is used in this reproduction; fall back to an
            # identity mixer so the block stays well-defined for other values
            self.mixer = nn.Identity()

        self.dw1 = Residual(Conv2D_BN(ed, ed, 3, 1, 1, groups=ed, bn_weight_init=0.))
        self.ffn1 = Residual(FFN(ed, int(ed * 2)))

    def forward(self, x):
        return self.ffn1(self.dw1(self.mixer(self.ffn0(self.dw0(x)))))
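
This is exactly the sandwich layout from section 1: a depthwise convolution and an FFN on each side of a single attention mixer, every sublayer wrapped in a residual connection. A quick check (illustrative):

# Illustrative: the sandwich block preserves the input shape.
blk = EfficientViTBlock('s', ed=64, kd=16, nh=4, ar=1, resolution=14, window_resolution=7)
print(blk(paddle.randn([2, 64, 14, 14])).shape)  # [2, 64, 14, 14]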
class EfficientViT(nn.Layer):
    def __init__(self, img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 stages=['s', 's', 's'],
                 embed_dim=[64, 128, 192],
                 key_dim=[16, 16, 16],
                 depth=[1, 2, 3],
                 num_heads=[4, 4, 4],
                 window_size=[7, 7, 7],
                 kernels=[5, 5, 5, 5],
                 down_ops=[['subsample', 2], ['subsample', 2], ['']]):
        super().__init__()

        resolution = img_size
        # Patch embedding
        self.patch_embed = nn.Sequential(Conv2D_BN(in_chans, embed_dim[0] // 8, 3, 2, 1), nn.ReLU(),
                           Conv2D_BN(embed_dim[0] // 8, embed_dim[0] // 4, 3, 2, 1), nn.ReLU(),
                           Conv2D_BN(embed_dim[0] // 4, embed_dim[0] // 2, 3, 2, 1), nn.ReLU(),
                           Conv2D_BN(embed_dim[0] // 2, embed_dim[0], 3, 2, 1))

        resolution = img_size // patch_size
        attn_ratio = [embed_dim[i] / (key_dim[i] * num_heads[i]) for i in range(len(embed_dim))]
        self.blocks1 = []
        self.blocks2 = []
        self.blocks3 = []

        # Build EfficientViT blocks
        for i, (stg, ed, kd, dpth, nh, ar, wd, do) in enumerate(
                zip(stages, embed_dim, key_dim, depth, num_heads, attn_ratio, window_size, down_ops)):
            for d in range(dpth):
                eval('self.blocks' + str(i+1)).append(EfficientViTBlock(stg, ed, kd, nh, ar, resolution, wd, kernels))
            if do[0] == 'subsample':
                # Build EfficientViT downsample block
                #('Subsample' stride)
                blk = eval('self.blocks' + str(i+2))
                resolution_ = (resolution - 1) // do[1] + 1
                blk.append(nn.Sequential(Residual(Conv2D_BN(embed_dim[i], embed_dim[i], 3, 1, 1, groups=embed_dim[i])),
                                    Residual(FFN(embed_dim[i], int(embed_dim[i] * 2)))))
                blk.append(PatchMerging(*embed_dim[i:i + 2]))
                resolution = resolution_
                blk.append(nn.Sequential(Residual(Conv2D_BN(embed_dim[i + 1], embed_dim[i + 1], 3, 1, 1, groups=embed_dim[i + 1])),
                                    Residual(FFN(embed_dim[i + 1], int(embed_dim[i + 1] * 2)))))
        self.blocks1 = nn.Sequential(*self.blocks1)
        self.blocks2 = nn.Sequential(*self.blocks2)
        self.blocks3 = nn.Sequential(*self.blocks3)

        # Classification head
        self.head = BN_Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.blocks1(x)
        x = self.blocks2(x)
        x = self.blocks3(x)
        x = F.adaptive_avg_pool2d(x, 1).flatten(1)
        x = self.head(x)
        return x
num_classes = 10

def EfficientViT_M0():

    model = EfficientViT(embed_dim=[64, 128, 192], depth=[1, 2, 3], num_heads=[4, 4, 4], kernels=[5, 5, 5, 5], num_classes=num_classes)

    return model


def EfficientViT_M1():

    model = EfficientViT(embed_dim=[128, 144, 192], depth=[1, 2, 3], num_heads=[2, 3, 3], kernels=[7, 5, 3, 3], num_classes=num_classes)

    return model


def EfficientViT_M2():

    model = EfficientViT(embed_dim=[128, 192, 224], depth=[1, 2, 3], num_heads=[4, 3, 2], kernels=[7, 5, 3, 3], num_classes=num_classes)

    return model


def EfficientViT_M3():

    model = EfficientViT(embed_dim=[128, 240, 320], depth=[1, 2, 3], num_heads=[4, 3, 4], kernels=[5, 5, 5, 5], num_classes=num_classes)

    return model


def EfficientViT_M4():

    model = EfficientViT(embed_dim=[128, 256, 384], depth=[1, 2, 3], num_heads=[4, 4, 4], kernels=[7, 5, 3, 3], num_classes=num_classes)

    return model


def EfficientViT_M5():

    model = EfficientViT(embed_dim=[192, 288, 384], depth=[1, 3, 4], num_heads=[3, 3, 4], kernels=[7, 5, 3, 3], num_classes=num_classes)

    return model
2.3.4 Model parameters
model = EfficientViT_M0()
paddle.summary(model, (1, 3, 224, 224))

model = EfficientViT_M1()
paddle.summary(model, (1, 3, 224, 224))

model = EfficientViT_M2()
paddle.summary(model, (1, 3, 224, 224))

model = EfficientViT_M3()
paddle.summary(model, (1, 3, 224, 224))

model = EfficientViT_M4()
paddle.summary(model, (1, 3, 224, 224))

model = EfficientViT_M5()
paddle.summary(model, (1, 3, 224, 224))

2.4 Training

learning_rate = 0.001
n_epochs = 100
paddle.seed(42)
np.random.seed(42)
work_path = 'work/model'

# EfficientViT-M0
model = EfficientViT_M0()

criterion = LabelSmoothingCrossEntropy()

scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=learning_rate, T_max=50000 // batch_size * n_epochs, verbose=False)
optimizer = paddle.optimizer.Adam(parameters=model.parameters(), learning_rate=scheduler, weight_decay=1e-5)

best_acc = 0.0
val_acc = 0.0
loss_record = {'train': {'loss': [], 'iter': []}, 'val': {'loss': [], 'iter': []}}   # for recording loss
acc_record = {'train': {'acc': [], 'iter': []}, 'val': {'acc': [], 'iter': []}}      # for recording accuracy

loss_iter = 0
acc_iter = 0

for epoch in range(n_epochs):
    # ---------- Training ----------
    model.train()
    train_num = 0.0
    train_loss = 0.0

    val_num = 0.0
    val_loss = 0.0
    accuracy_manager = paddle.metric.Accuracy()
    val_accuracy_manager = paddle.metric.Accuracy()
    print("#===epoch: {}, lr={:.10f}===#".format(epoch, optimizer.get_lr()))
    for batch_id, data in enumerate(train_loader):
        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)

        logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        accuracy_manager.update(acc)
        if batch_id % 10 == 0:
            loss_record['train']['loss'].append(loss.numpy())
            loss_record['train']['iter'].append(loss_iter)
            loss_iter += 1

        loss.backward()

        optimizer.step()
        scheduler.step()
        optimizer.clear_grad()

        train_loss += loss
        train_num += len(y_data)

    total_train_loss = (train_loss / train_num) * batch_size
    train_acc = accuracy_manager.accumulate()
    acc_record['train']['acc'].append(train_acc)
    acc_record['train']['iter'].append(acc_iter)
    acc_iter += 1
    # Print the information.
    print("#===epoch: {}, train loss is: {}, train acc is: {:2.2f}%===#".format(epoch, total_train_loss.numpy(), train_acc*100))

    # ---------- Validation ----------
    model.eval()

    for batch_id, data in enumerate(val_loader):

        x_data, y_data = data
        labels = paddle.unsqueeze(y_data, axis=1)
        with paddle.no_grad():
          logits = model(x_data)

        loss = criterion(logits, y_data)

        acc = paddle.metric.accuracy(logits, labels)
        val_accuracy_manager.update(acc)

        val_loss += loss
        val_num += len(y_data)

    total_val_loss = (val_loss / val_num) * batch_size
    loss_record['val']['loss'].append(total_val_loss.numpy())
    loss_record['val']['iter'].append(loss_iter)
    val_acc = val_accuracy_manager.accumulate()
    acc_record['val']['acc'].append(val_acc)
    acc_record['val']['iter'].append(acc_iter)

    print("#===epoch: {}, val loss is: {}, val acc is: {:2.2f}%===#".format(epoch, total_val_loss.numpy(), val_acc*100))

    # ===================save====================
    if val_acc > best_acc:
        best_acc = val_acc
        paddle.save(model.state_dict(), os.path.join(work_path, 'best_model.pdparams'))
        paddle.save(optimizer.state_dict(), os.path.join(work_path, 'best_optimizer.pdopt'))

print(best_acc)
paddle.save(model.state_dict(), os.path.join(work_path, 'final_model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(work_path, 'final_optimizer.pdopt'))

2.5 Analysis of results

def plot_learning_curve(record, title='loss', ylabel='CE Loss'):
    ''' Plot the learning curve of the model '''
    maxtrain = max(map(float, record['train'][title]))
    maxval = max(map(float, record['val'][title]))
    ymax = max(maxtrain, maxval) * 1.1
    mintrain = min(map(float, record['train'][title]))
    minval = min(map(float, record['val'][title]))
    ymin = min(mintrain, minval) * 0.9

    x_1 = list(map(int, record['train']['iter']))
    x_2 = list(map(int, record['val']['iter']))
    figure(figsize=(10, 6))
    plt.plot(x_1, record['train'][title], c='tab:red', label='train')
    plt.plot(x_2, record['val'][title], c='tab:cyan', label='val')
    plt.ylim(ymin, ymax)
    plt.xlabel('Training steps')
    plt.ylabel(ylabel)
    plt.title('Learning curve of {}'.format(title))
    plt.legend()
    plt.show()
plot_learning_curve(loss_record, title='loss', ylabel='CE Loss')

(learning curve of the loss)

plot_learning_curve(acc_record, title='acc', ylabel='Accuracy')
  • 1

(learning curve of the accuracy)

import time
work_path = 'work/model'
model = EfficientViT_M0()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
aa = time.time()
for batch_id, data in enumerate(val_loader):

    x_data, y_data = data
    labels = paddle.unsqueeze(y_data, axis=1)
    with paddle.no_grad():
        logits = model(x_data)
bb = time.time()
print("Throughout:{}".format(int(len(val_dataset)//(bb - aa))))
Throughput:1002
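
This measurement includes data loading and a cold start, so treat it as a rough end-to-end figure. A slightly more stable variant (a minimal sketch, which warms the model up first and then times only the forward passes on a fixed batch) could look like:

# Illustrative refinement: warm up, then time forward computation only.
dummy = paddle.randn([batch_size, 3, 224, 224])
with paddle.no_grad():
    for _ in range(5):       # warm-up passes exclude one-time setup costs
        model(dummy)
    start = time.time()
    for _ in range(20):
        model(dummy)
    elapsed = time.time() - start
print("Throughput: {} imgs/s".format(int(20 * batch_size / elapsed)))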
def get_cifar10_labels(labels):
    """返回CIFAR10数据集的文本标签。"""
    text_labels = [
        'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog',
        'horse', 'ship', 'truck']
    return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, pred=None, gt=None, scale=1.5):
    """Plot a list of images."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if paddle.is_tensor(img):
            ax.imshow(img.numpy())
        else:
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if pred or gt:
            ax.set_title("pt: " + pred[i] + "\ngt: " + gt[i])
    return axes
work_path = 'work/model'
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
model = EfficientViT_M0()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
model.eval()
logits = model(X)
y_pred = paddle.argmax(logits, -1)
X = paddle.transpose(X, [0, 2, 3, 1])
axes = show_images(X.reshape((18, 224, 224, 3)), 1, 18, pred=get_cifar10_labels(y_pred), gt=get_cifar10_labels(y))
plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
(The warning above is emitted once per image, 18 times in total, because the displayed tensors are normalized.)

(predictions vs. ground truth for 18 validation images)

!pip install interpretdl
import interpretdl as it
work_path = 'work/model'
model = EfficientViT_M0()
model_state_dict = paddle.load(os.path.join(work_path, 'best_model.pdparams'))
model.set_state_dict(model_state_dict)
X, y = next(iter(DataLoader(val_dataset, batch_size=18)))
lime = it.LIMECVInterpreter(model)
lime_weights = lime.interpret(X.numpy()[3], interpret_class=y.numpy()[3], batch_size=100, num_samples=10000, visual=True)
100%|██████████| 10000/10000 [00:55<00:00, 181.62it/s]

55<00:00, 181.62it/s]

(LIME interpretation of one validation image)

Summary

The EfficientViT-M0 reproduced here reaches 89.4% accuracy on CIFAR-10 with 2.2M parameters, and achieves a throughput of 1002 imgs/s at an input resolution of 224.

References

  1. EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention, CVPR 2023.
  2. Official implementation: microsoft/Cream/tree/main/EfficientViT
