当前位置:   article > 正文

Self-attention (PyTorch implementation)

pytorch selfattention

来源: MEF-GAN: Multi-Exposure Image Fusion via Generative Adversarial Networks

 

  1. class Attention(nn.Module):
  2. def __init__(self, bn=True):
  3. super(Attention, self).__init__()
  4. self.conv1 = nn.Conv2d(6, 16, kernel_size=3, stride=2)
  5. self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
  6. self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
  7. self.bn = nn.BatchNorm2d(16)
  8. self.relu = nn.ReLU()
  9. self.bn2 = nn.BatchNorm2d(32)
  10. self.Cv1 = nn.Conv2d(32, 32, kernel_size=1, stride=1)
  11. self.cv2 = nn.Conv2d(32, 8, kernel_size=1, stride=1)
  12. self.cv3 = nn.Conv2d(32, 8, kernel_size=1, stride=1)
  13. def forward(self, under, over):
  14. x = torch.cat((under, over), dim=1)
  15. output = self.relu(self.bn(self.conv1(x)))
  16. output = self.maxpool(output)
  17. output = self.relu(self.bn2(self.conv2(output)))
  18. C = self.Cv1(output)
  19. C = C.view(C.shape[0] * C.shape[1], C.shape[2] * C.shape[3])
  20. c1 = self.cv2(output)
  21. c1 = c1.view(c1.shape[0] * c1.shape[2] * c1.shape[3], 8)
  22. c2 = self.cv3(output)
  23. c2 = c2.view(c2.shape[0] * c2.shape[2] * c2.shape[3], 8).t()
  24. c = torch.nn.Softmax(torch.mm(c1, c2), dim=1)
  25. c = c.view(output.shape[0], c.shape[0], int(c.shape[1] // output.shape[0]))
  26. c = c.view(c.shape[0] * c.shape[1], c.shape[2])
  27. attention_map = torch.mm(C, c.t())
  28. attention_map = attention_map.view(output.shape[0], output.shape[1], output.shape[2] * output.shape[0], output.shape[3] * output.shape[0] )
  29. attention_map = F.interpolate(attention_map, size=[under.shape[2], under.shape[3]])
  30. return attention_map

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/351222
推荐阅读
相关标签
  

闽ICP备14008679号