
Grouped-query Attention (GQA) and Multi-query Attention (MQA)

grouped-query attention

Grouped-query attention is an interpolation of multi-query and multi-head attention that achieves quality close to multi-head attention at speed comparable to multi-query attention.
(Figure: overview of multi-head, grouped-query, and multi-query attention, from https://arxiv.org/pdf/2305.13245v3.pdf)
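
The only knob is the number of key/value heads: num_kv_heads == num_heads gives ordinary multi-head attention, num_kv_heads == 1 gives multi-query attention, and anything in between is GQA. A minimal sketch of that spectrum (hypothetical sizes, not from the paper):

num_heads = 8
for num_kv_heads in (8, 2, 1):   # 8 -> multi-head, 2 -> grouped, 1 -> multi-query
    group_size = num_heads // num_kv_heads
    print(f"num_kv_heads={num_kv_heads}: each K/V head is shared by {group_size} query head(s)")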

Let's start with the code:

import torch.nn as nn
import torch
from torch import Tensor
import math

class MyGQA(nn.Module):
    def __init__(self, embed_dim, num_heads, num_kv_heads):
        # Split the num_heads query heads into num_kv_heads groups; heads in the same group share the same K/V, then repeat_interleave the K/V heads to match.
        super(MyGQA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_key_value_groups = num_heads // num_kv_heads
        head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim,embed_dim)
        self.k_proj = nn.Linear(embed_dim,head_dim*num_kv_heads)
        self.v_proj = nn.Linear(embed_dim,head_dim*num_kv_heads)
        self.fc = nn.Linear(embed_dim,embed_dim)
        self.ln = nn.LayerNorm(embed_dim)

    def scaled_dot_product_attention(self, q:Tensor, k:Tensor, v:Tensor):
        bs, q_len, head_dim = q.shape
        q = q / math.sqrt(head_dim)
        # (bs, q_len, head_dim) x (bs, k_len, head_dim) -> (bs, q_len, k_len)
        attn = torch.bmm(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        # (bs, q_len, k_len) x (bs, v_len, head_dim) -> (bs, q_len, head_dim)
        output = torch.bmm(attn, v)
        return output,attn

    def forward(self, query:Tensor, key:Tensor, value:Tensor):
        # assert query, key, value have the same shape
        bs, q_len, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads

        q = self.q_proj(query).reshape(bs, q_len, self.num_heads, head_dim).transpose(1, 2).reshape(bs*self.num_heads, q_len, head_dim)
        # project K/V to num_kv_heads heads, then repeat each K/V head num_key_value_groups
        # times along the head dimension so it lines up with the query heads
        k = self.k_proj(key).reshape(bs, q_len, -1, head_dim).transpose(1, 2).repeat_interleave(self.num_key_value_groups, dim=1).reshape(bs*self.num_heads, q_len, head_dim)
        v = self.v_proj(value).reshape(bs, q_len, -1, head_dim).transpose(1, 2).repeat_interleave(self.num_key_value_groups, dim=1).reshape(bs*self.num_heads, q_len, head_dim)
        self_output,attn = self.scaled_dot_product_attention(q, k, v)
        # self_output: bs * num_heads, q_len, head_dim
        # attn: bs * num_heads, q_len, k_len
        output = self.fc(self_output.reshape(bs, self.num_heads, q_len, head_dim).transpose(1,2).reshape(bs, q_len, self.num_heads*head_dim))
        # The Hugging Face version moves this fc into BertSelfOutput
        return self.ln(output+query),attn
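
A side note, assuming PyTorch >= 2.0 (not part of the original code): the manual scale -> bmm -> softmax -> bmm in scaled_dot_product_attention above computes the same thing as the built-in fused kernel torch.nn.functional.scaled_dot_product_attention, which returns only the output (no attention weights). A minimal sanity check:

import torch
import torch.nn.functional as F

# Hypothetical sizes; q/k/v stand for the already-projected per-head tensors.
bs_heads, q_len, head_dim = 6, 4, 16
q = torch.randn(bs_heads, q_len, head_dim)
k = torch.randn(bs_heads, q_len, head_dim)
v = torch.randn(bs_heads, q_len, head_dim)

attn = torch.bmm(q / head_dim ** 0.5, k.transpose(-2, -1)).softmax(dim=-1)
manual = torch.bmm(attn, v)
fused = F.scaled_dot_product_attention(q, k, v)   # PyTorch >= 2.0
print(torch.allclose(manual, fused, atol=1e-5))   # expected: True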

embed_dim,num_heads=100,5
q_len,bs = 2,3

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
query = torch.ones(bs, q_len, embed_dim)
key = torch.ones(bs, q_len, embed_dim)
value = torch.ones(bs, q_len, embed_dim)

attn_output, attn_output_weights = multihead_attn(query, key, value)
print('attn_output={}'.format(attn_output.shape))
print('attn_output_weights={}'.format(attn_output_weights.shape))
print('--------------')
my_multihead_attn = MyGQA(embed_dim, num_heads, num_heads)  # num_kv_heads == num_heads: reduces to standard multi-head attention
my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value)
print('my_attn_output={}'.format(my_attn_output.shape))
print('my_attn_output_weights={}'.format(my_attn_output_weights.shape))

'''
Output:
attn_output=torch.Size([3, 2, 100])
attn_output_weights=torch.Size([3, 2, 2])
--------------
my_attn_output=torch.Size([3, 2, 100])
my_attn_output_weights=torch.Size([15, 2, 2])
'''
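
The test above uses num_kv_heads == num_heads, so MyGQA reduces to plain multi-head attention. A quick sketch of the actual grouped case, continuing the script above (hypothetical sizes; num_kv_heads must divide num_heads):

gqa = MyGQA(embed_dim=128, num_heads=8, num_kv_heads=2)   # 4 query heads share each K/V head
x = torch.randn(3, 10, 128)                               # (bs, q_len, embed_dim)
out, attn = gqa(x, x, x)
print(out.shape)    # torch.Size([3, 10, 128])
print(attn.shape)   # torch.Size([24, 10, 10]) -> (bs*num_heads, q_len, k_len)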
