
YOLOv5 Improvement (CVPR 2023): Adding the EfficientViT Backbone, a Memory-Efficient ViT with Cascaded Group Attention

 

Paper (CVPR 2023): EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention

Vision Transformers have achieved great success thanks to their high model capacity. However, their remarkable performance comes with a heavy computational cost, which makes them unsuitable for real-time applications. In this paper, we propose a family of high-speed vision transformers named EfficientViT. We find that the speed of existing transformer models is commonly bounded by memory-inefficient operations, especially the tensor reshaping and element-wise functions in MHSA. We therefore design a new building block with a sandwich layout, i.e., a single memory-bound MHSA placed between efficient FFN layers, which improves memory efficiency while enhancing channel communication. Moreover, we find that the attention maps are highly similar across heads, leading to computational redundancy. To address this, we present a cascaded group attention module that feeds the attention heads with different splits of the full feature, which not only saves computation but also improves attention diversity. Comprehensive experiments demonstrate that EfficientViT outperforms existing efficient models, striking a good trade-off between speed and accuracy. For instance, our EfficientViT-M5 surpasses MobileNetV3-Large by 1.9% in accuracy while achieving 40.4% and 45.2% higher throughput on an Nvidia V100 GPU and an Intel Xeon CPU, respectively. Compared with the recent efficient model MobileViT-XXS, EfficientViT-M2 achieves 1.8% higher accuracy while running 5.8×/3.7× faster on the GPU/CPU, and 7.4× faster when converted to ONNX format. Code and models are available here.
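To make the cascaded group attention idea concrete before the full listing below, here is a minimal conceptual sketch (my own illustration, not the paper's implementation): the input channels are split across heads, each head attends only over its own split, and each head's output is added to the next head's split before that head computes attention, so later heads refine earlier ones instead of producing near-duplicate attention maps. The class name and shapes here are hypothetical; the official module (partially reproduced in the listing below) additionally uses convolutional Q/K/V projections, depthwise convolutions on the queries, and learned attention biases.

```python
import torch
import torch.nn as nn


class CascadedGroupAttentionSketch(nn.Module):
    """Conceptual sketch of cascaded group attention (not the official code).

    Each head gets its own channel split of the input; the previous head's
    output is added to the next head's split before attention, so later heads
    refine what earlier heads produced instead of recomputing similar maps.
    """

    def __init__(self, dim, num_heads, key_dim=16):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.split = dim // num_heads
        self.key_dim = key_dim
        self.scale = key_dim ** -0.5
        # one q/k/v projection per head, operating on that head's split only
        self.qkv = nn.ModuleList(
            nn.Linear(self.split, key_dim * 2 + self.split) for _ in range(num_heads))
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                       # x: (B, N, C) tokens
        splits = x.chunk(self.num_heads, dim=-1)
        feat = splits[0]
        outs = []
        for i, qkv in enumerate(self.qkv):
            if i > 0:                           # cascade: add previous head's output
                feat = splits[i] + outs[-1]
            q, k, v = qkv(feat).split(
                [self.key_dim, self.key_dim, self.split], dim=-1)
            attn = (q @ k.transpose(-2, -1)) * self.scale
            outs.append(attn.softmax(dim=-1) @ v)
        return self.proj(torch.cat(outs, dim=-1))


if __name__ == "__main__":
    x = torch.randn(2, 196, 128)                # 196 tokens, 128 channels
    print(CascadedGroupAttentionSketch(dim=128, num_heads=4)(x).shape)  # torch.Size([2, 196, 128])
```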

 

The modifications below are based on YOLOv5 v7.0.

1. Create a nextvit.py file and add the following code:

# --------------------------------------------------------
# EfficientViT Model Architecture for Downstream Tasks
# Copyright (c) 2022 Microsoft
# Written by: Xinyu Liu
# --------------------------------------------------------
import itertools

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import SqueezeExcite

__all__ = ['EfficientViT_M0', 'EfficientViT_M1', 'EfficientViT_M2', 'EfficientViT_M3', 'EfficientViT_M4',
           'EfficientViT_M5']


class Conv2d_BN(torch.nn.Sequential):
    """Conv2d followed by BatchNorm2d; fuse() folds the BN into the conv for inference."""

    def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
                 groups=1, bn_weight_init=1, resolution=-10000):
        super().__init__()
        self.add_module('c', torch.nn.Conv2d(
            a, b, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', torch.nn.BatchNorm2d(b))
        torch.nn.init.constant_(self.bn.weight, bn_weight_init)
        torch.nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / \
            (bn.running_var + bn.eps) ** 0.5
        m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
            0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation,
            groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


def replace_batchnorm(net):
    """Recursively fuse Conv2d_BN modules and replace remaining BatchNorm2d layers with Identity."""
    for child_name, child in net.named_children():
        if hasattr(child, 'fuse'):
            setattr(net, child_name, child.fuse())
        elif isinstance(child, torch.nn.BatchNorm2d):
            setattr(net, child_name, torch.nn.Identity())
        else:
            replace_batchnorm(child)


class PatchMerging(torch.nn.Module):
    """Downsample by 2x: 1x1 expand -> 3x3 depthwise stride-2 -> SE -> 1x1 project."""

    def __init__(self, dim, out_dim, input_resolution):
        super().__init__()
        hid_dim = int(dim * 4)
        self.conv1 = Conv2d_BN(dim, hid_dim, 1, 1, 0, resolution=input_resolution)
        self.act = torch.nn.ReLU()
        self.conv2 = Conv2d_BN(hid_dim, hid_dim, 3, 2, 1, groups=hid_dim, resolution=input_resolution)
        self.se = SqueezeExcite(hid_dim, .25)
        self.conv3 = Conv2d_BN(hid_dim, out_dim, 1, 1, 0, resolution=input_resolution // 2)

    def forward(self, x):
        x = self.conv3(self.se(self.act(self.conv2(self.act(self.conv1(x))))))
        return x


class Residual(torch.nn.Module):
    """Residual wrapper with optional per-sample stochastic depth during training."""

    def __init__(self, m, drop=0.):
        super().__init__()
        self.m = m
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
                                              device=x.device).ge_(self.drop).div(1 - self.drop).detach()
        else:
            return x + self.m(x)


class FFN(torch.nn.Module):
    """Feed-forward network built from two pointwise Conv2d_BN layers."""

    def __init__(self, ed, h, resolution):
        super().__init__()
        self.pw1 = Conv2d_BN(ed, h, resolution=resolution)
        self.act = torch.nn.ReLU()
        self.pw2 = Conv2d_BN(h, ed, bn_weight_init=0, resolution=resolution)

    def forward(self, x):
        x = self.pw2(self.act(self.pw1(x)))
        return x


class CascadedGroupAttention(torch.nn.Module):
    r""" Cascaded Group Attention.

    Args:
        dim (int): Number of input channels.
        key_dim (int): The dimension for query and key.
        num_heads (int): Number of attention heads.
        attn_ratio (int): Multiplier for the query dim for value dimension.
        resolution (int): Input resolution, corresponding to the window size.
        kernels (List[int]): The kernel size of the dw conv on query.
    """

    def __init__(self, dim, key_dim, num_heads=8,
                 attn_ratio=4,
                 resolution=14,
                 kernels=[5, 5, 5, 5], ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.d = int(attn_ratio * key_dim)
        self.attn_ratio = attn_ratio

        qkvs = []
        dws = []
        for i in range(num_heads):
            ...  # The listing is cut off here in the source post. The remainder of
                 # CascadedGroupAttention and the rest of the file (through the
                 # EfficientViT_M5 builder named in __all__) must be copied from
                 # the complete version of this file.
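Because the listing above is truncated inside CascadedGroupAttention, the file has to be completed before it can be imported. Once it is, a quick sanity check before touching any YOLOv5 configuration is to push a dummy image through the backbone and inspect the returned feature maps. The snippet below is only a sketch: the models/nextvit.py path, the EfficientViT_M0 constructor call, and the exact return format (a single tensor or a list of multi-scale features) all depend on the full file you end up with.

```python
# Hypothetical smoke test -- adjust the import path and builder call to match
# your completed file; nothing here is specific to an official API.
import torch

from models.nextvit import EfficientViT_M0  # assumes the code above was saved as models/nextvit.py

backbone = EfficientViT_M0()
backbone.eval()

with torch.no_grad():
    out = backbone(torch.randn(1, 3, 640, 640))  # dummy 640x640 RGB input

# YOLOv5's neck consumes feature maps at strides 8/16/32, so record the shapes
# and channel counts here before editing the model yaml in a later step.
feats = out if isinstance(out, (list, tuple)) else [out]
for f in feats:
    print(tuple(f.shape))
```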