当前位置:   article > 正文

【深度学习】Self-Attention 原理与代码实现_selfattention代码

selfattention代码

1.Self-Attention 结构

        在计算的时候需要用到矩阵Q(查询),K(键值),V(值)。在实际中,Self-Attention 接收的是输入(单词的表示向量x组成的矩阵X) 或者上一个 Encoder block 的输出。而Q,K,V正是通过 Self-Attention 的输入进行线性变换得到的。

2. Q, K, V 的计算

        Self-Attention 的输入用矩阵X进行表示,使用线性变阵矩阵W_{q}W_{k}W_{k}经过计算得到QKV 计算如下图所示,注意 X, Q, K, V 的每一行都表示一个单词。

Q=Linear(X_{Embedding} )=X_{Embedding} *W_Q

K=Linear(K_{Embedding} )=K_{Embedding} *W_K

V=Linear(V_{Embedding} )=V_{Embedding} *W_V

3.Self-Attention 的计算

3.1 计算公式

        得到矩阵 Q, K, V之后就可以计算 Self-Attention 的值了,计算的公式如下:

3.2 计算相关系数 

        公式中计算矩阵QK每一行向量的内积,为了防止内积过大,因此除以d_{k}的平方根。Q*K^{T}后,得到的矩阵行列数都为 nn 为句子单词数,这个矩阵可以表示单词之间的 attention 强度。

        得到 Q* K^{T}之后,使用Softmax对矩阵的每一行都进行归一化,计算当前单词相对于其他单词的相关系数(attention值)。

3.3 相关系数相乘

        图中 Softmax 矩阵的第 1 行表示单词 1 与其他所有单词的 attention 系数,最终单词 1 的输出Z_{1}等于所有单词的v值根据 attention 系数相乘后加在一起得到,如下图所示:

3.4 self-attention的输出        

        得到 Softmax 矩阵之后可以和V相乘,得到最终的输出Z

 4.self-attention的代码实现

  1. from math import sqrt
  2. import torch
  3. import torch.nn as nn
  4. class SelfAttention(nn.Module):
  5. def __init__(self, dim_q, dim_k, dim_v):
  6. super(SelfAttention, self).__init__()
  7. self.dim_q = dim_q
  8. self.dim_k = dim_k
  9. self.dim_v = dim_v
  10. #定义线性变换函数
  11. self.linear_q = nn.Linear(dim_q, dim_k, bias=False)
  12. self.linear_k = nn.Linear(dim_q, dim_k, bias=False)
  13. self.linear_v = nn.Linear(dim_q, dim_v, bias=False)
  14. self._norm_fact = 1 / sqrt(dim_k)
  15. def forward(self, x):
  16. # x: batch, n, dim_q
  17. #根据文本获得相应的维度
  18. batch, n, dim_q = x.shape
  19. assert dim_q == self.dim_q
  20. q = self.linear_q(x) # batch, n, dim_k
  21. k = self.linear_k(x) # batch, n, dim_k
  22. v = self.linear_v(x) # batch, n, dim_v
  23. #q*k的转置 并*开根号后的dk
  24. dist = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact # batch, n, n
  25. #归一化获得attention的相关系数
  26. dist = torch.softmax(dist, dim=-1) # batch, n, n
  27. #attention系数和v相乘,获得最终的得分
  28. att = torch.bmm(dist, v)
  29. return att

图reference:Transformer模型详解(图解最完整版) - 知乎 (zhihu.com)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家自动化/article/detail/433739
推荐阅读
相关标签
  

闽ICP备14008679号