当前位置:   article > 正文

多头注意力机制_多头注意力机制csdn

多头注意力机制csdn

前面已经讲完了自注意力机制,简单来讲,就是对一组向量空间分别求内积,然后进行缩放,最后对不同的向量使用压缩后的分数累加求和。

1.多头是个什么东西?

        实际上很简单,自注意力层的输出空间被分解为一组独立的子空间,对这些子空间分别进行学习,也就是说,初始的Q,K,V三组独立的密集投影生成三组独立的向量[1],每个向量都通过神经注意力进行处理,然后将多个输出拼接为一个输出序列[2],然后将输出序列经过线性变换[3],每个这样的子空间叫做一个头。密集投影层是可学习层,因此投影过程是可以学习的,独立的头也有助于该层为每个词元学习多组特征,其中每一组内的特征彼此相关,但与其他组的特征几乎无关。

我标记出了三个位置,这三个位置的描述就是实现多头注意力的关键

按照之前我们实现了一个注意力层,我们将其打包为attention(q,k,v)

(1).Q,K,V三组投影,实际上就是线性变化Y = W X

newQ = W_q*Q\\ newK = W_k*K\\ newV = W_v*V

  1. import numpy as np
  2. #假设有矩阵Q,K,V,矩阵大小都一样,[batch_size, N, feature_numbers]
  3. head_num = 3 #三个头
  4. #这里的w矩阵需要能够学习,这里是选择了一个初始化为0的矩阵
  5. w_q = np.random.random((head_num, N, feature_numbers))
  6. w_k = np.random.random((head_num, N, feature_numbers))
  7. w_v = np.random.random((head_num, N, feature_numbers))
  8. #线性变换
  9. newQ = np.matmul(w_q, Q)
  10. newK = np.matmul(w_k, K)
  11. newV = np.matmul(w_v, V)
  12. #使用多头注意力
  13. result = attention(newQ, newK, newV)
  14. #这里只能算伪代码了
  15. #拼接多个头,假设各个矩阵大小一样,因此可以直接转换维度作为拼接
  16. output = result.reshape(Q.shape, head_num)
  17. #最终输出到密集层
  18. head_output = output * Wo

然后经过注意力机制,生成一个头,这是其中一个头而已,根据需要可以生成多个

h_i=attention(newQ, newK, newV)

(2).拼接多个头

output = concat(h_1,h_2,h_3,...,h_n)

(3).全连接

result = W*output

这个代码顶多算伪代码,以后有空修改吧

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/2023面试高手/article/detail/358799
推荐阅读
相关标签
  

闽ICP备14008679号