当前位置:   article > 正文

MAE技术总结_mae csdn

mae csdn

原理解读

        MAE 方法很简单:mask 输入图像的随机 patch,并重建缺失的像素。它基于两个核心设计。首先,作者开发了一种非对称编码器-解码器结构,其中的编码器仅对可见的 patch 子集(不带 mask token)进行操作,而轻量级解码器则从潜在表示和 mask token 重建原始图像。其次,作者发现对输入图像的高比例(例如 75%)进行 mask 会产生一项困难且有意义的自监督任务。将这两种设计结合起来,能够高效地训练大型模型:加快训练速度(3 倍或更多)并提高精度。

MAE 的结构如上图所示,与所有自动编码器一样,MAE 有一个编码器,将观察到的信号映射到潜在表示,还有一个解码器,从潜在表示重建原始信号。与经典的自动编码器不同,作者采用了一种非对称设计,允许编码器仅对部分观察信号(无mask token)进行操作,并采用一种轻量级解码器,从潜在表示和 mask token 重建完整信号。

具体来说,作者首先将图像划分为规则的非重叠 patch。然后,对一个子集的 patch 进行采样,并移除其余的 patch。然后将这些剩余的 patch 送入到编码器中,编码器是一个标准的 ViT 结构,由于编码器只处理很少一部分的 patch,因此可以用很少的计算和显存来训练非常大的编码器。编码器输出 token 后,作者在 mask 的位置加入了可学习的向量,组成完整的全套 token。

此外,作者向这个完整集合中的所有 token 添加位置嵌入;如果没有这一点,mask token 将没有关于其在图像中位置的信息。MAE 解码器仅在预训练期间用于执行图像重建任务(仅编码器用于生成用于识别的图像表示)。因此,可以以独立于编码器设计的方式灵活地设计解码器架构。作者用比编码器更窄、更浅的解码器进行实验。使用这种非对称设计,全套 token 仅由轻量级解码器处理,这大大减少了预训练时间。

代码实现(tiny版本)

  1. # MAE encoder
  2. class MAE_Encoder(torch.nn.Module):
  3. # cls_token 维度是(1,1,emb_dim)
  4. self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
  5. # 位置嵌入 维度是(img//patch **2,1,emb_dim)
  6. self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2, 1, emb_dim))
  7. # 打乱Patches并按mask_ratio筛选部分patches
  8. self.shuffle = PatchShuffle(mask_ratio)
  9. #按patch_size分割patch
  10. self.patchify = torch.nn.Conv2d(3, emb_dim, patch_size, patch_size)
  11. # transformer block
  12. self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
  13. # LN层
  14. self.layer_norm = torch.nn.LayerNorm(emb_dim)
  15. # 参数初始化
  16. self.init_weight()
  17. # 前向传播流程(2,3,32,32):
  18. #1、原图分割成patch (2, 192, 16, 16)——>rearrange(256,2,192)
  19. #2、位置嵌入
  20. #3、打乱并mask部分patch,只保留没被mask的patch (64,2,192)
  21. #4、cls_token嵌入 (65,2,192) ————>rearrange(2,256,192)
  22. #5、transformer&LayerNorm ————>rearrange(65,2,192)
  23. def forward(self, img):
  24. patches = self.patchify(img)
  25. patches = rearrange(patches, 'b c h w -> (h w) b c')
  26. patches = patches + self.pos_embedding
  27. patches, forward_indexes, backward_indexes = self.shuffle(patches)
  28. patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
  29. patches = rearrange(patches, 't b c -> b t c')
  30. features = self.layer_norm(self.transformer(patches))
  31. features = rearrange(features, 'b t c -> t b c')
  32. return features, backward_indexes
  33. # MAE Decoder
  34. class MAE_Decoder(torch.nn.Module):
  35. #定义mask_token,设置成全0,后续会初始化
  36. self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
  37. #位置嵌入
  38. self.pos_embedding = torch.nn.Parameter(torch.zeros((image_size // patch_size) ** 2 + 1, 1, emb_dim))
  39. #transfoemer block
  40. self.transformer = torch.nn.Sequential(*[Block(emb_dim, num_head) for _ in range(num_layer)])
  41. #线性层
  42. self.head = torch.nn.Linear(emb_dim, 3 * patch_size ** 2)
  43. #patch到原图的rerrange
  44. self.patch2img = Rearrange('(h w) b (c p1 p2) -> b c (h p1) (w p2)', p1=patch_size, p2=patch_size, h=image_size//patch_size)
  45. # 前向传播流程(2,3,32,32):
  46. #1、backward_indexes排序索引先增加一维(cls_token)
  47. #2、features增加backward_indexes-features维 (65, 2, 192)——>257, 2, 192,增加的是随机初始化的数值
  48. #3、按照backward_indexes重新排序features,排成原图对应位置的patch
  49. #4、位置嵌入————>rearrange(2,257,192)
  50. #5、transformer&LayerNorm ————>rearrange(257,2,192)
  51. #6、筛掉cls_token(256,2,192)
  52. #7、全连接层——>(256,2,12)
  53. #8、新建一个掩膜mask,被mask掉的patch位置值为1,其余是0
  54. #9、对这个掩膜重新排序,使得和原图对应
  55. #10、对patch和mask Rearrange成原图大小(2,3,32,32)
  56. def forward(self, features, backward_indexes):
  57. T = features.shape[0]
  58. backward_indexes = torch.cat([torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], dim=0)
  59. features = torch.cat([features, self.mask_token.expand(backward_indexes.shape[0] - features.shape[0], features.shape[1], -1)], dim=0)
  60. features = take_indexes(features, backward_indexes)
  61. features = features + self.pos_embedding
  62. features = rearrange(features, 't b c -> b t c')
  63. features = self.transformer(features)
  64. features = rearrange(features, 'b t c -> t b c')
  65. features = features[1:]
  66. patches = self.head(features)
  67. mask = torch.zeros_like(patches)
  68. mask[T-1:] = 1
  69. mask = take_indexes(mask, backward_indexes[1:] - 1)
  70. img = self.patch2img(patches)
  71. mask = self.patch2img(mask)
  72. return img, mask
  73. #损失函数部分使用MSE,但是只计算被mask掉部分的损失
  74. loss = torch.mean((predicted_img - img) ** 2 * mask) / args.mask_ratio

推理阶段

       选取MAE的encoder部分的feature,并在图像检索过程中选取cls_token作为特征向量进行匹配

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/364436
推荐阅读
相关标签
  

闽ICP备14008679号