Several Variants of Self-Attention: Multi-Query Attention and Grouped-Query Attention

  • Multi-head attention
  • Multi-query attention
  • Grouped-query attention

Self-attention is used heavily during both training and inference of large models, and the query, key and value tensors of every attention layer must be kept in GPU memory. In multi-head attention, every head has its own query, key and value projections, so the key/value cache grows with the number of heads and consumes a large amount of memory. Multi-query attention reduces this cost by having all heads share a single key and value projection, which shrinks the cache during training and inference, although it can hurt model quality. Grouped-query attention is a middle ground: the heads are divided into groups, and all heads within a group share one key and value projection. When the number of groups equals the number of heads it is exactly multi-head attention, and when there is a single group it is exactly multi-query attention.
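As a rough illustration of the savings (a back-of-the-envelope sketch; the layer count, head count, head dimension and group count below are example values, not taken from any particular model), the per-token KV-cache size scales with the number of key/value heads:

# per-token KV-cache elements = 2 (K and V) * n_layers * n_kv_heads * head_dim
n_layers, n_heads, head_dim, n_groups = 32, 32, 128, 8  # example values only
mha = 2 * n_layers * n_heads * head_dim   # multi-head: every head keeps its own K/V
mqa = 2 * n_layers * 1 * head_dim         # multi-query: all heads share one K/V
gqa = 2 * n_layers * n_groups * head_dim  # grouped-query: one K/V per group
print(mha, mqa, gqa)  # 262144, 8192, 65536 elements per token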
The code for the three self-attention variants is given below:

  1. Multi-head attention
import torch

# Incremental multi-head attention (a single decoding step with a KV cache)
def MultiheadSelfAttentionIncremental():
    """
    d_model: hidden size of the model
    b: batch size
    h: number of heads
    d_k: key dimension per head
    d_v: value dimension per head
    """
    # hidden size 512, batch size 32, 8 heads, key/value dimension 512 // 8
    d_model, b, h, d_k, d_v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # number of tokens already cached
    # previously computed key and value caches, here assuming 5 cached tokens (randomly initialized)
    prev_K = torch.rand(b, h, m, d_k)
    prev_V = torch.rand(b, h, m, d_v)

    X = torch.rand(b, d_model)  # input used for the query
    M = torch.rand(b, d_model)  # input used for the key and value
    # projection matrices for q, k, v and the output
    P_q = torch.rand(h, d_model, d_k)  # W_q
    P_k = torch.rand(h, d_model, d_k)  # W_k
    P_v = torch.rand(h, d_model, d_v)  # W_v
    P_o = torch.rand(h, d_model, d_v)  # W_o

    q = torch.einsum("bd,hdk->bhk", X, P_q)  # project the input to per-head queries
    # project the new token to per-head keys/values and append them to the cache along the
    # token dimension (prev_K is (batch, heads, cached tokens, key dim))
    new_K = torch.cat(
        [prev_K, torch.einsum("bd,hdk->bhk", M, P_k).unsqueeze(2)], dim=2
    )
    new_V = torch.cat(
        [prev_V, torch.einsum("bd,hdv->bhv", M, P_v).unsqueeze(2)], dim=2
    )
    # attention weights via softmax over the cached tokens
    logits = torch.einsum("bhk,bhmk->bhm", q, new_K)  # q·k
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bhmv->bhv", weights, new_V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, new_K, new_V

if __name__ == "__main__":
    print(MultiheadSelfAttentionIncremental())
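
As a quick sanity check (a small addition for illustration, not part of the original snippet), the shapes returned by one incremental step can be verified directly:

y, new_K, new_V = MultiheadSelfAttentionIncremental()
assert y.shape == (32, 512)           # (b, d_model): one output vector per batch element
assert new_K.shape == (32, 8, 6, 64)  # (b, h, m + 1, d_k): each head's cache grew by one token
assert new_V.shape == (32, 8, 6, 64)  # (b, h, m + 1, d_v)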
  2. Multi-query attention
import torch

# Incremental multi-query attention
def MultiquerySelfAttentionIncremental():
    # hidden size, batch size, number of heads, key dimension, value dimension
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # the sequence already contains 5 tokens
    # key and value caches for the 5 existing tokens; in multi-query attention a single
    # key/value is shared by all heads, so unlike multi-head attention the caches have no head dimension
    prev_K = torch.rand(b, m, k)
    prev_V = torch.rand(b, m, v)

    X = torch.rand(b, d)  # input used for the query (randomly initialized)
    M = torch.rand(b, d)  # input used for the key and value (randomly initialized)
    # projection matrices for q, k, v and the output; only P_q and P_o are per-head
    P_q = torch.rand(h, d, k)  # W_q
    P_k = torch.rand(d, k)  # W_k
    P_v = torch.rand(d, v)  # W_v
    P_o = torch.rand(h, d, v)  # W_o

    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.cat([prev_K, torch.einsum("bd,dk->bk", M, P_k).unsqueeze(1)], dim=1)
    V = torch.cat([prev_V, torch.einsum("bd,dv->bv", M, P_v).unsqueeze(1)], dim=1)
    logits = torch.einsum("bhk,bmk->bhm", q, K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bmv->bhv", weights, V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

if __name__ == "__main__":
    print(MultiquerySelfAttentionIncremental())
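
A similar check (again an addition for illustration) makes the memory difference visible: the multi-query cache has no head dimension, so it is h times smaller than in the multi-head case while the output shape is unchanged:

y, K, V = MultiquerySelfAttentionIncremental()
assert y.shape == (32, 512)    # (b, d): output shape is unchanged
assert K.shape == (32, 6, 64)  # (b, m + 1, k): a single key cache shared by all 8 heads
assert V.shape == (32, 6, 64)  # (b, m + 1, v)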
  3. Grouped-query attention
"""
在grouped-query attention中
当组数与头数相同时则为multi-head attention
当组数为1时则为multi-query attention
"""
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1: # MHA
        return x
    return ( # MQA / GQA
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )
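
# Example (illustrative shapes): with n_kv_heads = 2 and n_rep = 4, a cached tensor of
# shape (bs, slen, 2, head_dim) is expanded to (bs, slen, 8, head_dim), so every group
# of 4 query heads attends to the same repeated key/value head.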


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size  # KV heads on this device (one per group)
        self.n_rep = self.n_local_heads // self.n_local_kv_heads  # query heads that share each KV head (group size)
        self.head_dim = args.dim // args.n_heads

        self.wq = ColumnParallelLinear(
            args.dim,
            args.n_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wk = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,  # one key projection per KV head, shared within a group
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wv = ColumnParallelLinear(
            args.dim,
            self.n_kv_heads * self.head_dim,  # one value projection per KV head, shared within a group
            bias=False,
            gather_output=False,
            init_method=lambda x: x,
        )
        self.wo = RowParallelLinear(
            args.n_heads * self.head_dim,
            args.dim,
            bias=False,
            input_is_parallel=True,
            init_method=lambda x: x,
        )

        self.cache_k = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()
        self.cache_v = torch.zeros(
            (
                args.max_batch_size,
                args.max_seq_len,
                self.n_local_kv_heads,
                self.head_dim,
            )
        ).cuda()

    def forward(
        self,
        x: torch.Tensor,
        start_pos: int,
        freqs_cis: torch.Tensor,
        mask: Optional[torch.Tensor],
    ):
        bsz, seqlen, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv

        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]

        # repeat k/v heads if n_kv_heads < n_heads: expand each group's shared K/V to cover all query heads
        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output)
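
For comparison with the two incremental functions above, here is a minimal incremental grouped-query sketch (the function name GroupedquerySelfAttentionIncremental and the group count g = 2 are illustrative choices, not part of the code above; g = h recovers multi-head attention and g = 1 recovers multi-query attention):

import torch

def GroupedquerySelfAttentionIncremental():
    # hidden size, batch size, query heads, KV groups, key dimension, value dimension
    d, b, h, g, k, v = 512, 32, 8, 2, (512 // 8), (512 // 8)
    n_rep = h // g  # query heads that share each group's key/value

    m = 5  # tokens already cached
    prev_K = torch.rand(b, g, m, k)  # one key/value cache per group, not per head
    prev_V = torch.rand(b, g, m, v)

    X = torch.rand(b, d)  # input used for the query
    M = torch.rand(b, d)  # input used for the key and value
    P_q = torch.rand(h, d, k)  # per-head query projection
    P_k = torch.rand(g, d, k)  # per-group key projection
    P_v = torch.rand(g, d, v)  # per-group value projection
    P_o = torch.rand(h, d, v)  # per-head output projection

    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.cat([prev_K, torch.einsum("bd,gdk->bgk", M, P_k).unsqueeze(2)], dim=2)
    V = torch.cat([prev_V, torch.einsum("bd,gdv->bgv", M, P_v).unsqueeze(2)], dim=2)

    # split the query heads into groups so each group attends over its shared K/V
    q = q.view(b, g, n_rep, k)
    logits = torch.einsum("bgrk,bgmk->bgrm", q, K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bgrm,bgmv->bgrv", weights, V).reshape(b, h, v)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

The cache here stores one key/value per group rather than per head, so its size sits between the multi-head and multi-query cases.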