当前位置:   article > 正文

「解析」Vision Transformer 在图像分类中的应用_vision transformer 网络对动物图像分类数据集的分类,包含训练权重和数据集

vision transformer 网络对动物图像分类数据集的分类,包含训练权重和数据集

在这里插入图片描述

An Image is Worth 16x16 Words:Transformers for Image Recognition at Scale

代码:https://github.com/google-research/vision_transformer


小序

ViT(Vision Transformer) 是直接将Transformer直接应用在图像,经过微调:将图像拆成16x16 patch,然后将patch 的 the sequence of linear embedding 作为 Transformers 的输入。

  1. 数据处理部分:
    先对图片作分块,再将每个图片块展平成一维向量
  2. 数据嵌入部分:
    Patch Embedding:对每个向量都做一个线性变换
    Positional Encoding:通过一个可学习向量加入序列的位置信息
  3. 编码部分:
    class_token:额外的一个用于分类的可学习向量,与输入进行拼接
  4. 分类部分:
    mlp_head:使用 LayerNorm 和两层全连接层实现的,采用的是GELU激活函数

但是实验表明,在中等尺寸数据集训练后,分类正确率相比于ResNet上往往降低几个百分点,这是由于transformer缺乏CNN的固有的inductive bias 如 translation equivariance and locality,因而在数据不充分情况时不能很好泛化。而在数据尺寸足够的情况下训练transfprmer,是能够应对这种inductive bias,实现对流行模型的性能逼近甚至超越。

1. 什么是CNN 的 inductive bisa?
表现为:transformers 在小数据上的预测正确率比 CNN 低,当采用混合结构时(即将CNN的输出特征作为输入序列时,尽在小数据上实现性能提升),这与我们预期有差,期望CNN的引入能够提升所有尺寸训练样本下的性能。就是凭借一些规律得出的偏好:如CNN天然的对图像处理的较好,天然的具有平移不变性等;

2. Patch 如何理解?
patch 是将 3 维图像 reshape 为2维之后进行切分,使用的 position embedding 是1维,将 patch作为一个小整体,然后对patch在整个图像中的位置进行编码,还是按照分割后的位置信息。


1、ViT原理分析:

这个工作本着尽可能少修改的原则,将原版的Transformer开箱即用地迁移到分类任务上面。并且作者认为没有必要总是依赖于CNN,只用Transformer也能够在分类任务中表现很好,尤其是在使用大规模训练集的时候。同时,在大规模数据集上预训练好的模型,在迁移到中等数据集或小数据集的分类任务上以后,也能取得比CNN更优的性能。下面看具体的方法:

这个工作首先把 x ∈ H × W × C x\in H \times W \times C xH×W×C 的图像,变成一个 x p ∈ N × ( P 2 ⋅ C ) x_p \in N \times (P^2 \cdot C) xpN×(P2C) 的sequence of flattened 2D patches。它可以看做是一系列的展平的2D块的序列,这个序列中一共有 N = H W / P 2 N =HW/P^2 N=HW/P2 个展平的2D块,每个块的维度是 ( P 2 × C ) (P^2\times C) (P2×C) 。其中 P P P 是块大小, C C C 是channel数。

注意作者做这步变化的意图:
根据之前的讲解,Transformer希望输入一个二维的矩阵 ( N , D ) (N,D) (N,D) ,其中 N N N 是sequence的长度, D D D 是sequence的每个向量的维度,常用256。所以这里也要设法把 H × W × C H\times W \times C H×W×C 的三维图片转化成 ( N , D ) (N,D) (N,D) 的二维输入。

所以有: H × W × C → N × ( P 2 ⋅ C ) H \times W \times C \to N \times (P^2 \cdot C) H×W×CN×(P2C),where N = H W / P 2 N=HW/P^2 N=HW/P2

其中, N N N 是Transformer输入的sequence的长度。

代码是:

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=p, p2=p)
  • 1

具体是采用了einops库实现,具体可以参考这篇博客,科技猛兽:PyTorch 70.einops:优雅地操作张量维度

现在得到的向量维度是: x p ∈ N × ( P 2 × C ) x_p \in N \times (P^2 \times C) xpN×(P2×C) ,要转化成 ( N , D ) (N,D) (N,D) 的二维输入,我们还需要做一步叫做Patch Embedding的步骤。


