当前位置:   article > 正文

YOLOv8模型改进6【增加注意力机制Gaussian Context Transformer+解决报错:ImportError:cannot import name ‘GCT‘ ......】_gtg yolo

gtg yolo

一、Gaussian Context Transformer注意力机制简介

Gaussian Context Transformer缩写为:GCT注意力机制。它是一种通道注意力结构的注意力模块,与SE注意力相似,但它可以做到在几乎不引入参数的前提下优于SE等注意力模型。(属于轻量高效类型的注意力机制)

GCT注意力机制的模型结构如图所示:
在这里插入图片描述
下图是GCT模块在模型参数方面的表现:
在这里插入图片描述
下图为GCT模块在目标检测任务上的表现:
在这里插入图片描述
在这里插入图片描述
【注:论文-Gaussian Context Transformer链接】:https://openaccess.thecvf.com/content/CVPR2021/papers/Ruan_Gaussian_Context_Transformer_CVPR_2021_paper.pdf

【注:没有找到Github上的代码,但是有厉害的大哥给了代码,具体见下面的改进】

二、增加GCT注意力机制到YOLOv8模型上

基本上还是一样的,主要看差别吧:(注意:依据的代码是直到发文为止最新的YOLOv8代码进行的改进,请及时更新代码)

【1: …/ultralytics/nn/modules/conv.py

在文件末尾增加有关GCT模块的代码:

#增加GCT注意力
class GCT(nn.Module):
    def __init__(self, channels, c=2, eps=1e-5):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.eps = eps
        self.c = c

    def forward(self, x):
        y = self.avgpool(x)
        mean = y.mean(dim=1, keepdim=True)
        mean_x2 = (y ** 2).mean(dim=1, keepdim=True)
        var = mean_x2 - mean ** 2
        y_norm = (y - mean) / torch.sqrt(var + self.eps)
        y_transform = torch.exp(-(y_norm ** 2 / 2 * self.c))
        return x * y_transform.expand_as(x)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

【2:…ultralytics-main/ultralytics/nn/modules/__init__.py

这个是一样的,在这个文件下对GCT模块进行声明,声明的名字要与前面的GCT代码模块名保持一致,可以参考这篇文章:https://blog.csdn.net/A__MP/article/details/136597192

1:找到这段,在后面加上模块名:GCT
from .conv 
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/不正经/article/detail/558480
推荐阅读
相关标签
  

闽ICP备14008679号