赞
踩
按照定义一个a×b的矩阵乘以一个b×c的矩阵要做abc次乘法,所以abc就是两个矩阵相乘的复杂度了,这是我们估算Transformer复杂度的依据
设n为序列长度,d为head_size(base版是64),h为head的数目(base版是12),那么hd就是我们通常说的“hidden_size”(base版是768)。
对于SA来说:
Q,K,V的投影变换,即n×hd的矩阵乘以hd×hd的矩阵做3次,因此计算量是3n(hd)2;
h个Attention头的运算,每个头先是n×d的Q与d×n的KT相乘得到n×n的Attention矩阵(softmax和归一化的计算量暂且忽略),然后n×n的矩阵与n×d的V相乘得到n×d的矩阵,这两步的计算量都是n2d,所以总计算量是h(n2d+n2d);
输出投影变换,也是n×hd的矩阵乘以hd×hd的矩阵,计算量是n(hd)2
所以,SA的总计算量是
3n(hd)2+h(n2d+n2d)+n(hd)2=4nh2d2+2n2hd
FFN就是两个全连接层,也就是两个矩阵变换(激活函数的计算量也忽略不计),一般的参数设置是:第一层是n×hd的矩阵乘以hd×4hd的矩阵,第二层就是n×4hd的矩阵乘以4hd×hd的矩阵。所以总计算量是
n×hd×4hd+n×4hd×hd=8nh2d2
4nh2d2+2n2hd > 8nh2d2 ==> n>2hd
对于base版来说,这意味着n>1536!也就是说,只有当序列长度超过1536时,SA的计算量才大于FFN,在这之前,都是线性复杂度的FFN占主导
4nh2d2+2n2hd + 8nh2d2 = 12nh2d2+2n2hd
它是关于n的一次项和二次项的求和,当n足够大时,复杂度自然是声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/小小林熬夜学编程/article/detail/352502
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。