1.1 Patch Embedding

方法是对每个向量都做一个线性变换(即全连接层),压缩后的维度为 D D D ,这里我们称其为 Patch Embedding。

z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; . . . . ; x p n E ] + E p o s (1) z_0 = [\color{green}x_{class}; \color{back} x_p^1E; x_p^2E; .... ; x_p^nE]+ E_{pos} \tag1 z0=[xclass;xp1E;xp2E;....;xpnE]+Epos(1)

这个全连接层就是上式(5.1)中的 E \color{red}E E ,它的输入维度大小是 ( P 2 ⋅ C ) (P^2 \cdot C) (P2C) ,输出维度大小是 D D D

# 将3072变成dim,假设是1024
self.patch_to_embedding = nn.Linear(patch_dim, dim)
x = self.patch_to_embedding(x)
  • 1
  • 2
  • 3

注意这里的绿色字体 x c l a s s \color{green}x_{class} xclass ,假设切成9个块,但是最终到Transfomer输入是10个向量,这是人为增加的一个向量。


为什么要追加这个向量?

如果没有这个向量,假设 N = 9 N=9 N=9 个向量输入Transformer Encoder,输出9个编码向量,然后呢?对于分类任务而言,我应该取哪个输出向量进行后续分类呢?
不知道。干脆就再来一个向量 x c l a s s ( v e c t o r , d i m = D ) \color{green}x_{class}(vector ,dim =D) xclass(vector,dim=D) ,这个向量是可学习的嵌入向量,它和那9个向量一并输入Transfomer Encoder,输出1+9个编码向量。然后就用第0个编码向量,即 x c l a s s \color{green}x_{class} xclass 的输出进行分类预测即可。

这么做的原因可以理解为:ViT其实只用到了Transformer的Encoder,而并没有用到Decoder,而 x c l a s s \color{green}x_{class} xclass 的作用有点类似于解码器中的 Q u e r y Query Query 的作用,相对应的 K e y , V a l u e Key, Value Key,Value 就是其他9个编码向量的输出。 x c l a s s \color{green}x_{class} xclass 是一个可学习的嵌入向量,它的意义说通俗一点为:寻找其他9个输入向量对应的 i m a g e image image 的类别。

代码为:

# dim=1024
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

# forward前向代码
# 变成(b,64,1024)
cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
# 跟前面的分块进行concat
# 额外追加token,变成b,65,1024
x = torch.cat((cls_tokens, x), dim=1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

1.2 Positional Encoding

按照Transformer的位置编码的习惯,这个工作也使用了位置编码。引入了一个 Positional encoding E p o s \color{violet}E_{pos} Epos来加入序列的位置信息,同样在这里也引入了pos_embedding,是用一个可训练的变量。

z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; . . . . ; x p n E ] + E p o s z_0 = [x_{class}; x_p^1E; x_p^2E; .... ; x_p^nE]+ \color{violet}E_{pos} z0=[xclass;xp1E;xp2E;....;xpnE]+Epos

没有采用原版Transformer的 s i n c o s sincos sincos 编码,而是直接设置为可学习的Positional Encoding,效果差不多。对训练好的pos_embedding进行可视化,如下图所示。
我们发现,位置越接近,往往具有更相似的位置编码。此外,出现了行列结构;同一行/列中的patch具有相似的位置编码。

在这里插入图片描述

# num_patches=64,dim=1024,+1是因为多了一个cls开启解码标志
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
  • 1
  • 2

1.3 Transformer Encoder的前向过程

z 0 = [ x c l a s s ; x p 1 E ; x p 2 E ; . . . . ; x p n E ] + E p o s , E ∈ R P 2 × C × D , E p o s ∈ R ( N + 1 ) × D (2) z_0 = [x_{class}; x_p^1E; x_p^2E; .... ; x_p^nE]+ E_{pos}, \qquad \qquad E\in \mathbb{R}^{P^2 \times C\times D}, E_{pos} \in \mathbb{R}^{(N+1)\times D} \tag2 z0=[xclass;xp1E;xp2E;....;xpnE]+Epos,ERP2×C×D,EposR(N+1)×D(2)

