
ICCV 2023 | Dynamic Snake Convolution (with plug-and-play code and a test case)

A walkthrough of the Dynamic Snake Convolution code

Paper:

https://arxiv.org/abs/2307.08388

Code:

https://github.com/YaoleiQi/DSCNet

The full code is given below. The original source also ships a test case in its __main__ block, so the module is plug-and-play; a minimal sketch of dropping it into a larger model follows the listing.

import os
import torch
import numpy as np
from torch import nn
import warnings

warnings.filterwarnings("ignore")

"""
This code is mainly the deformation process of our DSConv
"""


class DSConv(nn.Module):

    def __init__(self, in_ch, out_ch, kernel_size, extend_scope, morph,
                 if_offset, device):
        """
        The Dynamic Snake Convolution
        :param in_ch: input channel
        :param out_ch: output channel
        :param kernel_size: the size of kernel
        :param extend_scope: the range to expand (default 1 for this method)
        :param morph: the morphology of the convolution kernel is mainly divided into two types,
                      along the x-axis (0) and the y-axis (1) (see the paper for details)
        :param if_offset: whether deformation is required; if it is False, it is the standard convolution kernel
        :param device: set on gpu
        """
        super(DSConv, self).__init__()
        # use the <offset_conv> to learn the deformable offset
        self.offset_conv = nn.Conv2d(in_ch, 2 * kernel_size, 3, padding=1)
        self.bn = nn.BatchNorm2d(2 * kernel_size)
        self.kernel_size = kernel_size

        # two types of the DSConv (along x-axis and y-axis)
        self.dsc_conv_x = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(kernel_size, 1),
            stride=(kernel_size, 1),
            padding=0,
        )
        self.dsc_conv_y = nn.Conv2d(
            in_ch,
            out_ch,
            kernel_size=(1, kernel_size),
            stride=(1, kernel_size),
            padding=0,
        )

        self.gn = nn.GroupNorm(out_ch // 4, out_ch)
        self.relu = nn.ReLU(inplace=True)

        self.extend_scope = extend_scope
        self.morph = morph
        self.if_offset = if_offset
        self.device = device

    def forward(self, f):
        offset = self.offset_conv(f)
        offset = self.bn(offset)
        # We need a range of deformation between -1 and 1 to mimic the snake's swing
        offset = torch.tanh(offset)
        input_shape = f.shape
        dsc = DSC(input_shape, self.kernel_size, self.extend_scope, self.morph,
                  self.device)
        deformed_feature = dsc.deform_conv(f, offset, self.if_offset)
        if self.morph == 0:
            x = self.dsc_conv_x(deformed_feature)
            x = self.gn(x)
            x = self.relu(x)
            return x
        else:
            x = self.dsc_conv_y(deformed_feature)
            x = self.gn(x)
            x = self.relu(x)
            return x


# Core code. For ease of understanding, we mark the dimensions of input and output next to the code
class DSC(object):

    def __init__(self, input_shape, kernel_size, extend_scope, morph, device):
        self.num_points = kernel_size
        self.width = input_shape[2]
        self.height = input_shape[3]
        self.morph = morph
        self.device = device
        self.extend_scope = extend_scope  # offset (-1 ~ 1) * extend_scope

        # define feature map shape
        """
        B: Batch size  C: Channel  W: Width  H: Height
        """
        self.num_batch = input_shape[0]
        self.num_channels = input_shape[1]

    """
    input: offset [B,2*K,W,H]  K: Kernel size (2*K: 2D image, deformation contains <x_offset> and <y_offset>)
    output_x: [B,1,W,K*H]   coordinate map
    output_y: [B,1,K*W,H]   coordinate map
    """

    def _coordinate_map_3D(self, offset, if_offset):
        # offset
        y_offset, x_offset = torch.split(offset, self.num_points, dim=1)

        y_center = torch.arange(0, self.width).repeat([self.height])
        y_center = y_center.reshape(self.height, self.width)
        y_center = y_center.permute(1, 0)
        y_center = y_center.reshape([-1, self.width, self.height])
        y_center = y_center.repeat([self.num_points, 1, 1]).float()
        y_center = y_center.unsqueeze(0)

        x_center = torch.arange(0, self.height).repeat([self.width])
        x_center = x_center.reshape(self.width, self.height)
        x_center = x_center.permute(0, 1)
        x_center = x_center.reshape([-1, self.width, self.height])
        x_center = x_center.repeat([self.num_points, 1, 1]).float()
        x_center = x_center.unsqueeze(0)

        if self.morph == 0:
            """
            Initialize the kernel and flatten the kernel
            y: only need 0
            x: -num_points//2 ~ num_points//2 (Determined by the kernel size)
            !!! The related PPT will be submitted later, and the PPT will contain the whole changes of each step
            """
            y = torch.linspace(0, 0, 1)
            x = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )

            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)  # [B*K*K, W,H]

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)  # [B*K*K, W,H]

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1).to(self.device)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1).to(self.device)

            y_offset_new = y_offset.detach().clone()

            if if_offset:
                y_offset = y_offset.permute(1, 0, 2, 3)
                y_offset_new = y_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)

                # The center position remains unchanged and the rest of the positions begin to swing
                # This part is quite simple. The main idea is that "offset is an iterative process"
                y_offset_new[center] = 0
                for index in range(1, center):
                    y_offset_new[center + index] = (y_offset_new[center + index - 1] + y_offset[center + index])
                    y_offset_new[center - index] = (y_offset_new[center - index + 1] + y_offset[center - index])
                y_offset_new = y_offset_new.permute(1, 0, 2, 3).to(self.device)
                y_new = y_new.add(y_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, self.num_points, 1, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, self.num_points * self.width, 1 * self.height
            ])
            return y_new, x_new

        else:
            """
            Initialize the kernel and flatten the kernel
            y: -num_points//2 ~ num_points//2 (Determined by the kernel size)
            x: only need 0
            """
            y = torch.linspace(
                -int(self.num_points // 2),
                int(self.num_points // 2),
                int(self.num_points),
            )
            x = torch.linspace(0, 0, 1)

            y, x = torch.meshgrid(y, x)
            y_spread = y.reshape(-1, 1)
            x_spread = x.reshape(-1, 1)

            y_grid = y_spread.repeat([1, self.width * self.height])
            y_grid = y_grid.reshape([self.num_points, self.width, self.height])
            y_grid = y_grid.unsqueeze(0)

            x_grid = x_spread.repeat([1, self.width * self.height])
            x_grid = x_grid.reshape([self.num_points, self.width, self.height])
            x_grid = x_grid.unsqueeze(0)

            y_new = y_center + y_grid
            x_new = x_center + x_grid

            y_new = y_new.repeat(self.num_batch, 1, 1, 1)
            x_new = x_new.repeat(self.num_batch, 1, 1, 1)

            y_new = y_new.to(self.device)
            x_new = x_new.to(self.device)
            x_offset_new = x_offset.detach().clone()

            if if_offset:
                x_offset = x_offset.permute(1, 0, 2, 3)
                x_offset_new = x_offset_new.permute(1, 0, 2, 3)
                center = int(self.num_points // 2)
                x_offset_new[center] = 0
                for index in range(1, center):
                    x_offset_new[center + index] = (x_offset_new[center + index - 1] + x_offset[center + index])
                    x_offset_new[center - index] = (x_offset_new[center - index + 1] + x_offset[center - index])
                x_offset_new = x_offset_new.permute(1, 0, 2, 3).to(self.device)
                x_new = x_new.add(x_offset_new.mul(self.extend_scope))

            y_new = y_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            y_new = y_new.permute(0, 3, 1, 4, 2)
            y_new = y_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            x_new = x_new.reshape(
                [self.num_batch, 1, self.num_points, self.width, self.height])
            x_new = x_new.permute(0, 3, 1, 4, 2)
            x_new = x_new.reshape([
                self.num_batch, 1 * self.width, self.num_points * self.height
            ])
            return y_new, x_new

    """
    input: input feature map [N,C,D,W,H]; coordinate map [N,K*D,K*W,K*H]
    output: [N,1,K*D,K*W,K*H] deformed feature map
    """

    def _bilinear_interpolate_3D(self, input_feature, y, x):
        y = y.reshape([-1]).float()
        x = x.reshape([-1]).float()

        zero = torch.zeros([]).int()
        max_y = self.width - 1
        max_x = self.height - 1

        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y)
        y1 = torch.clamp(y1, zero, max_y)
        x0 = torch.clamp(x0, zero, max_x)
        x1 = torch.clamp(x1, zero, max_x)

        input_feature_flat = input_feature.flatten()
        input_feature_flat = input_feature_flat.reshape(
            self.num_batch, self.num_channels, self.width, self.height)
        input_feature_flat = input_feature_flat.permute(0, 2, 3, 1)
        input_feature_flat = input_feature_flat.reshape(-1, self.num_channels)
        dimension = self.height * self.width

        base = torch.arange(self.num_batch) * dimension
        base = base.reshape([-1, 1]).float()

        repeat = torch.ones([self.num_points * self.width * self.height]).unsqueeze(0)
        repeat = repeat.float()

        base = torch.matmul(base, repeat)
        base = base.reshape([-1])
        base = base.to(self.device)

        base_y0 = base + y0 * self.height
        base_y1 = base + y1 * self.height

        # top rectangle of the neighbourhood volume
        index_a0 = base_y0 - base + x0
        index_c0 = base_y0 - base + x1

        # bottom rectangle of the neighbourhood volume
        index_a1 = base_y1 - base + x0
        index_c1 = base_y1 - base + x1

        # get 8 grid values
        value_a0 = input_feature_flat[index_a0.type(torch.int64)].to(self.device)
        value_c0 = input_feature_flat[index_c0.type(torch.int64)].to(self.device)
        value_a1 = input_feature_flat[index_a1.type(torch.int64)].to(self.device)
        value_c1 = input_feature_flat[index_c1.type(torch.int64)].to(self.device)

        # find 8 grid locations
        y0 = torch.floor(y).int()
        y1 = y0 + 1
        x0 = torch.floor(x).int()
        x1 = x0 + 1

        # clip out coordinates exceeding feature map volume
        y0 = torch.clamp(y0, zero, max_y + 1)
        y1 = torch.clamp(y1, zero, max_y + 1)
        x0 = torch.clamp(x0, zero, max_x + 1)
        x1 = torch.clamp(x1, zero, max_x + 1)

        x0_float = x0.float()
        x1_float = x1.float()
        y0_float = y0.float()
        y1_float = y1.float()

        vol_a0 = ((y1_float - y) * (x1_float - x)).unsqueeze(-1).to(self.device)
        vol_c0 = ((y1_float - y) * (x - x0_float)).unsqueeze(-1).to(self.device)
        vol_a1 = ((y - y0_float) * (x1_float - x)).unsqueeze(-1).to(self.device)
        vol_c1 = ((y - y0_float) * (x - x0_float)).unsqueeze(-1).to(self.device)

        outputs = (value_a0 * vol_a0 + value_c0 * vol_c0 + value_a1 * vol_a1 +
                   value_c1 * vol_c1)

        if self.morph == 0:
            outputs = outputs.reshape([
                self.num_batch,
                self.num_points * self.width,
                1 * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        else:
            outputs = outputs.reshape([
                self.num_batch,
                1 * self.width,
                self.num_points * self.height,
                self.num_channels,
            ])
            outputs = outputs.permute(0, 3, 1, 2)
        return outputs

    def deform_conv(self, input, offset, if_offset):
        y, x = self._coordinate_map_3D(offset, if_offset)
        deformed_feature = self._bilinear_interpolate_3D(input, y, x)
        return deformed_feature


# Code for testing the DSConv
if __name__ == '__main__':
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = np.random.rand(4, 5, 6, 7)
    # A = np.ones(shape=(3, 2, 2, 3), dtype=np.float32)
    # print(A)
    A = A.astype(dtype=np.float32)
    A = torch.from_numpy(A)
    # print(A.shape)
    conv0 = DSConv(
        in_ch=5,
        out_ch=10,
        kernel_size=15,
        extend_scope=1,
        morph=0,
        if_offset=True,
        device=device)
    if torch.cuda.is_available():
        A = A.to(device)
        conv0 = conv0.to(device)
    out = conv0(A)
    print(out.shape)
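
To illustrate the "plug-and-play" point, here is a minimal integration sketch of my own (it is not part of the DSCNet repository): a small block that runs a plain 3x3 convolution alongside the two DSConv orientations (morph=0 along x, morph=1 along y) and fuses the three branches with a 1x1 convolution. It assumes the listing above has been saved as dsconv.py in the same directory; the name SnakeFusionBlock and all hyperparameters are purely illustrative.

import torch
from torch import nn

from dsconv import DSConv  # assumes the listing above is saved as dsconv.py


class SnakeFusionBlock(nn.Module):
    """Illustrative block: standard conv + DSConv along x and y, fused by a 1x1 conv."""

    def __init__(self, in_ch, out_ch, kernel_size=9, extend_scope=1, device="cpu"):
        super().__init__()
        # plain 3x3 convolution branch (spatial size preserved by padding=1)
        self.conv_std = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        # dynamic snake convolution along the x-axis (morph=0) and y-axis (morph=1);
        # out_ch should be a multiple of 4 because DSConv uses GroupNorm(out_ch // 4, out_ch)
        self.conv_x = DSConv(in_ch, out_ch, kernel_size, extend_scope,
                             morph=0, if_offset=True, device=device)
        self.conv_y = DSConv(in_ch, out_ch, kernel_size, extend_scope,
                             morph=1, if_offset=True, device=device)
        # 1x1 convolution to fuse the three branches back to out_ch channels
        self.fuse = nn.Conv2d(3 * out_ch, out_ch, 1)

    def forward(self, x):
        feats = torch.cat([self.conv_std(x), self.conv_x(x), self.conv_y(x)], dim=1)
        return self.fuse(feats)


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    block = SnakeFusionBlock(in_ch=16, out_ch=32, device=device).to(device)
    x = torch.randn(2, 16, 64, 64, device=device)
    print(block(x).shape)  # expected: torch.Size([2, 32, 64, 64])

Since all three branches preserve the input's spatial size, a block like this can stand in for an ordinary convolution layer without changing anything else in the network.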

Thanks for your support! If you would like a tutorial on integrating this module into different models, please leave a comment.
