当前位置:   article > 正文

涨点神器!gnConv打造新视觉主干家族:HorNet

gnconv

点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

作者夙曦 |  已授权转载(源:知乎)编辑:CVer

https://zhuanlan.zhihu.com/p/553143354

c75e9291d5b4977b5bf1cd8f2af16dfd.png

HorNet: Efficient High-Order Spatial Interactions with Recursive Gated Convolutions

代码:https://github.com/raoyongming/HorNet

论文:https://arxiv.org/abs/2207.14284

总结

  • 提出了递归门控卷积(gnConv),它通过门控卷积和递归设计来执行高阶空间交互,具有高度的灵活性和可定制性,兼容各种卷积变量,并将自注意的两阶交互扩展到任意阶,而不引入显著的额外计算。

  • gnConv可以作为一个即插即用的模块,以改进各种视觉Transformer和基于卷积的模型。在此基础上构建了一个新的通用视觉骨干家族,名为HorNet。

前言

图1展示了几张不同卷积的结构,并说明了优劣:

  1. 标准的卷积运算并没有明确地考虑空间间的相互作用。

  2. 动态卷积和SE引入了动态权值,提高具有额外空间交互的卷积的建模能力。

  3. 自注意操作通过两个连续的矩阵乘法进行二阶空间交互。

  4. gnConv使用门控卷积和递归对的高效实现实现任意顺序的空间交互。

8349b24fcf6cfa6c86eabee71e2629bf.png

不同卷积结构

方法

gnConv 递归门控卷积

gnConv是用标准卷积、线性投影和元素乘法构建的,但具有类似于自注意的输入-自适应空间混合函数。

882eb76cfe99f22e1e65b8cd931cfd92.png

与门控卷积之间的输入-自适应交互作用

图片的大小阻碍着视觉Transformer的应用,特别是分割和大分辨率检测。本文并没有寻求降低自注意的复杂性,而是寻求一种更有效的方法来通过卷积和全连接层等简单的操作来执行空间交互。

设x∈RHW×C为输入特征,门控卷积y=gConv(x)的输出可以写为

eea1bb6bfdb6cbc50e1071a61f2ae932.png

其中,φin,φout是执行通道混合的投影层,f是深度卷积。gConv中的交互作用是一阶交互作用,因为每个p0与它的邻居特征q0只有交互作用一次。

与递归门控的高阶交互作用

在与gConv实现有效的一阶空间交互作用后设计了gnConv,这是一种递归门控卷积,通过引入高阶交互作用进一步提高模型容量。

我们首先使用φin来获得一组投影特征p0和{qk}n−1k=0:

ca17d32772121ef925dbb2bc2ee02b4f.png

然后递归地执行门控卷积

4f96e6f3d91a6f93b6790f11732c52d8.png

我们将输出缩放为1/α来稳定训练。是一组基于深度的卷积层,并用于以不同的顺序匹配维度:

8f15854ffbfd9eeafd8b1485747d82ab.png

最后,我们将最后一个递归步骤qn的输出输入给投影层φout,得到gnConv的结果。

与大型核卷积的长期交互作用

传统的CNNs通常在整个网络中使用3×3卷积,而视觉Transformer在整个特征图或一个相对较大的局部窗口(例如7×7)内计算自注意。受此设计的启发,最近有一些努力将大型内核卷积引入cnn的。为了使我们的gnConv能够捕获长期的交互,我们采用了两种深度卷积的实现f:

  • 7 * 7卷积

  • 全局滤波器(Global Filter)

实验

99d5c89370b2a17c072fe6cd95571221.png

3242f58be30576d8fce93c384c16b2f4.png

通过ImageNet w.r.t.上的前1个精度来比较模型的权衡(a)个参数数;(b)FLOPs;(c)延迟。延迟是用一个单一的NVIDIA RTX 3090 GPU来测量的。

a00dbae76de8259171ae8f842cf27e8c.png


模块代码

