当前位置:   article > 正文

重新审视MHA与Transformer_手写mha

手写mha

本文将基于PyTorch源码重新审视MultiheadAttention与Transformer。事实上,早在一年前博主就已经分别介绍了两者:各种注意力机制的PyTorch实现从零开始手写一个Transformer,但当时的实现大部分是基于d2l教程的,这次将基于PyTorch源码重新实现一遍。

1. MultiheadAttention

1.1 思路

回顾多头注意力,其公式如下:

MHA ( Q , K , V ) = Concat ( head 1 , ⋯   , head h ) W O head i = Attn ( Q W i Q , K W i K , V W i V ) \text{MHA}(Q,K,V)=\text{Concat}(\text{head}_1,\cdots,\text{head}_h)W^O \\ \text{head}_i=\text{Attn}(QW_i^Q,KW_i^K,VW_i^V) MHA(Q,K,V)=Concat(head1,,headh)WOheadi=Attn(QWiQ,KWiK,VWiV)

其中 W i Q ∈ R d m o d e l × d k W_i^Q\in \mathbb{R}^{d_{model}\times d_k} WiQRdmodel×dk W i K ∈ R d m o d e l × d k W_i^K\in \mathbb{R}^{d_{model}\times d_k} WiKRdmodel×dk W i V ∈ R d m o d e l × d v W_i^V\in \mathbb{R}^{d_{model}\times d_v} WiVRdmodel×dv W O ∈ R h d v × d m o d e l W^O\in \mathbb{R}^{hd_v\times d_{model}} WORhdv×dmodel,且 d k = d v = d m o d e l / h d_k=d_v=d_{model}/h dk=dv=dmodel/h

如果记 d h e a d = d m o d e l / h d_{head}=d_{model}/h dhead=dmodel/h,则 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV 的形状均为 ( d m o d e l , d h e a d ) (d_{model},d_{head}) (dmodel,dhead) W O W^O WO 的形状为 ( d m o d e l , d m o d e l ) (d_{model},d_{model}) (dmodel,dmodel)

先不考虑batch和mask的情形,在只有一个头的情况下( h = 1 h=1 h=1),MHA的计算方式为

