当前位置:   article > 正文

[YoloV5修改]基于GnConv卷积模块的yolov5修改_yolo更换conv模块

yolo更换conv模块

HorNet论文地址:https://arxiv.org/pdf/2207.14284.pdf

HorNet是在Swin transformer结构的基础上,结合大核思想提出的新的网络结构模块,使用该模块,作者在ImageNet-1k数据集上做分类,分割以及检测任务都在当时达到了SOTA的效果,是一个能有效增强各种网络的性能而不会引入太大参数量的一种改进思路,已经有很多博主提出将该模块用于yolo系列网络中,以期望达到更好的效果。本文主要是针对YoloV5系列的网络进行C3模块的替换,替换成HorNet模块。

HorNet模块的结构如下图所示:

该图来源于论文中。从图中我们可以清晰的看到,HorNet模块和Swin transformer模块有着相似的结构,不同的是HorNet中使用到了GnConv这样一个新的模块,GnConv的结构也在上图中给出来了。

下面进行yolo网络的修改。

首先在common.py文件中添加如下代码:

  1. class HorLayerNorm(nn.Module):
  2. def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
  3. super().__init__()
  4. self.weight = nn.Parameter(torch.ones(normalized_shape))
  5. self.bias = nn.Parameter(torch.zeros((normalized_shape)))
  6. self.eps = eps
  7. self.data_format = data_format
  8. if self.data_format not in ["channels_last", "channels_first"]:
  9. raise NotImplementedError #by iscyy/air
  10. self.normalized_shape = (normalized_shape,)
  11. def forward(self, x):
  12. if self.data_format == "channels_last":
  13. return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
  14. elif self.data_format == "channels_first":
  15. u = x.mean(1, keepdim=True)
  16. s = (x - u).pow(2).mean(1, keepdim=True)
  17. x = (x - u)/torch.sqrt(s + self.eps)
  18. x = self.weight[:, None, None] * x + self.bias[:, None, None]
  19. return x
  20. class GlobalLocalFilter(nn.Module):
  21. def __init__(self, dim, h=14, w=8):
  22. super().__init__()
  23. self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
  24. self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
  25. trunc_normal_(self.complex_weight, std=0.02)
  26. self.pre_norm = HorLayerNorm(dim, eps=1e-6, data_format='channels_first')
  27. self.post_norm = HorLayerNorm(dim, eps=1e-6, data_format='channels_first')
  28. def forward(self, x):
  29. x = self.pre_norm(x)
  30. x1, x2 = torch.chunk(x, 2, dim=1)
  31. x1 = self.dw(x1)
  32. x2 = x2.to(torch.float32)
  33. B, C, a, b = x2.shape
  34. x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
  35. weight = self.complex_weight
  36. if not weight.shape[1:3] == x2.shape[2:4]:
  37. weight = F.interpolate(weight.permute(3, 0, 1, 2), size=x2.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)
  38. weight = torch.view_as_complex(weight.contiguous())
  39. x2 = x2*weight
  40. x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2,3), norm='ortho')
  41. x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B,2*C,a,b)
  42. x = self.post_norm(x)
  43. return x
  44. class gnconv(nn.Module):
  45. def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
  46. super().__init__()
  47. self.order = order
  48. self.dims = [dim//2**i for i in range(order)]
  49. self.dims.reverse()
  50. self.proj_in = nn.Conv2d(dim, 2*dim, 1)
  51. if gflayer is None:
  52. self.dwconv = get_dwconv(sum(self.dims), 7, True)
  53. else:
  54. self.dwconv = gflayer(sum(self.dims), h=h, w=w)
  55. self.proj_out = nn.Conv2d(dim, dim, 1)
  56. self.pws = nn.ModuleList(
  57. [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
  58. )
  59. self.scale = s
  60. def forward(self, x, mask=None, dummy=False):
  61. # B, C, H, W = x.shape gnconv [512]by iscyy/air
  62. fused_x = self.proj_in(x)
  63. pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
  64. dw_abc = self.dwconv(abc) * self.scale
  65. dw_list = torch.split(dw_abc, self.dims, dim=1)
  66. x = pwa* dw_list[0]
  67. for i in range(self.order-1):
  68. x = self.pws[i](x) * dw_list[i+1]
  69. x = self.proj_out(x)
  70. return x
  71. def get_dwconv(dim, kernel, bias):
  72. return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2, bias=bias, groups=dim)
  73. class HorBlock(nn.Module):
  74. def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, gnconv=gnconv):
  75. super().__init__()
  76. self.norm1 = HorLayerNorm(dim, eps=1e-6, data_format='channels_first')
  77. self.gnconv = gnconv(dim)
  78. self.norm2 = HorLayerNorm(dim, eps=1e-6)
  79. self.pwconv1 = nn.Linear(dim, 4*dim)
  80. self.act = nn.GELU()
  81. self.pwconv2 = nn.Linear(4 * dim, dim)
  82. self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
  83. requires_grad=True) if layer_scale_init_value >0 else None
  84. self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
  85. requires_grad=True) if layer_scale_init_value > 0 else None
  86. self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
  87. def forward(self, x):
  88. B, C, H, W = x.shape
  89. if self.gamma1 is not None:
  90. gamma1 = self.gamma1.view(C, 1, 1)
  91. else:
  92. gamma1 = 1
  93. x = x + self.drop_path(gamma1 * self.gnconv(self.norm1(x)))
  94. input = x
  95. x = x.permute(0, 2, 3, 1)
  96. x = self.norm2(x)
  97. x = self.pwconv1(x)
  98. x = self.act(x)
  99. x = self.pwconv2(x)
  100. if self.gamma2 is not None:
  101. x = self.gamma2 * x
  102. x = x.permute(0, 3, 1, 2)
  103. x = input + self.drop_path(x)
  104. return x

