
Vision Mamba Code Notes

1 Paper Recap

  • Basic idea

For a full interpretation of the paper, see:

《VideoMamba》论文笔记_video mamba-CSDN博客

  • Notes

  1. Vision Mamba and ViT have identical input and output shapes. ViT is built on the Transformer encoder: after stacked MHA and MLP layers, the output shape equals the input shape. Mamba's SSM architecture likewise preserves both the number of tokens and the dimension of each token, so the overall shape is preserved as well; and since Vision Mamba deliberately mirrors ViT's structure, it naturally keeps this property. In both models, the tokens carry contextual information after passing through their respective encoders, so the effect is the same, but the Mamba-based model is more efficient.
  2. As noted in 1, Vision Mamba follows the same pipeline as ViT: the image is split into patches, a class token is concatenated, and a position embedding is added. The only difference between the two models is the encoder: ViT uses the Transformer encoder, while Vim uses the Mamba encoder. Both serve the same purpose of token-to-token interaction and context modeling (see the shape sketch below).
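To make the shape contract in note 1 concrete, here is a minimal sketch (my addition, not from the original post). It uses a stock nn.TransformerEncoderLayer purely as a stand-in encoder; a Vim encoder block obeys the same contract: a (B, N, D) token sequence goes in, a (B, N, D) sequence comes out.

import torch
import torch.nn as nn

# Vim-Tiny-like numbers: 224x224 image, 16x16 patches -> 14*14 = 196 patch tokens, plus 1 cls token
B, N, D = 2, 196 + 1, 192
tokens = torch.randn(B, N, D)

# Stand-in encoder layer, used here only to demonstrate shape preservation
encoder_layer = nn.TransformerEncoderLayer(d_model=D, nhead=4, batch_first=True)
out = encoder_layer(tokens)
print(tokens.shape, out.shape)  # both torch.Size([2, 197, 192])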

2 Environment Setup

Follow the official readme.md. If something goes wrong, the fixes in this link may help:

vision mamba 运行训练记录,解决bimamba_type错误-CSDN博客

One thing worth noting: if you already have an environment from running other Mamba code, you cannot reuse it directly, because the standard Mamba class has no bimamba_type argument.

So you need to get the mamba-1p1p1 package from the official Vim repository and put it in your own project.

Vision Mamba actually rewrites the Mamba class, and you can see that it does take a bimamba_type argument (this is in fact one of Vision Mamba's main contributions). Then run the following command (a quick sanity check follows it):

cp -rf mamba-1p1p1/mamba_ssm /home/liyhc/anaconda3/envs/mamba/lib/python3.10/site-packages
# the destination is the mamba_ssm install path of your environment; adjust it to your own setup
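As a quick sanity check (my addition, not part of the original post), you can verify that Python now picks up the patched package and that its Mamba class accepts bimamba_type:

import inspect
import mamba_ssm
from mamba_ssm import Mamba

print(mamba_ssm.__file__)  # should point at the copy inside your site-packages
print("bimamba_type" in inspect.signature(Mamba.__init__).parameters)  # expect True with the patched package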

3 Code Notes

3.1 Code Links

The official code:

Vim/vim/models_mamba.py at main · hustvl/Vim (github.com)

My hand-typed copy with Chinese comments:

Johnny-Haytham/Vim: Vim with chinese notation (github.com)

3.2 Module

3.2.1 PatchEmbed

     

class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=16, stride=16, in_channels=3, embed_dim=768, norm_layer=None, flatten=True):
        super(PatchEmbed, self).__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)  # turn img_size and patch_size into 2-tuples
        self.img_size = img_size
        self.patch_size = patch_size
        # each patch becomes one grid cell; record the grid shape
        self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[1] - patch_size[1]) // stride + 1)
        self.num_patches = self.grid_size[0] * self.grid_size[1]  # total number of patches
        self.flatten = flatten
        # patchify is implemented as a convolution (to avoid overlapping patches, the stride should equal the kernel size)
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()  # nn.Identity returns its input unchanged; used as a placeholder layer

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input img size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})"
        x = self.proj(x)  # B,C,H,W -> B,embed_dim,grid_size,grid_size
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # B,embed_dim,grid_size,grid_size -> B,embed_dim,grid_size*grid_size -> B,grid_size*grid_size,embed_dim
        x = self.norm(x)
        return x
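A minimal usage sketch of the class above (my addition; it assumes the same imports as the file, i.e. torch, torch.nn and timm's to_2tuple are available):

import torch

patch_embed = PatchEmbed(img_size=224, patch_size=16, stride=16, in_channels=3, embed_dim=192)
imgs = torch.randn(2, 3, 224, 224)   # (B, C, H, W)
tokens = patch_embed(imgs)
print(tokens.shape)                  # torch.Size([2, 196, 192]) -> (B, num_patches, embed_dim)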

3.2.2 Vim Encoder Block

class Block(nn.Module):
    def __init__(
        self, dim, mixer_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False, residual_in_fp32=False, drop_path=0.
    ):
        super(Block, self).__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)  # instantiate the mixer: a Mamba with some arguments pre-bound via partial
        self.norm = norm_cls(dim)
        self.drop_path = DropPath(drop_path)
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(self,
                hidden_states: Tensor,  # token states coming from the previous block
                residual: Optional[Tensor] = None,
                inference_params=None):
        if not self.fused_add_norm:  # fused_add_norm disabled
            if residual is None:  # residual is None only in the first block, which receives the raw input
                residual = hidden_states
            else:  # every later block takes this branch
                residual = residual + self.drop_path(hidden_states)
            # cast residual to self.norm.weight.dtype and store the normalized result as hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:  # keep the residual in float32 if requested
                residual = residual.to(torch.float32)
        else:  # fused_add_norm enabled
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            if residual is None:  # first block, which receives the raw input
                hidden_states, residual = fused_add_norm_fn(
                    hidden_states,
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,
                )
            else:  # every later block
                hidden_states, residual = fused_add_norm_fn(
                    self.drop_path(hidden_states),
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,
                )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        return hidden_states, residual


def create_block(
    d_model,                  # token dimension
    ssm_cfg=None,             # config dict for the SSM
    norm_epsilon=1e-5,
    drop_path=0.,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
    if_bimamba=False,         # whether to use bidirectional Mamba scanning
    bimamba_type="none",
    if_devide_out=False,
    init_layer_scale=None,
):
    if if_bimamba:            # bidirectional Mamba scanning
        bimamba_type = "v1"   # "v1" denotes the first bidirectional variant
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(      # pre-bind Mamba's arguments; each Block later calls mixer_cls(dim) to build its mixer
        Mamba,
        layer_idx=layer_idx,
        bimamba_type=bimamba_type,
        if_devide_out=if_devide_out,
        init_layer_scale=init_layer_scale,
        **ssm_cfg,
        **factory_kwargs
    )
    norm_cls = partial(       # normalization layer used inside each block (LayerNorm or RMSNorm)
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )                         # eps avoids a zero denominator during normalization
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block
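To see how hidden_states and residual are threaded through a stack of these blocks, here is a small sketch (my addition; it assumes the patched mamba_ssm from section 2 is installed and a CUDA GPU is available, since the selective-scan kernel is CUDA-only):

import torch

blocks = torch.nn.ModuleList(
    [create_block(d_model=192, layer_idx=i, bimamba_type="v2") for i in range(4)]
).cuda()

hidden_states = torch.randn(2, 197, 192, device="cuda")  # (B, tokens, dim)
residual = None
for blk in blocks:
    hidden_states, residual = blk(hidden_states, residual)
print(hidden_states.shape, residual.shape)               # both torch.Size([2, 197, 192])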

3.3 The Full Vision Mamba Model

class VisionMamba(nn.Module):
    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 stride=16,
                 depth=24,                      # number of blocks to build
                 embed_dim=192,
                 channels=3,
                 num_classes=1000,              # ImageNet classification, so the final MLP head has 1000 output nodes
                 ssm_cfg=None,                  # SSM config
                 drop_rate=0.,                  # dropout rate (randomly zeroing individual units)
                 drop_path_rate=0.1,            # drop-path rate (randomly dropping whole residual branches, i.e. stochastic depth)
                 norm_epsilon: float = 1e-5,
                 rms_norm: bool = False,        # whether to use RMSNorm
                 fused_add_norm=False,
                 residual_in_fp32=False,        # whether the residual is kept in float32
                 device=None,
                 dtype=None,
                 pt_hw_seq_len=14,              # sequence length along height/width used at pre-training time
                 if_bidirectional=False,
                 final_pool_type='none',        # type of final pooling
                 if_abs_pos_embed=False,        # whether to use an absolute (learnable parameter) position embedding
                 if_rope=False,                 # RoPE (Rotary Position Embedding), a rotation-based position encoding
                 if_rope_residual=False,        # also apply RoPE to the residual, intended to improve robustness
                 flip_img_sequences_ratio=-1.,  # probability of flipping the image token sequence
                 if_bimamba=False,
                 bimamba_type="none",           # which bidirectional Mamba variant to use
                 if_cls_token=False,            # whether to concatenate a cls token
                 if_devide_out=False,
                 init_layer_scale=None,
                 use_double_cls_token=False,
                 use_middle_cls_token=False,
                 **kwargs):                     # **kwargs keeps the model extensible
        factory_kwargs = {"device": device, "dtype": dtype}
        # add factory_kwargs into kwargs
        kwargs.update(factory_kwargs)
        super(VisionMamba, self).__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.if_bidirectional = if_bidirectional
        self.final_pool_type = final_pool_type
        self.if_abs_pos_embed = if_abs_pos_embed
        self.if_rope = if_rope
        self.if_rope_residual = if_rope_residual
        self.flip_img_sequences_ratio = flip_img_sequences_ratio
        self.if_cls_token = if_cls_token
        self.use_double_cls_token = use_double_cls_token  # one cls token at the head and one at the tail
        self.use_middle_cls_token = use_middle_cls_token  # one cls token inserted in the middle
        self.num_tokens = 1 if if_cls_token else 0        # number of extra (cls) tokens concatenated
        # pretrain parameters
        self.num_classes = num_classes
        self.d_model = self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, stride=stride, in_channels=channels, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        if if_cls_token:  # if a cls token is used
            if use_double_cls_token:
                self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))  # cls token prepended to the sequence
                self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))  # cls token appended to the sequence
                self.num_tokens = 2  # number of cls tokens concatenated
            else:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                # self.num_tokens = 1
        if if_abs_pos_embed:  # absolute (learnable) position embedding
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim))
            self.pos_drop = nn.Dropout(p=drop_rate)
        # if if_rope:  # RoPE (Rotary Position Embedding) for the token sequence
        #     half_head_dim = embed_dim // 2
        #     hw_seq_len = img_size // patch_size  # sequence length along the height/width direction
        #     self.rope = VisionRotaryEmbeddingFast(
        #         dim=half_head_dim,
        #         pt_seq_len=pt_hw_seq_len,
        #         ft_seq_len=hw_seq_len
        #     )
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()  # final classification head
        # drop path randomly disables whole residual branches, which makes the model more robust
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # evenly spaced values from 0 to drop_path_rate, one drop-path rate per layer
        inter_dpr = [0.0] + dpr  # the first layer gets no drop path, so prepend a 0
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.layers = nn.ModuleList(
            [
                create_block(  # build the Vision Mamba encoder blocks
                    embed_dim,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    if_bimamba=if_bimamba,
                    bimamba_type=bimamba_type,
                    drop_path=inter_dpr[i],
                    if_devide_out=if_devide_out,
                    init_layer_scale=init_layer_scale,
                    **factory_kwargs
                )
                for i in range(depth)
            ]
        )
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwargs)
        # trunc_normal_ initializes a tensor from a truncated normal distribution; commonly used for weights and biases
        if if_abs_pos_embed:
            trunc_normal_(self.pos_embed, std=.02)
        if if_cls_token:
            if use_double_cls_token:
                trunc_normal_(self.cls_token_head, std=.02)
                trunc_normal_(self.cls_token_tail, std=.02)
            else:
                trunc_normal_(self.cls_token, std=.02)
    # forward pass for feature extraction
    def forward_features(self, x, inference_params=None,
                         if_random_cls_token_position=False,
                         if_random_token_rank=False):
        x = self.patch_embed(x)
        B, M, _ = x.shape

        if self.if_cls_token:
            if self.use_double_cls_token:  # one cls token at the head of the sequence, one at the tail
                cls_token_head = self.cls_token_head.expand(B, -1, -1)  # expand shares memory; it does not allocate a new tensor
                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                token_position = [0, M + 1]
                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
                M = x.shape[1]
            else:
                if self.use_middle_cls_token:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = M // 2
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                elif if_random_cls_token_position:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = random.randint(0, M)
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                    print("token_position: ", token_position)
                else:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = 0
                    x = torch.cat((cls_token, x), dim=1)
                M = x.shape[1]

        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        if if_random_token_rank:  # optionally shuffle the whole token sequence; the stored cls token position must then be updated
            # generate random shuffle indices
            shuffle_indices = torch.randperm(M)  # torch.randperm(M) returns a random permutation of the integers 0..M-1
            if isinstance(token_position, list):
                print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                print("original value: ", x[0, token_position, 0])
            print("original token_position: ", token_position)
            # apply the shuffle
            x = x[:, shuffle_indices, :]
            if isinstance(token_position, list):
                new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position))]
                token_position = new_token_position
            else:
                token_position = torch.where(shuffle_indices == token_position)[0].item()
            if isinstance(token_position, list):
                print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                print("new value: ", x[0, token_position, 0])
            print("new token_position: ", token_position)

        if_flip_img_sequences = False
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
            x = x.flip([1])  # returns a new tensor of the same shape with the token dimension reversed
            if_flip_img_sequences = True

        # the Mamba encoder stack
        residual = None
        hidden_states = x
        if not self.if_bidirectional:  # unidirectional scan only (the sequence may already have been flipped above, so this covers forward or backward scanning, each optionally with RoPE)
            for layer in self.layers:
                if if_flip_img_sequences and self.if_rope:  # sequence was flipped and RoPE is enabled
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])
                # rope about
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)
                if if_flip_img_sequences and self.if_rope:  # not a duplicate of the block above: after RoPE the sequence is flipped back
                    hidden_states = hidden_states.flip([1])
                    if residual is not None:
                        residual = residual.flip([1])
                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params,
                )
        else:  # bidirectional scan: consecutive pairs of layers process the forward and the flipped sequence
            for i in range(len(self.layers) // 2):
                if self.if_rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)
                hidden_states_f, residual_f = self.layers[i * 2](
                    hidden_states, residual, inference_params=inference_params
                )
                hidden_states_b, residual_b = self.layers[i * 2 + 1](
                    hidden_states.flip([1]), None if residual is None else residual.flip([1]), inference_params=inference_params
                )
                hidden_states = hidden_states_f + hidden_states_b.flip([1])
                residual = residual_f + residual_b.flip([1])

        if not self.fused_add_norm:  # fused_add_norm disabled
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )

        # return only the cls token if it exists
        if self.if_cls_token:
            if self.use_double_cls_token:
                return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
            else:
                if self.use_middle_cls_token:
                    return hidden_states[:, token_position, :]
                elif if_random_cls_token_position:
                    return hidden_states[:, token_position, :]
                else:
                    return hidden_states[:, token_position, :]

        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]  # take the last token so the subsequent MLP head gets a (B, dim) tensor
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            return hidden_states
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError

    def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
        if return_features:
            return x
        x = self.head(x)
        if self.final_pool_type == 'max':
            x = x.max(dim=1)[0]
        return x
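The trickiest bookkeeping above is arguably the if_random_token_rank branch: after shuffling the tokens, the model must still know where the cls token landed. Here is a tiny self-contained sketch of that logic (my addition, plain PyTorch with toy sizes):

import torch

# insert a cls token in the middle, shuffle all tokens, then recover where the cls token went
B, M, D = 1, 6, 4
x = torch.randn(B, M, D)
cls_token = torch.zeros(B, 1, D)          # all-zero so it is easy to spot
token_position = M // 2
x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
M = x.shape[1]

shuffle_indices = torch.randperm(M)       # random permutation of 0..M-1
x = x[:, shuffle_indices, :]
token_position = torch.where(shuffle_indices == token_position)[0].item()
print(token_position, x[0, token_position])  # the recovered slot still holds the zero cls token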

3.4 Test

def test():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = VisionMamba(
        patch_size=16,
        embed_dim=192,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type='mean',
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bimamba_type="v2",
        if_cls_token=True,
        if_devide_out=True,
        use_double_cls_token=True
    ).to(device)
    x = torch.randn(size=(4, 3, 224, 224)).to(device)
    preds = model(x)
    print(preds.shape)


if __name__ == '__main__':
    test()

3.5 Output
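The original post shows the printed result as a screenshot. With the test settings above (batch size 4, num_classes left at its default of 1000, final_pool_type='mean'), the expected print should be:

torch.Size([4, 1000])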

References

下个风口?Mamba手推公式&代码手搓_哔哩哔哩_bilibili