gnConv

  1. class gnconv(nn.Module):
  2. def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
  3. super().__init__()
  4. self.order = order
  5. self.dims = [dim // 2 ** i for i in range(order)]
  6. self.dims.reverse()
  7. self.proj_in = nn.Conv2d(dim, 2*dim, 1)
  8. if gflayer is None:
  9. self.dwconv = get_dwconv(sum(self.dims), 7, True)
  10. else:
  11. self.dwconv = gflayer(sum(self.dims), h=h, w=w)
  12. self.proj_out = nn.Conv2d(dim, dim, 1)
  13. self.pws = nn.ModuleList(
  14. [nn.Conv2d(self.dims[i], self.dims[i+1], 1) for i in range(order-1)]
  15. )
  16. self.scale = s
  17. print('[gnconv]', order, 'order with dims=', self.dims, 'scale=%.4f'%self.scale)
  18. def forward(self, x, mask=None, dummy=False):
  19. B, C, H, W = x.shape
  20. fused_x = self.proj_in(x)
  21. pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)
  22. dw_abc = self.dwconv(abc) * self.scale
  23. dw_list = torch.split(dw_abc, self.dims, dim=1)
  24. x = pwa * dw_list[0]
  25. for i in range(self.order -1):
  26. x = self.pws[i](x) * dw_list[i+1]
  27. x = self.proj_out(x)
  28. return x

全局滤波器

  1. class GlobalLocalFilter(nn.Module):
  2. def __init__(self, dim, h=14, w=8):
  3. super().__init__()
  4. self.dw = nn.Conv2d(dim // 2, dim // 2, kernel_size=3, padding=1, bias=False, groups=dim // 2)
  5. self.complex_weight = nn.Parameter(torch.randn(dim // 2, h, w, 2, dtype=torch.float32) * 0.02)
  6. trunc_normal_(self.complex_weight, std=.02)
  7. self.pre_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
  8. self.post_norm = LayerNorm(dim, eps=1e-6, data_format='channels_first')
  9. def forward(self, x):
  10. x = self.pre_norm(x)
  11. x1, x2 = torch.chunk(x, 2, dim=1)
  12. x1 = self.dw(x1)
  13. x2 = x2.to(torch.float32)
  14. B, C, a, b = x2.shape
  15. x2 = torch.fft.rfft2(x2, dim=(2, 3), norm='ortho')
  16. weight = self.complex_weight
  17. if not weight.shape[1:3] == x2.shape[2:4]:
  18. weight = F.interpolate(weight.permute(3,0,1,2), size=x2.shape[2:4], mode='bilinear', align_corners=True).permute(1,2,3,0)
  19. weight = torch.view_as_complex(weight.contiguous())
  20. x2 = x2 * weight
  21. x2 = torch.fft.irfft2(x2, s=(a, b), dim=(2, 3), norm='ortho')
  22. x = torch.cat([x1.unsqueeze(2), x2.unsqueeze(2)], dim=2).reshape(B, 2 * C, a, b)
  23. x = self.post_norm(x)
  24. return x

点击进入—> CV 微信技术交流群

CVPR 2022论文和代码下载

 
 

后台回复:CVPR2022,即可下载CVPR 2022论文和代码开源的论文合集

后台回复:Transformer综述,即可下载最新的3篇Transformer综述PDF

  1. 目标检测和Transformer交流群成立
  2. 扫描下方二维码,或者添加微信:CVer6666,即可添加CVer小助手微信,便可申请加入CVer-目标检测或者Transformer 微信交流群。另外其他垂直方向已涵盖:目标检测、图像分割、目标跟踪、人脸检测&识别、OCR、姿态估计、超分辨率、SLAM、医疗影像、Re-ID、GAN、NAS、深度估计、自动驾驶、强化学习、车道线检测、模型剪枝&压缩、去噪、去雾、去雨、风格迁移、遥感图像、行为识别、视频理解、图像融合、图像检索、论文投稿&交流、PyTorch、TensorFlow和Transformer等。
  3. 一定要备注:研究方向+地点+学校/公司+昵称(如目标检测或者Transformer+上海+上交+卡卡),根据格式备注,可更快被通过且邀请进群
  4. ▲扫码或加微信: CVer6666,进交流群
  5. CVer学术交流群(知识星球)来了!想要了解最新最快最好的CV/DL/ML论文速递、优质开源项目、学习教程和实战训练等资料,欢迎扫描下方二维码,加入CVer学术交流群,已汇集数千人!
  6. ▲扫码进群
  7. ▲点击上方卡片,关注CVer公众号
  8. 整理不易,请点赞和在看
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/94001
推荐阅读
相关标签
  

闽ICP备14008679号