当前位置:   article > 正文

self-attention 的 pytorch 实现_pytorch self-attention

pytorch self-attention

参考self-attention 的 pytorch 实现 - 云+社区 - 腾讯云

问题

基于条件的卷积GAN 在那些约束较少的类别中生成的图片较好,比如大海,天空等;但是在那些细密纹理,全局结构较强的类别中生成的图片不是很好,如人脸(可能五官不对应),狗(可能狗腿数量有差,或者毛色不协调)。

可能的原因

大部分卷积神经网络都严重依赖于局部感受野,而无法捕捉全局特征。另外,在多次卷积之后,细密的纹理特征逐渐消失。

SA-GAN解决思路

不仅仅依赖于局部特征,也利用全局特征,通过将不同位置的特征图结合起来(转置就可以结合不同位置的特征)。

  1. ##############################
  2. # self attention layer
  3. # author Xu Mingle
  4. # time Feb 18, 2019
  5. ##############################
  6. import torch.nn.Module
  7. import torch
  8. import torch.nn.init
  9. def init_conv(conv, glu=True):
  10. init.xavier_uniform_(conv.weight)
  11. if conv.bias is not None:
  12. conv.bias.data.zero_()
  13. class SelfAttention(nn.Module):
  14. r"""
  15. Self attention Layer.
  16. Source paper: https://arxiv.org/abs/1805.08318
  17. """
  18. def __init__(self, in_dim, activation=F.relu):
  19. super(SelfAttention, self).__init__()
  20. self.chanel_in = in_dim
  21. self.activation = activation
  22. self.f = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8 , kernel_size=1)
  23. self.g = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8 , kernel_size=1)
  24. self.h = nn.Conv2d(in_channels=in_dim, out_channels=in_dim , kernel_size=1)
  25. self.gamma = nn.Parameter(torch.zeros(1))
  26. self.softmax = nn.Softmax(dim=-1)
  27. init_conv(self.f)
  28. init_conv(self.g)
  29. init_conv(self.h)
  30. def forward(self, x):
  31. """
  32. inputs :
  33. x : input feature maps( B X C X W X H)
  34. returns :
  35. out : self attention feature maps
  36. """
  37. m_batchsize, C, width, height = x.size()
  38. f = self.f(x).view(m_batchsize, -1, width * height) # B * (C//8) * (W * H)
  39. g = self.g(x).view(m_batchsize, -1, width * height) # B * (C//8) * (W * H)
  40. h = self.h(x).view(m_batchsize, -1, width * height) # B * C * (W * H)
  41. attention = torch.bmm(f.permute(0, 2, 1), g) # B * (W * H) * (W * H)
  42. attention = self.softmax(attention)
  43. self_attetion = torch.bmm(h, attention) # B * C * (W * H)
  44. self_attetion = self_attetion.view(m_batchsize, C, width, height) # B * C * W * H
  45. out = self.gamma * self_attetion + x
  46. return out

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/306206
推荐阅读
相关标签
  

闽ICP备14008679号