赞
踩
空间维度(channel)
来进行特征压缩,将每个二维的特征通道变成一个实数,这个实数某种程度上具有全局的感受野,并且输出的维度和输入的特征通道数相匹配。它表征着在特征通道上响应的全局分布,且使得靠近输入的层也可以获得全局的感受野。50×512×7×7
进行global average pooling,然后得到了一个50×512×1×1
大小的特征图,这个特征图具有全局感受野。50×512×1×1
特征图,经过两个全连接神经网络,最后用一 个类似于循环神经网络中门的机制,通过参数来为每个特征通道生成权重,中参数被学习用来显式地建模特征通道间的相关性(论文中使用的是sigmoid
)。50×512×1×1
变成50×512 / 16×1×1
,最后再还原回来:50×512×1×1
50×512×1×1
通过expand_as
得到50×512×7×7
), 完成在通道维度上对原始特征的重标定,并作为下一级的输入数据。class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
import numpy as np
import torch
from torch import nn
from torch.nn import init
class SEAttention(nn.Module):
def __init__(self, channel=512, reduction=16):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) # 全局均值池化 输出的是c×1×1
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False), # channel // reduction代表通道压缩
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False), # 还原
nn.Sigmoid()
)
def init_weights(self):
for m in self.modules():
print(m) # 没运行到这儿
if isinstance(m, nn.Conv2d): # 判断类型函数——:m是nn.Conv2d类吗?
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x):
b, c, _, _ = x.size() # 50×512×7×7
y = self.avg_pool(x).view(b, c) # ① maxpool之后得:50×512×1×1 ② view形状得到50×512
y = self.fc(y).view(b, c, 1, 1) # 50×512×1×1
return x * y.expand_as(x) # 根据x.size来扩展y
if __name__ == '__main__':
input = torch.randn(50, 512, 7, 7)
se = SEAttention(channel=512, reduction=8) # 实例化模型se
output = se(input)
print(output.shape)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。