当前位置:   article > 正文

【YOLOv5/v7改进系列】引入卷积块注意力模块CBAM注意力机制

【YOLOv5/v7改进系列】引入卷积块注意力模块CBAM注意力机制
一、导言

CBAM(Convolutional Block Attention Module)是一种简单而有效的注意力机制模块,旨在增强卷积神经网络(CNN)的表现力。该模块通过引入两个独立的注意力机制——通道注意力和空间注意力——来适应性地精炼特征图,从而提升模型的整体性能。

通道注意力

通道注意力关注于“哪些”特征应该被强调或抑制。它基于全局信息,通过平均池化和最大池化操作从特征图中提取两种类型的全局描述符。这两种描述符分别捕捉了特征图中不同通道的重要性和分布情况。之后,这些描述符被传递给一系列卷积层,最终生成一个与输入特征图通道数相同的注意力图。此注意力图被乘以原始特征图,从而实现对不同通道特征的选择性放大或减弱。

空间注意力

空间注意力则关注于“哪里”的问题,即特征图中的哪些位置应该被强调或抑制。它通过分析特征图的局部模式来生成一个二维的注意力图。该过程首先计算特征图中每个位置的平均值和最大值,然后将这两个统计量连接起来并传递给一个卷积层,生成一个单一的二维注意力图。这个注意力图同样被乘以原始特征图,使得模型能够聚焦于关键区域。

CBAM模块的工作流程

CBAM模块依次应用通道注意力和空间注意力机制。具体来说:

  1. 通道注意力:对输入特征图进行平均池化和最大池化,生成两种全局描述符,接着通过卷积层计算出通道注意力图。
  2. 空间注意力:将经过通道注意力处理后的特征图再进行平均和最大值计算,生成两种空间描述符,然后通过卷积层计算出空间注意力图。

最后,通道注意力图和空间注意力图分别与输入特征图相乘,以实现特征的精炼。CBAM模块轻量化且通用,可以无缝集成到任何卷积神经网络架构中,并且整个模块可以与基础CNN一起端到端训练。

实验结果

通过在ImageNet-1K、MS COCO检测和VOC 2007检测数据集上的广泛实验,验证了CBAM的有效性。实验结果显示,CBAM能够显著提高各种基准模型的分类和检测性能,而且这种性能提升不是因为增加了大量的额外参数,而是因为更有效地利用了现有的特征。此外,使用轻量级骨干网络时,CBAM也显示出了良好的适用性,这意味着它对于资源受限的设备同样有益。

总之,CBAM是一种强大的注意力机制,能够帮助CNN更好地聚焦于目标对象,从而提高其识别和定位能力。

优点
  1. 通用性:CBAM可以无缝地集成到任何CNN架构中,而且不会增加显著的计算负担,这使得它能够作为一种“即插即用”的组件被广泛采用。
  2. 有效性验证:作者通过在多个数据集上的实验验证了CBAM的有效性,包括ImageNet-1K分类任务、MS COCO目标检测以及VOC 2007目标检测任务,证明了其在分类和检测任务上的一致改进。
  3. 轻量级设计:CBAM的设计考虑到了计算效率和参数数量,这意味着它可以被轻易地添加到现有的深度学习模型中而不增加过多的开销。
  4. 注意力机制的分解:CBAM将注意力机制分解为通道注意力和空间注意力两个独立的部分,这种方法简化了注意力的学习过程,并且降低了计算复杂度和参数数量。
  5. 性能提升:通过在不同模型上实现CBAM,作者表明这种机制不仅可以提高分类准确率,还可以提高目标检测任务的性能,甚至在某些情况下达到了最先进的水平。

二、准备工作