z ℓ ′ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 , ℓ = 1... L (3) {z}'_\ell = \color{violet}MSA(LN(z_{\ell-1}))+z_{\ell-1}, \qquad \qquad \color{back}\ell=1...L \qquad \qquad \qquad \tag3 z=MSA(LN(z1))+z1,=1...L(3)

z ℓ = M L P ( L N ( z ℓ ′ ) ) + z e l l ′ , ℓ = 1... L (4) z_{\ell} = \color{blue}MLP(LN({z}'_\ell))+{z}'_{ell}, \qquad \qquad \color{back} \ell=1...L \qquad \qquad \quad \tag4 z=MLP(LN(z))+zell,=1...L(4)

y = L N ( z ℓ 0 ) (5) y = LN(z^0_{\ell}) \qquad\qquad\qquad \tag5 y=LN(z0)(5)

  • 其中,第1个式子为上面讲到的 Patch Embedding 和 Positional Encoding 的过程。
  • 第2个式子为Transformer Encoder的 M u l t i − h e a d S e l f − A t t e n t i o n , A d d a n d N o r m \color{violet}Multi-head \quad Self-Attention, Add and Norm MultiheadSelfAttention,AddandNorm 的过程,重复 L L L 次。
  • 第3个式子为Transformer Encoder的 F e e d F o r w a r d , A d d a n d N o r m \color{blue}Feed Forward, AddandNorm FeedForward,AddandNorm 的过程,重复 L L L 次。

作者采用的是没有任何改动的 Transformer。

最后是一个 M L P MLP MLP C l a s s f i c a t i o n − H e a d Classfication - Head ClassficationHead ,整个的结构只有这些,如下图所示,为了方便读者的理解,我把变量的维度变化过程标注在了图中。

x  = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
  • 1

在这里插入图片描述


1.4 训练方法:

先在大数据集上预训练,再迁移到小数据集上面。做法是把ViT的 p r e d i c t i o n − h e a d \color{violet}prediction-head predictionhead 去掉,换成一个 D × K D \times K D×K F e e d F o r w a r d L a y e r \color{violet}FeedForwardLayer FeedForwardLayer 。其中 K K K 为对应数据集的类别数。

当输入的图片是更大的shape时,patch size P P P 保持不变,则 N = H W / P 2 N=HW/P^2 N=HW/P2 会增大。

ViT可以处理任意 N N N 的输入,但是Positional Encoding是按照预训练的输入图片的尺寸设计的,所以输入图片变大之后,Positional Encoding需要根据它们在原始图像中的位置做2D插值。


1.5 最后,展示下ViT的动态过程:

ViT的动态过程

在这里插入图片描述

整个流程:

  • 一个图片256x256,分成了64个32x32的patch;
  • 对这么多的patch做embedding,成64个1024向量;
  • 再拼接一个cls_tokens,变成65个1024向量;
  • 再加上pos_embedding,还是65个1024向量;
  • 这些向量输入到transformer中进行自注意力的特征提取;
  • 输出的是64个1024向量,然后对这个50个求均值,变成一个1024向量;
  • 然后线性层把1024维变成 mlp_head维从而完成分类任务的transformer模型。

1.6 Experiments:

预训练模型使用到的数据集有:

  • ILSVRC-2012 ImageNet dataset:1000 classes
  • ImageNet-21k:21k classes
  • JFT:18k High Resolution Images

将预训练迁移到的数据集有:

  • CIFAR-10/100
  • Oxford-IIIT Pets
  • Oxford Flowers-102
  • VTAB

作者设计了3种不同答小的ViT模型,它们分别是:

DModelLayersHidden sizeMLP sizeHeadsParams
ViT-Base1276830721286M
ViT-Large241024409616307M
ViT-Huge321280512016632M

ViT-L/16代表ViT-Large + 16 patch size


评价指标 Metrics :

结果都是下游数据集上经过finetune之后的Accuracy,记录的是在各自数据集上finetune后的性能。

在这里插入图片描述


实验1:性能对比

实验结果如下图所示,整体模型还是挺大的,而经过大数据集的预训练后,性能也超过了当前CNN的一些SOTA结果。对比的CNN模型主要是:

2020年ECCV的Big Transfer (BiT)模型,它使用大的ResNet进行有监督转移学习。

2020年CVPR的Noisy Student模型,这是一个在ImageNet和JFT300M上使用半监督学习进行训练的大型高效网络,去掉了标签。

All models were trained on TPUv3 hardware。

在这里插入图片描述

在JFT-300M上预先训练的较小的ViT-L/16模型在所有任务上都优于BiT-L(在同一数据集上预先训练的),同时训练所需的计算资源要少得多。 更大的模型ViT-H/14进一步提高了性能,特别是在更具挑战性的数据集上——ImageNet, CIFAR-100和VTAB数据集。 与现有技术相比,该模型预训练所需的计算量仍然要少得多。

下图为VTAB数据集在Natural, Specialized, 和Structured子任务与CNN模型相比的性能,ViT模型仍然可以取得最优。

在这里插入图片描述


实验2:ViT对预训练数据的要求
ViT对于预训练数据的规模要求到底有多苛刻?

作者分别在下面这几个数据集上进行预训练:ImageNet, ImageNet-21k, 和JFT-300M。

结果如下图所示:

在这里插入图片描述

我们发现: 当在最小数据集ImageNet上进行预训练时,尽管进行了大量的正则化等操作,但ViT-大模型的性能不如ViT-Base模型

但是有了稍微大一点的ImageNet-21k预训练,它们的表现也差不多

只有到了JFT 300M,我们才能看到更大的ViT模型全部优势。 图3还显示了不同大小的BiT模型跨越的性能区域。BiT CNNs在ImageNet上的表现优于ViT(尽管进行了正则化优化),但在更大的数据集上,ViT超过了所有的模型,取得了SOTA。

作者还进行了一个实验: 在9M、30M和90M的随机子集以及完整的JFT300M数据集上训练模型,结果如下图所示。 ViT在较小数据集上的计算成本比ResNet高, ViT-B/32比ResNet50稍快;它在9M子集上表现更差, 但在90M+子集上表现更好。ResNet152x2和ViT-L/16也是如此。这个结果强化了一种直觉,即:

残差对于较小的数据集是有用的,但是对于较大的数据集,像attention一样学习相关性就足够了,甚至是更好的选择。

在这里插入图片描述


实验3:ViT的注意力机制Attention

作者还给了注意力观察得到的图片块, Self-attention使得ViT能够整合整个图像中的信息,甚至是最底层的信息。作者欲探究网络在多大程度上利用了这种能力。

具体来说,我们根据注意力权重计算图像空间中整合信息的平均距离,如下图所示。

在这里插入图片描述

注意这里我们只使用了attention,而没有使用CNN,所以这里的attention distance相当于CNN的receptive field的大小。
作者发现:在最底层, 有些head也已经注意到了图像的大部分,说明模型已经可以globally地整合信息了,说明它们负责global信息的整合。其他的head 只注意到图像的一小部分,说明它们负责local信息的整合。Attention Distance随深度的增加而增加。

整合局部信息的attention head在混合模型(有CNN存在)时,效果并不好,说明它可能与CNN的底层卷积有着类似的功能。

作者给出了attention的可视化,注意到了适合分类的位置:

在这里插入图片描述


2. ViT代码解读:

2.1 使用:

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)
mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to

preds = v(img, mask = mask) # (1, 1000)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 传入参数的意义: image_size:输入图片大小。
  • patch_size:论文中 patch size: 图片 的大小。
  • num_classes:数据集类别数。
  • dim:Transformer的隐变量的维度。
  • depth:Transformer的Encoder,Decoder的Layer数。
  • heads:Multi-head Attention
  • layer的head数。
  • mlp_dim:MLP层的hidden dim。
  • dropout:Dropout rate。
  • emb_dropout:Embedding dropout rate。

2.2 定义残差,FeedForward Layer 等:

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

Attention和Transformer,注释已标注在代码中:

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
		# b, 65, 1024, heads = 8
        b, n, _, h = *x.shape, self.heads

		# self.to_qkv(x): b, 65, 64*8*3
		# qkv: b, 65, 64*8
        qkv = self.to_qkv(x).chunk(3, dim = -1)

		# b, 65, 64, 8
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

		# dots:b, 65, 64, 64
        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

		# attn:b, 65, 64, 64
        attn = dots.softmax(dim=-1)

		# 使用einsum表示矩阵乘法:
		# out:b, 65, 64, 8
        out = torch.einsum('bhij,bhjd->bhid', attn, v)

		# out:b, 64, 65*8
        out = rearrange(out, 'b h n d -> b n (h d)')

		# out:b, 64, 1024
        out =  self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63

2.3 Class ViT:

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2
        assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img, mask = None):
        p = self.patch_size

		# 图片分块
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)

		# 降维(b,N,d)
        x = self.patch_to_embedding(x)
        b, n, _ = x.shape

		# 多一个可学习的x_class,与输入concat在一起,一起输入Transformer的Encoder。(b,1,d)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)

		# Positional Encoding:(b,N+1,d)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

		# Transformer的输入维度x的shape是:(b,N+1,d)
        x = self.transformer(x, mask)

		# (b,1,d)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)	# (b,1,num_class)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52


