当前位置:   article > 正文

深度学习笔记3——Pytorch 图像处理中ECA注意力机制的解析与代码详解_eca模块代码

eca模块代码

一、ECA

在这里插入图片描述

  • ECA模块去除了原来SENet模块中的全连接层
  • 全局平均池化之后的特征上通过一个1D卷积进行学习
  • 通过sigmoid函数获取特征层每一个通道的权重
  • 将获取的权值与输入的Feature相乘

二、ECA的pytorch实现

class eca_block(nn.Module):

    def __init__(self, channel, gamma=2, b=1):
        super(eca_block, self).__init__()
        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
        # 控制根据输入的channel控制kernel_size 的大小,保证kernel_size是一个奇数 
        padding = kernel_size // 2
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.size()
        avg = self.avg_pool(x).view([b, 1, c])
        out = self.conv(avg)
        out = self.sigmoid(out).view([b, c, 1, 1])
        return out * x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/603962
推荐阅读
相关标签
  

闽ICP备14008679号