class MHA(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.w_q = nn.Parameter(torch.empty(d_model, d_model))
        self.w_k = nn.Parameter(torch.empty(d_model, d_model))
        self.w_v = nn.Parameter(torch.empty(d_model, d_model))
        self.w_o = nn.Parameter(torch.empty(d_model, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, query, key, value):
        """
        Args:
            query: (n, d_model),n是query的个数,m是key-value的个数
            key: (m, d_model)
            value: (m, d_model)
        """
        q = query @ self.w_q
        k = key @ self.w_k
        v = value @ self.w_v

        attn_logits = q @ k.transpose(0, 1) / math.sqrt(q.size(1))  # attn_logits: (n, m)
        attn_probs = F.softmax(attn_logits, dim=-1)
        attn_output = attn_probs @ v  # attn_output: (n, d_model)
        return attn_output, attn_probs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

现在考虑 h = 2 h=2 h=2 的情形,此时一共需要 3 ⋅ 2 + 1 = 7 3\cdot2+1=7 32+1=7 个参数矩阵

class MHA(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.w_q_1 = nn.Parameter(torch.empty(d_model, d_model // 2))
        self.w_k_1 = nn.Parameter(torch.empty(d_model, d_model // 2))
        self.w_v_1 = nn.Parameter(torch.empty(d_model, d_model // 2))

        self.w_q_2 = nn.Parameter(torch.empty(d_model, d_model // 2))
        self.w_k_2 = nn.Parameter(torch.empty(d_model, d_model // 2))
        self.w_v_2 = nn.Parameter(torch.empty(d_model, d_model // 2))

        self.w_o = nn.Parameter(torch.empty(d_model, d_model))

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, query, key, value):
        """
        Args:
            query: (n, d_model),n是query的个数,m是key-value的个数
            key: (m, d_model)
            value: (m, d_model)
        """
        q_1 = query @ self.w_q_1
        k_1 = key @ self.w_k_1
        v_1 = value @ self.w_v_1

        q_2 = query @ self.w_q_2
        k_2 = key @ self.w_k_2
        v_2 = value @ self.w_v_2

        attn_logits_1 = q_1 @ k_1.transpose(0, 1) / math.sqrt(q_1.size(1))
        attn_probs_1 = F.softmax(attn_logits_1, dim=-1)
        attn_output_1 = attn_probs_1 @ v_1

        attn_logits_2 = q_2 @ k_2.transpose(0, 1) / math.sqrt(q_2.size(1))
        attn_probs_2 = F.softmax(attn_logits_2, dim=-1)
        attn_output_2 = attn_probs_2 @ v_2

        attn_output = torch.cat([attn_output_1, attn_output_2], dim=-1) @ self.w_o  # attn_output: (n, d_model)
        attn_probs = torch.stack([attn_probs_1, attn_probs_2], dim=0)  # attn_probs: (2, n, m),其中2是头数

        return attn_output, attn_probs
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

可以看到代码量已经增加了不少,如果扩展到 h h h 个头的情形,则需要 3 h + 1 3h+1 3h+1 个参数矩阵。手动去一个个声明显然不现实,因为 h h h 是动态变化的,而用for循环创建又略显笨拙,有没有更简便的方法呢?

在上面的代码中,我们用小写 q q q 来代表查询 Q Q Q 经过投影后的结果( k , v k,v k,v 同理),即

q i = Q W i Q , i = 1 , 2 , ⋯   , h q_i=QW_i^Q,\quad i =1,2,\cdots,h qi=QWiQ,i=1,2,,h

其中 Q Q Q 的形状为 ( n , d m o d e l ) (n,d_{model}) (n,dmodel) q i q_i qi 的形状为 ( n , d h e a d ) (n,d_{head}) (n,dhead),且有

h e a d i = softmax ( q i k i T d h e a d ) v i head_i=\text{softmax}\left(\frac{q_ik_i^{T}}{\sqrt{d_{head}}}\right)v_i headi=softmax(dhead qikiT)vi

注意到

[ q 1 , q 2 , ⋯   , q h ] = Q [ W 1 Q , W 2 Q , ⋯   , W h Q ] (1) [q_1,q_2,\cdots,q_h]=Q[W_1^Q,W_2^Q,\cdots,W_h^Q]\tag{1} [q1,q2,,qh]=Q[W1Q,W2Q,,WhQ](1)

如果记 q ≜ [ q 1 , q 2 , ⋯   , q h ] q\triangleq [q_1,q_2,\cdots,q_h] q[q1,q2,,qh] W Q ≜ [ W 1 Q , W 2 Q , ⋯   , W h Q ] W^Q\triangleq [W_1^Q,W_2^Q,\cdots,W_h^Q] WQ[W1Q,W2Q,,WhQ],则 W Q W^Q WQ 的形状为 ( d m o d e l , d m o d e l ) (d_{model},d_{model}) (dmodel,dmodel) h h h 无关 q q q 的形状为 ( n , d m o d e l ) (n,d_{model}) (n,dmodel)。这样一来,我们就不需要一个个声明 W i Q W_i^Q WiQ 了,并且可以一次性存储所有的 q i q_i qi

要计算 h e a d 1 head_1 head1,我们需要能够从 q q q 中取出 q 1 q_1 q1 k , v k,v k,v 同理),所以我们期望 q q q 的形状是 ( h , n , d h e a d ) (h,n,d_{head}) (h,n,dhead),从而 q [ 1 ] q[1] q[1] 就是 q 1 q_1 q1(这里下标从 1 1 1 开始)。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/351236

推荐阅读
相关标签