当前位置:   article > 正文

【一站式梳理】ViT - Vision Transformer 流程+代码 学习记录_visiontransform mlphead

visiontransform mlphead

ViT Paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, ICLR 2021.

目录

ViT Paper: AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE, ICLR 2021.

模型流程:

输入:

Transformer Encoder:

MLP head 分类 输出

训练

代码解析:请参考


模型流程:

输入:

1)原始图像:

        将图像按顺序分成指定的Patches,输入Linear Projection后进行Flatten操作。

2)位置编码:

 举个例子讲下transformer的输入输出细节及其他 - 知乎

(pytorch进阶之路)四种Position Embedding的原理及实现-CSDN博客

Learned Positional Embedding ,这个是绝对位置编码,即直接对不同的位置随机初始化一个postion embedding,这个postion embedding作为参数进行训练。(1D PE)

Sinusoidal Position Embedding ,相对位置编码,即三角函数编码。(2D PE)

 ViT使用1D位置编码得到position embedding,因为实验表明使用1DPE和2DPE的对性能影响不大。

  1. import torch
  2. import torch.nn as nn
  3. def create_1d_learnable_embedding(pos_len, dim):
  4. pos_emb = nn.Embedding(pos_len, dim)
  5. # 初始化成全0
  6. nn.init.constant_(pos_emb.weight, 0)
  7. return pos_emb

 3) Class token

class token的embedding被随机初始化并与pos embedding相加,论文里面是class token是放在首位,也就是第0个位置. VIT中特殊class token的一些问题-CSDN博客

  1. # 随机初始化
  2. self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
  3. # Classifier head
  4. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  5. # 具体forward过程
  6. B = x.shape[0]
  7. x = self.patch_embed(x)
  8. # stole cls_tokens impl from Phil Wang, thanks
  9. cls_tokens = self.cls_token.expand(B, -1, -1)
  10. x = torch.cat((cls_tokens, x), dim=1)
  11. x = x + self.pos_embed

Transformer Encoder:

所有flatten之后的Patches与class token + PE执行stacked或者concatenated输入encoder中。

模型图片来源:霹雳吧啦Wz

Multi-head Self-attention & 应用到图片_多头注意力机制的图像应用-CSDN博客

10.6. 自注意力和位置编码 — 动手学深度学习 2.0.0 documentation

MLP head 分类 输出

ViT 模型中只使用了 class token 的输出,将其送入 MLP 模块中,去输出最终的分类结果。class token的输出里包含了其他patches的综合编码信息。

训练

  1. 在较大的数据集上预训练;
  2. 在下游数据集上微调用于图像分类。

代码解析:请参考

【超详细】初学者包会的Vision Transformer(ViT)的PyTorch实现代码学习_vit_base_patch16_224_in21k模型-CSDN博客Vision Transformer(ViT)PyTorch代码全解析(附图解)_vit代码-CSDN博客

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/寸_铁/article/detail/875181
推荐阅读
相关标签
  

闽ICP备14008679号