当前位置:   article > 正文

Pytorch学习之VisionTransformer图片位置编码实现_神经网络中处理图像位置的位置编码器

神经网络中处理图像位置的位置编码器

Pytorch学习之图片位置编码

前提

在VisionTransformer模型中,使用一个二维的卷积核,将图片展开成一个patch序列

patch_embed = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
  • 1

通过训练一个位置编码参数来学习记录图片的位置信息

num_patches为图片展开的patch数目,加一是包含了cls_token,详细请阅读VisionTransformer论文

pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
  • 1

部分代码实现

# 对图片进行展开操作
class PatchEmbed(nn.Module):

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)

        return x

x = PatchEmbed(img)

# 添加位置信息
x = x + pos_embed
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24

位置编码转换

对于一个已经训练好的VisionTransformer模型,如何将学习的位置信息转换到一张任意分辨率的图片上

# 代码引用自https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
def interpolate_pos_encoding(self, x, w, h):
        npatch = x.shape[1] - 1	# ,减去cls_token,得到输入图片的patch数量
        N = self.pos_embed.shape[1] - 1 # 原模型训练时,patch数量
        
        if npatch == N and w == h: #输入图片符合训练模型时的图片大小,可直接使用训练好的位置编码信息
            return self.pos_embed
        
        # 将位置编码转换到一张任意分辨率的图片上
        
        class_pos_embed = self.pos_embed[:, 0]	# 提取cls_token的位置编码信息
        patch_pos_embed = self.pos_embed[:, 1:]	# 提取图片patch序列的位置编码信息
        dim = x.shape[-1]
        
        # 对图片进行patch分割
        w0 = w // self.patch_embed.patch_size
        h0 = h // self.patch_embed.patch_size
        w0, h0 = w0 + 0.1, h0 + 0.1
        
        # 根据给定的size或scale_factor参数来对输入进行下/上采样
        # 将原本在224*224训练得到的位置编码信息,转换到任意大小图片上
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), # 指定输出为输入的多少倍数。
            mode='bicubic',
        ) 
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

# img已通过transform转换,并添加batch信息
B, c, w, h = img.shape
x = PatchEmbed(img)

# 添加cls_token,保持与模型一致
x = torch.cat((cls_tokens, x), dim=1)

# 添加位置信息
x = x + interpolate_pos_encoding(self, x, w, h)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/你好赵伟/article/detail/763125
推荐阅读
相关标签
  

闽ICP备14008679号