
TransReID Learning Notes


1. TransReID paper links

Paper: TransReID: Transformer-based Object Re-Identification
Code: GitHub - damo-cv/TransReID: [ICCV-2021] TransReID: Transformer-based Object Re-Identification
Authors: Alibaba & Zhejiang University

This is Hao Luo's work applying the Vision Transformer (ViT) to the ReID field; it surpasses CNN-based methods on multiple ReID benchmark datasets, making it a ViT-based ReID model that successfully topped the leaderboards.

Main ideas of the paper:

1. Overlapping Patches

This is the core idea of the paper. As noted in Swin Transformer, if the image is simply split into equal, non-overlapping patches, the local structure around patch boundaries is broken up and that boundary information is lost to self-attention. The code below therefore uses Overlapping Patches (a patch-projection convolution whose stride is smaller than its kernel), which has a clear advantage over a plain non-overlapping split.

    # Helper imports used by all the snippets in this post (the TransReID repo defines its own
    # to_2tuple / trunc_normal_ / DropPath / Mlp equivalents inside vit_pytorch.py):
    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from timm.models.layers import to_2tuple, trunc_normal_, DropPath, Mlp

    # To turn the image into patches there are two options: convert the image directly into
    # patches, or convert the feature map produced by a CNN backbone into patches.
    class PatchEmbed(nn.Module):
        """ Image to Patch Embedding: split the image into patches.
        Following the Transformer convention for positional encoding, this work also uses position
        embeddings. Unlike the sin/cos encoding of the original Transformer, ViT sets the Positional
        Encoding to be directly learnable. Visualizing the trained Positional Encoding shows that
        nearby positions have more similar encodings, and a row/column structure appears: patches in
        the same row/column have similar position embeddings. (embed_dim is simply the patch
        embedding dimension, 768 by default.)
        """
        # 1) Convert the image directly into patches:
        # input x has shape (B, C, H, W)
        # output PatchEmbedding has shape (B, 14*14, 768); 768 is embed_dim, 14*14 = 196 is the number of patches.
        def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
            super().__init__()
            img_size = to_2tuple(img_size)
            patch_size = to_2tuple(patch_size)
            num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
            self.img_size = img_size
            self.patch_size = patch_size
            self.num_patches = num_patches
            # kernel_size = patch size, so each patch yields one value per output channel,
            # equivalent to flattening every patch and applying the same linear layer to it.
            # Input channels: 3; output channels: the patch-vector length (embed_dim).
            # This matches the paper's split -> flatten -> linear projection.
            # Conv output shape: [B, embed_dim, H/patch_size, W/patch_size]
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

        def forward(self, x):
            B, C, H, W = x.shape
            # FIXME look at relaxing size constraints
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            # [B, C, H, W] -> [B, C, H*W] -> [B, H*W, C]
            x = self.proj(x).flatten(2).transpose(1, 2)
            # flattened into a token sequence; .transpose(1, 2) and .transpose(2, 1) give identical results
            return x

    # 2) Convert the feature map produced by a CNN backbone into patches:
    # input x has shape (B, C, H, W)
    # the backbone output has shape (B, feature_dim, feature_size, feature_size)
    # output PatchEmbedding has shape (B, feature_size * feature_size, embed_dim).
    class HybridEmbed(nn.Module):
        """ CNN Feature Map Embedding (hybrid embedding).
        Extract feature map from CNN, flatten, project to embedding dim.
        """
        def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
            super().__init__()
            assert isinstance(backbone, nn.Module)
            img_size = to_2tuple(img_size)
            self.img_size = img_size
            self.backbone = backbone
            if feature_size is None:
                with torch.no_grad():
                    # FIXME this is hacky, but it is the most reliable way of determining the exact
                    # dim of the output feature map for all networks: the feature metadata has
                    # reliable channel and stride info, but using the stride to compute the feature
                    # dim would require padding info for each stage that isn't captured.
                    training = backbone.training
                    if training:
                        backbone.eval()
                    o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1]))
                    if isinstance(o, (list, tuple)):
                        o = o[-1]  # last feature if backbone outputs list/tuple of features
                    feature_size = o.shape[-2:]
                    feature_dim = o.shape[1]
                    backbone.train(training)
            else:
                feature_size = to_2tuple(feature_size)
                if hasattr(self.backbone, 'feature_info'):
                    feature_dim = self.backbone.feature_info.channels()[-1]
                else:
                    feature_dim = self.backbone.num_features
            self.num_patches = feature_size[0] * feature_size[1]
            self.proj = nn.Conv2d(feature_dim, embed_dim, 1)  # 1x1 conv projection to embed_dim

        def forward(self, x):
            x = self.backbone(x)
            if isinstance(x, (list, tuple)):
                x = x[-1]  # last feature if backbone outputs list/tuple of features
            x = self.proj(x).flatten(2).transpose(1, 2)
            return x

    class PatchEmbed_overlap(nn.Module):
        """ Image to Patch Embedding with overlapping patches
        """
        def __init__(self, img_size=224, patch_size=16, stride_size=20, in_chans=3, embed_dim=768):
            super().__init__()
            img_size = to_2tuple(img_size)
            patch_size = to_2tuple(patch_size)
            stride_size_tuple = to_2tuple(stride_size)
            # "//" is Python's floor division, e.g. (224 - 16) // 20 + 1 = 10 + 1 = 11
            self.num_x = (img_size[1] - patch_size[1]) // stride_size_tuple[1] + 1
            self.num_y = (img_size[0] - patch_size[0]) // stride_size_tuple[0] + 1
            print('using stride: {}, and patch number is num_y{} * num_x{}'.format(stride_size, self.num_y, self.num_x))
            num_patches = self.num_x * self.num_y  # total number of patches
            self.img_size = img_size
            self.patch_size = patch_size
            self.num_patches = num_patches
            # stride_size < patch_size makes neighbouring patches overlap
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride_size)
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, nn.BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
                elif isinstance(m, nn.InstanceNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()

        def forward(self, x):
            B, C, H, W = x.shape
            # FIXME look at relaxing size constraints
            assert H == self.img_size[0] and W == self.img_size[1], \
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            x = self.proj(x)
            x = x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
            return x
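
As a quick sanity check of the patch counts, here is a hedged usage sketch. The 256×128 input with patch size 16 and stride 12 is a typical person-ReID setting assumed only for illustration (the released configs may use other strides); it reuses the imports and classes defined above:

    # Hypothetical comparison of a plain split vs. overlapping patches on a 256 x 128 person image.
    x = torch.randn(2, 3, 256, 128)                                    # (B, C, H, W)

    plain = PatchEmbed(img_size=(256, 128), patch_size=16)             # (256/16) * (128/16) = 16 * 8 = 128 patches
    overlap = PatchEmbed_overlap(img_size=(256, 128), patch_size=16,
                                 stride_size=12)                       # 21 * 10 = 210 overlapping patches

    print(plain(x).shape)    # torch.Size([2, 128, 768])
    print(overlap(x).shape)  # torch.Size([2, 210, 768])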

2. Position Embeddings

The position embedding used here is not original either; the paper adopts the approach most commonly used with ViT.

Fixed Positional Encoding: each position is assigned a fixed value, usually built from sin/cos functions of different frequencies (a minimal sketch is shown below).
Learnable Positional Encoding: at the start of training, a tensor with as many entries as there are input tokens is initialized and then updated step by step during training; this is the variant ViT (and TransReID) uses.
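
For comparison only, here is a minimal sketch of the fixed sinusoidal variant, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)). TransReID itself keeps ViT's learnable nn.Parameter position embedding, so this helper is purely illustrative and not part of the repository:

    # Fixed (sinusoidal) positional encoding sketch; not used by TransReID.
    def sinusoidal_pos_encoding(num_tokens, dim):
        pos = torch.arange(num_tokens, dtype=torch.float32).unsqueeze(1)                            # (N, 1)
        freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))  # (dim/2,)
        pe = torch.zeros(num_tokens, dim)
        pe[:, 0::2] = torch.sin(pos * freq)   # even dimensions
        pe[:, 1::2] = torch.cos(pos * freq)   # odd dimensions
        return pe                             # (N, dim), fixed, never trained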

    # posemb is the position-embedding weight before interpolation; posemb_token is its [cls]-token
    # part and posemb_grid the grid part to be interpolated. The grid part is first reshaped to
    # (1, gs_old, gs_old, -1), interpolated to the new grid size, and finally concatenated with the
    # token part along dim 1 to obtain the resized position embedding posemb.
    def resize_pos_embed(posemb, posemb_new, hight, width):
        # Rescale the grid of position embeddings when loading from state_dict. Adapted from
        # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
        ntok_new = posemb_new.shape[1]
        posemb_token, posemb_grid = posemb[:, :1], posemb[0, 1:]
        ntok_new -= 1
        gs_old = int(math.sqrt(len(posemb_grid)))
        print('Resized position embedding from size:{} to size: {} with height:{} width: {}'.format(
            posemb.shape, posemb_new.shape, hight, width))
        posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
        posemb_grid = F.interpolate(posemb_grid, size=(hight, width), mode='bilinear')
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, hight * width, -1)
        posemb = torch.cat([posemb_token, posemb_grid], dim=1)
        return posemb
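
A hedged usage sketch: when an ImageNet-pretrained ViT (224×224 input, i.e. a 14×14 grid plus the [cls] token) is loaded into a ReID model with the 21×10 overlapping-patch grid from Section 1, the grid part is bilinearly interpolated to the new size. The random tensors below only stand in for real checkpoint weights, and the actual call site in the repo may differ:

    pretrained_posemb = torch.randn(1, 1 + 14 * 14, 768)    # stand-in for the checkpoint's pos_embed
    model_posemb = torch.zeros(1, 1 + 21 * 10, 768)          # target shape in the new ReID model
    resized = resize_pos_embed(pretrained_posemb, model_posemb, 21, 10)
    print(resized.shape)                                     # torch.Size([1, 211, 768])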

3. Jigsaw Patch Module

The paper proposes a Jigsaw Patch Module (JPM) that shuffles the patch embeddings and then regroups them into several parts, each containing multiple random patch embeddings from the whole image. Introducing this extra perturbation during training also helps improve the robustness of the object ReID model. It consists of two operations (a usage sketch follows the code below):

(1) Patch Shuffle Operation

(2) Shift Operation

    # Shift Operation: the first m patches (excluding the [cls] token) are moved to the end;
    # Patch Shuffle Operation: the shifted patches are then further shuffled with k groups.
    def shuffle_unit(features, shift, group, begin=1):
        batchsize = features.size(0)
        dim = features.size(-1)
        # Shift Operation (note that the [cls] token at index 0 is dropped here)
        feature_random = torch.cat([features[:, begin - 1 + shift:], features[:, begin:begin - 1 + shift]], dim=1)
        x = feature_random
        # Patch Shuffle Operation: split the sequence into `group` groups and interleave them
        try:
            x = x.view(batchsize, group, -1, dim)
        except Exception:
            # if the token number is not divisible by `group`, pad with a duplicated token
            x = torch.cat([x, x[:, -2:-1, :]], dim=1)
            x = x.view(batchsize, group, -1, dim)
        x = torch.transpose(x, 1, 2).contiguous()  # neighbouring positions now come from different groups
        x = x.view(batchsize, -1, dim)
        return x
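
A hedged usage sketch, assuming the 1 + 21*10 = 211-token sequence from Section 1 and a shift = 5, group = 2 setting (one typical configuration; the released configs may differ):

    features = torch.randn(2, 211, 768)                 # (B, 1 + num_patches, embed_dim)
    shuffled = shuffle_unit(features, shift=5, group=2)
    print(shuffled.shape)                                # torch.Size([2, 210, 768]); the [cls] token is removed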

4. Side Information Embeddings

    class TransReID(nn.Module):
        """ Transformer-based Object Re-Identification
        (this is essentially ViT, renamed TransReID here)
        """
        def __init__(self, img_size=224, patch_size=16, stride_size=16, in_chans=3, num_classes=1000, embed_dim=768,
                     depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0.,
                     attn_drop_rate=0., camera=0, view=0, drop_path_rate=0., hybrid_backbone=None,
                     norm_layer=nn.LayerNorm, local_feature=False, sie_xishu=1.0):
            super().__init__()
            self.num_classes = num_classes
            self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
            self.local_feature = local_feature
            # build the patch embedding module (and thereby the number of patches):
            if hybrid_backbone is not None:
                self.patch_embed = HybridEmbed(
                    hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
            else:
                self.patch_embed = PatchEmbed_overlap(
                    img_size=img_size, patch_size=patch_size, stride_size=stride_size, in_chans=in_chans,
                    embed_dim=embed_dim)
            num_patches = self.patch_embed.num_patches
            # class token: defined as (1, 1, 768) and later expanded to (B, 1, 768)
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            # learnable position embedding:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            self.cam_num = camera
            self.view_num = view
            self.sie_xishu = sie_xishu  # weight (lambda) of the Side Information Embedding (SIE)
            # Initialize SIE Embedding
            if camera > 1 and view > 1:
                self.sie_embed = nn.Parameter(torch.zeros(camera * view, 1, embed_dim))
                trunc_normal_(self.sie_embed, std=.02)
                print('camera number is : {} and viewpoint number is : {}'.format(camera, view))
                print('using SIE_Lambda is : {}'.format(sie_xishu))
            elif camera > 1:
                self.sie_embed = nn.Parameter(torch.zeros(camera, 1, embed_dim))
                trunc_normal_(self.sie_embed, std=.02)
                print('camera number is : {}'.format(camera))
                print('using SIE_Lambda is : {}'.format(sie_xishu))
            elif view > 1:
                self.sie_embed = nn.Parameter(torch.zeros(view, 1, embed_dim))
                trunc_normal_(self.sie_embed, std=.02)
                print('viewpoint number is : {}'.format(view))
                print('using SIE_Lambda is : {}'.format(sie_xishu))
            print('using drop_out rate is : {}'.format(drop_rate))
            print('using attn_drop_out rate is : {}'.format(attn_drop_rate))
            print('using drop_path rate is : {}'.format(drop_path_rate))
            self.pos_drop = nn.Dropout(p=drop_rate)
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
            # stack `depth` (12 by default) Transformer blocks
            self.blocks = nn.ModuleList([
                Block(
                    dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
                for i in range(depth)])
            self.norm = norm_layer(embed_dim)
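
The Side Information Embedding (SIE) encodes non-visual side information (camera ID and viewpoint ID) as a learnable embedding table. Following the paper, the (camera, view) index selects one row of sie_embed, which is added to the input tokens together with the position embedding and weighted by the coefficient lambda (sie_xishu). A minimal sketch of the corresponding step inside forward_features follows; the exact code in the repository may differ slightly:

    # after patch embedding, x has shape (B, num_patches, embed_dim);
    # camera_id / view_id are integer indices carried with each image in the batch.
    cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)    # (B, 1, embed_dim)
    x = torch.cat((cls_tokens, x), dim=1)                     # prepend [cls]: (B, 1 + num_patches, embed_dim)
    if self.cam_num > 0 and self.view_num > 0:
        x = x + self.pos_embed + self.sie_xishu * self.sie_embed[camera_id * self.view_num + view_id]
    else:
        x = x + self.pos_embed
    x = self.pos_drop(x)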

5. Transformer Block

There are 12 Transformer blocks in total (depth = 12).

    # Pre-norm design: LayerNorm first, then Attention with DropPath and a residual add;
    # then LayerNorm again, then the feed-forward network (MLP) with DropPath and a residual add.
    class Block(nn.Module):
        # Transformer Encoder Block:
        # Embedded Patches ==> LayerNorm ==> Multi-Head Attention ==> (+) ==> LayerNorm ==> MLP ==> (+) ==>
        #          |____________________residual_____________________|   |_________residual__________|
        def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                     drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
            super().__init__()
            self.norm1 = norm_layer(dim)
            # Multi-head Self-attention
            self.attn = Attention(
                dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
            # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
            self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
            self.norm2 = norm_layer(dim)
            mlp_hidden_dim = int(dim * mlp_ratio)
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        def forward(self, x):
            # LayerNorm, Multi-head Self-attention, DropPath, residual Add
            x = x + self.drop_path(self.attn(self.norm1(x)))
            # LayerNorm, Feed Forward (MLP), DropPath, residual Add
            x = x + self.drop_path(self.mlp(self.norm2(x)))
            return x
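
A quick shape check of a single block, as a hedged sketch that assumes the Attention class shown in the next section plus the timm Mlp/DropPath imports from Section 1 are available:

    block = Block(dim=768, num_heads=12)
    tokens = torch.randn(2, 211, 768)     # (B, 1 + num_patches, embed_dim)
    print(block(tokens).shape)            # torch.Size([2, 211, 768]); a block preserves the sequence shape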

6. Attention

    # Multi-head self-attention module (num_heads=8 by default). The hyper-parameters are the
    # embedding dim, the number of heads, the qkv bias, an optional scale factor and dropout rates.
    class Attention(nn.Module):
        def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
            super().__init__()
            self.num_heads = num_heads
            head_dim = dim // num_heads
            # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
            self.scale = qk_scale or head_dim ** -0.5
            # a single linear layer produces Q, K and V together
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
            self.attn_drop = nn.Dropout(attn_drop)
            # final output projection
            self.proj = nn.Linear(dim, dim)
            self.proj_drop = nn.Dropout(proj_drop)

        def forward(self, x):
            B, N, C = x.shape
            # linear projection, then split into heads: (3, B, num_heads, N, head_dim)
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            # split query, key, value
            q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
            # Scaled Dot-Product Attention: MatMul + Scale ("@" is the matrix-multiplication operator)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            # SoftMax over the key dimension
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            # MatMul with the values, then merge the heads back: (B, N, C)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C)
            # output projection
            x = self.proj(x)
            x = self.proj_drop(x)
            return x
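
As an illustrative shape trace through this module (values assumed for illustration only: B = 2, N = 211 tokens, C = 768, 12 heads, so head_dim = 64):

    attn_layer = Attention(dim=768, num_heads=12)
    x = torch.randn(2, 211, 768)
    # inside forward: qkv -> (3, 2, 12, 211, 64); attn = q @ k^T * scale -> (2, 12, 211, 211);
    # attn @ v -> (2, 12, 211, 64), merged back to (2, 211, 768)
    print(attn_layer(x).shape)            # torch.Size([2, 211, 768])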

7. DropPath

The model uses DropPath (stochastic depth) to improve robustness; a minimal sketch of the commonly used implementation is given below.

Reference: DropPath正则化 (烟雨行舟#'s CSDN blog)
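
The following is a generic, timm-style sketch of DropPath rather than code copied from the TransReID repository: during training, each sample's residual branch is dropped with probability drop_prob, and surviving branches are rescaled by 1 / keep_prob so the expected value is unchanged.

    class DropPath(nn.Module):
        """Drop entire residual paths per sample (stochastic depth) - illustrative sketch."""
        def __init__(self, drop_prob=0.):
            super().__init__()
            self.drop_prob = drop_prob

        def forward(self, x):
            if self.drop_prob == 0. or not self.training:
                return x
            keep_prob = 1 - self.drop_prob
            # one Bernoulli sample per example in the batch, broadcast over the remaining dims
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
            random_tensor.floor_()                  # binarize to 0 / 1
            return x.div(keep_prob) * random_tensor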

8. Class Token

Why is an extra learnable embedding added to the input tokens?
The Transformer takes a sequence of patch embeddings as input and outputs an equally long sequence of patch features, so at classification time it is not obvious which feature to use; a feature representing the whole image is needed. A simple option is to average-pool all patch features into an image feature, but the ViT authors do not use that. Instead they introduce a class token; attaching a linear classifier to its output feature is enough for classification. The class token is randomly initialized and then learned during training.
Reference: Vision Transformer (ViT) -- TransReID学习记录(一) (陈朔怡's CSDN blog)

    # class token and learnable position embedding (same __init__ as shown in Section 4)
    self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
    self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
    # ... the SIE embedding, dropout and the stack of 12 Blocks are built exactly as in Section 4 ...
    self.norm = norm_layer(embed_dim)
    # Classifier head: outputs num_classes logits (the original ViT optionally inserts a
    # representation layer of size representation_size before it)
    self.fc = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
    trunc_normal_(self.cls_token, std=.02)
    trunc_normal_(self.pos_embed, std=.02)
    self.apply(self._init_weights)
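
A hedged sketch of how the class token is finally used: the output feature at position 0 (the [cls] token) serves as the global image feature and is passed to the classifier head. The method name and signature below are illustrative; the real TransReID forward additionally returns local features from the JPM branch and uses separate heads in make_model.py.

    def forward(self, x, camera_id=None, view_id=None):
        x = self.forward_features(x, camera_id, view_id)    # (B, 1 + num_patches, embed_dim) after self.norm
        cls_feat = x[:, 0]                                   # [cls] output: (B, embed_dim) global feature
        return self.fc(cls_feat)                             # (B, num_classes) logits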
