当前位置:   article > 正文

YoloV8改进---注意力机制:高斯上下文变换器GCT,性能优于ECA、SE等注意力模块 | CVPR2021

yolov8改进

目录

 1.GCT介绍

实验结果

 2.GCT引入到yolov8

2.1 修改modules.py中:

2.2 加入tasks.py中:

2.3 yolov8_GCT.yaml

3.YOLOv8魔术师专栏介绍



 

 1.GCT介绍

 论文:https://openaccess.thecvf.com/content/CVPR2021/papers/Ruan_Gaussian_Context_Transformer_CVPR_2021_paper.pdf

浙江大学等机构发布的一篇收录于CVPR2021的文章,提出了一种新的通道注意力结构,在几乎不引入参数的前提下优于大多SOTA通道注意力模型,如SE、ECA等。这篇文章虽然叫Gaussian Context Transformer,但是和Transformer并无太多联系,这里可以理解为高斯上下文变换器。

         LCT(linear context transform)观察所得,如下图所示,SE倾向于学习一种负相关,即全局上下文偏离均值越多,得到的注意力激活值就越小。为了更加精准地学习这种相关性,LCT使用一个逐通道地变换来替代SE中的两个全连接层。然而,实验表明,LCT学得的这种负相关质量并不是很高,下图中右侧可以看出,LCT的注意力激活值波动是很大的。
 

         在本文中,我们假设这种关系是预先确定的。基于这个假设,我们提出了一个简单但极其有效的通道注意力块,称为高斯上下文Transformer (GCT),它使用满足预设关系的高斯函数实现上下文特征激励。

实验结果

在ImageNet 和 MS COCO 基准测试的大量实验表明,我们的 GCT 导致各种深度 CNN 和检测器的持续改进。与一系列最先进的通道注意力块(例如 SE 和 ECA)相比,我们的 GCT 在有效性和效率方面更为出色。

 2.GCT引入到yolov8

2.1 修改modules.py中:

  1. ###################### Gaussian Context Transformer attention #### END by AI&CV ###############################
  2. """
  3. PyTorch implementation of Gaussian Context Transformer
  4. As described in http://openaccess.thecvf.com//content/CVPR2021/papers/Ruan_Gaussian_Context_Transformer_CVPR_2021_paper.pdf
  5. Gaussian Context Transformer (GCT), which achieves contextual feature excitation using
  6. a Gaussian function that satisfies the presupposed relationship.
  7. """
  8. import torch
  9. from torch import nn
  10. class GCT(nn.Module):
  11. def __init__(self, channels, c=2, eps=1e-5):
  12. super().__init__()
  13. self.avgpool = nn.AdaptiveAvgPool2d(1)
  14. self.eps = eps
  15. self.c = c
  16. def forward(self, x):
  17. y = self.avgpool(x)
  18. mean = y.mean(dim=1, keepdim=True)
  19. mean_x2 = (y ** 2).mean(dim=1, keepdim=True)
  20. var = mean_x2 - mean ** 2
  21. y_norm = (y - mean) / torch.sqrt(var + self.eps)
  22. y_transform = torch.exp(-(y_norm ** 2 / 2 * self.c))
  23. return x * y_transform.expand_as(x)
  24. ###################### Gaussian Context Transformer attention #### END by AI&CV ###############################

2.2 加入tasks.py中:

from ultralytics.nn.modules import (C1, C2, C3, C3TR, SPP, SPPF, Bottleneck, BottleneckCSP, C2f, C3Ghost, C3x, Classify, Concat, Conv, ConvTranspose, Detect, DWConv, DWConvTranspose2d, Ensemble, Focus,GCT)

 修改def parse_model(d, ch, verbose=True): # model_dict, input_channels(3):

  1. if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,
  2. BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x,GCT
  3. ):

2.3 yolov8_GCT.yaml

  1. # Ultralytics YOLO
    声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小丑西瓜9/article/detail/271248
    推荐阅读
    相关标签