当前位置:   article > 正文

即插即用模块之DO-Conv(深度过度参数化卷积层)详解

do-conv

目录

一、摘要

二、核心创新点

三、代码详解

四、实验结果

4.1Image Classification

4.2Semantic Segmentation

4.3Object Detection 

五、总结


论文:DOConv论文

代码:DOConv代码

一、摘要

卷积层是卷积神经网络(cnn)的核心组成部分。在本文中,我们建议用额外的深度卷积来增强卷积层,其中每个输入通道与不同的二维核进行卷积。这两个卷积的组合构成了一个过度参数化,因为它增加了可学习的参数,而结果的线性操作可以用单个卷积层来表示。我们把这个深度过度参数化的卷积层称为DO-Conv。我们通过大量的实验表明,仅仅用DO-Conv层替换传统的卷积层就可以提高cnn在许多经典视觉任务上的性能,例如图像分类、检测和分割。此外,在推理阶段,深度卷积被折叠成常规卷积,将计算量减少到完全等同于卷积层的计算量,而没有过度参数化。由于DO-Conv在不增加推理计算复杂度的情况下引入了性能提升,我们主张将其作为传统卷积层的替代方案。

二、核心创新点

深度过参数化卷积层(DO-Conv)是一个具有可训练kernel深度卷积和一个具有可训练常规卷积的组合。给定一个输入, DO-Conv算子的输出与卷积层相同,是一个同维特征。DO-Conv算子是深度卷积算子和卷积算子的复合,如图所示,有两种数学上等价的方法来实现复合:特征复合(a)和核复合(b)。