首先在YOLOv5/v7的models文件夹下新建文件moreattention.py,导入如下代码

  1. from models.common import *
  2. # CBAM注意力机制 https://arxiv.org/pdf/1807.06521
  3. class ChannelAttention(nn.Module):
  4. def __init__(self, in_planes, ratio=16):
  5. super(ChannelAttention, self).__init__()
  6. self.avg_pool = nn.AdaptiveAvgPool2d(1)
  7. self.max_pool = nn.AdaptiveMaxPool2d(1)
  8. self.f1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
  9. self.relu = nn.ReLU()
  10. self.f2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
  11. self.sigmoid = nn.Sigmoid()
  12. def forward(self, x):
  13. avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
  14. max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
  15. out = self.sigmoid(avg_out + max_out)
  16. return out
  17. class SpatialAttention(nn.Module):
  18. def __init__(self, kernel_size=7):
  19. super(SpatialAttention, self).__init__()
  20. assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
  21. padding = 3 if kernel_size == 7 else 1
  22. self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
  23. self.sigmoid = nn.Sigmoid()
  24. def forward(self, x):
  25. avg_out = torch.mean(x, dim=1, keepdim=True)
  26. max_out, _ = torch.max(x, dim=1, keepdim=True)
  27. x = torch.cat([avg_out, max_out], dim=1)
  28. x = self.conv(x)
  29. return self.sigmoid(x)
  30. class CBAM(nn.Module):
  31. # Standard convolution
  32. def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groups
  33. super(CBAM, self).__init__()
  34. self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
  35. self.bn = nn.BatchNorm2d(c2)
  36. self.act = nn.LeakyReLU(0.1) if act else nn.Identity()
  37. self.ca = ChannelAttention(c2)
  38. self.sa = SpatialAttention()
  39. def forward(self, x):
  40. x = self.act(self.bn(self.conv(x)))
  41. x = self.ca(x) * x
  42. x = self.sa(x) * x
  43. return x
  44. def fuseforward(self, x):
  45. return self.act(self.conv(x))

其次在在YOLOv5/v7项目文件下的models/yolo.py中在文件首部添加代码

from models.moreattention import CBAM

并搜索def parse_model(d, ch)

定位到如下行添加以下代码

                 CBAM,

三、YOLOv7-tiny改进工作

