赞
踩
【GiantPandaCV导语】来自商汤和南洋理工的工作,也是使用卷积来增强模型提出low-level特征的能力,增强模型获取局部性的能力,核心贡献是LCA模块,可以用于捕获多层特征表示。
针对先前Transformer架构需要大量额外数据或者额外的监督(Deit),才能获得与卷积神经网络结构相当的性能,为了克服这种缺陷,提出结合CNN来弥补Transformer的缺陷,提出了CeiT:
(1)设计Image-to-Tokens模块来从low-level特征中得到embedding。
(2)将Transformer中的Feed Forward模块替换为Locally-enhanced Feed-Forward(LeFF)模块,增加了相邻token之间的相关性。
(3)使用Layer-wise Class Token Attention(LCA)捕获多层的特征表示。
经过以上修改,可以发现模型效率方面以及泛化能力得到了提升,收敛性也有所改善,如下图所示:
使用卷积+池化来取代原先ViT中7x7的大型patch。
x ′ = I 2 T ( x ) = MaxPool ( BN ( Conv ( x ) ) ) \mathbf{x}^{\prime}=\mathrm{I} 2 \mathrm{~T}(\mathbf{x})=\operatorname{MaxPool}(\operatorname{BN}(\operatorname{Conv}(\mathbf{x}))) x′=I2 T(x)=MaxPool(BN(Conv(x)))
将tokens重新拼成feature map,然后使用深度可分离卷积添加局部性的处理,然后再使用一个Linear层映射至tokens。
x
c
h
,
x
p
h
=
Split
(
x
t
h
)
x
p
l
1
=
GELU
(
BN
(
Linear
(
(
x
p
h
)
)
)
x
p
s
=
SpatialRestore
(
x
p
l
1
)
x
p
d
=
GELU
(
BN
(
DWConv
(
x
p
s
)
)
)
x
p
f
=
Flatten
(
x
p
d
)
x
p
l
2
=
GELU
(
BN
(
Linear
2
(
x
p
f
)
)
)
x
t
h
+
1
=
Concat
(
x
c
h
,
x
p
l
2
)
前两个都比较常规,最后一个比较有特色,经过所有Transformer层以后使用的Layer-wise Class-token Attention,如下图所示:
LCA模块会将所有Transformer Block中得到的class token作为输入,然后再在其基础上使用一个MSA+FFN得到最终的logits输出。作者认为这样可以获取多尺度的表征。
SOTA比较:
I2T消融实验:
LeFF消融实验:
LCA有效性比较:
收敛速度比较:
模块1:I2T Image-to-Token
# IoT self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, conv_kernel, stride, 4), nn.BatchNorm2d(out_channels), nn.MaxPool2d(pool_kernel, stride) ) feature_size = image_size // 4 assert feature_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' num_patches = (feature_size // patch_size) ** 2 patch_dim = out_channels * patch_size ** 2 self.to_patch_embedding = nn.Sequential( Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size), nn.Linear(patch_dim, dim), )
模块2:LeFF
class LeFF(nn.Module): def __init__(self, dim = 192, scale = 4, depth_kernel = 3): super().__init__() scale_dim = dim*scale self.up_proj = nn.Sequential(nn.Linear(dim, scale_dim), Rearrange('b n c -> b c n'), nn.BatchNorm1d(scale_dim), nn.GELU(), Rearrange('b c (h w) -> b c h w', h=14, w=14) ) self.depth_conv = nn.Sequential(nn.Conv2d(scale_dim, scale_dim, kernel_size=depth_kernel, padding=1, groups=scale_dim, bias=False), nn.BatchNorm2d(scale_dim), nn.GELU(), Rearrange('b c h w -> b (h w) c', h=14, w=14) ) self.down_proj = nn.Sequential(nn.Linear(scale_dim, dim), Rearrange('b n c -> b c n'), nn.BatchNorm1d(dim), nn.GELU(), Rearrange('b c n -> b n c') ) def forward(self, x): x = self.up_proj(x) x = self.depth_conv(x) x = self.down_proj(x) return x class TransformerLeFF(nn.Module): def __init__(self, dim, depth, heads, dim_head, scale = 4, depth_kernel = 3, dropout = 0.): super().__init__() self.layers = nn.ModuleList([]) for _ in range(depth): self.layers.append(nn.ModuleList([ Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))), Residual(PreNorm(dim, LeFF(dim, scale, depth_kernel))) ])) def forward(self, x): c = list() for attn, leff in self.layers: x = attn(x) cls_tokens = x[:, 0] c.append(cls_tokens) x = leff(x[:, 1:]) x = torch.cat((cls_tokens.unsqueeze(1), x), dim=1) return x, torch.stack(c).transpose(0, 1)
模块3:LCA
class LCAttention(nn.Module): def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): super().__init__() inner_dim = dim_head * heads project_out = not (heads == 1 and dim_head == dim) self.heads = heads self.scale = dim_head ** -0.5 self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) self.to_out = nn.Sequential( nn.Linear(inner_dim, dim), nn.Dropout(dropout) ) if project_out else nn.Identity() def forward(self, x): b, n, _, h = *x.shape, self.heads qkv = self.to_qkv(x).chunk(3, dim = -1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) q = q[:, :, -1, :].unsqueeze(2) # Only Lth element use as query dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale attn = dots.softmax(dim=-1) out = einsum('b h i j, b h j d -> b h i d', attn, v) out = rearrange(out, 'b h n d -> b n (h d)') out = self.to_out(out) return out class LCA(nn.Module): # I remove Residual connection from here, in paper author didn't explicitly mentioned to use Residual connection, # so I removed it, althougth with Residual connection also this code will work. def __init__(self, dim, heads, dim_head, mlp_dim, dropout = 0.): super().__init__() self.layers = nn.ModuleList([]) self.layers.append(nn.ModuleList([ PreNorm(dim, LCAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) ])) def forward(self, x): for attn, ff in self.layers: x = attn(x) + x[:, -1].unsqueeze(1) x = x[:, -1].unsqueeze(1) + ff(x) return x
https://arxiv.org/abs/2103.11816
https://github.com/rishikksh20/CeiT-pytorch/blob/master/ceit.py
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。