三、代码详解

  1. # 使用 utf-8 编码
  2. # 导入必要的库
  3. import math # 导入数学库
  4. import torch # 导入 PyTorch 库
  5. import numpy as np # 导入 NumPy 库
  6. from torch.nn import init # 导入 PyTorch 中的初始化函数
  7. from itertools import repeat # 导入 itertools 库中的 repeat 函数
  8. from torch.nn import functional as F # 导入 PyTorch 中的函数式接口
  9. from torch._jit_internal import Optional # 导入 PyTorch 中的可选模块
  10. from torch.nn.parameter import Parameter # 导入 PyTorch 中的参数类
  11. from torch.nn.modules.module import Module # 导入 PyTorch 中的模块类
  12. import collections # 导入 collections 库
  13. # 定义自定义模块 DOConv2d
  14. class DOConv2d(Module):
  15. """
  16. DOConv2d 可以作为 torch.nn.Conv2d 的替代。
  17. 接口与 Conv2d 类似,但有一个例外:
  18. 1. D_mul:超参数的深度乘法器。
  19. 请注意,groups 参数在 DO-Conv(groups=1)、DO-DConv(groups=in_channels)、DO-GConv(其他情况)之间切换。
  20. """
  21. # 常量声明
  22. __constants__ = ['stride', 'padding', 'dilation', 'groups',
  23. 'padding_mode', 'output_padding', 'in_channels',
  24. 'out_channels', 'kernel_size', 'D_mul']
  25. # 注解声明
  26. __annotations__ = {'bias': Optional[torch.Tensor]}
  27. # 初始化函数
  28. def __init__(self, in_channels, out_channels, kernel_size, D_mul=None, stride=1,
  29. padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
  30. super(DOConv2d, self).__init__()
  31. # 将 kernel_size、stride、padding、dilation 转化为二元元组
  32. kernel_size = _pair(kernel_size)
  33. stride = _pair(stride)
  34. padding = _pair(padding)
  35. dilation = _pair(dilation)
  36. # 检查 groups 是否合法
  37. if in_channels % groups != 0:
  38. raise ValueError('in_channels 必须能被 groups 整除')
  39. if out_channels % groups != 0:
  40. raise ValueError('out_channels 必须能被 groups 整除')
  41. # 检查 padding_mode 是否合法
  42. valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
  43. if padding_mode not in valid_padding_modes:
  44. raise ValueError("padding_mode 必须为 {} 中的一种,但得到 padding_mode='{}'".format(
  45. valid_padding_modes, padding_mode))
  46. # 初始化模块参数
  47. self.in_channels = in_channels
  48. self.out_channels = out_channels
  49. self.kernel_size = kernel_size
  50. self.stride = stride
  51. self.padding = padding
  52. self.dilation = dilation
  53. self.groups = groups
  54. self.padding_mode = padding_mode
  55. self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))
  56. #################################### 初始化 D & W ###################################
  57. M = self.kernel_size[0]
  58. N = self.kernel_size[1]
  59. self.D_mul = M * N if D_mul is None or M * N <= 1 else D_mul
  60. self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, self.D_mul))
  61. init.kaiming_uniform_(self.W, a=math.sqrt(5))
  62. if M * N > 1:
  63. self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
  64. init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
  65. self.D.data = torch.from_numpy(init_zero)
  66. eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
  67. d_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
  68. if self.D_mul % (M * N) != 0: # 当 D_mul > M * N 时
  69. zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
  70. self.d_diag = Parameter(torch.cat([d_diag, zeros], dim=2), requires_grad=False)
  71. else: # 当 D_mul = M * N 时
  72. self.d_diag = Parameter(d_diag, requires_grad=False)
  73. ##################################################################################################
  74. # 初始化偏置参数
  75. if bias:
  76. self.bias = Parameter(torch.Tensor(out_channels))
  77. fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
  78. bound = 1 / math.sqrt(fan_in)
  79. init.uniform_(self.bias, -bound, bound)
  80. else:
  81. self.register_parameter('bias', None)
  82. # 返回模块配置的字符串表示形式
  83. def extra_repr(self):
  84. s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
  85. ', stride={stride}')
  86. if self.padding != (0,) * len(self.padding):
  87. s += ', padding={padding}'
  88. if self.dilation != (1,) * len(self.dilation):
  89. s += ', dilation={dilation}'
  90. if self.groups != 1:
  91. s += ', groups={groups}'
  92. if self.bias is None:
  93. s += ', bias=False'
  94. if self.padding_mode != 'zeros':
  95. s += ', padding_mode={padding_mode}'
  96. return s.format(**self.__dict__)
  97. # 重新设置状态
  98. def __setstate__(self, state):
  99. super(DOConv2d, self).__setstate__(state)
  100. if not hasattr(self, 'padding_mode'):
  101. self.padding_mode = 'zeros'
  102. # 辅助函数,执行卷积操作
  103. def _conv_forward(self, input, weight):
  104. if self.padding_mode != 'zeros':
  105. return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
  106. weight, self.bias, self.stride,
  107. _pair(0), self.dilation, self.groups)
  108. return F.conv2d(input, weight, self.bias, self.stride,
  109. self.padding, self.dilation, self.groups)
  110. # 前向传播函数
  111. def forward(self, input):
  112. M = self.kernel_size[0]
  113. N = self.kernel_size[1]
  114. DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
  115. if M * N > 1:
  116. ######################### 计算 DoW #################
  117. # (input_channels, D_mul, M * N)
  118. D = self.D + self.d_diag
  119. W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))
  120. # einsum 输出 (out_channels // groups, in_channels, M * N),
  121. # 重塑为
  122. # (out_channels, in_channels // groups, M, N)
  123. DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
  124. #######################################################
  125. else:
  126. # 在这种情况下 D_mul == M * N
  127. # 从 (out_channels, in_channels // groups, D_mul) 重塑为 (out_channels, in_channels // groups, M, N)
  128. DoW = torch.reshape(self.W, DoW_shape)
  129. return self._conv_forward(input, DoW)
  130. # 定义辅助函数
  131. def _ntuple(n):
  132. def parse(x):
  133. if isinstance(x, collections.abc.Iterable):
  134. return x
  135. return tuple(repeat(x, n))
  136. return parse
  137. # 定义辅助函数,将输入转化为二元元组
  138. _pair = _ntuple(2)

四、实验结果

4.1Image Classification

4.2Semantic Segmentation

4.3Object Detection 

五、总结

DO-Conv是一种深度过参数化卷积层,是一种新颖、简单、通用的提高cnn性能的方法。除了提高现有cnn的训练和最终精度的实际意义之外,在推理阶段不引入额外的计算,我们设想其优势的揭示也可以鼓励进一步探索过度参数化作为网络架构设计的一个新维度。

在未来,对这一相当简单的方法进行理论理解,以在一系列应用中实现令人惊讶的非凡性能改进,将是有趣的。此外,我们希望扩大这些过度参数化卷积层可能有效的应用范围,并了解哪些超参数可以从中受益更多。

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号