当前位置:   article > 正文

【pytorch】手把手实现自注意力机制_自注意力机制 pyth

自注意力机制 pyth

背景:

        不仅在NLP领域,自注意力机制也在CV领域有着广泛的应用。所以,如何很好地实现自注意力机制成为比较关键的问题。下面我们来对于该机制进行简单实现。

        先总结一下思路:

        1. 我们的输入是一个(B,N,C)形状的矩阵,其中B代表Batch Size,N代表Time Step,C代表每个Time Step的维度。

        2. 我们想做的是,根据输入得到多头的qkv。qkv分别代表query,key,value。我们想用query来查询key而得到一个关联度矩阵A。

        3. 由于是多头注意力,我们得到了多个关联度矩阵,我们要将多个关联度矩阵合并为一个。

        4. 最后的关联度矩阵和value矩阵相乘,等到最后的输出。

最后的代码如下:

        

  1. import torch,math
  2. import torch.nn as nn
  3. class MultiHead_SelfAttention(nn.Module):
  4. def __init__(self, dim, num_head):
  5. '''
  6. Args:
  7. dim: dimension for each time step
  8. num_head:num head for multi-head self-attention
  9. '''
  10. super().__init__()
  11. self.dim=dim
  12. self.num_head=num_head
  13. self.qkv=nn.Linear(dim, dim*3) # extend the dimension for later spliting
  14. def forward(self, x):
  15. B, N, C = x.shape
  16. qkv = self.qkv(x).reshape(B, N, 3, self.num_head, C//self.num_head).permute(2, 0, 3, 1, 4)
  17. q, k, v= qkv[0], qkv[1], qkv[2]
  18. att = q@k.transpose(-1, -2)/ math.sqrt(C)
  19. att = att.softmax(dim=1) # 将多个注意力矩阵合并为一个
  20. x = (att@v).transpose(1, 2)
  21. x=x.reshape(B, N, C)
  22. return x
  23. if __name__=='__main__':
  24. B = 10
  25. N = 20
  26. C = 32
  27. num_head=8
  28. x = torch.rand((B, N, C))
  29. MHSA=Multihead_SelfAttention(C, num_head)
  30. print(MHSA(x).shape)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/291033
推荐阅读
相关标签
  

闽ICP备14008679号