当前位置:   article > 正文

YOLOv8改进之C2f-DySnakeConv(动态蛇形卷积Dynamic Snake Convolution)


一、动态蛇形卷积Dynamic Snake Convolution论文

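论文地址:https://arxiv.org/abs/2307.08388(Dynamic Snake Convolution based on Topological Geometric Constraints for Tubular Structure Segmentation)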


   动态蛇形卷积(Dynamic Snake Convolution)的设计灵感来源于蛇形的形状,用于增强卷积对目标形状和边界的敏感性。它能够帮助神经网络更好地捕捉目标的形状信息,尤其适合细长、弯曲的管状结构(论文面向血管、道路等管状结构的分割)以及其他复杂或不规则形状的目标。它通过引入动态的、可变形的卷积核来实现这一目标:卷积核上各采样点的偏移量由网络学习,并从卷积核中心向两端逐点累加,使卷积核能够像蛇一样沿目标的形状和边界弯曲,从而更好地适应目标的特定形状。






参考 YOLOv8改进之C2f-DBB(C2f模块中融合多元分支模块Diverse Branch Block)-CSDN博客 一文创建 ultralytics/nn/other_modules 目录,然后在该目录下新建 dynamic_snake_conv.py,写入以下代码:

  1. import torch
  2. import torch.nn as nn
  3. from ..modules.conv import Conv
  4. __all__ = ['DySnakeConv']
  5. class DySnakeConv(nn.Module):
  6. def __init__(self, inc, ouc, k=3) -> None:
  7. super().__init__()
  8. self.conv_0 = Conv(inc, ouc, k)
  9. self.conv_x = DSConv(inc, ouc, 0, k)
  10. self.conv_y = DSConv(inc, ouc, 1, k)
  11. def forward(self, x):
  12. return torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1)
  13. class DSConv(nn.Module):
  14. def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
  15. """
  16. The Dynamic Snake Convolution
  17. :param in_ch: input channel
  18. :param out_ch: output channel
  19. :param kernel_size: the size of kernel
  20. :param extend_scope: the range to expand (default 1 for this method)
  21. :param morph: the morphology of the convolution kernel is mainly divided into two types
  22. along the x-axis (0) and the y-axis (1) (see the paper for details)
  23. :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
  24. """
  25. super(DSConv, self).__init__()
  26. # use the <offset_conv> to learn the deformable offset
  27. self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
  28. self.bn = nn.BatchNorm2d(2 * kernel_size)
  29. self.kernel_size = kernel_size
  30. # two types of the DSConv (along x-axis and y-axis)
  31. self.dsc_conv_x = nn.Conv2d(
  32. in_ch,
  33. out_ch,
  34. kernel_size=(kernel_size, 1),
  35. stride=(kernel_size, 1),
  36. padding=0,
  37. )
  38. self.dsc_conv_y = nn.Conv2d(
  39. in_ch,
  40. out_ch,
  41. kernel_size=(1, kernel_size),
  42. stride=(1, kernel_size),
  43. padding=0,
  44. )
  45. self.gn = nn.GroupNorm(out_ch // 4, out_ch)
  46. self.act = Conv.default_act
  47. self.extend_scope = extend_scope
  48. self.morph = morph
  49. self.if_offset = if_offset
  50. def forward(self, f):
  51. offset = self.offset_conv(f)
  52. offset = self.bn(offset)
  53. # We need a range of deformation between -1 and 1 to mimic the snake's swing
  54. offset = torch.tanh(offset)
  55. input_shape = f.shape
  56. dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
  57. deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
  58. if self.morph == 0:
  59. x = self.dsc_conv_x(deformed_feature.type(f.dtype))
  60. x = self.gn(x)
  61. x = self.act(x)
  62. return x
  63. else:
  64. x = self.dsc_conv_y(deformed_feature.type(f.dtype))
  65. x = self.gn(x)
  66. x = self.act(x)
  67. return x
  68. # Core code, for ease of understanding, we mark the dimensions of input and output next to the code
  69. class DSC(object):
  70. def __init__(self, input_shape, kernel_size, extend_scope, morph):
  71. self.num_points = kernel_size
  72. self.width = input_shape[2]
  73. self.height = input_shape[3]
  74. self.morph = morph
  75. self.extend_scope = extend_scope # offset (-1 ~ 1) * extend_scope
  76. # define feature map shape
  77. """
  78. B: Batch size C: Channel W: Width H: Height
  79. """
  80. self.num_batch = input_shape[0]
  81. self.num_channels = input_shape[1]
  82. """
  83. input: offset [B,2*K,W,H] K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
  84. output_x: [B,1,W,K*H] coordinate map
  85. output_y: [B,1,K*W,H] coordinate map
  86. """
  87. def _coordinate_map_3D(self, offset, if_offset):
  88. device = offset.device
  89. # offset
  90. y_offset, x_offset = torch.split(offset, self.num_points, dim=1)
  91. y_center = torch.arange(0, self.width).repeat([self.height])
  92. y_center = y_center.reshape(self.height, self.width)
  93. y_center = y_center.permute(1, 0)
  94. y_center = y_center.reshape([-1, self.width, self.height])
  95. y_center = y_center.repeat([self.num_points, 1, 1]).float()
  96. y_center = y_center.unsqueeze(0)
  97. x_center = torch.arange(0, self.height).repeat([self.width])
  98. x_center = x_center.reshape(self.width, self.height)
  99. x_center = x_center.permute(0, 1)
  100. x_center = x_center.reshape([-1, self.width, self.height])
  101. x_center = x_center.repeat([self.num_points, 1, 1]).float()
  102. x_center = x_center.unsqueeze(0)
  103. if self.morph == 0:
  104. """
  105. Initialize the kernel and flatten the kernel
  106. y: only need 0
  107. x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  108. !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
  109. """
  110. y = torch.linspace(0, 0, 1)
  111. x = torch.linspace(
  112. -int(self.num_points // 2),
  113. int(self.num_points // 2),
  114. int(self.num_points),
  115. )
  116. y, x = torch.meshgrid(y, x)
  117. y_spread = y.reshape(-1, 1)
  118. x_spread = x.reshape(-1, 1)
  119. y_grid = y_spread.repeat([1, self.width * self.height])
  120. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  121. y_grid = y_grid.unsqueeze(0) # [B*K*K, W,H]
  122. x_grid = x_spread.repeat([1, self.width * self.height])
  123. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  124. x_grid = x_grid.unsqueeze(0) # [B*K*K, W,H]
  125. y_new = y_center + y_grid
  126. x_new = x_center + x_grid
  127. y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
  128. x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)
  129. y_offset_new = y_offset.detach().clone()
  130. if if_offset:
  131. y_offset = y_offset.permute(1, 0, 2, 3)
  132. y_offset_new = y_offset_new.permute(1, 0, 2, 3)
  133. center = int(self.num_points // 2)
  134. # The center position remains unchanged and the rest of the positions begin to swing
  135. # This part is quite simple. The main idea is that "offset is an iterative process"
  136. y_offset_new[center] = 0
  137. for index in range(1, center):
  138. y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
  139. y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
  140. y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
  141. y_new = y_new.add(y_offset_new.mul(self.extend_scope))
  142. y_new = y_new.reshape(
  143. [self.num_batch, self.num_points, 1, self.width, self.height])
  144. y_new = y_new.permute(0, 3, 1, 4, 2)
  145. y_new = y_new.reshape([
  146. self.num_batch, self.num_points * self.width, 1 * self.height
  147. ])
  148. x_new = x_new.reshape(
  149. [self.num_batch, self.num_points, 1, self.width, self.height])
  150. x_new = x_new.permute(0, 3, 1, 4, 2)
  151. x_new = x_new.reshape([
  152. self.num_batch, self.num_points * self.width, 1 * self.height
  153. ])
  154. return y_new, x_new
  155. else:
  156. """
  157. Initialize the kernel and flatten the kernel
  158. y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  159. x: only need 0
  160. """
  161. y = torch.linspace(
  162. -int(self.num_points // 2),
  163. int(self.num_points // 2),
  164. int(self.num_points),
  165. )
  166. x = torch.linspace(0, 0, 1)
  167. y, x = torch.meshgrid(y, x)
  168. y_spread = y.reshape(-1, 1)
  169. x_spread = x.reshape(-1, 1)
  170. y_grid = y_spread.repeat([1, self.width * self.height])
  171. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  172. y_grid = y_grid.unsqueeze(0)
  173. x_grid = x_spread.repeat([1, self.width * self.height])
  174. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  175. x_grid = x_grid.unsqueeze(0)
  176. y_new = y_center + y_grid
  177. x_new = x_center + x_grid
  178. y_new = y_new.repeat(self.num_batch, 1, 1, 1)
  179. x_new = x_new.repeat(self.num_batch, 1, 1, 1)
  180. y_new = y_new.to(device)
  181. x_new = x_new.to(device)
  182. x_offset_new = x_offset.detach().clone()
  183. if if_offset:
  184. x_offset = x_offset.permute(1, 0, 2, 3)
  185. x_offset_new = x_offset_new.permute(1, 0, 2, 3)
  186. center = int(self.num_points // 2)
  187. x_offset_new[center] = 0
  188. for index in range(1, center):
  189. x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
  190. x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
  191. x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
  192. x_new = x_new.add(x_offset_new.mul(self.extend_scope))
  193. y_new = y_new.reshape(
  194. [self.num_batch, 1, self.num_points, self.width, self.height])
  195. y_new = y_new.permute(0, 3, 1, 4, 2)
  196. y_new = y_new.reshape([
  197. self.num_batch, 1 * self.width, self.num_points * self.height
  198. ])
  199. x_new = x_new.reshape(
  200. [self.num_batch, 1, self.num_points, self.width, self.height])
  201. x_new = x_new.permute(0, 3, 1, 4, 2)
  202. x_new = x_new.reshape([
  203. self.num_batch, 1 * self.width, self.num_points * self.height
  204. ])
  205. return y_new, x_new
  206. """
  207. input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H]
  208. output: [N,1,K*D,K*W,K*H] deformed feature map
  209. """
  210. def _bilinear_interpolate_3D(self, input_feature, y, x):
  211. device = input_feature.device
  212. y = y.reshape([-1]).float()
  213. x = x.reshape([-1]).float()
  214. zero = torch.zeros([]).int()
  215. max_y = self.width - 1
  216. max_x = self.height - 1
  217. # find 8 grid locations
  218. y0 = torch.floor(y).int()
  219. y1 = y0 + 1
  220. x0 = torch.floor(x).int()
  221. x1 = x0 + 1
  222. # clip out coordinates exceeding feature map volume
  223. y0 = torch.clamp(y0, zero, max_y)
  224. y1 = torch.clamp(y1, zero, max_y)
  225. x0 = torch.clamp(x0, zero, max_x)
  226. x1 = torch.clamp(x1, zero, max_x)
  227. input_feature_flat = input_feature.flatten()
  228. input_feature_flat = input_feature_flat.reshape(
  229. self.num_batch, self.num_channels, self.width, self.height)
  230. input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
  231. input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
  232. dimension = self.height * self.width
  233. base = torch.arange(self.num_batch) * dimension
  234. base = base.reshape([-1, 1]).float()
  235. repeat = torch.ones([self.num_points * self.width * self.height
  236. ]).unsqueeze(0)
  237. repeat = repeat.float()
  238. base = torch.matmul(base, repeat)
  239. base = base.reshape([-1])
  240. base = base.to(device)
  241. base_y0 = base + y0 * self.height
  242. base_y1 = base + y1 * self.height
  243. # top rectangle of the neighbourhood volume
  244. index_a0 = base_y0 - base + x0
  245. index_c0 = base_y0 - base + x1
  246. # bottom rectangle of the neighbourhood volume
  247. index_a1 = base_y1 - base + x0
  248. index_c1 = base_y1 - base + x1
  249. # get 8 grid values
  250. value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
  251. value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
  252. value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
  253. value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)
  254. # find 8 grid locations
  255. y0 = torch.floor(y).int()
  256. y1 = y0 + 1
  257. x0 = torch.floor(x).int()
  258. x1 = x0 + 1
  259. # clip out coordinates exceeding feature map volume
  260. y0 = torch.clamp(y0, zero, max_y + 1)
  261. y1 = torch.clamp(y1, zero, max_y + 1)
  262. x0 = torch.clamp(x0, zero, max_x + 1)
  263. x1 = torch.clamp(x1, zero, max_x + 1)
  264. x0_float = x0.float()
  265. x1_float = x1.float()
  266. y0_float = y0.float()
  267. y1_float = y1.float()
  268. vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
  269. vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
  270. vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
  271. vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)
  272. outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
  273. value_c1 * vol_c1)
  274. if self.morph == 0:
  275. outputs = outputs.reshape([
  276. self.num_batch,
  277. self.num_points * self.width,
  278. 1 * self.height,
  279. self.num_channels,
  280. ])
  281. outputs = outputs.permute(0, 3, 1, 2)
  282. else:
  283. outputs = outputs.reshape([
  284. self.num_batch,
  285. 1 * self.width,
  286. self.num_points * self.height,
  287. self.num_channels,
  288. ])
  289. outputs = outputs.permute(0, 3, 1, 2)
  290. return outputs
  291. def deform_conv(self, input, offset, if_offset):
  292. y, x = self._coordinate_map_3D(offset, if_offset)
  293. deformed_feature = self._bilinear_interpolate_3D(input, y, x)
  294. return deformed_feature

在 ultralytics\ultralytics\nn\other_modules\block.py 文件开头导入动态蛇形卷积:

  1. from ultralytics.nn.modules import Conv
  2. from .dynamic_snake_conv import DySnakeConv


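然后还是在这个 block.py 文件中添加 C2f_DySnakeConv 代码(这两个类继承自 Bottleneck 和 C2f,需确保它们在该文件中可用,例如 from ultralytics.nn.modules import Bottleneck, C2f):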

  1. ####### 添加 C2f DySnakeConv ##########
  2. class Bottleneck_DySnakeConv(Bottleneck):
  3. """Standard bottleneck with DySnakeConv."""
  4. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut, groups, kernels, expand
  5. super().__init__(c1, c2, shortcut, g, k, e)
  6. c_ = int(c2 * e) # hidden channels
  7. self.cv2 = DySnakeConv(c_, c2, k[1])
  8. self.cv3 = Conv(c2 * 3, c2, k=1)
  9. def forward(self, x):
  10. """'forward()' applies the YOLOv5 FPN to input data."""
  11. return x + self.cv3(self.cv2(self.cv1(x))) if self.add else self.cv3(self.cv2(self.cv1(x)))
  12. class C2f_DySnakeConv(C2f):
  13. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5):
  14. super().__init__(c1, c2, n, shortcut, g, e)
  15. self.m = nn.ModuleList(Bottleneck_DySnakeConv(self.c, self.c, shortcut, g, k=(3, 3), e=1.0) for _ in range(n))



  1. from ultralytics.nn.other_modules.block import C2f_DySnakeConv
  2. from ultralytics.nn.other_modules import *

上面两行导入语句添加在 ultralytics/nn/tasks.py 文件开头。然后修改该文件中的 parse_model 函数:在括号里面添加一个参数 warehouse_manager,并把 C2f_DySnakeConv 加入下方代码中的两处模块元组:

def parse_model(d, ch, verbose=True, warehouse_manager=None):



  1. if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
  2. BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,
  3. C2f_DBB,C2f_DySnakeConv):
  4. if args[0] == 'head_channel':
  5. args[0] = d[args[0]]
  6. c1, c2 = ch[f], args[0]
  7. if c2 != nc: # if c2 not equal to number of classes (i.e. for Classify() output)
  8. c2 = make_divisible(min(c2, max_channels) * width, 8)
  9. args = [c1, c2, *args[1:]]
  10. if m in (
  11. BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3, C2f_DBB,C2f_DySnakeConv):
  12. args.insert(2, n) # number of repeats
  13. n = 1

3、创建 YOLOv8+C2f-DySnakeConv 的 yaml 文件:

  1. # Ultralytics YOLO