当前位置:   article > 正文

Mamba-Yolo:基于Mamba架构的yolov8目标检测模型_mamba yolo

mamba yolo

关注up主的B站号:Ai缝合怪,

点赞评论截图发给我,代码和配置文件免费私发给大家。

mamba-yolo:在YOLOV8,V9,V10 中通用

一、mamba-yolo配置文件

二、mamba-yolo代码:

动机:

基于 CNN 和 Transformer 的模型各有局限性。CNN 在捕获长距离信息上存在局部感受野限制,导致在某些情况下难以有效捕获长距离信息,可能导致分割等任务的结果不佳。另一方面,Transformer 在全局建模方面表现出色,能够有效捕获长距离依赖关系,但自注意力机制在处理图像尺寸较大时的复杂度较高,特别是在处理超高清图像检测以及小目标检测等任务时可能面临挑战。

CNN 主要局限性:

    局部感受野限制:CNN 的卷积操作在每一层只能感知局部区域的特征,难以捕获长距离依赖关系。
    参数共享:CNN 中参数共享的特性可能限制其在处理某些复杂模式和全局信息时的表现。

Transformer 主要局限性:

    自注意力机制复杂度:Transformer 中的自注意力机制在处理大规模图像时需要高计算复杂度和显存消耗。
    缺乏局部信息:Transformer 更注重全局关系,可能在一些需要局部信息的任务中表现不佳。

因此,为了克服CNN和Transformer的局限性,SSMs(如Mamba)通过建立远距离依赖关系并保持线性复杂度,展现出在各种任务中的潜力。本文首次提出了 mamba-Yolov8,这是一种将Mamba结合到Yolov8架构中的方法,旨在展示其在目标检测任务中的潜力。通过结合Mamba的优势,mamba-Yolov8旨在改善长距离信息捕获和全局建模能力,以提高目标检测任务的性能和效果。这种结合可能有助于克服传统CNN和Transformer在某些任务中的局限性,为目标检测等任务带来新的发展和进步。

若有想进行魔改、发文章的小伙伴,可在此基础上进行调整、以适配个人发文章的需求。

下图为打印出的结构

其中ultralytics.nn.Addmodules.mamba.MambaLayer 为mamba结构

核心:VSSblock(上图中的MambaLayer)

mamba-yolov8的核心模块是来自 VMamba 的 VSS 块,如图下图所示。

对于经过层归一化后的输入,模型分为两个分支处理:第一个分支经过线性层和激活函数处理,第二个分支经过线性层、深度可分离卷积和激活函数处理,然后进入2D-Selective-Scan(SS2D)。处理后的特征再次归一化,并与第一个分支的输出进行逐元素乘积合并,随后经过一个线性层混合特征,再与残差连接相加形成VSS块的输出。默认情况下,使用激活函数SiLU。

主要还是在 SS2D 这个新的模块,大家可以参考下下面的示意图。

SS2D模块通过扫描展开操作将输入图像在四个方向上展开成序列,然后通过S6块提取特征,以确保全面扫描信息并捕获多样特征。随后,扫描合并操作对四个方向的序列进行求和合并,将输出图像恢复为输入大小。S6块是基于Mamba模块的进一步发展,在S4基础上引入选择机制,有助于保留相关信息并过滤无关信息。

YoloV8改进步骤

1.在该ultralytics/nn下创建Addmodules文件夹,并在下面新建mamba.py文件

2.在mamba.py文件中写入。(注:全部代码私信博主获取,将博主所给代码文件mamba.py,放置在ultralytics/nn/Addmodules/目录结构下)

  1. class MambaLayer(nn.Module):
  2. def __init__(self, dim, d_state=16, d_conv=4, expand=2):
  3. super().__init__()
  4. self.dim = dim
  5. self.norm = nn.LayerNorm(dim)
  6. self.mamba = Mamba(
  7. d_model=dim, # Model dimension d_model
  8. d_state=d_state, # SSM state expansion factor
  9. d_conv=d_conv, # Local convolution width
  10. expand=expand, # Block expansion factor
  11. bimamba_type="v2",
  12. )
  13. def forward(self, x):
  14. B, C = x.shape[:2]
  15. assert C == self.dim
  16. n_tokens = x.shape[2:].numel()
  17. img_dims = x.shape[2:]
  18. x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2)
  19. x_norm = self.norm(x_flat)
  20. # x_norm = x_norm.to('cuda')
  21. x_mamba = self.mamba(x_norm)
  22. out = x_mamba.transpose(-1, -2).reshape(B, C, *img_dims)
  23. #out = out.to(x.device)
  24. return out

3.在ultralytics/nn/Addmodules/__init__.py文件中写入

  1. from .mamba import *
  2. 如下图(注:全部代码私信博主获取,将博主所给代码文件__init__.py,放置在ultralytics/nn/Addmodules/目录结构下)
  3. 4. 在ultralytics/nn/tasks.py中导入MambaLayer
  4. from .Addmodules import *
  5. 5.在在ultralytics/nn/tasks.py中加入MambaLayer模块
  6. 6.在ultralytics/nn/tasks.py的class DetectionModel(BaseModel)类中进行如下修改
  7. class DetectionModel(BaseModel):
  8. """YOLOv8 detection model."""
  9. def __init__(self, cfg='yolov8n.yaml', ch=3, nc=None, verbose=True): # model, input channels, number of classes
  10. """Initialize the YOLOv8 detection model with the given config and parameters."""
  11. super().__init__()
  12. self.yaml = cfg if isinstance(cfg, dict) else yaml_model_load(cfg) # cfg dict
  13. # Define model
  14. ch = self.yaml['ch'] = self.yaml.get('ch', ch) # input channels
  15. if nc and nc != self.yaml['nc']:
  16. LOGGER.info(f"Overriding model.yaml nc={self.yaml['nc']} with nc={nc}")
  17. self.yaml['nc'] = nc # override YAML value
  18. self.model, self.save = parse_model(deepcopy(self.yaml), ch=ch, verbose=verbose) # model, savelist
  19. self.names = {i: f'{i}' for i in range(self.yaml['nc'])} # default names dict
  20. self.inplace = self.yaml.get('inplace', True)
  21. # Build strides
  22. m = self.model[-1] # Detect()
  23. if isinstance(m, (Detect, Segment, Pose)):
  24. s = 256 # 2x min stride
  25. m.inplace = self.inplace
  26. forward = lambda x: self.forward(x)[0] if isinstance(m, (Segment, Pose)) else self.forward(x)
  27. # -------原始---------
  28. #m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))]) # forward ,模型是通过一次前向传播的方式进行输入、输出比来知道步长缩放比
  29. #self.stride = m.stride
  30. # --------------------
  31. #--基于mamba的改进
  32. self.stride=torch.tensor([8., 16., 32.])
  33. m.stride=self.stride
  34. #----------------------
  35. m.bias_init() # only run once
  36. else:
  37. self.stride = torch.Tensor([32]) # default stride for i.e. RTDETR
  38. # Init weights, biases
  39. initialize_weights(self)
  40. if verbose:
  41. self.info()
  42. LOGGER.info('')

7. 在ultralytics/cfg/models/v8/mamba.yaml中配置网络模型结构文件

  

  1. # Ultralytics YOLO
    声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/神奇cpp/article/detail/915418
    推荐阅读
    相关标签