
[Untitled] max_energy_0 = torch.max(energy, -1, keepdim=True)

The line below, which reappears in CAM_Module's forward pass further down, takes the row-wise maximum of the B x C x C energy matrix and broadcasts it back to the full matrix shape:

max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)
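
A minimal standalone check of what that line does; the tensor sizes here are illustrative, not taken from the modules below:

import torch

energy = torch.randn(2, 4, 4)                          # B x C x C similarity matrix
row_max = torch.max(energy, -1, keepdim=True)[0]       # (2, 4, 1): max of each row
max_energy_0 = row_max.expand_as(energy)               # broadcast back to (2, 4, 4)
assert torch.equal(max_energy_0[..., 0], max_energy_0[..., 1])  # each column repeats the row max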
import torch
import torch.nn as nn

class PRM(nn.Module):
    """Pose Refine Machine: rescales features with a channel gate (global
    pooling + two 1x1 convs) and a spatial gate (1x1 conv + depthwise 9x9
    conv). Relies on a conv_bn_relu helper, sketched after this block."""

    def __init__(self, output_chl_num, efficient=False):
        super(PRM, self).__init__()
        self.output_chl_num = output_chl_num
        # shared 3x3 conv applied before both gating branches
        self.conv_bn_relu_prm_1 = conv_bn_relu(self.output_chl_num, self.output_chl_num,
                                               kernel_size=3, stride=1, padding=1,
                                               has_bn=True, has_relu=True,
                                               efficient=efficient)
        # channel branch: two 1x1 convs on the globally pooled feature
        self.conv_bn_relu_prm_2_1 = conv_bn_relu(self.output_chl_num, self.output_chl_num,
                                                 kernel_size=1, stride=1, padding=0,
                                                 has_bn=True, has_relu=True,
                                                 efficient=efficient)
        self.conv_bn_relu_prm_2_2 = conv_bn_relu(self.output_chl_num, self.output_chl_num,
                                                 kernel_size=1, stride=1, padding=0,
                                                 has_bn=True, has_relu=True,
                                                 efficient=efficient)
        self.sigmoid2 = nn.Sigmoid()
        # spatial branch: 1x1 conv followed by a depthwise 9x9 conv
        self.conv_bn_relu_prm_3_1 = conv_bn_relu(self.output_chl_num, self.output_chl_num,
                                                 kernel_size=1, stride=1, padding=0,
                                                 has_bn=True, has_relu=True,
                                                 efficient=efficient)
        self.conv_bn_relu_prm_3_2 = conv_bn_relu(self.output_chl_num, self.output_chl_num,
                                                 kernel_size=9, stride=1, padding=4,
                                                 has_bn=True, has_relu=True,
                                                 efficient=efficient, groups=self.output_chl_num)
        self.sigmoid3 = nn.Sigmoid()

    def forward(self, x):
        out = self.conv_bn_relu_prm_1(x)
        out_1 = out
        # channel gate: B x C x 1 x 1
        out_2 = torch.nn.functional.adaptive_avg_pool2d(out_1, (1, 1))
        out_2 = self.conv_bn_relu_prm_2_1(out_2)
        out_2 = self.conv_bn_relu_prm_2_2(out_2)
        out_2 = self.sigmoid2(out_2)
        # spatial gate: B x C x H x W
        out_3 = self.conv_bn_relu_prm_3_1(out_1)
        out_3 = self.conv_bn_relu_prm_3_2(out_3)
        out_3 = self.sigmoid3(out_3)
        # residual reweighting: out_1 * (1 + channel_gate * spatial_gate)
        out = out_1.mul(1 + out_2.mul(out_3))
        return out
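
PRM calls a conv_bn_relu helper that the post never defines; it comes from the module's original codebase. A minimal stand-in, written only to match the call sites above (the real helper also implements the efficient checkpointing mode, which this sketch accepts but ignores):

import torch.nn as nn

class conv_bn_relu(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and ReLU.
    Minimal stand-in matching the keyword arguments used by PRM."""
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding,
                 has_bn=True, has_relu=True, efficient=False, groups=1):
        super(conv_bn_relu, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding, groups=groups,
                              bias=not has_bn)
        self.bn = nn.BatchNorm2d(out_planes) if has_bn else nn.Identity()
        self.relu = nn.ReLU(inplace=True) if has_relu else nn.Identity()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

With this in place, PRM(64)(torch.randn(1, 64, 32, 32)) returns a tensor of the same shape, since both gates only rescale the input feature map.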

import torch
from torch.nn import Module, Parameter, Softmax

class CAM_Module(Module):
    """Channel attention module (from DANet)."""
    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.channel_in = in_dim
        self.gamma = Parameter(torch.zeros(1))  # learnable residual weight, starts at 0
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        """
        inputs:
            x : input feature maps (B x C x H x W)
        returns:
            out : attention value + input feature
            attention : B x C x C
        """
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)                  # B x C x N
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)   # B x N x C
        energy = torch.bmm(proj_query, proj_key)                 # B x C x C channel similarity
        # Row-wise max minus energy (the max_energy_0 line quoted above).
        # Softmax is shift-invariant along dim=-1, so this equals
        # softmax(-energy): less similar channel pairs get larger weights.
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)                  # B x C x N
        out = torch.bmm(attention, proj_value)                   # reweight channels
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x                               # residual connection
        return out
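
A quick smoke test; the sizes are chosen only for illustration:

import torch

cam = CAM_Module(in_dim=16)
x = torch.randn(2, 16, 8, 8)   # B x C x H x W
out = cam(x)
print(out.shape)               # torch.Size([2, 16, 8, 8])
# gamma is initialized to zero, so the module starts as an identity mapping
assert torch.allclose(out, x)

Because gamma starts at 0, the attention branch contributes nothing at initialization and its influence is learned gradually during training.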
