
PyTorch Multi-Head Attention Implementation (pytorch multiheadattention)


Notes from my learning process.

m_{a} = \mathrm{softmax}\left(\frac{qk^{T}}{\sqrt{d_k}}\right)v
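As a quick sanity check of the formula, here is a minimal sketch with tiny illustrative tensors (the shapes and names are my own, not from the original post):

```python
import math
import torch

# One head: 3 query positions, 4 key/value positions, dimension d = 8 (illustrative)
d = 8
q = torch.randn(3, d)
k = torch.randn(4, d)
v = torch.randn(4, d)

# softmax(q k^T / sqrt(d)) v, exactly as in the formula above
scores = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d), dim=-1)  # [3, 4]
m_a = scores @ v  # [3, 8]
print(scores.sum(dim=-1))  # each row of the attention weights sums to 1
```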

The inputs are first projected through linear layers to obtain the q, k, and v matrices; each is then split into multiple heads of the desired shape (this is where the number of heads comes in), and the formula above is applied per head to produce the attention output. The full implementation follows.

```python
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    '''
    Multi-head scaled dot-product attention.
    q, k, v are each projected through their own linear layer.
    Note: mask must be a tensor, otherwise an error is raised.
    '''
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        # d_k is the per-head dimension and also the scaling factor,
        # so the scaling adapts automatically to d_model
        self.d_k = d_model // heads
        # layers
        self.q_l = nn.Linear(d_model, d_model)
        self.v_l = nn.Linear(d_model, d_model)
        self.k_l = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(d_model, d_model)

    def attention(self, q, k, v, d_k, mask=None, dropout=None):
        '''
        input:
            q, k, v = [batch_size, heads, seq_len, d_k]
        output:
            output  = [batch_size, heads, seq_len, d_k]
        '''
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # broadcast the mask over the heads dimension,
            # then zero out masked positions with a large negative value
            mask = mask.unsqueeze(1)
            scores = scores.masked_fill(mask == 0, -1e9)
        scores = torch.softmax(scores, dim=-1)
        if dropout is not None:
            scores = dropout(scores)
        output = torch.matmul(scores, v)  # [batch_size, heads, seq_len, d_k]
        return output

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        '''
        Split into heads:
        reshape [batch_size, seq_len, d_model] to [batch_size, -1, self.heads, self.d_k],
        then transpose. We reshape first and transpose afterwards so that each head
        keeps the linear structure of the projected features; reshaping directly to
        the final layout would scramble positions across heads.
        '''
        k = self.k_l(k).view(batch_size, -1, self.heads, self.d_k)
        q = self.q_l(q).view(batch_size, -1, self.heads, self.d_k)
        v = self.v_l(v).view(batch_size, -1, self.heads, self.d_k)
        k = k.transpose(1, 2)
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        scores = self.attention(q, k, v, self.d_k, mask, self.dropout)
        # merge the heads back: [batch_size, seq_len, d_model]
        concat = scores.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out(concat)  # final output projection
        return output
```
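A minimal usage sketch, assuming the class above is defined and using illustrative values (batch of 2, sequence length 5, d_model = 512, 8 heads; none of these numbers come from the original post):

```python
import torch

mha = MultiHeadAttention(heads=8, d_model=512)
x = torch.randn(2, 5, 512)

# Padding mask of shape [batch_size, 1, seq_len]: 1 = attend, 0 = masked out.
# After unsqueeze(1) inside attention() it broadcasts to [batch_size, 1, 1, seq_len].
mask = torch.ones(2, 1, 5, dtype=torch.long)
mask[1, :, 3:] = 0  # pretend the second sequence has only 3 real tokens

out = mha(x, x, x, mask)  # self-attention: q, k, v all come from x
print(out.shape)  # torch.Size([2, 5, 512])
```

For production use, PyTorch also ships a built-in torch.nn.MultiheadAttention, but implementing the layer by hand as above makes the head-splitting, scaling, and masking steps explicit.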
