赞
踩
在VisionTransformer模型中,使用一个二维的卷积核,将图片展开成一个patch序列
patch_embed = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
通过训练一个位置编码参数来学习记录图片的位置信息
num_patches
为图片展开的patch数目,加一是包含了cls_token,详细请阅读VisionTransformer论文
pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
部分代码实现
# 对图片进行展开操作 class PatchEmbed(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): super().__init__() num_patches = (img_size // patch_size) * (img_size // patch_size) self.img_size = img_size self.patch_size = patch_size self.num_patches = num_patches self.proj = nn.Conv2d(in_channels=in_chans, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x: Tensor): B, C, H, W = x.shape x = self.proj(x) x = x.flatten(2) x = x.transpose(1, 2) return x x = PatchEmbed(img) # 添加位置信息 x = x + pos_embed
对于一个已经训练好的VisionTransformer模型,如何将学习的位置信息转换到一张任意分辨率的图片上
# 代码引用自https://github.com/facebookresearch/dino/blob/main/vision_transformer.py def interpolate_pos_encoding(self, x, w, h): npatch = x.shape[1] - 1 # ,减去cls_token,得到输入图片的patch数量 N = self.pos_embed.shape[1] - 1 # 原模型训练时,patch数量 if npatch == N and w == h: #输入图片符合训练模型时的图片大小,可直接使用训练好的位置编码信息 return self.pos_embed # 将位置编码转换到一张任意分辨率的图片上 class_pos_embed = self.pos_embed[:, 0] # 提取cls_token的位置编码信息 patch_pos_embed = self.pos_embed[:, 1:] # 提取图片patch序列的位置编码信息 dim = x.shape[-1] # 对图片进行patch分割 w0 = w // self.patch_embed.patch_size h0 = h // self.patch_embed.patch_size w0, h0 = w0 + 0.1, h0 + 0.1 # 根据给定的size或scale_factor参数来对输入进行下/上采样 # 将原本在224*224训练得到的位置编码信息,转换到任意大小图片上 patch_pos_embed = nn.functional.interpolate( patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), # 指定输出为输入的多少倍数。 mode='bicubic', ) assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) # img已通过transform转换,并添加batch信息 B, c, w, h = img.shape x = PatchEmbed(img) # 添加cls_token,保持与模型一致 x = torch.cat((cls_tokens, x), dim=1) # 添加位置信息 x = x + interpolate_pos_encoding(self, x, w, h)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。