赞
踩
class ChannelAttention(nn.Module):
    """CBAM channel-attention gate.

    Each channel of the input is summarized twice, by global average pooling
    and by global max pooling. Both descriptors go through the same 1x1-conv
    bottleneck (``f1`` -> ReLU -> ``f2``), are summed, and are squashed with a
    sigmoid. The result is a per-channel weight in (0, 1) with shape
    ``(N, in_planes, 1, 1)``, intended to be multiplied onto the input.

    Args:
        in_planes: Number of channels of the input feature map.
        ratio: Bottleneck reduction ratio. The hidden width is
            ``in_planes // ratio``, but never less than 1.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Clamp to >= 1: for in_planes < ratio the floor division yields 0,
        # which would build zero-width convs (a degenerate, input-independent
        # gate at best) instead of a usable bottleneck.
        hidden = max(in_planes // ratio, 1)
        # Shared MLP, expressed as 1x1 convs, applied to both pooled descriptors.
        self.f1 = nn.Conv2d(in_planes, hidden, 1, bias=False)
        self.relu = nn.ReLU()
        self.f2 = nn.Conv2d(hidden, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return channel-attention weights for ``x`` (expected (N, C, H, W))."""
        avg_out = self.f2(self.relu(self.f1(self.avg_pool(x))))
        max_out = self.f2(self.relu(self.f1(self.max_pool(x))))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """CBAM spatial-attention map.

    The input is reduced across its channel dimension twice, once with the
    mean and once with the max. The two single-channel maps are stacked
    (mean first, then max) and convolved down to one channel with no bias,
    then passed through a sigmoid. With stride 1 and "same" padding, the output
    keeps the input's spatial size and has a single channel.

    Args:
        kernel_size: Side of the square convolution kernel; must be 3 or 7.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        # Half the (odd) kernel size keeps H and W unchanged: 7 -> 3, 3 -> 1.
        half = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=half, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (N, 1, H, W) attention map for ``x``."""
        channel_mean = torch.mean(x, dim=1, keepdim=True)
        channel_max = torch.max(x, dim=1, keepdim=True).values
        stacked = torch.cat([channel_mean, channel_max], dim=1)
        logits = self.conv(stacked)
        return self.sigmoid(logits)
class CBAM(nn.Module):
    """Conv -> BatchNorm -> activation, followed by CBAM attention.

    After the conv block, the feature map is reweighted first by a
    ``ChannelAttention`` gate and then by a ``SpatialAttention`` map, both
    operating on the ``c2`` output channels.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size.
        s: Stride.
        p: Padding, resolved through ``autopad(k, p)`` (defined elsewhere;
            presumably ``None`` picks a default from ``k`` — confirm).
        g: Convolution groups.
        act: Truthy -> ``nn.Hardswish`` after BN; falsy -> identity.
    """

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super(CBAM, self).__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.Hardswish() if act else nn.Identity()
        self.ca = ChannelAttention(c2)
        self.sa = SpatialAttention()

    def _attend(self, x):
        """Apply channel attention, then spatial attention, to ``x``."""
        x = self.ca(x) * x
        return self.sa(x) * x

    def forward(self, x):
        """Conv -> BN -> activation -> channel attention -> spatial attention."""
        return self._attend(self.act(self.bn(self.conv(x))))

    def fuseforward(self, x):
        """Forward pass that skips ``self.bn`` (for BN already folded into ``self.conv``).

        The attention stages contain no BatchNorm, so they must still run;
        otherwise the fused path would compute a different function than
        ``forward``.
        """
        return self._attend(self.act(self.conv(x)))
# parameters
nc: 2  # number of classes
depth_multiple: 1.0  # model depth multiple
width_multiple: 1.0  # layer channel multiple

# anchors
anchors:
  - [12,16, 19,36, 40,28]  # P3/8
  - [36,75, 76,55, 72,146]  # P4/16
  - [142,110, 192,243, 459,401]  # P5/32

# yolov7 backbone
# CBAM is used at layers 24 and 37, which the head routes in as backbone P3 and P4.
# NOTE(review): CBAM args are [c2, k, s] like Conv. This assumes parse_model in
# models/yolo.py handles CBAM like Conv (prepending the input channels c1) — confirm,
# otherwise e.g. CBAM(512, 1, 1) would be built with c1=512, c2=1.
backbone:
  # [from, number, module, args]
  [[-1, 1, Conv, [32, 3, 1]],  # 0

   [-1, 1, Conv, [64, 3, 2]],  # 1-P1/2
   [-1, 1, Conv, [64, 3, 1]],

   [-1, 1, Conv, [128, 3, 2]],  # 3-P2/4
   [-1, 1, Conv, [64, 1, 1]],
   [-2, 1, Conv, [64, 1, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],  # 11

   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 16-P3/8
   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, CBAM, [512, 1, 1]],  # 24

   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 29-P4/16
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, CBAM, [1024, 1, 1]],  # 37

   [-1, 1, MP, []],
   [-1, 1, Conv, [512, 1, 1]],
   [-3, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [512, 3, 2]],
   [[-1, -3], 1, Concat, [1]],  # 42-P5/32
   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -3, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [1024, 1, 1]],  # 50
  ]

# yolov7 head
head:
  [[-1, 1, SPPCSPC, [512]],  # 51

   [-1, 1, Conv, [256, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [37, 1, Conv, [256, 1, 1]],  # route backbone P4
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],  # 63

   [-1, 1, Conv, [128, 1, 1]],
   [-1, 1, nn.Upsample, [None, 2, 'nearest']],
   [24, 1, Conv, [128, 1, 1]],  # route backbone P3
   [[-1, -2], 1, Concat, [1]],

   [-1, 1, Conv, [128, 1, 1]],
   [-2, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [-1, 1, Conv, [64, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [128, 1, 1]],  # 75

   [-1, 1, MP, []],
   [-1, 1, Conv, [128, 1, 1]],
   [-3, 1, Conv, [128, 1, 1]],
   [-1, 1, Conv, [128, 3, 2]],
   [[-1, -3, 63], 1, Concat, [1]],

   [-1, 1, Conv, [256, 1, 1]],
   [-2, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [-1, 1, Conv, [128, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [256, 1, 1]],  # 88

   [-1, 1, MP, []],
   [-1, 1, Conv, [256, 1, 1]],
   [-3, 1, Conv, [256, 1, 1]],
   [-1, 1, Conv, [256, 3, 2]],
   [[-1, -3, 51], 1, Concat, [1]],

   [-1, 1, Conv, [512, 1, 1]],
   [-2, 1, Conv, [512, 1, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [-1, 1, Conv, [256, 3, 1]],
   [[-1, -2, -3, -4, -5, -6], 1, Concat, [1]],
   [-1, 1, Conv, [512, 1, 1]],  # 101

   [75, 1, RepConv, [256, 3, 1]],
   [88, 1, RepConv, [512, 3, 1]],
   [101, 1, RepConv, [1024, 3, 1]],

   [[102,103,104], 1, IDetect, [nc, anchors]],  # Detect(P3, P4, P5)
  ]
# Model-definition YAML; the default points at the CBAM variant of the training config.
parser.add_argument(
    '--cfg',
    default='cfg/training/yolov7-cbam.yaml',
    type=str,
    help='model.yaml path',
)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。