当前位置:   article > 正文

YOLOv9独家改进|动态蛇形卷积Dynamic Snake Convolution与RepNCSPELAN4融合_repncsp c2f

repncsp c2f


专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,助力高效涨点!!!


一、改进点介绍

        Dynamic Snake Convolution是一种针对细长微弱的局部结构特征与复杂多变的全局形态特征设计的卷积模块。

        RepNCSPELAN4是YOLOv9中的特征提取模块,作用类似于YOLOv5中的C3模块和YOLOv8中的C2f模块。


二、RepNCSPELAN4Dynamic模块详解

 2.1 模块简介

       RepNCSPELAN4Dynamic(代码中的类名为RepNCSPELAN4DySnakeConv)的主要思想:将Dynamic Snake Convolution融合进RepNCSPELAN4模块,用其替换模块内部的卷积层。


三、 RepNCSPELAN4Dynamic模块使用教程

3.1 RepNCSPELAN4Dynamic模块的代码

  1. class DySnakeConv(nn.Module):
  2. def __init__(self, inc, ouc, k=3) -> None:
  3. super().__init__()
  4. c_ = ouc//3
  5. self.conv_0 = Conv(inc, ouc-2*c_, k)
  6. self.conv_x = DSConv(inc, c_, 0, k)
  7. self.conv_y = DSConv(inc, c_, 1, k)
  8. def forward(self, x):
  9. return torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1)
  10. class DSConv(nn.Module):
  11. def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
  12. """
  13. The Dynamic Snake Convolution
  14. :param in_ch: input channel
  15. :param out_ch: output channel
  16. :param kernel_size: the size of kernel
  17. :param extend_scope: the range to expand (default 1 for this method)
  18. :param morph: the morphology of the convolution kernel is mainly divided into two types
  19. along the x-axis (0) and the y-axis (1) (see the paper for details)
  20. :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
  21. """
  22. super(DSConv, self).__init__()
  23. # use the <offset_conv> to learn the deformable offset
  24. self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
  25. self.bn = nn.BatchNorm2d(2 * kernel_size)
  26. self.kernel_size = kernel_size
  27. # two types of the DSConv (along x-axis and y-axis)
  28. self.dsc_conv_x = nn.Conv2d(
  29. in_ch,
  30. out_ch,
  31. kernel_size=(kernel_size, 1),
  32. stride=(kernel_size, 1),
  33. padding=0,
  34. )
  35. self.dsc_conv_y = nn.Conv2d(
  36. in_ch,
  37. out_ch,
  38. kernel_size=(1, kernel_size),
  39. stride=(1, kernel_size),
  40. padding=0,
  41. )
  42. self.gn = nn.GroupNorm([i + 1 for i in range(4) if out_ch % (i+1) == 0][-1], out_ch)
  43. self.act = Conv.default_act
  44. self.extend_scope = extend_scope
  45. self.morph = morph
  46. self.if_offset = if_offset
  47. def forward(self, f):
  48. offset = self.offset_conv(f)
  49. offset = self.bn(offset)
  50. # We need a range of deformation between -1 and 1 to mimic the snake's swing
  51. offset = torch.tanh(offset)
  52. input_shape = f.shape
  53. dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
  54. deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
  55. if self.morph == 0:
  56. x = self.dsc_conv_x(deformed_feature.type(f.dtype))
  57. x = self.gn(x)
  58. x = self.act(x)
  59. return x
  60. else:
  61. x = self.dsc_conv_y(deformed_feature.type(f.dtype))
  62. x = self.gn(x)
  63. x = self.act(x)
  64. return x
  65. class DSC(object):
  66. def __init__(self, input_shape, kernel_size, extend_scope, morph):
  67. self.num_points = kernel_size
  68. self.width = input_shape[2]
  69. self.height = input_shape[3]
  70. self.morph = morph
  71. self.extend_scope = extend_scope # offset (-1 ~ 1) * extend_scope
  72. # define feature map shape
  73. """
  74. B: Batch size C: Channel W: Width H: Height
  75. """
  76. self.num_batch = input_shape[0]
  77. self.num_channels = input_shape[1]
  78. """
  79. input: offset [B,2*K,W,H] K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
  80. output_x: [B,1,W,K*H] coordinate map
  81. output_y: [B,1,K*W,H] coordinate map
  82. """
  83. def _coordinate_map_3D(self, offset, if_offset):
  84. device = offset.device
  85. # offset
  86. y_offset, x_offset = torch.split(offset, self.num_points, dim=1)
  87. y_center = torch.arange(0, self.width).repeat([self.height])
  88. y_center = y_center.reshape(self.height, self.width)
  89. y_center = y_center.permute(1, 0)
  90. y_center = y_center.reshape([-1, self.width, self.height])
  91. y_center = y_center.repeat([self.num_points, 1, 1]).float()
  92. y_center = y_center.unsqueeze(0)
  93. x_center = torch.arange(0, self.height).repeat([self.width])
  94. x_center = x_center.reshape(self.width, self.height)
  95. x_center = x_center.permute(0, 1)
  96. x_center = x_center.reshape([-1, self.width, self.height])
  97. x_center = x_center.repeat([self.num_points, 1, 1]).float()
  98. x_center = x_center.unsqueeze(0)
  99. if self.morph == 0:
  100. """
  101. Initialize the kernel and flatten the kernel
  102. y: only need 0
  103. x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  104. !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
  105. """
  106. y = torch.linspace(0, 0, 1)
  107. x = torch.linspace(
  108. -int(self.num_points // 2),
  109. int(self.num_points // 2),
  110. int(self.num_points),
  111. )
  112. y, x = torch.meshgrid(y, x)
  113. y_spread = y.reshape(-1, 1)
  114. x_spread = x.reshape(-1, 1)
  115. y_grid = y_spread.repeat([1, self.width * self.height])
  116. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  117. y_grid = y_grid.unsqueeze(0) # [B*K*K, W,H]
  118. x_grid = x_spread.repeat([1, self.width * self.height])
  119. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  120. x_grid = x_grid.unsqueeze(0) # [B*K*K, W,H]
  121. y_new = y_center + y_grid
  122. x_new = x_center + x_grid
  123. y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
  124. x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)
  125. y_offset_new = y_offset.detach().clone()
  126. if if_offset:
  127. y_offset = y_offset.permute(1, 0, 2, 3)
  128. y_offset_new = y_offset_new.permute(1, 0, 2, 3)
  129. center = int(self.num_points // 2)
  130. # The center position remains unchanged and the rest of the positions begin to swing
  131. # This part is quite simple. The main idea is that "offset is an iterative process"
  132. y_offset_new[center] = 0
  133. for index in range(1, center):
  134. y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
  135. y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
  136. y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
  137. y_new = y_new.add(y_offset_new.mul(self.extend_scope))
  138. y_new = y_new.reshape(
  139. [self.num_batch, self.num_points, 1, self.width, self.height])
  140. y_new = y_new.permute(0, 3, 1, 4, 2)
  141. y_new = y_new.reshape([
  142. self.num_batch, self.num_points * self.width, 1 * self.height
  143. ])
  144. x_new = x_new.reshape(
  145. [self.num_batch, self.num_points, 1, self.width, self.height])
  146. x_new = x_new.permute(0, 3, 1, 4, 2)
  147. x_new = x_new.reshape([
  148. self.num_batch, self.num_points * self.width, 1 * self.height
  149. ])
  150. return y_new, x_new
  151. else:
  152. """
  153. Initialize the kernel and flatten the kernel
  154. y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  155. x: only need 0
  156. """
  157. y = torch.linspace(
  158. -int(self.num_points // 2),
  159. int(self.num_points // 2),
  160. int(self.num_points),
  161. )
  162. x = torch.linspace(0, 0, 1)
  163. y, x = torch.meshgrid(y, x)
  164. y_spread = y.reshape(-1, 1)
  165. x_spread = x.reshape(-1, 1)
  166. y_grid = y_spread.repeat([1, self.width * self.height])
  167. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  168. y_grid = y_grid.unsqueeze(0)
  169. x_grid = x_spread.repeat([1, self.width * self.height])
  170. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  171. x_grid = x_grid.unsqueeze(0)
  172. y_new = y_center + y_grid
  173. x_new = x_center + x_grid
  174. y_new = y_new.repeat(self.num_batch, 1, 1, 1)
  175. x_new = x_new.repeat(self.num_batch, 1, 1, 1)
  176. y_new = y_new.to(device)
  177. x_new = x_new.to(device)
  178. x_offset_new = x_offset.detach().clone()
  179. if if_offset:
  180. x_offset = x_offset.permute(1, 0, 2, 3)
  181. x_offset_new = x_offset_new.permute(1, 0, 2, 3)
  182. center = int(self.num_points // 2)
  183. x_offset_new[center] = 0
  184. for index in range(1, center):
  185. x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
  186. x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
  187. x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
  188. x_new = x_new.add(x_offset_new.mul(self.extend_scope))
  189. y_new = y_new.reshape(
  190. [self.num_batch, 1, self.num_points, self.width, self.height])
  191. y_new = y_new.permute(0, 3, 1, 4, 2)
  192. y_new = y_new.reshape([
  193. self.num_batch, 1 * self.width, self.num_points * self.height
  194. ])
  195. x_new = x_new.reshape(
  196. [self.num_batch, 1, self.num_points, self.width, self.height])
  197. x_new = x_new.permute(0, 3, 1, 4, 2)
  198. x_new = x_new.reshape([
  199. self.num_batch, 1 * self.width, self.num_points * self.height
  200. ])
  201. return y_new, x_new
  202. """
  203. input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H]
  204. output: [N,1,K*D,K*W,K*H] deformed feature map
  205. """
  206. def _bilinear_interpolate_3D(self, input_feature, y, x):
  207. device = input_feature.device
  208. y = y.reshape([-1]).float()
  209. x = x.reshape([-1]).float()
  210. zero = torch.zeros([]).int()
  211. max_y = self.width - 1
  212. max_x = self.height - 1
  213. # find 8 grid locations
  214. y0 = torch.floor(y).int()
  215. y1 = y0 + 1
  216. x0 = torch.floor(x).int()
  217. x1 = x0 + 1
  218. # clip out coordinates exceeding feature map volume
  219. y0 = torch.clamp(y0, zero, max_y)
  220. y1 = torch.clamp(y1, zero, max_y)
  221. x0 = torch.clamp(x0, zero, max_x)
  222. x1 = torch.clamp(x1, zero, max_x)
  223. input_feature_flat = input_feature.flatten()
  224. input_feature_flat = input_feature_flat.reshape(
  225. self.num_batch, self.num_channels, self.width, self.height)
  226. input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
  227. input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
  228. dimension = self.height * self.width
  229. base = torch.arange(self.num_batch) * dimension
  230. base = base.reshape([-1, 1]).float()
  231. repeat = torch.ones([self.num_points * self.width * self.height
  232. ]).unsqueeze(0)
  233. repeat = repeat.float()
  234. base = torch.matmul(base, repeat)
  235. base = base.reshape([-1])
  236. base = base.to(device)
  237. base_y0 = base + y0 * self.height
  238. base_y1 = base + y1 * self.height
  239. # top rectangle of the neighbourhood volume
  240. index_a0 = base_y0 - base + x0
  241. index_c0 = base_y0 - base + x1
  242. # bottom rectangle of the neighbourhood volume
  243. index_a1 = base_y1 - base + x0
  244. index_c1 = base_y1 - base + x1
  245. # get 8 grid values
  246. value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
  247. value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
  248. value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
  249. value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)
  250. # find 8 grid locations
  251. y0 = torch.floor(y).int()
  252. y1 = y0 + 1
  253. x0 = torch.floor(x).int()
  254. x1 = x0 + 1
  255. # clip out coordinates exceeding feature map volume
  256. y0 = torch.clamp(y0, zero, max_y + 1)
  257. y1 = torch.clamp(y1, zero, max_y + 1)
  258. x0 = torch.clamp(x0, zero, max_x + 1)
  259. x1 = torch.clamp(x1, zero, max_x + 1)
  260. x0_float = x0.float()
  261. x1_float = x1.float()
  262. y0_float = y0.float()
  263. y1_float = y1.float()
  264. vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
  265. vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
  266. vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
  267. vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)
  268. outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
  269. value_c1 * vol_c1)
  270. if self.morph == 0:
  271. outputs = outputs.reshape([
  272. self.num_batch,
  273. self.num_points * self.width,
  274. 1 * self.height,
  275. self.num_channels,
  276. ])
  277. outputs = outputs.permute(0, 3, 1, 2)
  278. else:
  279. outputs = outputs.reshape([
  280. self.num_batch,
  281. 1 * self.width,
  282. self.num_points * self.height,
  283. self.num_channels,
  284. ])
  285. outputs = outputs.permute(0, 3, 1, 2)
  286. return outputs
  287. def deform_conv(self, input, offset, if_offset):
  288. y, x = self._coordinate_map_3D(offset, if_offset)
  289. deformed_feature = self._bilinear_interpolate_3D(input, y, x)
  290. return deformed_feature
  291. class RepNBottleneck_DySnakeConv(RepNBottleneck):
  292. # Standard bottleneck
  293. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, kernels, groups, expand
  294. super().__init__(c1, c2, shortcut, g, k, e)
  295. c_ = int(c2 * e) # hidden channels
  296. self.cv1 = RepConvN(c1, c_, k[0], 1)
  297. self.cv2 = Conv(c_, c2, k[1], s=1, g=g)
  298. self.add = shortcut and c1 == c2
  299. class RepNCSP_DySnakeConv(RepNCSP):
  300. # CSP Bottleneck with 3 convolutions
  301. def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  302. super().__init__(c1, c2, n, shortcut, g, e)
  303. c_ = int(c2 * e) # hidden channels
  304. self.cv1 = DySnakeConv(c1, c_)
  305. self.cv2 = DySnakeConv(c1, c_)
  306. self.cv3 = DySnakeConv(2 * c_, c2) # optional act=FReLU(c2)
  307. self.m = nn.Sequential(*(RepNBottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))
  308. class RepNCSPELAN4DySnakeConv(RepNCSPELAN4):
  309. # csp-elan
  310. def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
  311. super().__init__(c1, c2, c3, c4, c5)
  312. self.cv1 = Conv(c1, c3, k=1, s=1)
  313. self.cv2 = nn.Sequential(RepNCSP_DySnakeConv(c3 // 2, c4, c5), DySnakeConv(c4, c4, 3))
  314. self.cv3 = nn.Sequential(RepNCSP_DySnakeConv(c4, c4, c5), DySnakeConv(c4, c4, 3))
  315. self.cv4 = Conv(c3 + (2 * c4), c2, 1, 1)

3.2 在YOLOv9中的添加教程

阅读YOLOv9添加模块教程或使用下文操作

        1. 将上述模块代码添加到YOLOv9工程models/common.py文件的最末尾(必须放在最后,否则所继承的RepNBottleneck、RepNCSP、RepNCSPELAN4等父类可能尚未定义,导致报错)。

         2. 打开YOLOv9工程models/yolo.py文件,在parse_model函数中(约第681行,行号可能因版本不同而变化)找到以`RepNCSPELAN4, SPPELAN}:`结尾的模块集合,在集合末尾加入RepNCSPELAN4DySnakeConv,修改后该行如下:

            RepNCSPELAN4, SPPELAN, RepNCSPELAN4DySnakeConv}:

3.3 运行配置文件

  1. # YOLOv9
  2. # parameters
  3. nc: 80 # number of classes
  4. depth_multiple: 1.0 # model depth multiple
  5. width_multiple: 1.0 # layer channel multiple
  6. #activation: nn.LeakyReLU(0.1)
  7. #activation: nn.ReLU()
  8. # anchors
  9. anchors: 3
  10. # YOLOv9 backbone
  11. backbone:
  12. [
  13. [-1, 1, Silence, []],
  14. # conv down
  15. [-1, 1, Conv, [64, 3, 2]], # 1-P1/2
  16. # conv down
  17. [-1, 1, Conv, [128, 3, 2]], # 2-P2/4
  18. # elan-1 block
  19. [-1, 1, RepNCSPELAN4DySnakeConv, [256, 128, 64, 1]], # 3
  20. # avg-conv down
  21. [-1, 1, ADown, [256]], # 4-P3/8
  22. # elan-2 block
  23. [-1, 1, RepNCSPELAN4DySnakeConv, [512, 256, 128, 1]], # 5
  24. # avg-conv down
  25. [-1, 1, ADown, [512]], # 6-P4/16
  26. # elan-2 block
  27. [-1, 1, RepNCSPELAN4DySnakeConv, [512, 512, 256, 1]], # 7
  28. # avg-conv down
  29. [-1, 1, ADown, [512]], # 8-P5/32
  30. # elan-2 block
  31. [-1, 1, RepNCSPELAN4DySnakeConv, [512, 512, 256, 1]], # 9
  32. ]
  33. # YOLOv9 head
  34. head:
  35. [
  36. # elan-spp block
  37. [-1, 1, SPPELAN, [512, 256]], # 10
  38. # up-concat merge
  39. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  40. [[-1, 7], 1, Concat, [1]], # cat backbone P4
  41. # elan-2 block
  42. [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 13
  43. # up-concat merge
  44. [-1, 1, nn.Upsample, [None, 2, 'nearest']],
  45. [[-1, 5], 1, Concat, [1]], # cat backbone P3
  46. # elan-2 block
  47. [-1, 1, RepNCSPELAN4, [256, 256, 128, 1]], # 16 (P3/8-small)
  48. # avg-conv-down merge
  49. [-1, 1, ADown, [256]],
  50. [[-1, 13], 1, Concat, [1]], # cat head P4
  51. # elan-2 block
  52. [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 19 (P4/16-medium)
  53. # avg-conv-down merge
  54. [-1, 1, ADown, [512]],
  55. [[-1, 10], 1, Concat, [1]], # cat head P5
  56. # elan-2 block
  57. [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 22 (P5/32-large)
  58. # multi-level reversible auxiliary branch
  59. # routing
  60. [5, 1, CBLinear, [[256]]], # 23
  61. [7, 1, CBLinear, [[256, 512]]], # 24
  62. [9, 1, CBLinear, [[256, 512, 512]]], # 25
  63. # conv down
  64. [0, 1, Conv, [64, 3, 2]], # 26-P1/2
  65. # conv down
  66. [-1, 1, Conv, [128, 3, 2]], # 27-P2/4
  67. # elan-1 block
  68. [-1, 1, RepNCSPELAN4, [256, 128, 64, 1]], # 28
  69. # avg-conv down fuse
  70. [-1, 1, ADown, [256]], # 29-P3/8
  71. [[23, 24, 25, -1], 1, CBFuse, [[0, 0, 0]]], # 30
  72. # elan-2 block
  73. [-1, 1, RepNCSPELAN4, [512, 256, 128, 1]], # 31
  74. # avg-conv down fuse
  75. [-1, 1, ADown, [512]], # 32-P4/16
  76. [[24, 25, -1], 1, CBFuse, [[1, 1]]], # 33
  77. # elan-2 block
  78. [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 34
  79. # avg-conv down fuse
  80. [-1, 1, ADown, [512]], # 35-P5/32
  81. [[25, -1], 1, CBFuse, [[2]]], # 36
  82. # elan-2 block
  83. [-1, 1, RepNCSPELAN4, [512, 512, 256, 1]], # 37
  84. # detection head
  85. # detect
  86. [[31, 34, 37, 16, 19, 22], 1, DualDDetect, [nc]], # DualDDetect(A3, A4, A5, P3, P4, P5)
  87. ]

3.4 训练过程


欢迎关注!


声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号