当前位置:   article > 正文

多头注意力机制_bert性能优化之——用另一种方式整合多头注意力

多头注意力机制的分数怎么融合

作者:邱震宇(华泰证券股份有限公司 算法工程师)

知乎专栏:我的ai之路


a7fa4b2d4c9e23378976d8105036a9d2.png

今天我想给大家介绍这样一篇论文:Multi-Head Attention: Collaborate Instead of Concatenate。作者均来自

洛桑联邦理工学院_百度百科baike.baidu.comb3c5ca9dc1ed7adb950ec08d6ce9e4d5.png

看过我文章的同学肯定知道,我一直在关注bert模型的性能优化相关研究,而这篇论文正好是与transformer的性能优化相关,并且我认为它的方法不需要做太多的适配就能应用在预训练模型上面,实用性较高,因此推荐给大家。


众所周知,经典的transformer架构中采用了multi-head attention机制来引导模型从不同角度学习不同的语义信息,从各种实验对比中也能发现多头机制确实能够提升模型在NLP任务上的精度。然而,随着目前大规模预训练模型的普及,多头注意力机制在带来精度提升的同时,也增加了计算的成本,带来了性能上的限制。

因此最近两年,有些研究人员尝试从不同的维度去探讨是否能从多头机制上去优化transformer的性能。有些工作重点关注了多头中每个头的注意力到底捕捉了哪些语义信息,头与头之间捕捉的信息是否有冗余,例如这篇论文:Analyzing multi-head self-attention: Specialized heads do the heavy lifting, the rest can be pruned,提出了一种量化注意力头重要程度的方法。还有一些工作更加激进,提出了多头注意力机制是否有必要的疑问,例如这篇论文:Are sixteen heads really better than one。它对transformer中的每个头都做了消融实验,探讨了每个头在不同下游NLP任务上的作用,最后提出了一种迭代式地剪枝注意力头的方法。

与上述工作不同,本篇论文并非直接对注意力头进行结构性剪枝,而是关注所有注意力头捕捉的通用信息,试图将这些信息提取出来作为sharing weights,每个头各自关注自己独有的工作,从而减少多头注意力计算时的成本。下面我就详细得为大家解读这篇论文的工作。

单个注意力头的减负

在那篇经典的Attention is all you need论文中,对于注意力分数的计算是这样的:

85045ddb-322b-eb11-8da9-e4434bdf6706.svg

其中, 88045ddb-322b-eb11-8da9-e4434bdf6706.svg 当X和Y是同一个序列时,就是自注意力模型,此时 8a045ddb-322b-eb11-8da9-e4434bdf6706.svg 。

然而,在各种版本的transformer实现中,上述各种线性映射计算是附加bias的,即 8c045ddb-322b-eb11-8da9-e4434bdf6706.svg ,其中 8f045ddb-322b-eb11-8da9-e4434bdf6706.svg 。因为在各种深度学习框架中,默认支持broadcasting,所以这里公式上引入了 90045ddb-322b-eb11-8da9-e4434bdf6706.svg ,相当于实现了broadcasting。

在引入了bias后,我们重新对 91045ddb-322b-eb11-8da9-e4434bdf6706.svg 进行展开,可得:

92045ddb-322b-eb11-8da9-e4434bdf6706.svg

备注一下:论文这里的公式貌似有点问题,最后一项应该是我推导出的项。

最后两项在做softmax的时候可以舍弃掉,为什么呢?其实很简单,我们得到的Attention分数是一个T*T的矩阵,而 94045ddb-322b-eb11-8da9-e4434bdf6706.svg 和 95045ddb-322b-eb11-8da9-e4434bdf6706.svg 得到的都是一个T*1的向量,最后通过 96045ddb-322b-eb11-8da9-e4434bdf6706.svg 重复了T列扩充成了矩阵,因此每一行上,它的每一列的值都是相同的,因为softmax针对的是列维度,因此后两项对于整体的attention计算来说是一个常量,又因为:

 98045ddb-322b-eb11-8da9-e4434bdf6706.svg 

因此最后两项计算可以舍弃。又因为前面两项中,不存在 99045ddb-322b-eb11-8da9-e4434bdf6706.svg ,因此我们甚至不用去定义这个bias项。

另外,对于上述推导式的第一项,由于其计算了Query和key的相互关系,因此相当于捕捉了上下文的相关信息&

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/365009
推荐阅读
相关标签
  

闽ICP备14008679号