然后在yolo.py中找到parse_model函数,对HorBlock类进行声明。声明的位置如下图所示。可通过Ctrl+F的形式找关键词找到这部分代码。

最后就是针对配置文件.yaml进行修改。修改后的代码如下:

  1. # parameters
  2. nc: 2 # number of classes
  3. depth_multiple: 0.33 # model depth multiple
  4. width_multiple: 0.50 # layer channel multiple
  5. # anchors
  6. anchors:
  7. - [10,13, 16,30, 33,23] # P3/8
  8. - [30,61, 62,45, 59,119] # P4/16
  9. - [116,90, 156,198, 373,326] # P5/32
  10. # YOLOv5 v6.0 backbone
  11. backbone:
  12. # [from, number, module, args]
  13. [[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
  14. [-1, 1, Conv, [128, 3, 2]], # 1-P2/4
  15. [-1, 3, HorBlock, [128]],
  16. [-1, 1, Conv, [256, 3, 2]], # 3-P3/8
  17. [-1, 6, HorBlock, [256]],
  18. [-1, 1, Conv, [512, 3, 2]], # 5-P4/16
  19. [-1, 9, HorBlock, [512]],
  20. [-1, 1, Conv, [1024, 3, 2]], # 7-P5/32
  21. [-1, 3, HorBlock, [1024]],
  22. [-1, 1, SPPF, [1024, 5]], # 9
  23. ]
  24. # YOLOv5 v6.0 head
  25. head:
  26. [[-1, 1, Conv, [512, 1, 1]],
  27. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  28. [[-1, 6], 1, Concat, [1]], # cat backbone P4
  29. [-1, 3, C3, [512, False]], # 13
  30. [-1, 1, Conv, [256, 1, 1]],
  31. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  32. [[-1, 4], 1, Concat, [1]], # cat backbone P3
  33. [-1, 3, C3, [256, False]], # 17 (P3/8-small)
  34. [-1, 1, Conv, [256, 3, 2]],
  35. [[-1, 14], 1, Concat, [1]], # cat head P4
  36. [-1, 3, C3, [512, False]], # 20 (P4/16-medium)
  37. [-1, 1, Conv, [512, 3, 2]],
  38. [[-1, 10], 1, Concat, [1]], # cat head P5
  39. [-1, 3, C3, [1024, False]], # 23 (P5/32-large)
  40. [[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)
  41. ]

其中的nc值得根据你任务的类别来进行修改。

以上就是整体的修改过程,整体网络的结构如下:

在一个小数据集上跑了150轮后的结果如下,单从这个结果上来看效果并不好,不如原始的网络的效果好。从结果反推原因,可能是将所有的C3模块都进行替换并不是一个好的选择。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/weixin_40725706/article/detail/93992
推荐阅读
相关标签
  

闽ICP备14008679号