
Beginner Tutorial: Improving the YOLOv5 Model by Adding 13 Attention Mechanisms

1. Preparation

First, the download link for the 13 attention mechanisms:

https://github.com/z1069614715/objectdetection_script

2. Adding an Attention Mechanism

1. Example: adding SimAM (an attention mechanism that does not need to be given a channel count)

1. Create a new file under the models folder named SimAM.py

Copy the following code into SimAM.py:

```python
import torch
import torch.nn as nn


class SimAM(torch.nn.Module):
    # parameter-free attention: no channel count needs to be passed in
    def __init__(self, e_lambda=1e-4):
        super(SimAM, self).__init__()
        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda  # stabilizer in the energy denominator

    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s

    @staticmethod
    def get_module_name():
        return "simam"

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        # per-channel squared deviation from the spatial mean
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activation(y)
```
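As a quick sanity check before wiring the module into YOLOv5, here is a minimal sketch (repeating the SimAM class above, with the input shape chosen only for illustration) showing that SimAM takes no channel argument and leaves the feature-map shape unchanged:

```python
import torch
import torch.nn as nn


class SimAM(nn.Module):
    # SimAM as defined above: parameter-free, so no channel argument
    def __init__(self, e_lambda=1e-4):
        super().__init__()
        self.activation = nn.Sigmoid()
        self.e_lambda = e_lambda

    def forward(self, x):
        b, c, h, w = x.size()
        n = w * h - 1
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
        return x * self.activation(y)


x = torch.randn(1, 64, 32, 32)  # a feature map like one inside the YOLOv5 backbone
out = SimAM()(x)                # constructed with no channel count
print(out.shape)                # torch.Size([1, 64, 32, 32]) -- shape is unchanged
```

Because the sigmoid gate lies in (0, 1), the module only rescales activations; it never changes the tensor shape, which is why it can be dropped between any two layers of the yaml.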

2. Import the SimAM class at the top of yolo.py (e.g. `from models.SimAM import SimAM`, assuming the file from step 1 sits in models/)

3. Copy yolov5s.yaml into the same directory and name the copy yolov5s-SimAM.yaml

Add the attention mechanism at a chosen layer. Each row of the model yaml describes one layer, in the format:

[from, number, module, args]

Note!

Inserting an attention layer shifts the index of every layer after it, so remember to update the layer numbers referenced in the detection head (the `from` list of the Detect layer).
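For example, a SimAM row inserted into yolov5s-SimAM.yaml could look like the fragment below (the placement after SPPF and the layer indices are illustrative, not prescribed by the source; SimAM takes no channel count, so `args` is empty):

```yaml
# illustrative excerpt of yolov5s-SimAM.yaml
# backbone:
   [-1, 1, SPPF, [1024, 5]]    # layer 9 in the stock yolov5s.yaml
   [-1, 1, SimAM, []]          # inserted layer 10: parameter-free, args empty
# head:
#  Detect originally reads from layers [17, 20, 23]; after inserting one
#  backbone layer, those references must become [18, 21, 24].
```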

2. Adding the SE attention mechanism (an attention mechanism that needs to be given the channel count)

1. Create SE.py under models and paste in:

```python
import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):
    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)
```
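Unlike SimAM, SEAttention must be constructed with the channel count of its input, which is why the extra handling in yolo.py (step 2 below) is needed. A minimal sketch (repeating the class above without `init_weights`, with an input shape chosen only for illustration):

```python
import torch
from torch import nn


class SEAttention(nn.Module):
    # squeeze-and-excitation: global average pool, bottleneck MLP, sigmoid gate
    def __init__(self, channel=512, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)   # squeeze: one value per channel
        y = self.fc(y).view(b, c, 1, 1)   # excitation: per-channel weights in (0, 1)
        return x * y.expand_as(x)         # rescale the input channels


x = torch.randn(2, 64, 20, 20)              # 64-channel feature map
se = SEAttention(channel=64, reduction=16)  # channel must match the input
out = se(x)
print(out.shape)                            # torch.Size([2, 64, 20, 20])
```

Constructing `SEAttention(channel=64)` against, say, a 128-channel input would raise a shape error inside the `nn.Linear` layers, which is exactly what the `args = [ch[f]]` line in yolo.py prevents.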

2. Modify yolo.py

In the module-building loop of parse_model, add these two lines:

```python
elif m is SEAttention:
    args = [ch[f]]
```
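To see what those two lines do, here is a toy illustration (the values are made up): in parse_model, `ch` records the output channel count of each parsed layer and `f` is the row's `from` index, so `args` becomes exactly the channel count that SEAttention's constructor needs:

```python
# toy illustration of the parse_model logic (values are hypothetical)
ch = [3, 32, 64, 128]  # output channels recorded for layers parsed so far
f = -1                 # "from" field of the yaml row: the previous layer
args = [ch[f]]         # -> [128]; the module is then built as SEAttention(*args)
print(args)            # [128]
```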

3. Create yolov5s-SE.yaml under models

# YOLOv5
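An SEAttention row in this yaml can leave `args` empty, because whatever is written there is overwritten by the `args = [ch[f]]` branch added to parse_model in step 2. An illustrative row (placement is an example, not prescribed by the source):

```yaml
# illustrative row for yolov5s-SE.yaml
   [-1, 1, SEAttention, []]   # args is filled in by parse_model as [ch[f]]
```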