赞
踩
class SE(nn.Module): def __init__(self, c1, c2, r=16): super(SE, self).__init__() self.avgpool = nn.AdaptiveAvgPool2d(1) self.l1 = nn.Linear(c1, c1 // r, bias=False) self.relu = nn.ReLU(inplace=True) self.l2 = nn.Linear(c1 // r, c1, bias=False) self.sig = nn.Sigmoid() def forward(self, x): print(x.size()) b, c, _, _ = x.size() y = self.avgpool(x).view(b, c) y = self.l1(y) y = self.relu(y) y = self.l2(y) y = self.sig(y) y = y.view(b, c, 1, 1) return x * y.expand_as(x)
以直接修改yolov5s.yaml 为例讲两种思路
思路1:直接放在backbone末尾
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]], # 6
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SPPF, [1024, 5]], # 9
[-1, 1, SE, [1024, 2]],
]
思路2:放在SPPF前
backbone:
# [from, number, module, args]
[[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[-1, 1, Conv, [128, 3, 2]], # 1-P2/4
[-1, 3, C3, [128]],
[-1, 1, Conv, [256, 3, 2]], # 3-P3/8
[-1, 6, C3, [256]],
[-1, 1, Conv, [512, 3, 2]], # 5-P4/16
[-1, 9, C3, [512]], # 6
[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
[-1, 3, C3, [1024]],
[-1, 1, SE, [1024,2]],
[-1, 1, SPPF, [1024, 5]], # 10
]
(以上思路2选1)
重要!:添加完SE之后,相应的head层(超过10的)都需要将层数+1
head修改为:
head: [[-1, 1, Conv, [512, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 6], 1, Concat, [1]], # cat backbone P4 [-1, 3, C3, [512, False]], # 13 [-1, 1, Conv, [256, 1, 1]], [-1, 1, nn.Upsample, [None, 2, 'nearest']], [[-1, 4], 1, Concat, [1]], # cat backbone P3 [-1, 3, C3, [256, False]], # 17 (P3/8-small) [-1, 1, Conv, [256, 3, 2]], [[-1, 15], 1, Concat, [1]], # cat head P4 +! [-1, 3, C3, [512, False]], # 20 (P4/16-medium) [-1, 1, Conv, [512, 3, 2]], [[-1, 11], 1, Concat, [1]], # cat head P5 +1 [-1, 3, C3, [1024, False]], # 23 (P5/32-large) [[18, 21, 24], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5) +1 ]
修改yolo.py ,在def parse_model(d, ch): 下添加SE模块判断语句:
elif m is SE:
c1 = ch[f]
c2 = args[0]
if c2 !=no:
c2 = make_divisible(c2 * gw, 8)
args = [c1, args[1]]
即完整yolo.py:
# YOLOv5 声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/508811
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。