当前位置:   article > 正文

YOLOV8改进-添加Dynamic Snake Convolution-动态蛇形卷积_dysnakeconv

dysnakeconv

01 Dynamic Snake Convolution介绍

        动态蛇形卷积主要用于管状结构的分割,例如血管、道路等细长且连续的结构。精确分割这类结构的主要挑战源于细长而微弱的局部结构特征,以及复杂多变的全局形态特征。因此,模型在学习特征的过程中会自适应地改变卷积核的形状,从而关注管状结构的核心特征。由于管状结构在图像中所占比例较小,普通卷积容易失去对相应结构的感知,卷积核甚至会完全游离在目标之外。因此我们希望根据管状结构的特点设计特定的网络结构,引导模型关注关键特征。

        本文关注到管状结构细长连续的拓扑特征,并利用这一信息在神经网络以下三个阶段同时增强感知:特征提取、特征融合和损失约束。分别设计了动态蛇形卷积,多视角特征融合策略与连续性拓扑约束损失。我们同时给出了基于 2D 和 3D 的方法设计,通过实验证明了本文所提出的 DSCNet 在管状结构分割任务上提供了更好的精度和连续性。

        论文地址:https://arxiv.org/abs/2307.08388

02 DySnakeConv-动态蛇形卷积代码

        以下代码包括

