当前位置:   article > 正文

【论文笔记】AK卷积(Convolutional Kernel with Arbitrary Sampled Shapes and Arbitrary Number of Parameters)_akconv

akconv

本文介绍AK卷积,传统的卷积有2个缺陷:

1、卷积运算在固定大小的窗口运行、无法捕获其他窗口的信息,并且窗口的形状是固定的;

2、卷积核的尺寸固定为k\times k,窗口大小固定为k,随着k增加,参数会快速增加。

针对传统卷积的缺陷,作者提出了AK卷积,AK卷积拥有任意形状和任意的参数。作者在yolov5n和yolov8n上进行了测试,效果非常好。

论文地址:AKConv: Convolutional Kernel with Arbitrary Sampled Shapes and Arbitrary Number of Parameters

代码:https://github.com/cv-zhangxin/akconv

一、AKConv

前面已经提到了传统卷积的2个缺陷,那么AKConv是怎么做的呢?

1、任意形状

标准卷积是k\times k的矩形,而可变形卷积(Deformable Conv)是可以调整形状的,类似可变形卷积,AKConv也会学习偏移量,来改变卷积核的形状,如下图所示。

N是AKConv卷积参数的数量,特征图经过卷积运算得到卷积的位置偏移量,然后进行卷积运算,和可变形卷积一样。

2、任意参数数量

可变形卷积可以通过学习偏移量改变卷积计算的位置,从而使得卷积核的形状不固定,但是可变形卷积有个缺陷:卷积核参数是固定的。(比如1,9, 27...)

AKConv的另一个特点就是参数数量是任意的(可以设置为1,2,3,4,5...任意值),如下图,这点是和传统卷积不一样的,摆脱了k\times k的参数限制。

除了参数数量可以任意选择,初始的卷积核形状也是可以任意选择,下图为5个卷积参数时,卷积核的初始形状设计方案。

二、性能

AKConv是对可变形卷积的巨大改进,他的性能也是非常好的,在yolov5n上添加AKConv,可以看到在COCO2017数据集上的表现非常亮眼:

不同的卷积形状在yoloV8的测试(COCO2017数据集):

关于实验可以参考原论文,不多赘述。

三、代码

