当前位置:   article > 正文

yolov7改进--添加CBAM模块(注意力机制)_yolov7加入cbam

yolov7加入cbam

1.在models/common.py中添加如下代码(代码依赖该文件中已有的torch、nn以及autopad函数)

  class ChannelAttention(nn.Module):
      """Channel-attention gate built from pooled channel descriptors.

      The input is reduced to one value per channel twice (adaptive average
      pooling and adaptive max pooling to a 1x1 map). Both descriptors go
      through the same two 1x1 convolutions (``f1`` -> ReLU -> ``f2``), are
      summed, and squashed with a sigmoid.

      Args:
          in_planes: Number of channels of the input feature map.
          ratio: Reduction factor of the bottleneck between ``f1`` and ``f2``.
              The bottleneck width is ``in_planes // ratio`` but never less
              than 1, so inputs with fewer than ``ratio`` channels still get
              a usable (non-empty) bottleneck.
      """

      def __init__(self, in_planes, ratio=16):
          super().__init__()
          # Clamp to 1: a plain floor division yields 0 hidden channels when
          # in_planes < ratio, which makes the gate degenerate.
          hidden = max(in_planes // ratio, 1)

          self.avg_pool = nn.AdaptiveAvgPool2d(1)
          self.max_pool = nn.AdaptiveMaxPool2d(1)
          self.f1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
          self.relu = nn.ReLU()
          self.f2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
          self.sigmoid = nn.Sigmoid()

      def _mlp(self, pooled):
          """Apply the shared bottleneck (f1 -> ReLU -> f2) to a pooled descriptor."""
          return self.f2(self.relu(self.f1(pooled)))

      def forward(self, x):
          """Return the per-channel gate for ``x``.

          The result has ``in_planes`` channels and 1x1 spatial size; the caller
          is expected to multiply it with ``x`` (broadcast over H and W).
          """
          avg_out = self._mlp(self.avg_pool(x))
          max_out = self._mlp(self.max_pool(x))
          return self.sigmoid(avg_out + max_out)
  class SpatialAttention(nn.Module):
      """Spatial-attention gate computed from channel-wise statistics.

      The channel dimension is collapsed twice (mean and max over ``dim=1``),
      the two single-channel maps are stacked, passed through one bias-free
      convolution with a single output channel, and squashed with a sigmoid.

      Args:
          kernel_size: Size of the convolution kernel; only 3 and 7 are accepted.
      """

      def __init__(self, kernel_size=7):
          super().__init__()
          assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
          # Half the kernel keeps the spatial size: 7 -> 3, 3 -> 1.
          half = kernel_size // 2

          self.conv = nn.Conv2d(2, 1, kernel_size, padding=half, bias=False)
          self.sigmoid = nn.Sigmoid()

      def forward(self, x):
          """Return a single-channel gate with the same spatial size as ``x``."""
          mean_map = torch.mean(x, dim=1, keepdim=True)
          peak_map = torch.max(x, dim=1, keepdim=True)[0]
          stacked = torch.cat([mean_map, peak_map], dim=1)
          return self.sigmoid(self.conv(stacked))
  class CBAM(nn.Module):
      """Conv + BatchNorm + activation, followed by channel and spatial attention.

      The constructor signature mirrors a plain convolution block so the
      module can be listed in a model yaml with ``[c2, k, s]`` style arguments.

      Args:
          c1: Input channels.
          c2: Output channels (also the channel count of the attention gates).
          k: Convolution kernel size.
          s: Convolution stride.
          p: Explicit padding; passed to ``autopad`` together with ``k``
              (``autopad`` is defined elsewhere in the target file — presumably
              it derives "same" padding when ``p`` is None; verify there).
          g: Convolution groups.
          act: If truthy use ``nn.Hardswish``, otherwise ``nn.Identity``.
      """

      def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
          super().__init__()
          self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
          self.bn = nn.BatchNorm2d(c2)
          self.act = nn.Hardswish() if act else nn.Identity()
          self.ca = ChannelAttention(c2)
          self.sa = SpatialAttention()

      def _attend(self, x):
          """Scale ``x`` by the channel gate, then by the spatial gate."""
          x = self.ca(x) * x
          return self.sa(x) * x

      def forward(self, x):
          """Run conv -> bn -> activation, then both attention stages."""
          return self._attend(self.act(self.bn(self.conv(x))))

      def fuseforward(self, x):
          """Forward pass that skips ``self.bn``.

          Intended for use after batch norm has been folded into ``self.conv``
          (NOTE(review): the folding itself happens outside this class —
          confirm the caller). The attention stages are still applied, so this
          path computes the same block as ``forward`` instead of silently
          dropping CBAM.
          """
          return self._attend(self.act(self.conv(x)))

2.修改models/yolo.py

2.1定位到parse_model方法

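2.2添加CBAM:在parse_model中处理卷积类模块的列表(即 `if m in [nn.Conv2d, Conv, ...]` 这一行)里加入CBAM,这样yaml中的CBAM层才能像Conv一样自动获得输入通道数c1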

 3.创建cfg/training/yolov7-cbam.yaml,填充如下代码

  1. # parameters
  2. nc: 2 # number of classes (presumably must match the training dataset - verify)
  3. depth_multiple: 1.0 # model depth multiple
  4. width_multiple: 1.0 # layer channel multiple
  5. # anchors: one row per detection scale, labelled with its pyramid level/stride
  6. anchors:
  7. - [12,16, 19,36, 40,28] # P3/8
  8. - [36,75, 76,55, 72,146] # P4/16
  9. - [142,110, 192,243, 459,401] # P5/32
  10. # yolov7 backbone (layers 24 and 37 are CBAM blocks; the head reads both by absolute index)
  11. backbone:
  12. # [from, number, module, args]
  13. [[-1, 1, Conv, [32, 3, 1]], # 0
  14. [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
  15. [-1, 1, Conv, [64, 3, 1]],
  16. [-1, 1, Conv, [128, 3, 2]], # 3-P2/4
  17. [-1, 1, Conv, [64, 1, 1]],
  18. [-2, 1, Conv, [64, 1, 1]],
  19. [-1, 1, Conv, [64, 3, 1]],
  20. [-1, 1, Conv, [64, 3, 1]],
  21. [-1, 1, Conv, [64, 3, 1]],
  22. [-1, 1, Conv, [64, 3, 1]],
  23. [[-1, -3, -5, -6], 1, Concat, [1]],
  24. [-1, 1, Conv, [256, 1, 1]], # 11
  25. [-1, 1, MP, []],
  26. [-1, 1, Conv, [128, 1, 1]],
  27. [-3, 1, Conv, [128, 1, 1]],
  28. [-1, 1, Conv, [128, 3, 2]],
  29. [[-1, -3], 1, Concat, [1]], # 16-P3/8
  30. [-1, 1, Conv, [128, 1, 1]],
  31. [-2, 1, Conv, [128, 1, 1]],
  32. [-1, 1, Conv, [128, 3, 1]],
  33. [-1, 1, Conv, [128, 3, 1]],
  34. [-1, 1, Conv, [128, 3, 1]],
  35. [-1, 1, Conv, [128, 3, 1]],
  36. [[-1, -3, -5, -6], 1, Concat, [1]],
  37. [-1, 1, CBAM, [512, 1, 1]], # 24 - CBAM; read by the head's "route backbone P3" layer (from=24)
  38. [-1, 1, MP, []],
  39. [-1, 1, Conv, [256, 1, 1]],
  40. [-3, 1, Conv, [256, 1, 1]],
  41. [-1, 1, Conv, [256, 3, 2]],
  42. [[-1, -3], 1, Concat, [1]], # 29-P4/16
  43. [-1, 1, Conv, [256, 1, 1]],
  44. [-2, 1, Conv, [256, 1, 1]],
  45. [-1, 1, Conv, [256, 3, 1]],
  46. [-1, 1, Conv, [256, 3, 1]],
  47. [-1, 1, Conv, [256, 3, 1]],
  48. [-1, 1, Conv, [256, 3, 1]],
  49. [[-1, -3, -5, -6], 1, Concat, [1]],
  50. [-1, 1, CBAM, [1024, 1, 1]], # 37 - CBAM; read by the head's "route backbone P4" layer (from=37)
  51. [-1, 1, MP, []],
  52. [-1, 1, Conv, [512, 1, 1]],
  53. [-3, 1, Conv, [512, 1, 1]],
  54. [-1, 1, Conv, [512, 3, 2]],
  55. [[-1, -3], 1, Concat, [1]], # 42-P5/32
  56. [-1, 1, Conv, [256, 1, 1]],
  57. [-2, 1, Conv, [256, 1, 1]],
  58. [-1, 1, Conv, [256, 3, 1]],
  59. [-1, 1, Conv, [256, 3, 1]],
  60. [-1, 1, Conv, [256, 3, 1]],
  61. [-1, 1, Conv, [256, 3, 1]],
  62. [[-1, -3, -5, -6], 1, Concat, [1]],
  63. [-1, 1, Conv, [1024, 1, 1]], # 50
  64. ]
  65. # yolov7 head (layer indices continue from the backbone, starting at 51)
  66. head:
  67. [[-1, 1, SPPCSPC, [512]], # 51
  68. [-1, 1, Conv, [256, 1, 1]],
  69. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  70. [37, 1, Conv, [256, 1, 1]], # route backbone P4 (layer 37, a CBAM block)
  71. [[-1, -2], 1, Concat, [1]],
  72. [-1, 1, Conv, [256, 1, 1]],
  73. [-2, 1, Conv, [256, 1, 1]],
  74. [-1, 1, Conv, [128, 3, 1]],
  75. [-1, 1, Conv, [128, 3, 1]],
  76. [-1, 1, Conv, [128, 3, 1]],
  77. [-1, 1, Conv, [128, 3, 1]],
  78. [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
  79. [-1, 1, Conv, [256, 1, 1]], # 63
  80. [-1, 1, Conv, [128, 1, 1]],
  81. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  82. [24, 1, Conv, [128, 1, 1]], # route backbone P3 (layer 24, a CBAM block)
  83. [[-1, -2], 1, Concat, [1]],
  84. [-1, 1, Conv, [128, 1, 1]],
  85. [-2, 1, Conv, [128, 1, 1]],
  86. [-1, 1, Conv, [64, 3, 1]],
  87. [-1, 1, Conv, [64, 3, 1]],
  88. [-1, 1, Conv, [64, 3, 1]],
  89. [-1, 1, Conv, [64, 3, 1]],
  90. [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
  91. [-1, 1, Conv, [128, 1, 1]], # 75
  92. [-1, 1, MP, []],
  93. [-1, 1, Conv, [128, 1, 1]],
  94. [-3, 1, Conv, [128, 1, 1]],
  95. [-1, 1, Conv, [128, 3, 2]],
  96. [[-1, -3, 63], 1, Concat, [1]], # also concatenates head layer 63
  97. [-1, 1, Conv, [256, 1, 1]],
  98. [-2, 1, Conv, [256, 1, 1]],
  99. [-1, 1, Conv, [128, 3, 1]],
  100. [-1, 1, Conv, [128, 3, 1]],
  101. [-1, 1, Conv, [128, 3, 1]],
  102. [-1, 1, Conv, [128, 3, 1]],
  103. [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
  104. [-1, 1, Conv, [256, 1, 1]], # 88
  105. [-1, 1, MP, []],
  106. [-1, 1, Conv, [256, 1, 1]],
  107. [-3, 1, Conv, [256, 1, 1]],
  108. [-1, 1, Conv, [256, 3, 2]],
  109. [[-1, -3, 51], 1, Concat, [1]], # also concatenates head layer 51 (SPPCSPC)
  110. [-1, 1, Conv, [512, 1, 1]],
  111. [-2, 1, Conv, [512, 1, 1]],
  112. [-1, 1, Conv, [256, 3, 1]],
  113. [-1, 1, Conv, [256, 3, 1]],
  114. [-1, 1, Conv, [256, 3, 1]],
  115. [-1, 1, Conv, [256, 3, 1]],
  116. [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
  117. [-1, 1, Conv, [512, 1, 1]], # 101
  118. [75, 1, RepConv, [256, 3, 1]], # 102
  119. [88, 1, RepConv, [512, 3, 1]], # 103
  120. [101, 1, RepConv, [1024, 3, 1]], # 104
  121. [[102,103,104], 1, IDetect, [nc, anchors]], # Detect(P3, P4, P5)
  122. ]

4.修改train.py参数,使其应用修改后的yolov7-cbam.yaml

    parser.add_argument('--cfg', type=str, default='cfg/training/yolov7-cbam.yaml', help='model.yaml path')

5.结束

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/698037
推荐阅读
相关标签
  

闽ICP备14008679号