当前位置:   article > 正文

如何用pytorch调用预训练Swin Transformer中的一个Swin block模块_swin transformer加载预训练

swin transformer加载预训练

1,首先,我们需要知道的是,想要调用预训练的Swin Transformer模型,必须要安装pytorch2,因为pytorch1对应的torchvision中不包含Swin Transformer。

2,pytorch2调用预训练模型时,不建议使用pretrained=True,这个用法即将淘汰,会报警告。最好用如下方式:

  1. from torchvision.models.swin_transformer import swin_b, Swin_B_Weights
  2. model = swin_b(weights=Swin_B_Weights.DEFAULT)

这里调用的就是swin_b在imagenet上的预训练模型

3,swin_b的模型结构如下(仅展示到第一个patch merging部分),在绝大部分情况下,我们可能需要的不是整个模型,而是其中的一个模块,比如SwinTransformerBlock。

  1. SwinTransformer(
  2. (features): Sequential(
  3. (0): Sequential(
  4. (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
  5. (1): Permute()
  6. (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  7. )
  8. (1): Sequential(
  9. (0): SwinTransformerBlock(
  10. (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  11. (attn): ShiftedWindowAttention(
  12. (qkv): Linear(in_features=128, out_features=384, bias=True)
  13. (proj): Linear(in_features=128, out_features=128, bias=True)
  14. )
  15. (stochastic_depth): StochasticDepth(p=0.0, mode=row)
  16. (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  17. (mlp): MLP(
  18. (0): Linear(in_features=128, out_features=512, bias=True)
  19. (1): GELU(approximate='none')
  20. (2): Dropout(p=0.0, inplace=False)
  21. (3): Linear(in_features=512, out_features=128, bias=True)
  22. (4): Dropout(p=0.0, inplace=False)
  23. )
  24. )
  25. (1): SwinTransformerBlock(
  26. (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  27. (attn): ShiftedWindowAttention(
  28. (qkv): Linear(in_features=128, out_features=384, bias=True)
  29. (proj): Linear(in_features=128, out_features=128, bias=True)
  30. )
  31. (stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)
  32. (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  33. (mlp): MLP(
  34. (0): Linear(in_features=128, out_features=512, bias=True)
  35. (1): GELU(approximate='none')
  36. (2): Dropout(p=0.0, inplace=False)
  37. (3): Linear(in_features=512, out_features=128, bias=True)
  38. (4): Dropout(p=0.0, inplace=False)
  39. )
  40. )
  41. )
  42. (2): PatchMerging(
  43. (reduction): Linear(in_features=512, out_features=256, bias=False)
  44. (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  45. )

那么如何调用其中的SwinTransformerBlock呢。

由于该模型是个嵌套结构,而不是类似vgg一样简单的结构,所以不能直接用layer0=model.SwinTransformerBlock调用。

因为SwinTransformerBlock是Sequential下的子模块,故正确的调用代码如下:

swinblock = model.features[1][0]

结果如下,调用成功:

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/空白诗007/article/detail/799643
推荐阅读
相关标签
  

闽ICP备14008679号