赞
踩
1,首先,我们需要知道的是,想要调用预训练的Swin Transformer模型,必须要安装pytorch2,因为pytorch1对应的torchvision中不包含Swin Transformer。
2,pytorch2调用预训练模型时,不建议使用pretrained=True,这个用法即将淘汰,会报警告。最好用如下方式:
- from torchvision.models.swin_transformer import swin_b, Swin_B_Weights
-
- model = swin_b(weights=Swin_B_Weights.DEFAULT)
这里调用的就是swin_b在imagenet上的预训练模型
3,swin_b的模型结构如下(仅展示到第一个patch merging部分),在绝大部分情况下,我们可能需要的不是整个模型,而是其中的一个模块,比如SwinTransformerBlock。
- SwinTransformer(
- (features): Sequential(
- (0): Sequential(
- (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
- (1): Permute()
- (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
- )
- (1): Sequential(
- (0): SwinTransformerBlock(
- (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
- (attn): ShiftedWindowAttention(
- (qkv): Linear(in_features=128, out_features=384, bias=True)
- (proj): Linear(in_features=128, out_features=128, bias=True)
- )
- (stochastic_depth): StochasticDepth(p=0.0, mode=row)
- (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
- (mlp): MLP(
- (0): Linear(in_features=128, out_features=512, bias=True)
- (1): GELU(approximate='none')
- (2): Dropout(p=0.0, inplace=False)
- (3): Linear(in_features=512, out_features=128, bias=True)
- (4): Dropout(p=0.0, inplace=False)
- )
- )
- (1): SwinTransformerBlock(
- (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
- (attn): ShiftedWindowAttention(
- (qkv): Linear(in_features=128, out_features=384, bias=True)
- (proj): Linear(in_features=128, out_features=128, bias=True)
- )
- (stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)
- (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
- (mlp): MLP(
- (0): Linear(in_features=128, out_features=512, bias=True)
- (1): GELU(approximate='none')
- (2): Dropout(p=0.0, inplace=False)
- (3): Linear(in_features=512, out_features=128, bias=True)
- (4): Dropout(p=0.0, inplace=False)
- )
- )
- )
- (2): PatchMerging(
- (reduction): Linear(in_features=512, out_features=256, bias=False)
- (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
- )
那么如何调用其中的SwinTransformerBlock呢。
由于该模型是个嵌套结构,而不是类似vgg一样简单的结构,所以不能直接用layer0=model.SwinTransformerBlock调用。
因为SwinTransformerBlock是Sequential下的子模块,故正确的调用代码如下:
swinblock = model.features[1][0]
结果如下,调用成功:
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。