赞
踩
注意力机制(Attention Mechanism)是深度学习中一种重要的技术,它可以帮助模型更好地关注输入数据中的关键信息,从而提高模型的性能。注意力机制最早在自然语言处理领域的序列到序列(seq2seq)模型中得到广泛应用,后来逐渐扩展到了计算机视觉、语音识别等多个领域。
注意力机制的基本思想是为输入数据的每个部分分配一个权重,这个权重表示该部分对于当前任务的重要程度。在自然语言处理任务中,这通常意味着对输入句子中的每个单词分配一个权重,而在计算机视觉任务中,这可能意味着为输入图像的每个像素或区域分配一个权重。
总结:1.在conv.py加入注意力代码
2.在__init.oy__和tasks.py引用GAM
3.修改yaml文件
conv.py的路径:ultralytics-main\ultralytics\nn\modules\conv.py
如图下所示:
在conv.py的最下面添加注意力代码:
代码如下:
- #-----------注意力机制代码-----------------
- import torch.nn as nn
- import torch
-
- class GAM_Attention(nn.Module):
- def __init__(self, in_channels,c2, rate=4):
- super(GAM_Attention, self).__init__()
-
- self.channel_attention = nn.Sequential(
- nn.Linear(in_channels, int(in_channels / rate)),
- nn.ReLU(inplace=True),
- nn.Linear(int(in_channels / rate), in_channels)
- )
-
- self.spatial_attention = nn.Sequential(
- nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),
- nn.BatchNorm2d(int(in_channels / rate)),
- nn.ReLU(inplace=True),
- nn.Conv2d(int(in_channels / rate), in_channels, kernel_size=7, padding=3),
- nn.BatchNorm2d(in_channels)
- )
-
- def forward(self, x):
- b, c, h, w = x.shape
- x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
- x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
- x_channel_att = x_att_permute.permute(0, 3, 1, 2).sigmoid()
- x = x * x_channel_att
- x_spatial_att = self.spatial_attention(x).sigmoid()
- out = x * x_spatial_att
-
- return out
-
- if __name__ == '__main__':
- x = torch.randn(1, 64, 20, 20)
- b, c, h, w = x.shape
- net = GAM_Attention(in_channels=c)
- y = net(x)
- print(y.size())
效果如图下所示:
__init__.py文件中引用GAM_Attention
路径:ultralytics-main\ultralytics\nn\modules\__init__.py
如图下:
在__init__.py文件中,在导包里面找到from .conv import和__all__,最后面添加GAM_Attention。
如图下所示:
tasks.py 文件中引用GAM_Attention
路径:ultralytics-main\ultralytics\nn\tasks.py
如图下:
在tasks.py文件中,在导包里面找到from ultralytics.nn.modules最后面添加GAM_Attention
如图下所示:
在tasks.py里写入调用方式
打开tasks.py,Ctrl键+F查找n = 1(有空格)就可以找到添加的位置,如效果图:
- # """**************add Attention***************"""
- elif m in {GAM_Attention}:
- c1, c2 = ch[f], args[0]
- if c2 != nc: # if not output
- c2 = make_divisible(min(c2, max_channels) * width, 8)
- args = [c1, c2, *args[1:]]
效果如图下所示:
路径如下:ultralytics-main\ultralytics\cfg\models\v8\my_yolov8.yaml
如图下所示:
修改后的代码如下(可以直接复制到自己的yaml里面):
- # Ultralytics YOLO 声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/558453推荐阅读
相关标签
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。