DySnakeConv——完整的动态蛇形卷积模块:标准卷积、沿 x 轴的蛇形卷积、沿 y 轴的蛇形卷积三路并行,拼接后用 1×1 卷积融合
DSConv——单方向的蛇形卷积,先学习偏移量(offset),再对变形后的特征做卷积
DSC——DSConv 用到的辅助类,根据偏移量生成蛇形采样坐标并做双线性插值
  1. class DySnakeConv(nn.Module):
  2. def __init__(self, inc, ouc, k=3, act=True) -> None:
  3. super().__init__()
  4. self.conv_0 = Conv(inc, ouc, k, act=act)
  5. self.conv_x = DSConv(inc, ouc, 0, k)
  6. self.conv_y = DSConv(inc, ouc, 1, k)
  7. self.conv_1x1 = Conv(ouc * 3, ouc, 1, act=act)
  8. def forward(self, x):
  9. return self.conv_1x1(torch.cat([self.conv_0(x), self.conv_x(x), self.conv_y(x)], dim=1))
  10. class DSConv(nn.Module):
  11. def __init__(self, in_ch, out_ch, morph, kernel_size=3, if_offset=True, extend_scope=1):
  12. """
  13. The Dynamic Snake Convolution
  14. :param in_ch: input channel
  15. :param out_ch: output channel
  16. :param kernel_size: the size of kernel
  17. :param extend_scope: the range to expand (default 1 for this method)
  18. :param morph: the morphology of the convolution kernel is mainly divided into two types
  19. along the x-axis (0) and the y-axis (1) (see the paper for details)
  20. :param if_offset: whether deformation is required, if it is False, it is the standard convolution kernel
  21. """
  22. super(DSConv, self).__init__()
  23. # use the <offset_conv> to learn the deformable offset
  24. self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
  25. self.bn = nn.BatchNorm2d(2 * kernel_size)
  26. self.kernel_size = kernel_size
  27. # two types of the DSConv (along x-axis and y-axis)
  28. self.dsc_conv_x = nn.Conv2d(
  29. in_ch,
  30. out_ch,
  31. kernel_size=(kernel_size, 1),
  32. stride=(kernel_size, 1),
  33. padding=0,
  34. )
  35. self.dsc_conv_y = nn.Conv2d(
  36. in_ch,
  37. out_ch,
  38. kernel_size=(1, kernel_size),
  39. stride=(1, kernel_size),
  40. padding=0,
  41. )
  42. self.gn = nn.GroupNorm(out_ch // 4, out_ch)
  43. self.act = Conv.default_act
  44. self.extend_scope = extend_scope
  45. self.morph = morph
  46. self.if_offset = if_offset
  47. def forward(self, f):
  48. offset = self.offset_conv(f)
  49. offset = self.bn(offset)
  50. # We need a range of deformation between -1 and 1 to mimic the snake's swing
  51. offset = torch.tanh(offset)
  52. input_shape = f.shape
  53. dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph)
  54. deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
  55. if self.morph == 0:
  56. x = self.dsc_conv_x(deformed_feature.type(f.dtype))
  57. x = self.gn(x)
  58. x = self.act(x)
  59. return x
  60. else:
  61. x = self.dsc_conv_y(deformed_feature.type(f.dtype))
  62. x = self.gn(x)
  63. x = self.act(x)
  64. return x
  65. # Core code, for ease of understanding, we mark the dimensions of input and output next to the code
  66. class DSC(object):
  67. def __init__(self, input_shape, kernel_size, extend_scope, morph):
  68. self.num_points = kernel_size
  69. self.width = input_shape[2]
  70. self.height = input_shape[3]
  71. self.morph = morph
  72. self.extend_scope = extend_scope # offset (-1 ~ 1) * extend_scope
  73. # define feature map shape
  74. """
  75. B: Batch size C: Channel W: Width H: Height
  76. """
  77. self.num_batch = input_shape[0]
  78. self.num_channels = input_shape[1]
  79. """
  80. input: offset [B,2*K,W,H] K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
  81. output_x: [B,1,W,K*H] coordinate map
  82. output_y: [B,1,K*W,H] coordinate map
  83. """
  84. def _coordinate_map_3D(self, offset, if_offset):
  85. device = offset.device
  86. # offset
  87. y_offset, x_offset = torch.split(offset, self.num_points, dim=1)
  88. y_center = torch.arange(0, self.width).repeat([self.height])
  89. y_center = y_center.reshape(self.height, self.width)
  90. y_center = y_center.permute(1, 0)
  91. y_center = y_center.reshape([-1, self.width, self.height])
  92. y_center = y_center.repeat([self.num_points, 1, 1]).float()
  93. y_center = y_center.unsqueeze(0)
  94. x_center = torch.arange(0, self.height).repeat([self.width])
  95. x_center = x_center.reshape(self.width, self.height)
  96. x_center = x_center.permute(0, 1)
  97. x_center = x_center.reshape([-1, self.width, self.height])
  98. x_center = x_center.repeat([self.num_points, 1, 1]).float()
  99. x_center = x_center.unsqueeze(0)
  100. if self.morph == 0:
  101. """
  102. Initialize the kernel and flatten the kernel
  103. y: only need 0
  104. x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  105. !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
  106. """
  107. y = torch.linspace(0, 0, 1)
  108. x = torch.linspace(
  109. -int(self.num_points // 2),
  110. int(self.num_points // 2),
  111. int(self.num_points),
  112. )
  113. y, x = torch.meshgrid(y, x)
  114. y_spread = y.reshape(-1, 1)
  115. x_spread = x.reshape(-1, 1)
  116. y_grid = y_spread.repeat([1, self.width * self.height])
  117. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  118. y_grid = y_grid.unsqueeze(0) # [B*K*K, W,H]
  119. x_grid = x_spread.repeat([1, self.width * self.height])
  120. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  121. x_grid = x_grid.unsqueeze(0) # [B*K*K, W,H]
  122. y_new = y_center + y_grid
  123. x_new = x_center + x_grid
  124. y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(device)
  125. x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(device)
  126. y_offset_new = y_offset.detach().clone()
  127. if if_offset:
  128. y_offset = y_offset.permute(1, 0, 2, 3)
  129. y_offset_new = y_offset_new.permute(1, 0, 2, 3)
  130. center = int(self.num_points // 2)
  131. # The center position remains unchanged and the rest of the positions begin to swing
  132. # This part is quite simple. The main idea is that "offset is an iterative process"
  133. y_offset_new[center] = 0
  134. for index in range(1, center):
  135. y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
  136. y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
  137. y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(device)
  138. y_new = y_new.add(y_offset_new.mul(self.extend_scope))
  139. y_new = y_new.reshape(
  140. [self.num_batch, self.num_points, 1, self.width, self.height])
  141. y_new = y_new.permute(0, 3, 1, 4, 2)
  142. y_new = y_new.reshape([
  143. self.num_batch, self.num_points * self.width, 1 * self.height
  144. ])
  145. x_new = x_new.reshape(
  146. [self.num_batch, self.num_points, 1, self.width, self.height])
  147. x_new = x_new.permute(0, 3, 1, 4, 2)
  148. x_new = x_new.reshape([
  149. self.num_batch, self.num_points * self.width, 1 * self.height
  150. ])
  151. return y_new, x_new
  152. else:
  153. """
  154. Initialize the kernel and flatten the kernel
  155. y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
  156. x: only need 0
  157. """
  158. y = torch.linspace(
  159. -int(self.num_points // 2),
  160. int(self.num_points // 2),
  161. int(self.num_points),
  162. )
  163. x = torch.linspace(0, 0, 1)
  164. y, x = torch.meshgrid(y, x)
  165. y_spread = y.reshape(-1, 1)
  166. x_spread = x.reshape(-1, 1)
  167. y_grid = y_spread.repeat([1, self.width * self.height])
  168. y_grid = y_grid.reshape([self.num_points, self.width, self.height])
  169. y_grid = y_grid.unsqueeze(0)
  170. x_grid = x_spread.repeat([1, self.width * self.height])
  171. x_grid = x_grid.reshape([self.num_points, self.width, self.height])
  172. x_grid = x_grid.unsqueeze(0)
  173. y_new = y_center + y_grid
  174. x_new = x_center + x_grid
  175. y_new = y_new.repeat(self.num_batch, 1, 1, 1)
  176. x_new = x_new.repeat(self.num_batch, 1, 1, 1)
  177. y_new = y_new.to(device)
  178. x_new = x_new.to(device)
  179. x_offset_new = x_offset.detach().clone()
  180. if if_offset:
  181. x_offset = x_offset.permute(1, 0, 2, 3)
  182. x_offset_new = x_offset_new.permute(1, 0, 2, 3)
  183. center = int(self.num_points // 2)
  184. x_offset_new[center] = 0
  185. for index in range(1, center):
  186. x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
  187. x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
  188. x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(device)
  189. x_new = x_new.add(x_offset_new.mul(self.extend_scope))
  190. y_new = y_new.reshape(
  191. [self.num_batch, 1, self.num_points, self.width, self.height])
  192. y_new = y_new.permute(0, 3, 1, 4, 2)
  193. y_new = y_new.reshape([
  194. self.num_batch, 1 * self.width, self.num_points * self.height
  195. ])
  196. x_new = x_new.reshape(
  197. [self.num_batch, 1, self.num_points, self.width, self.height])
  198. x_new = x_new.permute(0, 3, 1, 4, 2)
  199. x_new = x_new.reshape([
  200. self.num_batch, 1 * self.width, self.num_points * self.height
  201. ])
  202. return y_new, x_new
  203. """
  204. input: input feature map [N,C,D,W,H];coordinate map [N,K*D,K*W,K*H]
  205. output: [N,1,K*D,K*W,K*H] deformed feature map
  206. """
  207. def _bilinear_interpolate_3D(self, input_feature, y, x):
  208. device = input_feature.device
  209. y = y.reshape([-1]).float()
  210. x = x.reshape([-1]).float()
  211. zero = torch.zeros([]).int()
  212. max_y = self.width - 1
  213. max_x = self.height - 1
  214. # find 8 grid locations
  215. y0 = torch.floor(y).int()
  216. y1 = y0 + 1
  217. x0 = torch.floor(x).int()
  218. x1 = x0 + 1
  219. # clip out coordinates exceeding feature map volume
  220. y0 = torch.clamp(y0, zero, max_y)
  221. y1 = torch.clamp(y1, zero, max_y)
  222. x0 = torch.clamp(x0, zero, max_x)
  223. x1 = torch.clamp(x1, zero, max_x)
  224. input_feature_flat = input_feature.flatten()
  225. input_feature_flat = input_feature_flat.reshape(
  226. self.num_batch, self.num_channels, self.width, self.height)
  227. input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
  228. input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
  229. dimension = self.height * self.width
  230. base = torch.arange(self.num_batch) * dimension
  231. base = base.reshape([-1, 1]).float()
  232. repeat = torch.ones([self.num_points * self.width * self.height
  233. ]).unsqueeze(0)
  234. repeat = repeat.float()
  235. base = torch.matmul(base, repeat)
  236. base = base.reshape([-1])
  237. base = base.to(device)
  238. base_y0 = base + y0 * self.height
  239. base_y1 = base + y1 * self.height
  240. # top rectangle of the neighbourhood volume
  241. index_a0 = base_y0 - base + x0
  242. index_c0 = base_y0 - base + x1
  243. # bottom rectangle of the neighbourhood volume
  244. index_a1 = base_y1 - base + x0
  245. index_c1 = base_y1 - base + x1
  246. # get 8 grid values
  247. value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(device)
  248. value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(device)
  249. value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(device)
  250. value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(device)
  251. # find 8 grid locations
  252. y0 = torch.floor(y).int()
  253. y1 = y0 + 1
  254. x0 = torch.floor(x).int()
  255. x1 = x0 + 1
  256. # clip out coordinates exceeding feature map volume
  257. y0 = torch.clamp(y0, zero, max_y + 1)
  258. y1 = torch.clamp(y1, zero, max_y + 1)
  259. x0 = torch.clamp(x0, zero, max_x + 1)
  260. x1 = torch.clamp(x1, zero, max_x + 1)
  261. x0_float = x0.float()
  262. x1_float = x1.float()
  263. y0_float = y0.float()
  264. y1_float = y1.float()
  265. vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(device)
  266. vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(device)
  267. vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(device)
  268. vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(device)
  269. outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
  270. value_c1 * vol_c1)
  271. if self.morph == 0:
  272. outputs = outputs.reshape([
  273. self.num_batch,
  274. self.num_points * self.width,
  275. 1 * self.height,
  276. self.num_channels,
  277. ])
  278. outputs = outputs.permute(0, 3, 1, 2)
  279. else:
  280. outputs = outputs.reshape([
  281. self.num_batch,
  282. 1 * self.width,
  283. self.num_points * self.height,
  284. self.num_channels,
  285. ])
  286. outputs = outputs.permute(0, 3, 1, 2)
  287. return outputs
  288. def deform_conv(self, input, offset, if_offset):
  289. y, x = self._coordinate_map_3D(offset, if_offset)
  290. deformed_feature = self._bilinear_interpolate_3D(input, y, x)
  291. return deformed_feature

03 代码使用方法

        DySnakeConv是卷积的一种,将完整代码放在卷积的路径中,我的路径是ultralytics/nn/modules/conv.py。

01在 __all__ 中添加导出名称

        在Python代码中,__all__ 是一个特殊的变量,用于指定一个模块中哪些对象(函数、类、变量等)应该被导出(即可以通过from 模块名 import *导入),以及哪些对象应该被视为模块的公共接口。

  1. __all__ = [
  2. 'Conv', 'LightConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv', 'ChannelAttention',
  3. 'SpatialAttention', 'CBAM', 'Concat', 'RepConv',
  4. 'DCNv2','DiverseBranchBlock','DySnakeConv']

01直接使用DySnakeConv

01加载conv模块

        在路径ultralytics/nn/modules/__init__.py中加入DySnakeConv

  1. from .conv import (CBAM, ChannelAttention, Concat, Conv, ConvTranspose, DWConv, DWConvTranspose2d, Focus, GhostConv,
  2. LightConv, RepConv, SpatialAttention,
  3. DCNv2,DySnakeConv,
  4. )

        在下方 __all__ 中也加入 'DySnakeConv'。

  1. __all__ = [
  2. 'Conv', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv','DySnakeConv',
  3. 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 'TransformerBlock', 'MLPBlock',
  4. 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
  5. 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 'Segment', 'Pose', 'Classify',
  6. 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', 'DeformableTransformerDecoder',
  7. 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP','deattn',
  8. 'DCNv2','ODConv_3rd',
  9. 'C2f_DCN','C2f_DSConv','C2f_Biformer','DiverseBranchBlock','Bottleneck_DBB','C2f_DBB','EVCBlock','C2f_EMA','C2f_ODConv']

02在task脚本中添加

        路径在ultralytics/nn/tasks.py,添加导入的名称DySnakeConv。

  1. from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
  2. Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus,
  3. GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, RTDETRDecoder,
  4. DCNv2,
  5. DySnakeConv,
  6. Segment,CBAM,C2f_DCN,C2f_DSConv,C2f_Biformer,C2f_DBB, C2f_EMA,EVCBlock,
  7. ODConv_3rd)

03解析参数添加

        在def parse_model(d, ch, verbose=True):中,找到并添加DySnakeConv。

  1. if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,DCNv2,DySnakeConv,
  2. BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,BasicRFB_a,BasicRFB,
  3. C2f_DCN,C2f_DSConv,C2f_DBB,EVCBlock,BoT3, C2f_EMA,ODConv_3rd,C2f_Biformer):

04带入yaml中使用

        由于 DySnakeConv 不改变特征图尺寸(没有步长为 2 的下采样),所以选择新增一层,输出通道数与输入相同,卷积核大小设为 3。YOLOv8 的 yaml 中没有不改变特征图尺寸的 Conv,因此无法直接替换原有的 Conv,只能额外插入。注意:插入新层后,它之后所有层的索引都会加 1,需要同步修改后续 Concat 和 Detect 中引用的层号。

  - [-1, 1, DySnakeConv, [512, 3]]
  1. head:
  2. - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 40*40*512 w*r
  3. - [[-1, 6], 1, Concat, [1]] # cat backbone P4 11 40*40*512 *(w+wr)
  4. - [-1, 3, C2f, [512]] # 12 # 40*40*512 *w
  5. - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 80*80*256 *w
  6. - [[-1, 4], 1, Concat, [1]] # cat backbone P3 80*80*512 *w
  7. - [-1, 3, C2f, [256]] # 15 (P3/8-small) 80*80*256 *w --Detect
  8. - [-1, 1, Conv, [256, 3, 2]] # 40*40*256 *w
  9. - [-1, 1, DySnakeConv, [256, 3]]
  10. - [[-1, 12], 1, Concat, [1]] # cat head P4 # 40*40*512 *w
  11. - [-1, 3, C2f, [512]] # 18 (P4/16-medium) # 40*40*512 *w --Detect
  12. - [-1, 1, Conv, [512, 3, 2]] # 20*20*512 *w
  13. - [[-1, 9], 1, Concat, [1]] # cat head P5 # 20*20*512 *(w+wr)
  14. - [-1, 3, C2f, [1024]] # 21 (P5/32-large) # 20*20*512 *w --Detect
  15. - [[15, 18, 22], 1, Detect, [nc]] # Detect(P3=80*80*256 *w, P4=40*40*512 *w, P5=20*20*512)

02将DySnakeConv融入模块使用

        将 DySnakeConv 融入 Bottleneck 形成 Bottleneck_DySnakeConv,再将其作为 C2f 的子模块,形成 C2f_DySnakeConv。

01导入DySnakeConv

        和其他卷积一样,作为模型的子模块导入。

from .conv import Conv, DWConv, GhostConv, LightConv, RepConv,DCNv2,DySnakeConv

02添加DySnakeConv到Bottleneck

  1. class Bottleneck_DySnakeConv(nn.Module):
  2. # Standard bottleneck with DCN
  3. def __init__(self, c1, c2, shortcut=True, g=1, k=(3, 3), e=0.5): # ch_in, ch_out, shortcut-残差连接, groups, kernels, expand
  4. super().__init__()
  5. c_ = int(c2 * e) # hidden channels
  6. if k[0] == 3:
  7. self.cv1 = DySnakeConv(c1, c_, k[0], 1)
  8. else:
  9. self.cv1 = Conv(c1, c_, k[0], 1) #self.cv2 = DySnakeConv(c_, c2, 3)
  10. if k[1] == 3:
  11. self.cv2 = DySnakeConv(c_, c2, k[1])
  12. else:
  13. self.cv2 = Conv(c_, c2, k[1], 1, g=g)
  14. self.add = shortcut and c1 == c2 #如果残差连接以及通道数等
  15. def forward(self, x):
  16. return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))

03添加Bottleneck_DySnakeConv到C2f

  1. class C2f_DySnakeConv(nn.Module):
  2. # CSP Bottleneck with 2 convolutions #两个卷积的梯度流
  3. def __init__(self, c1, c2, n=1, shortcut=False, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansion
  4. super().__init__()
  5. self.c = int(c2 * e) # hidden channels
  6. self.cv1 = Conv(c1, 2 * self.c, 1, 1)
  7. self.cv2 = Conv((2 + n) * self.c, c2, 1) # optional act=FReLU(c2)
  8. self.m = nn.ModuleList(Bottleneck_DySnakeConv(self.c, self.c, shortcut, g, k=(3, 3), e=1.0) for _ in range(n)) #Bottleneck
  9. def forward(self, x):
  10. y = list(self.cv1(x).split((self.c, self.c), 1)) #先进行卷积 在进行切分
  11. y.extend(m(y[-1]) for m in self.m)
  12. return self.cv2(torch.cat(y, 1))

04将C2f_DySnakeConv写入 __all__ 中

        __all__ 在同一个脚本(block.py)中

  1. __all__ = [
  2. 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost', 'GhostBottleneck',
  3. 'Bottleneck', 'BottleneckCSP', 'Proto', 'RepC3',
  4. 'C2f_DCN','C2f_DSConv','C2f_Biformer','Bottleneck_DBB','C2f_DBB','EVCBlock','C2f_EMA','ODConv_3rd','C2f_DySnakeConv']

05将C2f_DySnakeConv写入init

        C2f_DySnakeConv 是一个模块,定义在 block.py 中,因此从 .block 导入。
  1. from .block import (C1, C2, C3, C3TR, DFL, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, GhostBottleneck,
  2. HGBlock, HGStem, Proto, RepC3,
  3. C2f_DCN,C2f_DSConv,C2f_Biformer,C2f_DBB,Bottleneck_DBB,EVCBlock, C2f_EMA,ODConv_3rd,C2f_DySnakeConv)#自己加的

        在下面的 __all__ 中也添加

  1. __all__ = [
  2. 'Conv', 'LightConv', 'RepConv', 'DWConv', 'DWConvTranspose2d', 'ConvTranspose', 'Focus', 'GhostConv','DySnakeConv',
  3. 'ChannelAttention', 'SpatialAttention', 'CBAM', 'Concat', 'TransformerLayer', 'TransformerBlock', 'MLPBlock',
  4. 'LayerNorm2d', 'DFL', 'HGBlock', 'HGStem', 'SPP', 'SPPF', 'C1', 'C2', 'C3', 'C2f', 'C3x', 'C3TR', 'C3Ghost',
  5. 'GhostBottleneck', 'Bottleneck', 'BottleneckCSP', 'Proto', 'Detect', 'Segment', 'Pose', 'Classify',
  6. 'TransformerEncoderLayer', 'RepC3', 'RTDETRDecoder', 'AIFI', 'DeformableTransformerDecoder',
  7. 'DeformableTransformerDecoderLayer', 'MSDeformAttn', 'MLP','deattn',
  8. 'DCNv2','ODConv_3rd',
  9. 'C2f_DCN','C2f_DSConv','C2f_Biformer','DiverseBranchBlock','Bottleneck_DBB','C2f_DBB','EVCBlock','C2f_EMA','C2f_ODConv','C2f_DySnakeConv']

06在task.py中添加

        路径在ultralytics/nn/tasks.py。

        从 ultralytics.nn.modules 中导入 C2f_DySnakeConv

  1. from ultralytics.nn.modules import (AIFI, C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x,
  2. Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Focus,
  3. GhostBottleneck, GhostConv, HGBlock, HGStem, Pose, RepC3, RepConv, RTDETRDecoder,
  4. DCNv2,
  5. DySnakeConv,C2f_DySnakeConv,
  6. Segment,CBAM,C2f_DCN,C2f_DSConv,C2f_Biformer,C2f_DBB, C2f_EMA,EVCBlock,
  7. ODConv_3rd)

        在解析参数-def parse_model(d, ch, verbose=True):中添加

  1. if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,DCNv2,DySnakeConv,
  2. BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, RepC3,BasicRFB_a,BasicRFB,
  3. C2f_DCN,C2f_DSConv,C2f_DBB,EVCBlock,BoT3, C2f_EMA,ODConv_3rd,C2f_Biformer,C2f_DySnakeConv):

        C2f_DySnakeConv 与 C2f 一样会重复堆叠 n 次,因此下面那个元组中也要添加(parse_model 会把重复次数 n 插入到它的参数中)。

  1. if m in (BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, C3x, RepC3,
  2. C2f_DCN,C2f_DSConv,C2f_DBB,BoT3,C2f_EMA,C2f_Biformer,C2f_DySnakeConv):#增加block

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/AllinToyou/article/detail/95218
推荐阅读
相关标签
  

闽ICP备14008679号