完成二后,在YOLOv7项目文件下的models文件夹下创建新的文件yolov7-tiny-cbam.yaml,导入如下代码。

  1. # parameters
  2. nc: 80 # number of classes
  3. depth_multiple: 1.0 # model depth multiple
  4. width_multiple: 1.0 # layer channel multiple
  5. # anchors
  6. anchors:
  7. - [10,13, 16,30, 33,23] # P3/8
  8. - [30,61, 62,45, 59,119] # P4/16
  9. - [116,90, 156,198, 373,326] # P5/32
  10. # yolov7-tiny backbone
  11. backbone:
  12. # [from, number, module, args] c2, k=1, s=1, p=None, g=1, act=True
  13. [[-1, 1, Conv, [32, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 0-P1/2
  14. [-1, 1, Conv, [64, 3, 2, None, 1, nn.LeakyReLU(0.1)]], # 1-P2/4
  15. [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  16. [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  17. [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  18. [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  19. [[-1, -2, -3, -4], 1, Concat, [1]],
  20. [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 7
  21. [-1, 1, MP, []], # 8-P3/8
  22. [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  23. [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  24. [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  25. [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  26. [[-1, -2, -3, -4], 1, Concat, [1]],
  27. [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 14
  28. [-1, 1, MP, []], # 15-P4/16
  29. [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  30. [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  31. [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  32. [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  33. [[-1, -2, -3, -4], 1, Concat, [1]],
  34. [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 21
  35. [-1, 1, MP, []], # 22-P5/32
  36. [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  37. [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  38. [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  39. [-1, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  40. [[-1, -2, -3, -4], 1, Concat, [1]],
  41. [-1, 1, Conv, [512, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 28
  42. ]
  43. # yolov7-tiny head
  44. head:
  45. [[-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  46. [-2, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  47. [-1, 1, SP, [5]],
  48. [-2, 1, SP, [9]],
  49. [-3, 1, SP, [13]],
  50. [[-1, -2, -3, -4], 1, Concat, [1]],
  51. [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  52. [[-1, -7], 1, Concat, [1]],
  53. [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 37
  54. [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  55. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  56. [21, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P4
  57. [[-1, -2], 1, Concat, [1]],
  58. [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  59. [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  60. [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  61. [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  62. [[-1, -2, -3, -4], 1, Concat, [1]],
  63. [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 47
  64. [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  65. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  66. [14, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # route backbone P3
  67. [[-1, -2], 1, Concat, [1]],
  68. [-1, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  69. [-2, 1, Conv, [32, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  70. [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  71. [-1, 1, Conv, [32, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  72. [[-1, -2, -3, -4], 1, Concat, [1]],
  73. [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 57
  74. [-1, 1, Conv, [128, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
  75. [[-1, 47], 1, Concat, [1]],
  76. [-1, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  77. [-2, 1, Conv, [64, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  78. [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  79. [-1, 1, Conv, [64, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  80. [[-1, -2, -3, -4], 1, Concat, [1]],
  81. [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 65
  82. [-1, 1, Conv, [256, 3, 2, None, 1, nn.LeakyReLU(0.1)]],
  83. [[-1, 37], 1, Concat, [1]],
  84. [-1, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  85. [-2, 1, Conv, [128, 1, 1, None, 1, nn.LeakyReLU(0.1)]],
  86. [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  87. [-1, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  88. [[-1, -2, -3, -4], 1, Concat, [1]],
  89. [-1, 1, Conv, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 73
  90. [-1, 1, CBAM, [256, 1, 1, None, 1, nn.LeakyReLU(0.1)]], # 74
  91. [57, 1, Conv, [128, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  92. [65, 1, Conv, [256, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  93. [74, 1, Conv, [512, 3, 1, None, 1, nn.LeakyReLU(0.1)]],
  94. [[75,76,77], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5)
  95. ]
  1. from n params module arguments
  2. 0 -1 1 928 models.common.Conv [3, 32, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
  3. 1 -1 1 18560 models.common.Conv [32, 64, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
  4. 2 -1 1 2112 models.common.Conv [64, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  5. 3 -2 1 2112 models.common.Conv [64, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  6. 4 -1 1 9280 models.common.Conv [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  7. 5 -1 1 9280 models.common.Conv [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  8. 6 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  9. 7 -1 1 8320 models.common.Conv [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  10. 8 -1 1 0 models.common.MP []
  11. 9 -1 1 4224 models.common.Conv [64, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  12. 10 -2 1 4224 models.common.Conv [64, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  13. 11 -1 1 36992 models.common.Conv [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  14. 12 -1 1 36992 models.common.Conv [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  15. 13 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  16. 14 -1 1 33024 models.common.Conv [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  17. 15 -1 1 0 models.common.MP []
  18. 16 -1 1 16640 models.common.Conv [128, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  19. 17 -2 1 16640 models.common.Conv [128, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  20. 18 -1 1 147712 models.common.Conv [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  21. 19 -1 1 147712 models.common.Conv [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  22. 20 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  23. 21 -1 1 131584 models.common.Conv [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  24. 22 -1 1 0 models.common.MP []
  25. 23 -1 1 66048 models.common.Conv [256, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  26. 24 -2 1 66048 models.common.Conv [256, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  27. 25 -1 1 590336 models.common.Conv [256, 256, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  28. 26 -1 1 590336 models.common.Conv [256, 256, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  29. 27 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  30. 28 -1 1 525312 models.common.Conv [1024, 512, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  31. 29 -1 1 131584 models.common.Conv [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  32. 30 -2 1 131584 models.common.Conv [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  33. 31 -1 1 0 models.common.SP [5]
  34. 32 -2 1 0 models.common.SP [9]
  35. 33 -3 1 0 models.common.SP [13]
  36. 34 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  37. 35 -1 1 262656 models.common.Conv [1024, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  38. 36 [-1, -7] 1 0 models.common.Concat [1]
  39. 37 -1 1 131584 models.common.Conv [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  40. 38 -1 1 33024 models.common.Conv [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  41. 39 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
  42. 40 21 1 33024 models.common.Conv [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  43. 41 [-1, -2] 1 0 models.common.Concat [1]
  44. 42 -1 1 16512 models.common.Conv [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  45. 43 -2 1 16512 models.common.Conv [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  46. 44 -1 1 36992 models.common.Conv [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  47. 45 -1 1 36992 models.common.Conv [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  48. 46 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  49. 47 -1 1 33024 models.common.Conv [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  50. 48 -1 1 8320 models.common.Conv [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  51. 49 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
  52. 50 14 1 8320 models.common.Conv [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  53. 51 [-1, -2] 1 0 models.common.Concat [1]
  54. 52 -1 1 4160 models.common.Conv [128, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  55. 53 -2 1 4160 models.common.Conv [128, 32, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  56. 54 -1 1 9280 models.common.Conv [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  57. 55 -1 1 9280 models.common.Conv [32, 32, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  58. 56 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  59. 57 -1 1 8320 models.common.Conv [128, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  60. 58 -1 1 73984 models.common.Conv [64, 128, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
  61. 59 [-1, 47] 1 0 models.common.Concat [1]
  62. 60 -1 1 16512 models.common.Conv [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  63. 61 -2 1 16512 models.common.Conv [256, 64, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  64. 62 -1 1 36992 models.common.Conv [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  65. 63 -1 1 36992 models.common.Conv [64, 64, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  66. 64 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  67. 65 -1 1 33024 models.common.Conv [256, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  68. 66 -1 1 295424 models.common.Conv [128, 256, 3, 2, None, 1, LeakyReLU(negative_slope=0.1)]
  69. 67 [-1, 37] 1 0 models.common.Concat [1]
  70. 68 -1 1 65792 models.common.Conv [512, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  71. 69 -2 1 65792 models.common.Conv [512, 128, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  72. 70 -1 1 147712 models.common.Conv [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  73. 71 -1 1 147712 models.common.Conv [128, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  74. 72 [-1, -2, -3, -4] 1 0 models.common.Concat [1]
  75. 73 -1 1 131584 models.common.Conv [512, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  76. 74 -1 1 74338 models.moreattention.CBAM [256, 256, 1, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  77. 75 57 1 73984 models.common.Conv [64, 128, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  78. 76 65 1 295424 models.common.Conv [128, 256, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  79. 77 74 1 1180672 models.common.Conv [256, 512, 3, 1, None, 1, LeakyReLU(negative_slope=0.1)]
  80. 78 [75, 76, 77] 1 17132 models.yolo.IDetect [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
  81. Model Summary: 277 layers, 6089326 parameters, 6089326 gradients, 13.2 GFLOPS

运行后若打印出如上文本代表改进成功。

四、YOLOv5s改进工作

完成二后,在YOLOv5项目文件下的models文件夹下创建新的文件yolov5s-cbam.yaml,导入如下代码。

  1. # Parameters
  2. nc: 1 # number of classes
  3. depth_multiple: 0.33 # model depth multiple
  4. width_multiple: 0.50 # layer channel multiple
  5. anchors:
  6. - [10,13, 16,30, 33,23] # P3/8
  7. - [30,61, 62,45, 59,119] # P4/16
  8. - [116,90, 156,198, 373,326] # P5/32
  9. # YOLOv5 v6.0 backbone
  10. backbone:
  11. # [from, number, module, args]
  12. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
  13. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
  14. [-1, 3, C3, [128]],
  15. [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
  16. [-1, 6, C3, [256]],
  17. [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
  18. [-1, 9, C3, [512]],
  19. [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
  20. [-1, 3, C3, [1024]],
  21. [-1, 1, SPPF, [1024, 5]], # 9
  22. ]
  23. # YOLOv5 v6.0 head
  24. head:
  25. [[-1, 1, Conv, [512, 1, 1]],
  26. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  27. [[-1, 6], 1, Concat, [1]], # cat backbone P4
  28. [-1, 3, C3, [512, False]], # 13
  29. [-1, 1, Conv, [256, 1, 1]],
  30. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  31. [[-1, 4], 1, Concat, [1]], # cat backbone P3
  32. [-1, 3, C3, [256, False]], # 17 (P3/8-small)
  33. [-1, 1, Conv, [256, 3, 2]],
  34. [[-1, 14], 1, Concat, [1]], # cat head P4
  35. [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
  36. [-1, 1, Conv, [512, 3, 2]],
  37. [[-1, 10], 1, Concat, [1]], # cat head P5
  38. [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
  39. [-1, 1, CBAM, [1024]],# 24 (P5/32-large)+attention
  40. [[17, 20, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  41. ]
  1. from n params module arguments
  2. 0 -1 1 3520 models.common.Conv [3, 32, 6, 2, 2]
  3. 1 -1 1 18560 models.common.Conv [32, 64, 3, 2]
  4. 2 -1 1 18816 models.common.C3 [64, 64, 1]
  5. 3 -1 1 73984 models.common.Conv [64, 128, 3, 2]
  6. 4 -1 2 115712 models.common.C3 [128, 128, 2]
  7. 5 -1 1 295424 models.common.Conv [128, 256, 3, 2]
  8. 6 -1 3 625152 models.common.C3 [256, 256, 3]
  9. 7 -1 1 1180672 models.common.Conv [256, 512, 3, 2]
  10. 8 -1 1 1182720 models.common.C3 [512, 512, 1]
  11. 9 -1 1 656896 models.common.SPPF [512, 512, 5]
  12. 10 -1 1 131584 models.common.Conv [512, 256, 1, 1]
  13. 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
  14. 12 [-1, 6] 1 0 models.common.Concat [1]
  15. 13 -1 1 361984 models.common.C3 [512, 256, 1, False]
  16. 14 -1 1 33024 models.common.Conv [256, 128, 1, 1]
  17. 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
  18. 16 [-1, 4] 1 0 models.common.Concat [1]
  19. 17 -1 1 90880 models.common.C3 [256, 128, 1, False]
  20. 18 -1 1 147712 models.common.Conv [128, 128, 3, 2]
  21. 19 [-1, 14] 1 0 models.common.Concat [1]
  22. 20 -1 1 296448 models.common.C3 [256, 256, 1, False]
  23. 21 -1 1 590336 models.common.Conv [256, 256, 3, 2]
  24. 22 [-1, 10] 1 0 models.common.Concat [1]
  25. 23 -1 1 1182720 models.common.C3 [512, 512, 1, False]
  26. 24 -1 1 296034 models.moreattention.CBAM [512, 512]
  27. 25 [17, 20, 24] 1 16182 models.yolo.Detect [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
  28. Model Summary: 284 layers, 7318360 parameters, 7318360 gradients, 16.2 GFLOPs

运行后若打印出如上文本代表改进成功。

五、YOLOv5n改进工作

完成二后,在YOLOv5项目文件下的models文件夹下创建新的文件yolov5n-cbam.yaml,导入如下代码。

  1. # Parameters
  2. nc: 1 # number of classes
  3. depth_multiple: 0.33 # model depth multiple
  4. width_multiple: 0.25 # layer channel multiple
  5. anchors:
  6. - [10,13, 16,30, 33,23] # P3/8
  7. - [30,61, 62,45, 59,119] # P4/16
  8. - [116,90, 156,198, 373,326] # P5/32
  9. # YOLOv5 v6.0 backbone
  10. backbone:
  11. # [from, number, module, args]
  12. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
  13. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
  14. [-1, 3, C3, [128]],
  15. [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
  16. [-1, 6, C3, [256]],
  17. [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
  18. [-1, 9, C3, [512]],
  19. [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
  20. [-1, 3, C3, [1024]],
  21. [-1, 1, SPPF, [1024, 5]], # 9
  22. ]
  23. # YOLOv5 v6.0 head
  24. head:
  25. [[-1, 1, Conv, [512, 1, 1]],
  26. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  27. [[-1, 6], 1, Concat, [1]], # cat backbone P4
  28. [-1, 3, C3, [512, False]], # 13
  29. [-1, 1, Conv, [256, 1, 1]],
  30. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  31. [[-1, 4], 1, Concat, [1]], # cat backbone P3
  32. [-1, 3, C3, [256, False]], # 17 (P3/8-small)
  33. [-1, 1, Conv, [256, 3, 2]],
  34. [[-1, 14], 1, Concat, [1]], # cat head P4
  35. [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
  36. [-1, 1, Conv, [512, 3, 2]],
  37. [[-1, 10], 1, Concat, [1]], # cat head P5
  38. [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
  39. [-1, 1, CBAM, [1024]],# 24 (P5/32-large)+attention
  40. [[17, 20, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  41. ]
  1. from n params module arguments
  2. 0 -1 1 1760 models.common.Conv [3, 16, 6, 2, 2]
  3. 1 -1 1 4672 models.common.Conv [16, 32, 3, 2]
  4. 2 -1 1 4800 models.common.C3 [32, 32, 1]
  5. 3 -1 1 18560 models.common.Conv [32, 64, 3, 2]
  6. 4 -1 2 29184 models.common.C3 [64, 64, 2]
  7. 5 -1 1 73984 models.common.Conv [64, 128, 3, 2]
  8. 6 -1 3 156928 models.common.C3 [128, 128, 3]
  9. 7 -1 1 295424 models.common.Conv [128, 256, 3, 2]
  10. 8 -1 1 296448 models.common.C3 [256, 256, 1]
  11. 9 -1 1 164608 models.common.SPPF [256, 256, 5]
  12. 10 -1 1 33024 models.common.Conv [256, 128, 1, 1]
  13. 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
  14. 12 [-1, 6] 1 0 models.common.Concat [1]
  15. 13 -1 1 90880 models.common.C3 [256, 128, 1, False]
  16. 14 -1 1 8320 models.common.Conv [128, 64, 1, 1]
  17. 15 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest']
  18. 16 [-1, 4] 1 0 models.common.Concat [1]
  19. 17 -1 1 22912 models.common.C3 [128, 64, 1, False]
  20. 18 -1 1 36992 models.common.Conv [64, 64, 3, 2]
  21. 19 [-1, 14] 1 0 models.common.Concat [1]
  22. 20 -1 1 74496 models.common.C3 [128, 128, 1, False]
  23. 21 -1 1 147712 models.common.Conv [128, 128, 3, 2]
  24. 22 [-1, 10] 1 0 models.common.Concat [1]
  25. 23 -1 1 296448 models.common.C3 [256, 256, 1, False]
  26. 24 -1 1 74338 models.moreattention.CBAM [256, 256]
  27. 25 [17, 20, 24] 1 8118 models.yolo.Detect [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [64, 128, 256]]
  28. Model Summary: 284 layers, 1839608 parameters, 1839608 gradients, 4.3 GFLOPs

运行后打印如上代码说明改进成功。

六、注意

还有一些可添加的位置,常见可分为骨干和颈部,可作用于局部或全局。

本文只是一个示例修改,实际上还可以将注意力机制添加在更多地方,另外需要注意的是

第二步中self.act = nn.LeakyReLU(0.1) if act else nn.Identity()适用于YOLOv7-tiny,若你采用的是YOLOv5或YOLOv7,则需要修改为SiLU(),即self.act = nn.SiLU() if act else nn.Identity()

注意力机制能加在哪?会在下一篇文章中具体阐述给出,敬请关注。

更多文章产出中,主打简洁和准确,欢迎关注我,共同探讨!

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/在线问答5/article/detail/933920
推荐阅读
相关标签
  

闽ICP备14008679号