In the ultralytics/nn directory, create a file named SEAttention.py and add the following code:
```python
import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):
    """Squeeze-and-Excitation (SE) channel attention."""

    def __init__(self, channel=512, reduction=16):
        super().__init__()
        # Squeeze: global average pooling collapses each channel map to one value
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP that outputs a (0, 1) weight per channel
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)    # (b, c, h, w) -> (b, c)
        y = self.fc(y).view(b, c, 1, 1)    # per-channel attention weights
        return x * y.expand_as(x)          # rescale the input feature map
```
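As a quick sanity check, here is a minimal sketch that runs the module on a random feature map (run it from the same file, or import SEAttention first): the output shape matches the input, since SE attention only reweights channels.

```python
import torch

x = torch.randn(1, 512, 7, 7)                # (batch, channels, height, width)
se = SEAttention(channel=512, reduction=16)  # channel must match the input's C
out = se(x)
print(out.shape)                             # torch.Size([1, 512, 7, 7])
```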
In tasks.py, import the SEAttention attention module.
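A plausible import line, assuming SEAttention.py sits directly under ultralytics/nn as created above (adjust the module path if your layout differs):

```python
from ultralytics.nn.SEAttention import SEAttention
```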
Then register SEAttention in the parse_model function: press Ctrl+F to search for parse_model in tasks.py and add the following branch alongside the existing `elif m in {...}` branches:
```python
# SEAttention attention module
elif m in {SEAttention}:
    args = [ch[f], *args]
```
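Here `f` is the layer's "from" index and `ch` tracks the output channel count of every parsed layer, so prepending `ch[f]` feeds the incoming channel count into SEAttention's `channel` argument. A minimal stand-alone sketch of that effect (the values are illustrative, not Ultralytics code; the import path is the one assumed above):

```python
from ultralytics.nn.SEAttention import SEAttention

ch = [3, 64, 128, 256]   # hypothetical per-layer output channel counts
f = -1                   # "from" index: take input from the previous layer
args = [16]              # args parsed from the model YAML (here: reduction)

args = [ch[f], *args]    # becomes [256, 16]
se = SEAttention(*args)  # i.e. SEAttention(channel=256, reduction=16)
```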
Alternatively, you can replace the contents of tasks.py wholesale with a full copy that already contains these changes (the file begins with the standard `# Ultralytics YOLO` license header).
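However you wire it in, one quick way to confirm the registration worked is to build a model from a YAML that references SEAttention and inspect the printed module graph. The filename below is a placeholder for your own model config:

```python
from ultralytics import YOLO

# 'yolov8-SEAttention.yaml' is hypothetical: a copy of a YOLOv8 model YAML
# with an SEAttention layer added to its backbone or head.
model = YOLO('yolov8-SEAttention.yaml')
print(model.model)  # SEAttention modules should appear in the printed architecture
```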