
GitHub: ViT-pytorch study notes for vision classification (1)

GitHub - lucidrains/vit-pytorch: Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch

An implementation of the Vision Transformer, which achieves SOTA in vision classification with only a single transformer encoder.

There is not much code involved; using it as a basis for experiments can help accelerate the attention revolution. (It feels a bit like an integrated toolkit?)

For experiments based on pre-trained models, see here!

1. Installing vit-pytorch

pip install vit-pytorch

2. Usage

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

3. Parameters

  • image_size: int. If the image is rectangular, this should be the maximum of its width and height.

  • patch_size: int. Size of each patch the image is split into; the number of patches is n = (image_size // patch_size) ** 2, and n must be greater than 16. (A quick sanity-check sketch follows after this list.)

  • num_classes: int. Number of classes to classify. (Note: pay attention to this parameter.)

  • dim: int. The last dimension of the output tensor after the linear transformation nn.Linear(..., dim).

  • depth: int. Number of Transformer blocks. (Q: what exactly is a Transformer block?)

  • heads: int. Number of heads in the multi-head attention layers.

  • mlp_dim: int. Dimension of the MLP (feed-forward) layers.

  • channels: int, default 3 (RGB). Number of image channels.

  • dropout: float in [0, 1], default 0. Dropout rate.

  • emb_dropout: float in [0, 1], default 0. Embedding dropout rate.

  • pool: string, either 'cls' token pooling or 'mean' (average) pooling.
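
As a quick illustration of the constraints above, here is a minimal sketch (the hyperparameter values are illustrative only, not from the original example) that checks the patch count and uses the optional channels and pool arguments:

import torch
from vit_pytorch import ViT

image_size, patch_size = 256, 32
n_patches = (image_size // patch_size) ** 2  # 8 * 8 = 64 patches, which is > 16
assert image_size % patch_size == 0 and n_patches > 16

# single-channel (grayscale) input with mean pooling instead of the CLS token
v = ViT(
    image_size = image_size,
    patch_size = patch_size,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    channels = 1,   # default is 3 (RGB)
    pool = 'mean'   # 'cls' (default) or 'mean'
)

img = torch.randn(1, 1, image_size, image_size)
preds = v(img)  # (1, 10)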

4. Simple ViT

Simple ViT uses 2-D sinusoidal position embeddings, global average pooling (no CLS token), no dropout, a batch size of 1024 rather than 4096, and RandAugment and MixUp augmentations. They also show that a simple linear head at the end performs no differently from the original MLP head.

Paper

import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

5. Distillation

Distilling knowledge from a convolutional network into a vision transformer using a distillation token can yield small and efficient vision transformers. This repository provides an easy way to do such distillation.

e.g. distilling from Resnet50 (or any teacher) to a vision transformer

import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,    # temperature of distillation
    alpha = 0.5,        # trade between main loss and distillation loss
    hard = False        # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)

The DistillableViT class is identical to ViT except for how the forward pass is handled, so after distillation training is done you should be able to load the parameters back into a plain ViT.
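
For instance, a minimal sketch of doing this by hand via the state dict (assuming the plain ViT is built with exactly the same hyperparameters as the DistillableViT above):

from vit_pytorch import ViT

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# copy the distilled student's weights into the plain ViT
vit.load_state_dict(v.state_dict())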

You can also call the convenient .to_vit method on a DistillableViT instance to get back a ViT instance directly.

v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>

6. DeepViT

This work studies increasing the number of layers in ViT, i.e. network depth (past 12 layers), and finds it hard to train; it proposes mixing the attention of each head post-softmax as a solution, called Re-attention. The results are consistent with the Talking Heads paper from NLP.

import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

7. CaiT

This paper points out the difficulty of training deeper vision transformers and proposes two solutions. First, it proposes per-channel multiplication of the output of each residual block. Second, it proposes to have the patches attend to one another, and to allow the CLS token to attend to the patches only in the last few layers. They also add Talking Heads and report improvements.

import torch
from vit_pytorch.cait import CaiT

v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,          # depth of transformer for patch to patch attention only
    cls_depth = 2,       # depth of cross attention of CLS tokens to patch
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05 # randomly dropout 5% of the layers
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

8. Token-to-Token ViT

This paper proposes that the first couple of layers downsample the image sequence by unfolding, so that the image data in each token overlaps, as shown in the figure.

import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layers of the initial token to token module
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

9. CCT

CCT proposes compact transformers that use convolutions instead of patching and perform sequence pooling. This allows CCT to achieve high accuracy with a low number of parameters.

import torch
from vit_pytorch.cct import CCT

cct = CCT(
    img_size = (224, 448),
    embedding_dim = 384,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)

img = torch.randn(1, 3, 224, 448)

pred = cct(img) # (1, 1000)

Alternatively, you can use one of the predefined models [2, 4, 6, 7, 8, 14, 16], which come with preset numbers of layers, attention heads, MLP ratio, and embedding dimension.

import torch
from vit_pytorch.cct import cct_14

cct = cct_14(
    img_size = 224,
    n_conv_layers = 1,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)

10. CrossViT

This paper proposes having two vision transformers process the image at different scales, cross-attending to one another every so often. They show improvements over the base vision transformer.

import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

pred = v(img) # (1, 1000)

11. PiT

This paper proposes downsampling the tokens through a pooling procedure that uses depth-wise convolutions.

import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),   # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

To be continued...
