当前位置:   article > 正文

Mamba-yolo|结合Mamba注意力机制的视觉检测_memba+yolo

memba+yolo

一、本文介绍

 PDF地址:https://arxiv.org/pdf/2405.16605v1

代码地址:GitHub - LeapLabTHU/MLLA: Official repository of MLLA

Demystify Mamba in Vision: A Linear AttentionPerspective一文中引入Baseline Mamba,指明Mamba在处理各种高分辨率图像的视觉任务有着很好的效率。发现了强大的Mamba和线性注意力Transformer( linear attention Transformer)非常相似,然后就分析了两者之间的异同。将Mamba模型重述为linear attention Transformer的变体,并且主要有六大差异,分别是:input gate, forget gate,shortcut, no attention normalization, single-head, and modified block design。作者对每个设计都细致的分析了优缺点,评估了性能,最终发现forget gate和block design是Mamba这么给力的主要贡献点。基于以上发现,作者提出了一个类似mamba的线性注意力模型,Mamba-Like Linear Attention (MLLA) ,相当于取其精华,去其糟粕,把mamba两个最为关键的优点设计结合到线性注意力模型当中,具有可并行计算和快速推理的特点。本文将结合YOlOV8检测模型通过添加MLLA模块提升检测精度。

二、宏观架构设计

线性注意 Transformer 模型通常采用图 (a) 中的设计,它由线性注意力模块和 MLP 模块组成。相比之下,Mamba 通过结合 H3和 Gated Attention这两个设计来改进,得到如图 (b) 所示的架构。改进的 Mamba Block 集成了多种操作,例如选择性 SSM、深度卷积、线性映射、激活函数、门控机制等,并且往往比传统的 Transformer 设计更有效。

MLLA (Mamba-Like Linear Attention)的则是通过将Mamba模型的一些核心设计融入线性注意力机制,从而提升模型的性能。具体来说,MLLA主要整合了Mamba中的"忘记门”(forget gate9)和模块设计(block design)这两个关键因素,这些因素被认为是Mamba成功的主要原因。
以下是对MLLA原理的详细分析:
1.忘记门(Forget Gate)
1.忘记门提供了局部偏差和位置信息。所有的忘记门元素严格限制在0到1之间,这意味着模型在接收到当前输入后会持续衰减失前的隐藏状态。这种特性确保了模型对输入序列的顺序敏感。
2.忘记门的局部偏差和位置信息对于图像处理任务来说非常重要,尽管引入忘记门会导致计算需要采用递归的形式,从而降低并行计算的效率。
2.模块设计(Block Design)
1.Mamba的模块设计在保持相似的浮点运算次数(FLOPS)的同时,通过替换注意力子模块为线性注意力来提升性能。结果表明,采用这种模块设计能够显著提高模型的表现。
3.线性注意力的改进:
1.线性注意力被重新设计以整合忘记门和模块设计,这种改进后的模型被称为MLLA。实验结果显示,MLLA在图像分类和高分辨率密集预测任务中均优于各种视觉Mamba模型
4.并行计算和快速推理速度:
1.MLLA通过使用位置编码(ROPE)来替代忘记门,从而在保持并行计算和快速推理速度的同时,提供必要的位置信息。这使得MLLA在处理非自回归的视觉任务时更加有效

结合yolov8改进