2.4 ViT 模型完整代码

#  !/usr/bin/env  python
#  -*- coding:utf-8 -*-
# @Time   :  2021.
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :




import torch
from   torch import nn, einsum
import torch.nn.functional as F
from   einops import rearrange, repeat
# from   einops.layers.torch import Rearrange




class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x




class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn   = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)




class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(   nn.Linear(dim, hidden_dim),
                                    nn.GELU(),
                                    nn.Dropout(dropout),
                                    nn.Linear(hidden_dim, dim),
                                    nn.Dropout(dropout) )
    def forward(self, x):
        return self.net(x)





class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout) )


    def forward(self, x, mask = None):
        # b, 65, 1024, heads = 8
        b, n, _ = x.shape
        h = self.heads
        # self.to_qkv(x): b, 65, 64*8*3
		# qkv: b, 65, 64*8
        qkv = self.to_qkv(x).chunk(3, dim = -1)     # 沿-1轴分为3块

        # b, 65, 64, 8
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # dots:b, 65, 64, 64
        dots       =  torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max

        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        # attn:b, 65, 64, 64
        attn = dots.softmax(dim=-1)

        # 使用einsum表示矩阵乘法:
		# out:b, 65, 64, 8
        out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # out:b, 64, 65*8
        out = rearrange(out, 'b h n d -> b n (h d)')

        # out:b, 64, 1024
        out =  self.to_out(out)
        return out






