
A Detailed Look at DAMO-YOLO's Neck (Efficient RepGFPN)

[Figure: Efficient RepGFPN architecture diagram, as drawn in the DAMO-YOLO paper]

The figure above is a bit off: the GiraffeNeckV2 code contains only 5 Fusion Blocks, while the figure shows 6.

https://github.com/tinyvision/DAMO-YOLO/blob/master/damo/base_models/necks/giraffe_fpn_btn.py

The code indeed contains only 5 CSPStage modules (the Fusion Block in the figure corresponds to the CSPStage class in the code).

So I drew my own overall diagram and opened an issue on GitHub, and the original author confirmed the problem:

I think the pictures in your paper are not rigorous in several places · Issue #91 · tinyvision/DAMO-YOLO · GitHub

 

To understand the Neck, you only need to understand what the Fusion Block does; the rest is not much different from a PAN.
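At a high level, a Fusion Block concatenates feature maps that have already been brought to a common stride and refines the result with a CSPStage. A minimal dataflow sketch (my reading of the corrected diagram, not the repo's exact wiring; fusion_block is a hypothetical helper, not a function in the codebase):

import torch

def fusion_block(features, csp_stage):
    # features: a list of tensors already upsampled/downsampled
    # to the same spatial size
    x = torch.cat(features, dim=1)  # concatenate along the channel axis
    return csp_stage(x)             # CSPStage does the actual fusion work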

# Excerpt from the DAMO-YOLO repo; ConvBNAct, BasicBlock_3x3_Reverse
# and SPP are defined alongside it.
import torch
import torch.nn as nn


class CSPStage(nn.Module):
    def __init__(self,
                 block_fn,
                 ch_in,
                 ch_hidden_ratio,
                 ch_out,
                 n,
                 act='swish',
                 spp=False):
        super(CSPStage, self).__init__()

        split_ratio = 2
        ch_first = int(ch_out // split_ratio)  # channels of the conv1 bypass branch
        ch_mid = int(ch_out - ch_first)        # channels refined by the stacked blocks
        self.conv1 = ConvBNAct(ch_in, ch_first, 1, act=act)
        self.conv2 = ConvBNAct(ch_in, ch_mid, 1, act=act)
        self.convs = nn.Sequential()

        next_ch_in = ch_mid
        for i in range(n):
            if block_fn == 'BasicBlock_3x3_Reverse':
                self.convs.add_module(
                    str(i),
                    BasicBlock_3x3_Reverse(next_ch_in,
                                           ch_hidden_ratio,
                                           ch_mid,
                                           act=act,
                                           shortcut=True))
            else:
                raise NotImplementedError
            # SPP concatenates its input with 3 max-pooled copies, hence ch_mid * 4
            if i == (n - 1) // 2 and spp:
                self.convs.add_module(
                    'spp', SPP(ch_mid * 4, ch_mid, 1, [5, 9, 13], act=act))
            next_ch_in = ch_mid
        # conv3 fuses the bypass branch plus the output of every block
        self.conv3 = ConvBNAct(ch_mid * n + ch_first, ch_out, 1, act=act)

    def forward(self, x):
        y1 = self.conv1(x)
        y2 = self.conv2(x)

        mid_out = [y1]
        for conv in self.convs:
            y2 = conv(y2)
            mid_out.append(y2)  # keep every intermediate output for the final concat
        y = torch.cat(mid_out, axis=1)
        y = self.conv3(y)
        return y

The above is the CSPStage code. To understand it, we first need to understand two classes it builds on: ConvBNAct and BasicBlock_3x3_Reverse.
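Before diving into those, a quick worked example of CSPStage's channel bookkeeping (a minimal sketch; the concrete numbers are illustrative, not taken from any DAMO-YOLO config):

# Channel arithmetic inside CSPStage, with ch_out = 128 and n = 3.
ch_out, n = 128, 3
split_ratio = 2
ch_first = ch_out // split_ratio  # 64: the conv1 bypass branch
ch_mid = ch_out - ch_first        # 64: the branch refined by the n blocks
# mid_out collects y1 plus the output of each of the n blocks,
# so conv3 sees ch_first + n * ch_mid input channels:
concat_channels = ch_mid * n + ch_first
print(concat_channels)            # 256, projected back to ch_out = 128 by conv3

First, ConvBNAct: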

class ConvBNAct(nn.Module):
    """A Conv2d -> Batchnorm -> silu/leaky relu block"""
    def __init__(
        self,
        in_channels,
        out_channels,
        ksize,
        stride=1,
        groups=1,
        bias=False,
        act='silu',
        norm='bn',
        reparam=False,
    ):
        super().__init__()
        # same padding
        pad = (ksize - 1) // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            groups=groups,
            bias=bias,
        )
        if norm is not None:
            self.bn = get_norm(norm, out_channels, inplace=True)
        if act is not None:
            self.act = get_activation(act, inplace=True)
        self.with_norm = norm is not None
        self.with_act = act is not None

    def forward(self, x):
        x = self.conv(x)
        if self.with_norm:
            x = self.bn(x)
        if self.with_act:
            x = self.act(x)
        return x

    def fuseforward(self, x):
        return self.act(self.conv(x))

ConvBNAct is easy to follow: Conv + BN + SiLU, and that's all there is to it (other activation functions can be used; the paper uses SiLU). If the groups parameter is set, the convolution becomes a grouped convolution.
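A standalone check of the "same padding" rule pad = (ksize - 1) // 2 and of the groups behavior (plain nn.Conv2d here, not the repo's wrapper):

import torch
import torch.nn as nn

x = torch.randn(1, 16, 32, 32)
for ksize in (1, 3, 5):
    # (ksize - 1) // 2 keeps H and W unchanged at stride 1 for odd kernels
    conv = nn.Conv2d(16, 16, ksize, stride=1,
                     padding=(ksize - 1) // 2, groups=4)
    print(ksize, conv(x).shape)  # torch.Size([1, 16, 32, 32]) every time
# groups=4 splits the 16 channels into 4 groups of 4: a grouped convolution

Next up is BasicBlock_3x3_Reverse: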

class BasicBlock_3x3_Reverse(nn.Module):
    def __init__(self,
                 ch_in,
                 ch_hidden_ratio,
                 ch_out,
                 act='relu',
                 shortcut=True):
        super(BasicBlock_3x3_Reverse, self).__init__()
        assert ch_in == ch_out  # the residual add requires matching channels
        ch_hidden = int(ch_in * ch_hidden_ratio)
        self.conv1 = ConvBNAct(ch_hidden, ch_out, 3, stride=1, act=act)
        self.conv2 = RepConv(ch_in, ch_hidden, 3, stride=1, act=act)
        self.shortcut = shortcut

    def forward(self, x):
        # note the order: conv2 (RepConv) runs before conv1
        y = self.conv2(x)
        y = self.conv1(y)
        if self.shortcut:
            return x + y
        else:
            return y

To understand BasicBlock_3x3_Reverse, you need to know the RepConv class, which is adapted from the RepVGGBlock in the RepVGG network. (Incidentally, the "Reverse" in the name presumably refers to the forward order: conv2, the RepConv, runs before conv1.)

# np is numpy; get_activation and conv_bn (a Conv2d + BatchNorm2d wrapped in
# an nn.Sequential) are helpers defined elsewhere in the repo.
class RepConv(nn.Module):
    '''RepConv is a basic rep-style block, including training and deploy status
    Code is based on https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py
    '''
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 deploy=False,
                 act='relu',
                 norm=None):
        super(RepConv, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.out_channels = out_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        if isinstance(act, str):
            self.nonlinearity = get_activation(act)
        else:
            self.nonlinearity = act

        if deploy:
            # deploy mode: a single fused 3x3 conv with bias
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation,
                                         groups=groups,
                                         bias=True,
                                         padding_mode=padding_mode)
        else:
            # training mode: a 3x3 branch and a 1x1 branch (no identity branch here)
            self.rbr_identity = None
            self.rbr_dense = conv_bn(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=kernel_size,
                                     stride=stride,
                                     padding=padding,
                                     groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels,
                                   out_channels=out_channels,
                                   kernel_size=1,
                                   stride=stride,
                                   padding=padding_11,
                                   groups=groups)

    def forward(self, inputs):
        '''Forward process'''
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(
            self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        # fuse each branch's Conv+BN, pad the 1x1 kernel to 3x3, then sum
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            # identity branch: build an equivalent 3x3 kernel with 1 at the center
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # Conv+BN fusion: W' = W * gamma/std, b' = beta - running_mean * gamma/std
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.rbr_dense.conv.in_channels,
            out_channels=self.rbr_dense.conv.out_channels,
            kernel_size=self.rbr_dense.conv.kernel_size,
            stride=self.rbr_dense.conv.stride,
            padding=self.rbr_dense.conv.padding,
            dilation=self.rbr_dense.conv.dilation,
            groups=self.rbr_dense.conv.groups,
            bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True

The defining trait of RepConv is structural re-parameterization: during training the block uses parallel branches (in RepVGG, a 3x3 conv, a 1x1 conv, and an identity), and at inference switch_to_deploy fuses them into a single 3x3 convolution, which greatly reduces inference time (the RepVGG paper and its explainer videos are worth a look).
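The core of the fusion is folding each branch's BatchNorm into its convolution, which is exactly what _fuse_bn_tensor computes: W' = W * gamma / std and b' = beta - running_mean * gamma / std. A standalone numerical check in plain PyTorch (independent of the repo):

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
_ = bn(conv(torch.randn(4, 8, 16, 16)))  # populate non-trivial running stats
bn.eval()                                # fusion assumes frozen BN statistics

x = torch.randn(2, 8, 16, 16)
ref = bn(conv(x))

std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
fused = nn.Conv2d(8, 8, 3, padding=1, bias=True)
fused.weight.data = conv.weight * t                            # W' = W * gamma / std
fused.bias.data = bn.bias - bn.running_mean * bn.weight / std  # b' = beta - mu * gamma / std

print(torch.allclose(ref, fused(x), atol=1e-5))  # True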

Note that the code above fixes rbr_identity to None, so this RepConv actually trains with two branches, the 3x3 and the 1x1 (the two-branch structure (a)), rather than RepVGG's full three-branch block.
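The other ingredient of get_equivalent_kernel_bias is _pad_1x1_to_3x3_tensor: with an input padding of 1, a 1x1 convolution is exactly equivalent to a 3x3 convolution whose kernel is zero everywhere except the center. Another standalone check:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16, 16)
w1x1 = torch.randn(8, 8, 1, 1)

y_1x1 = F.conv2d(x, w1x1, padding=0)
y_3x3 = F.conv2d(x, F.pad(w1x1, [1, 1, 1, 1]), padding=1)  # zero-pad the kernel to 3x3

print(torch.allclose(y_1x1, y_3x3, atol=1e-6))  # True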

Other details may come in a later update. The code is not hard; read it slowly and it will all click. Apologies for anything I got wrong.
