import torch
import torch.nn as nn
import torch.nn.functional as F


def conv_bn_relu(in_chl, out_chl, kernel_size, stride, padding,
                 has_bn=True, has_relu=True, efficient=False, groups=1):
    # Minimal stand-in (an assumption) for the conv_bn_relu helper used by
    # the original repository; its gradient-checkpointed "efficient" path is
    # omitted here, so the flag is accepted but ignored.
    layers = [nn.Conv2d(in_chl, out_chl, kernel_size, stride, padding,
                        groups=groups, bias=not has_bn)]
    if has_bn:
        layers.append(nn.BatchNorm2d(out_chl))
    if has_relu:
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


class PRM(nn.Module):
    # Pose Refine Machine: rescales features with a channel branch (global
    # pooling + two 1x1 convs) and a spatial branch (1x1 conv + 9x9
    # depthwise conv), each gated by a sigmoid.

    def __init__(self, output_chl_num, efficient=False):
        super(PRM, self).__init__()
        self.output_chl_num = output_chl_num
        self.conv_bn_relu_prm_1 = conv_bn_relu(
            self.output_chl_num, self.output_chl_num, kernel_size=3,
            stride=1, padding=1, has_bn=True, has_relu=True,
            efficient=efficient)
        self.conv_bn_relu_prm_2_1 = conv_bn_relu(
            self.output_chl_num, self.output_chl_num, kernel_size=1,
            stride=1, padding=0, has_bn=True, has_relu=True,
            efficient=efficient)
        self.conv_bn_relu_prm_2_2 = conv_bn_relu(
            self.output_chl_num, self.output_chl_num, kernel_size=1,
            stride=1, padding=0, has_bn=True, has_relu=True,
            efficient=efficient)
        self.sigmoid2 = nn.Sigmoid()
        self.conv_bn_relu_prm_3_1 = conv_bn_relu(
            self.output_chl_num, self.output_chl_num, kernel_size=1,
            stride=1, padding=0, has_bn=True, has_relu=True,
            efficient=efficient)
        # 9x9 depthwise conv (groups == channels) for the spatial branch.
        self.conv_bn_relu_prm_3_2 = conv_bn_relu(
            self.output_chl_num, self.output_chl_num, kernel_size=9,
            stride=1, padding=4, has_bn=True, has_relu=True,
            efficient=efficient, groups=self.output_chl_num)
        self.sigmoid3 = nn.Sigmoid()

    def forward(self, x):
        out = self.conv_bn_relu_prm_1(x)
        out_1 = out
        # Channel attention vector, shape (B, C, 1, 1).
        out_2 = F.adaptive_avg_pool2d(out_1, (1, 1))
        out_2 = self.conv_bn_relu_prm_2_1(out_2)
        out_2 = self.conv_bn_relu_prm_2_2(out_2)
        out_2 = self.sigmoid2(out_2)
        # Spatial attention map, shape (B, C, H, W).
        out_3 = self.conv_bn_relu_prm_3_1(out_1)
        out_3 = self.conv_bn_relu_prm_3_2(out_3)
        out_3 = self.sigmoid3(out_3)
        # Residual reweighting: out_1 * (1 + channel_att * spatial_att).
        out = out_1.mul(1 + out_2.mul(out_3))
        return out
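
A quick shape check for the PRM sketch above, assuming the minimal conv_bn_relu stand-in defined with it: both attention branches only rescale features, so the module preserves the input's (B, C, H, W) shape.

# Hypothetical smoke test for the PRM module above.
x = torch.randn(2, 64, 32, 32)    # (B, C, H, W)
prm = PRM(output_chl_num=64)
y = prm(x)
print(y.shape)                    # torch.Size([2, 64, 32, 32])
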
class CAM_Module(nn.Module):
    """Channel attention module (as in DANet)."""

    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.channel_in = in_dim
        # Learnable scale on the attention output, initialized to zero so
        # the module starts out as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        inputs :
            x : input feature maps (B x C x H x W)
        returns :
            out : attention value + input feature
            attention : B x C x C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
        # Channel-to-channel affinity matrix, shape (B, C, C).
        energy = torch.bmm(proj_query, proj_key)
        # Per-row max minus energy; softmax is shift-invariant, so this is a
        # softmax over the negated affinities, as in the original code.
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)

        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)

        # Residual connection scaled by the learned gamma.
        out = self.gamma * out + x
        return out
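
A matching sketch for CAM_Module: shapes are preserved, and because gamma is initialized to zero the output equals the input exactly at initialization.

# Hypothetical smoke test for CAM_Module above.
x = torch.randn(2, 64, 32, 32)
cam = CAM_Module(in_dim=64)
y = cam(x)
print(y.shape)               # torch.Size([2, 64, 32, 32])
print(torch.allclose(y, x))  # True at init, because gamma starts at 0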