class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([  Residual(PreNorm(dim, Attention(  dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                                                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))    ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x






class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim   = channels * patch_size ** 2

        # assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size         = patch_size
        self.pos_embedding      = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token          = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout            = nn.Dropout(emb_dropout)
        self.transformer        = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool               = pool
        self.to_latent          = nn.Identity()

        self.mlp_head           = nn.Sequential(  nn.LayerNorm(dim), nn.Linear(dim, num_classes) )

    def forward(self, img, mask = None):
        p = self.patch_size

        # 图片分块
        # print(img.shape)
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)    # 1,3,256,256  ->  1,64,3072

        # 降维(b,N,d)
        x       = self.patch_to_embedding(x)
        b, n, _ = x.shape

        # 多一个可学习的x_class,与输入concat在一起,一起输入Transformer的Encoder。(b,1,d)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)

        # Positional Encoding:(b,N+1,d)
        x += self.pos_embedding[:, :(n + 1)]
        x  = self.dropout(x)

        # Transformer的输入维度x的shape是:(b,N+1,d)
        x = self.transformer(x, mask)

        # (b,1,d)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)	# (b,1,num_class)








if __name__ == '__main__':
    v = ViT(image_size=256, patch_size=32, num_classes=10, dim=1024, depth=6, heads=16, mlp_dim=2048, dropout=0.1,
            emb_dropout=0.1)

    img = torch.randn(1, 3, 256, 256)
    mask = torch.ones(1, 8, 8).bool()  # optional mask, designating which patch to attend to

    preds = v(img, mask=mask)  # (1, 1000)
    print(preds)



  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/935109
推荐阅读
相关标签
  

闽ICP备14008679号