
[Paper Notes] Dynamic Snake Convolution


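Accurately segmenting topological tubular structures, such as blood vessels and roads, is crucial in many fields and ensures the accuracy and efficiency of downstream tasks. Many factors complicate the task, however, including thin, fragile local structures and complex, variable global morphology. To address this, the authors propose dynamic snake convolution, which they report performs very well on tubular segmentation tasks.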

Paper: Dynamic Snake Convolution based on Topological Geometric Constraints for Tubular Structure Segmentation

Chinese title: 拓扑几何约束管状结构分割的动态蛇卷积

Code: https://github.com/yaoleiqi/dscnet

1. Applicable Scenarios

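Tubular targets are thin, elongated, and complex. Standard and dilated (atrous) convolutions cannot adapt their sampling region to the target's features. Deformable convolution can learn the region of interest adaptively from the features, but for tubular targets it cannot constrain the connectivity of that region. Dynamic snake convolution does constrain the connectivity of the sampled region, which makes it better suited to tubular scenarios.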

2. Dynamic Snake Convolution

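For a standard 3x3 2D convolution kernel K, the sampling grid is expressed as:

$$K = \{(x-1, y-1), (x-1, y), \cdots, (x+1, y+1)\} \tag{1}$$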

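To give the convolution kernel more flexibility so that it can focus on the target's complex geometric features, the authors, inspired by deformable convolution, introduce a deformation offset ∆. However, if the model is left completely free to learn the deformation offsets, the receptive field tends to drift away from the target, especially for thin tubular structures. The authors therefore adopt an iterative strategy (illustrated by a figure in the original post) that selects, in turn, the next position to observe for each target being processed. This keeps the attention continuous and prevents large deformation offsets from spreading the receptive field too far.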

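In dynamic snake convolution, the authors straighten the standard convolution kernel along both the x-axis and the y-axis. Consider a kernel of size 9 and take the x-axis direction as an example. The position of each grid in K is written K_{i±c} = (x_{i±c}, y_{i±c}), where c = 0, 1, 2, 3, 4 is the horizontal distance from the central grid. Choosing each grid position K_{i±c} is a cumulative process: starting from the central position K_i, a position farther from the center depends on the position of the previous grid, so K_{i+1} adds an offset ∆ = {δ | δ ∈ [−1, 1]} relative to K_i. The offsets therefore need to be accumulated (Σ), which keeps the kernel consistent with a linear morphological structure. The change along the x-axis is:

$$
K_{i\pm c} =
\begin{cases}
(x_{i+c}, y_{i+c}) = \left(x_i + c,\; y_i + \sum_{i}^{i+c} \Delta y\right), \\
(x_{i-c}, y_{i-c}) = \left(x_i - c,\; y_i + \sum_{i-c}^{i} \Delta y\right),
\end{cases} \tag{2}
$$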

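The change along the y-axis is:

$$
K_{j\pm c} =
\begin{cases}
(x_{j+c}, y_{j+c}) = \left(x_j + \sum_{j}^{j+c} \Delta x,\; y_j + c\right), \\
(x_{j-c}, y_{j-c}) = \left(x_j + \sum_{j-c}^{j} \Delta x,\; y_j - c\right),
\end{cases} \tag{3}
$$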
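The cumulative construction of the kernel positions can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the authors' implementation: the helper name `snake_kernel_x` and the example offsets are hypothetical, the offsets are given per step rather than predicted by a network, and the central grid is left unshifted. It covers the x-axis case; the y-axis case swaps the roles of x and y.

```python
from itertools import accumulate


def snake_kernel_x(center, offsets_right, offsets_left):
    """Grid positions of an x-axis snake kernel (hypothetical helper).

    center:        (x_i, y_i) of the central grid.
    offsets_right: per-step y offsets for K_{i+1}, ..., K_{i+c}.
    offsets_left:  per-step y offsets for K_{i-1}, ..., K_{i-c}.
    Every offset is clipped to [-1, 1] before it is accumulated.
    """
    x_i, y_i = center

    def clip(delta):
        return max(-1.0, min(1.0, delta))

    # Running sums: each grid's y position builds on the previous grid's.
    right = list(accumulate(clip(d) for d in offsets_right))
    left = list(accumulate(clip(d) for d in offsets_left))

    points = [(x_i - c, y_i + dy) for c, dy in enumerate(left, start=1)][::-1]
    points.append((x_i, y_i))  # the central grid stays where it is
    points += [(x_i + c, y_i + dy) for c, dy in enumerate(right, start=1)]
    return points


# A size-9 kernel centered at (10, 5) with made-up offsets.
points = snake_kernel_x(
    (10, 5),
    offsets_right=[0.5, 0.5, -0.25, 1.0],
    offsets_left=[-0.5, 0.0, 0.25, 0.25],
)
for x, y in points:
    print(x, y)
```

Because each step adds at most 1 in magnitude, neighbouring grids never differ by more than 1 along y. That bound is what keeps the sampled positions connected, in contrast to freely learned offsets.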

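Because the offset ∆ is usually fractional while coordinates are usually integers, bilinear interpolation is used:

$$K = \sum_{K'} B(K', K) \cdot K' \tag{4}$$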

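Here K denotes a fractional position from Equation 2 and Equation 3, K′ enumerates all integer spatial positions, and B is the bilinear interpolation kernel, which separates into two one-dimensional kernels:

$$B(K, K') = b(K_x, K'_x) \cdot b(K_y, K'_y) \tag{5}$$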
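The separable bilinear kernel can be made concrete with a small pure-Python sketch. It assumes the usual one-dimensional kernel b(a, a') = max(0, 1 − |a − a'|) from deformable convolution, which these notes do not spell out. The function names and the tiny 2x2 feature map are hypothetical.

```python
import math


def b(a, a_int):
    """1-D bilinear kernel: the weight falls linearly to 0 at distance 1."""
    return max(0.0, 1.0 - abs(a - a_int))


def bilinear_sample(feature, x, y):
    """Sample feature[row][col] at the fractional position (x, y) = (col, row).

    Integer positions outside the grid contribute zero (zero padding).
    """
    height, width = len(feature), len(feature[0])
    x0, y0 = math.floor(x), math.floor(y)
    value = 0.0
    # Only the four integer neighbours K' can have a non-zero weight B(K, K').
    for yi in (y0, y0 + 1):
        for xi in (x0, x0 + 1):
            if 0 <= yi < height and 0 <= xi < width:
                value += b(x, xi) * b(y, yi) * feature[yi][xi]
    return value


feature = [[0.0, 1.0],
           [2.0, 3.0]]
print(bilinear_sample(feature, 0.5, 0.5))  # mean of the four values: 1.5
print(bilinear_sample(feature, 0.25, 0))   # 0.75 * 0.0 + 0.25 * 1.0 = 0.25
```

In the listing in the code section, this interpolation is delegated to `nn.functional.grid_sample` with `mode="bilinear"` rather than written by hand.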

[Figure: overall illustration of dynamic snake convolution]

3. Code

The code for the snake convolution is as follows:

# -*- coding: utf-8 -*-
"""Dynamic Snake Convolution Module"""
# Lets the `str | torch.device` hints below work on Python versions before 3.10
from __future__ import annotations

import einops
import torch
from torch import nn


class DSConv_pro(nn.Module):
    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_size: int = 9,
        extend_scope: float = 1.0,
        morph: int = 0,
        if_offset: bool = True,
        device: str | torch.device = "cuda",
    ):
        """
        A Dynamic Snake Convolution Implementation

        Based on:
            TODO

        Args:
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 1.
            kernel_size: the size of kernel. Defaults to 9.
            extend_scope: the range to expand. Defaults to 1 for this method.
            morph: the morphology of the convolution kernel is mainly divided into two types along the x-axis (0) and the y-axis (1) (see the paper for details).
            if_offset: whether deformation is required, if it is False, it is the standard convolution kernel. Defaults to True.
        """
        super().__init__()

        if morph not in (0, 1):
            raise ValueError("morph should be 0 or 1.")

        self.kernel_size = kernel_size
        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset  # stored, but not read by forward() in this listing
        self.device = torch.device(device)

        # self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.gn_offset = nn.GroupNorm(kernel_size, 2 * kernel_size)
        # Uses out_channels // 4 groups, so out_channels needs to be at least 4
        self.gn = nn.GroupNorm(out_channels // 4, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.tanh = nn.Tanh()

        self.offset_conv = nn.Conv2d(in_channels, 2 * kernel_size, 3, padding=1)

        # The strided convolutions collapse the K samples per position back to one value
        self.dsc_conv_x = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )

        # Move the module only after all submodules exist; the original listing
        # called self.to(device) before creating them, which left them on the CPU.
        self.to(self.device)

    def forward(self, input: torch.Tensor):
        # Predict offset map between [-1, 1]
        offset = self.offset_conv(input)
        # offset = self.bn(offset)
        offset = self.gn_offset(offset)
        offset = self.tanh(offset)

        # Run deformative conv
        y_coordinate_map, x_coordinate_map = get_coordinate_map_2D(
            offset=offset,
            morph=self.morph,
            extend_scope=self.extend_scope,
            device=self.device,
        )
        deformed_feature = get_interpolated_feature(
            input,
            y_coordinate_map,
            x_coordinate_map,
        )

        if self.morph == 0:
            output = self.dsc_conv_x(deformed_feature)
        elif self.morph == 1:
            output = self.dsc_conv_y(deformed_feature)

        # Groupnorm & ReLU
        output = self.gn(output)
        output = self.relu(output)

        return output


def get_coordinate_map_2D(
    offset: torch.Tensor,
    morph: int,
    extend_scope: float = 1.0,
    device: str | torch.device = "cuda",
):
    """Computing 2D coordinate map of DSCNet based on: TODO

    Args:
        offset: offset predict by network with shape [B, 2*K, W, H]. Here K refers to kernel size.
        morph: the morphology of the convolution kernel is mainly divided into two types along the x-axis (0) and the y-axis (1) (see the paper for details).
        extend_scope: the range to expand. Defaults to 1 for this method.
        device: location of data. Defaults to 'cuda'.

    Return:
        y_coordinate_map: coordinate map along y-axis with shape [B, K_H * H, K_W * W]
        x_coordinate_map: coordinate map along x-axis with shape [B, K_H * H, K_W * W]
    """
    if morph not in (0, 1):
        raise ValueError("morph should be 0 or 1.")

    # Note: "width" names dim 2 and "height" names dim 3 of the offset tensor here;
    # the y coordinate runs along dim 2 and the x coordinate along dim 3.
    batch_size, _, width, height = offset.shape
    kernel_size = offset.shape[1] // 2
    center = kernel_size // 2
    device = torch.device(device)

    y_offset_, x_offset_ = torch.split(offset, kernel_size, dim=1)

    y_center_ = torch.arange(0, width, dtype=torch.float32, device=device)
    y_center_ = einops.repeat(y_center_, "w -> k w h", k=kernel_size, h=height)

    x_center_ = torch.arange(0, height, dtype=torch.float32, device=device)
    x_center_ = einops.repeat(x_center_, "h -> k w h", k=kernel_size, w=width)

    if morph == 0:
        # Initialize the kernel and flatten the kernel
        #   y: only need 0
        #   x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
        y_spread_ = torch.zeros([kernel_size], device=device)
        x_spread_ = torch.linspace(-center, center, kernel_size, device=device)

        y_grid_ = einops.repeat(y_spread_, "k -> k w h", w=width, h=height)
        x_grid_ = einops.repeat(x_spread_, "k -> k w h", w=width, h=height)

        y_new_ = y_center_ + y_grid_
        x_new_ = x_center_ + x_grid_

        y_new_ = einops.repeat(y_new_, "k w h -> b k w h", b=batch_size)
        x_new_ = einops.repeat(x_new_, "k w h -> b k w h", b=batch_size)

        y_offset_ = einops.rearrange(y_offset_, "b k w h -> k b w h")
        y_offset_new_ = y_offset_.detach().clone()

        # The center position remains unchanged and the rest of the positions begin to swing
        # This part is quite simple. The main idea is that "offset is an iterative process"
        y_offset_new_[center] = 0

        for index in range(1, center + 1):
            y_offset_new_[center + index] = (
                y_offset_new_[center + index - 1] + y_offset_[center + index]
            )
            y_offset_new_[center - index] = (
                y_offset_new_[center - index + 1] + y_offset_[center - index]
            )

        y_offset_new_ = einops.rearrange(y_offset_new_, "k b w h -> b k w h")

        y_new_ = y_new_.add(y_offset_new_.mul(extend_scope))

        y_coordinate_map = einops.rearrange(y_new_, "b k w h -> b (w k) h")
        x_coordinate_map = einops.rearrange(x_new_, "b k w h -> b (w k) h")

    elif morph == 1:
        # Initialize the kernel and flatten the kernel
        #   y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
        #   x: only need 0
        y_spread_ = torch.linspace(-center, center, kernel_size, device=device)
        x_spread_ = torch.zeros([kernel_size], device=device)

        y_grid_ = einops.repeat(y_spread_, "k -> k w h", w=width, h=height)
        x_grid_ = einops.repeat(x_spread_, "k -> k w h", w=width, h=height)

        y_new_ = y_center_ + y_grid_
        x_new_ = x_center_ + x_grid_

        y_new_ = einops.repeat(y_new_, "k w h -> b k w h", b=batch_size)
        x_new_ = einops.repeat(x_new_, "k w h -> b k w h", b=batch_size)

        x_offset_ = einops.rearrange(x_offset_, "b k w h -> k b w h")
        x_offset_new_ = x_offset_.detach().clone()

        # The center position remains unchanged and the rest of the positions begin to swing
        # This part is quite simple. The main idea is that "offset is an iterative process"
        x_offset_new_[center] = 0

        for index in range(1, center + 1):
            x_offset_new_[center + index] = (
                x_offset_new_[center + index - 1] + x_offset_[center + index]
            )
            x_offset_new_[center - index] = (
                x_offset_new_[center - index + 1] + x_offset_[center - index]
            )

        x_offset_new_ = einops.rearrange(x_offset_new_, "k b w h -> b k w h")

        x_new_ = x_new_.add(x_offset_new_.mul(extend_scope))

        y_coordinate_map = einops.rearrange(y_new_, "b k w h -> b w (h k)")
        x_coordinate_map = einops.rearrange(x_new_, "b k w h -> b w (h k)")

    return y_coordinate_map, x_coordinate_map


def get_interpolated_feature(
    input_feature: torch.Tensor,
    y_coordinate_map: torch.Tensor,
    x_coordinate_map: torch.Tensor,
    interpolate_mode: str = "bilinear",
):
    """From coordinate map interpolate feature of DSCNet based on: TODO

    Args:
        input_feature: feature that to be interpolated with shape [B, C, H, W]
        y_coordinate_map: coordinate map along y-axis with shape [B, K_H * H, K_W * W]
        x_coordinate_map: coordinate map along x-axis with shape [B, K_H * H, K_W * W]
        interpolate_mode: the arg 'mode' of nn.functional.grid_sample, can be 'bilinear' or 'bicubic'. Defaults to 'bilinear'.

    Return:
        interpolated_feature: interpolated feature with shape [B, C, K_H * H, K_W * W]
    """
    if interpolate_mode not in ("bilinear", "bicubic"):
        raise ValueError("interpolate_mode should be 'bilinear' or 'bicubic'.")

    y_max = input_feature.shape[-2] - 1
    x_max = input_feature.shape[-1] - 1

    # grid_sample expects coordinates normalized to [-1, 1]
    y_coordinate_map_ = _coordinate_map_scaling(y_coordinate_map, origin=[0, y_max])
    x_coordinate_map_ = _coordinate_map_scaling(x_coordinate_map, origin=[0, x_max])

    y_coordinate_map_ = torch.unsqueeze(y_coordinate_map_, dim=-1)
    x_coordinate_map_ = torch.unsqueeze(x_coordinate_map_, dim=-1)

    # Note here grid with shape [B, H, W, 2]
    # The last dimension holds (x, y), in that order
    grid = torch.cat([x_coordinate_map_, y_coordinate_map_], dim=-1)

    interpolated_feature = nn.functional.grid_sample(
        input=input_feature,
        grid=grid,
        mode=interpolate_mode,
        padding_mode="zeros",
        align_corners=True,
    )

    return interpolated_feature


def _coordinate_map_scaling(
    coordinate_map: torch.Tensor,
    origin: list,
    target: list = [-1, 1],
):
    """Map the value of coordinate_map from origin=[min, max] to target=[a,b] for DSCNet based on: TODO

    Args:
        coordinate_map: the coordinate map to be scaled
        origin: original value range of coordinate map, e.g. [coordinate_map.min(), coordinate_map.max()]
        target: target value range of coordinate map. Defaults to [-1, 1]

    Return:
        coordinate_map_scaled: the coordinate map after scaling
    """
    # Renamed from min/max to avoid shadowing the Python builtins
    origin_min, origin_max = origin
    a, b = target

    # Clamp to the valid range first, then rescale linearly to [a, b]
    coordinate_map_scaled = torch.clamp(coordinate_map, origin_min, origin_max)

    scale_factor = (b - a) / (origin_max - origin_min)
    coordinate_map_scaled = a + scale_factor * (coordinate_map_scaled - origin_min)

    return coordinate_map_scaled
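Two details of the listing are easy to miss: how the strided convolution brings the K-times expanded feature map back to its original size, and how pixel coordinates are mapped to the [-1, 1] range that `grid_sample` expects. The following pure-Python sketch works through the arithmetic only. The helper names and the 16 x 12 feature-map size are hypothetical, and the first function is just the standard convolution output-size formula for dilation 1.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def scale_to_grid(coord, size):
    """Clamp a pixel coordinate to [0, size - 1], then map it to [-1, 1]."""
    max_index = size - 1
    clamped = min(max(coord, 0.0), max_index)
    return -1.0 + 2.0 * clamped / max_index


# Shape flow for morph == 0 with a hypothetical 16 x 12 feature map and K = 9.
K, dim2, dim3 = 9, 16, 12
sampled_shape = (K * dim2, dim3)  # every position is expanded into K samples
out_shape = (
    conv_out_size(sampled_shape[0], kernel=K, stride=K),  # collapses K samples to 1
    conv_out_size(sampled_shape[1], kernel=1, stride=1),
)
print(sampled_shape, "->", out_shape)  # (144, 12) -> (16, 12)

# Coordinate normalization along an axis of length 12.
print(scale_to_grid(0, 12))     # -1.0, first pixel
print(scale_to_grid(11, 12))    # 1.0, last pixel
print(scale_to_grid(5.5, 12))   # 0.0, midpoint
print(scale_to_grid(20, 12))    # 1.0, out-of-range coordinates are clamped
```

Since the listing passes `align_corners=True`, the values -1 and 1 refer to the centers of the corner pixels, which is why the scaling divides by `size - 1` rather than `size`.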