官方的代码已经给出了在v5/7/8上的配置文件和代码,这里给出其核心代码,配置文件见源码:

  1. import math
  2. import torch.nn.functional as F
  3. from .conv import Conv
  4. import einops
  5. class AKConv(nn.Module):
  6. def __init__(self, inc, outc, num_param, stride=1, bias=None):
  7. super(AKConv, self).__init__()
  8. self.num_param = num_param
  9. self.stride = stride
  10. self.conv = nn.Sequential(nn.Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias),nn.BatchNorm2d(outc),nn.SiLU()) # the conv adds the BN and SiLU to compare original Conv in YOLOv5.
  11. self.p_conv = nn.Conv2d(inc, 2 * num_param, kernel_size=3, padding=1, stride=stride)
  12. nn.init.constant_(self.p_conv.weight, 0)
  13. self.p_conv.register_full_backward_hook(self._set_lr)
  14. @staticmethod
  15. def _set_lr(module, grad_input, grad_output):
  16. grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
  17. grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))
  18. def forward(self, x):
  19. # N is num_param.
  20. offset = self.p_conv(x)
  21. dtype = offset.data.type()
  22. N = offset.size(1) // 2
  23. # (b, 2N, h, w)
  24. p = self._get_p(offset, dtype)
  25. # (b, h, w, 2N)
  26. p = p.contiguous().permute(0, 2, 3, 1)
  27. q_lt = p.detach().floor()
  28. q_rb = q_lt + 1
  29. q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2) - 1), torch.clamp(q_lt[..., N:], 0, x.size(3) - 1)],
  30. dim=-1).long()
  31. q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2) - 1), torch.clamp(q_rb[..., N:], 0, x.size(3) - 1)],
  32. dim=-1).long()
  33. q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
  34. q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)
  35. # clip p
  36. p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2) - 1), torch.clamp(p[..., N:], 0, x.size(3) - 1)], dim=-1)
  37. # bilinear kernel (b, h, w, N)
  38. g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
  39. g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
  40. g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
  41. g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))
  42. # resampling the features based on the modified coordinates.
  43. x_q_lt = self._get_x_q(x, q_lt, N)
  44. x_q_rb = self._get_x_q(x, q_rb, N)
  45. x_q_lb = self._get_x_q(x, q_lb, N)
  46. x_q_rt = self._get_x_q(x, q_rt, N)
  47. # bilinear
  48. x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
  49. g_rb.unsqueeze(dim=1) * x_q_rb + \
  50. g_lb.unsqueeze(dim=1) * x_q_lb + \
  51. g_rt.unsqueeze(dim=1) * x_q_rt
  52. x_offset = self._reshape_x_offset(x_offset, self.num_param)
  53. out = self.conv(x_offset)
  54. return out
  55. # generating the inital sampled shapes for the AKConv with different sizes.
  56. def _get_p_n(self, N, dtype):
  57. base_int = round(math.sqrt(self.num_param))
  58. row_number = self.num_param // base_int
  59. mod_number = self.num_param % base_int
  60. p_n_x,p_n_y = torch.meshgrid(
  61. torch.arange(0, row_number),
  62. torch.arange(0,base_int))
  63. p_n_x = torch.flatten(p_n_x)
  64. p_n_y = torch.flatten(p_n_y)
  65. if mod_number > 0:
  66. mod_p_n_x,mod_p_n_y = torch.meshgrid(
  67. torch.arange(row_number,row_number+1),
  68. torch.arange(0,mod_number))
  69. mod_p_n_x = torch.flatten(mod_p_n_x)
  70. mod_p_n_y = torch.flatten(mod_p_n_y)
  71. p_n_x,p_n_y = torch.cat((p_n_x,mod_p_n_x)),torch.cat((p_n_y,mod_p_n_y))
  72. p_n = torch.cat([p_n_x,p_n_y], 0)
  73. p_n = p_n.view(1, 2 * N, 1, 1).type(dtype)
  74. return p_n
  75. # no zero-padding
  76. def _get_p_0(self, h, w, N, dtype):
  77. p_0_x, p_0_y = torch.meshgrid(
  78. torch.arange(0, h * self.stride, self.stride),
  79. torch.arange(0, w * self.stride, self.stride))
  80. p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
  81. p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
  82. p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
  83. return p_0
  84. def _get_p(self, offset, dtype):
  85. N, h, w = offset.size(1) // 2, offset.size(2), offset.size(3)
  86. # (1, 2N, 1, 1)
  87. p_n = self._get_p_n(N, dtype)
  88. # (1, 2N, h, w)
  89. p_0 = self._get_p_0(h, w, N, dtype)
  90. p = p_0 + p_n + offset
  91. return p
  92. def _get_x_q(self, x, q, N):
  93. b, h, w, _ = q.size()
  94. padded_w = x.size(3)
  95. c = x.size(1)
  96. # (b, c, h*w)
  97. x = x.contiguous().view(b, c, -1)
  98. # (b, h, w, N)
  99. index = q[..., :N] * padded_w + q[..., N:] # offset_x*w + offset_y
  100. # (b, c, h*w*N)
  101. index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)
  102. x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)
  103. return x_offset
  104. # Stacking resampled features in the row direction.
  105. @staticmethod
  106. def _reshape_x_offset(x_offset, num_param):
  107. b, c, h, w, n = x_offset.size()
  108. # using Conv3d
  109. # x_offset = x_offset.permute(0,1,4,2,3), then Conv3d(c,c_out, kernel_size =(num_param,1,1),stride=(num_param,1,1),bias= False)
  110. # using 1 × 1 Conv
  111. # x_offset = x_offset.permute(0,1,4,2,3), then, x_offset.view(b,c×num_param,h,w) finally, Conv2d(c×num_param,c_out, kernel_size =1,stride=1,bias= False)
  112. # using the column conv as follow, then, Conv2d(inc, outc, kernel_size=(num_param, 1), stride=(num_param, 1), bias=bias)
  113. x_offset = rearrange(x_offset, 'b c h w n -> b c (h n) w')
  114. return x_offset

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/IT小白/article/detail/339999
推荐阅读
相关标签
  

闽ICP备14008679号