核心代码
 

  1. import torch
  2. import torch.nn as nn
  3. __all__ = ['MLLAttention']
  4. class Mlp(nn.Module):
  5. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  6. super().__init__()
  7. out_features = out_features or in_features
  8. hidden_features = hidden_features or in_features
  9. self.fc1 = nn.Linear(in_features, hidden_features)
  10. self.act = act_layer()
  11. self.fc2 = nn.Linear(hidden_features, out_features)
  12. self.drop = nn.Dropout(drop)
  13. def forward(self, x):
  14. x = self.fc1(x)
  15. x = self.act(x)
  16. x = self.drop(x)
  17. x = self.fc2(x)
  18. x = self.drop(x)
  19. return x
  20. class ConvLayer(nn.Module):
  21. def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0, dilation=1, groups=1,
  22. bias=True, dropout=0, norm=nn.BatchNorm2d, act_func=nn.ReLU):
  23. super(ConvLayer, self).__init__()
  24. self.dropout = nn.Dropout2d(dropout, inplace=False) if dropout > 0 else None
  25. self.conv = nn.Conv2d(
  26. in_channels,
  27. out_channels,
  28. kernel_size=(kernel_size, kernel_size),
  29. stride=(stride, stride),
  30. padding=(padding, padding),
  31. dilation=(dilation, dilation),
  32. groups=groups,
  33. bias=bias,
  34. )
  35. self.norm = norm(num_features=out_channels) if norm else None
  36. self.act = act_func() if act_func else None
  37. def forward(self, x: torch.Tensor) -> torch.Tensor:
  38. if self.dropout is not None:
  39. x = self.dropout(x)
  40. x = self.conv(x)
  41. if self.norm:
  42. x = self.norm(x)
  43. if self.act:
  44. x = self.act(x)
  45. return x
  46. class RoPE(torch.nn.Module):
  47. r"""Rotary Positional Embedding.
  48. """
  49. def __init__(self, base=10000):
  50. super(RoPE, self).__init__()
  51. self.base = base
  52. def generate_rotations(self, x):
  53. # 获取输入张量的形状
  54. *channel_dims, feature_dim = x.shape[1:-1][0], x.shape[-1]
  55. k_max = feature_dim // (2 * len(channel_dims))
  56. assert feature_dim % k_max == 0, "Feature dimension must be divisible by 2 * k_max"
  57. # 生成角度
  58. theta_ks = 1 / (self.base ** (torch.arange(k_max, dtype=x.dtype, device=x.device) / k_max))
  59. angles = torch.cat([t.unsqueeze(-1) * theta_ks for t in
  60. torch.meshgrid([torch.arange(d, dtype=x.dtype, device=x.device) for d in channel_dims],
  61. indexing='ij')], dim=-1)
  62. # 计算旋转矩阵的实部和虚部
  63. rotations_re = torch.cos(angles).unsqueeze(dim=-1)
  64. rotations_im = torch.sin(angles).unsqueeze(dim=-1)
  65. rotations = torch.cat([rotations_re, rotations_im], dim=-1)
  66. return rotations
  67. def forward(self, x):
  68. # 生成旋转矩阵
  69. rotations = self.generate_rotations(x)
  70. # 将 x 转换为复数形式
  71. x_complex = torch.view_as_complex(x.reshape(*x.shape[:-1], -1, 2))
  72. # 应用旋转矩阵
  73. pe_x = torch.view_as_complex(rotations) * x_complex
  74. # 将结果转换回实数形式并展平最后两个维度
  75. return torch.view_as_real(pe_x).flatten(-2)
  76. class MLLAttention(nn.Module):
  77. r""" Linear Attention with LePE and RoPE.
  78. Args:
  79. dim (int): Number of input channels.
  80. num_heads (int): Number of attention heads.
  81. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  82. """
  83. def __init__(self, dim=3, input_resolution=[160, 160], num_heads=4, qkv_bias=True, **kwargs):
  84. super().__init__()
  85. self.dim = dim
  86. self.input_resolution = input_resolution
  87. self.num_heads = num_heads
  88. self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
  89. self.elu = nn.ELU()
  90. self.lepe = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
  91. self.rope = RoPE()
  92. def forward(self, x):
  93. """
  94. Args:
  95. x: input features with shape of (B, N, C)
  96. """
  97. x = x.reshape((x.size(0), x.size(2) * x.size(3), x.size(1)))
  98. b, n, c = x.shape
  99. h = int(n ** 0.5)
  100. w = int(n ** 0.5)
  101. # self.rope = RoPE(shape=(h, w, self.dim))
  102. num_heads = self.num_heads
  103. head_dim = c // num_heads
  104. qk = self.qk(x).reshape(b, n, 2, c).permute(2, 0, 1, 3)
  105. q, k, v = qk[0], qk[1], x
  106. # q, k, v: b, n, c
  107. q = self.elu(q) + 1.0
  108. k = self.elu(k) + 1.0
  109. q_rope = self.rope(q.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
  110. k_rope = self.rope(k.reshape(b, h, w, c)).reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
  111. q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
  112. k = k.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
  113. v = v.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)
  114. z = 1 / (q @ k.mean(dim=-2, keepdim=True).transpose(-2, -1) + 1e-6)
  115. kv = (k_rope.transpose(-2, -1) * (n ** -0.5)) @ (v * (n ** -0.5))
  116. x = q_rope @ kv * z
  117. x = x.transpose(1, 2).reshape(b, n, c)
  118. v = v.transpose(1, 2).reshape(b, h, w, c).permute(0, 3, 1, 2)
  119. x = x + self.lepe(v).permute(0, 2, 3, 1).reshape(b, n, c)
  120. x = x.transpose(2, 1).reshape((b, c, h, w))
  121. return x
  122. def extra_repr(self) -> str:
  123. return f'dim={self.dim}, num_heads={self.num_heads}'
  124. if __name__ == "__main__":
  125. # Generating Sample image
  126. image_size = (1, 64, 160, 160)
  127. image = torch.rand(*image_size)
  128. # Model
  129. model = MLLAttention(64)
  130. out = model(image)
  131. print(out.size())

修改一

第一还是建立文件,我们找到如下ultralvtics/n文件夹下建立一个目录名字呢就是'Addmodules文件夹(用群内的文件的话已经有了无需新建)!然后在其内部建立一个新的py文件将核心代码复制粘贴进去即可。

修改二

第二步我们在该目录下创建一个新的py文件名字为'  __init__ .py,然后在其内部导入我们的检测头如
下图所示。

修改三 

第三步我门中到如下文件uitralytics/nn/tasks.py进行导入和注册我们的模块

修改四

按照我的添加在parse model里添加即可。

修改5

修改6 配置yolov8-MLLA.yaml文件

# Ultralytics YOLO

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/黑客灵魂/article/detail/920808
推荐阅读
相关标签