Training and inference with large language models make heavy use of self-attention, and the query, key and value tensors for self-attention have to be kept in GPU memory. In Multi-head attention, every head has its own query, key and value projections, so the key/value cache takes up a large amount of memory. Multi-query attention instead lets all heads share a single key and value projection, which greatly reduces memory use during training and inference, although the sharing can hurt model quality. Grouped-query attention sits in between: heads are divided into groups, and the heads within a group share one key and value projection. When the number of groups equals the number of heads it reduces to Multi-head attention, and when there is only one group it reduces to Multi-query attention.
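To make the memory argument concrete, here is a rough back-of-the-envelope sketch (my own illustration; the head counts, sequence length, batch size and fp16 storage are assumed, not taken from any specific model) of how the KV-cache footprint shrinks as the number of key/value heads goes down:

# Rough KV-cache size comparison (illustrative, assumed numbers).
# Per token, each key/value head stores one key vector and one value vector of size head_dim.
def kv_cache_bytes(n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    return 2 * batch * seq_len * n_kv_heads * head_dim * bytes_per_elem  # 2x: key + value

n_heads, head_dim, seq_len, batch = 32, 128, 4096, 8  # assumed sizes
print("MHA (32 KV heads):", kv_cache_bytes(n_heads, head_dim, seq_len, batch) / 2**30, "GiB")
print("GQA (8 groups)   :", kv_cache_bytes(8, head_dim, seq_len, batch) / 2**30, "GiB")
print("MQA (1 KV head)  :", kv_cache_bytes(1, head_dim, seq_len, batch) / 2**30, "GiB")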
Below is the code for the three self-attention variants:
import torch

# Incremental (decoding-time) multi-head attention
def MultiheadSelfAttentionIncremental():
    """
    d_model: hidden size of the model
    b: batch size
    h: number of heads
    d_k: key dimension per head
    d_v: value dimension per head
    """
    # hidden size 512, batch size 32, 8 heads, key/value dimension 512 // 8 per head
    d_model, b, h, d_k, d_v = 512, 32, 8, (512 // 8), (512 // 8)
    m = 5  # number of tokens already cached
    # previously computed key and value caches; here we pretend 5 tokens are cached (random init)
    prev_K = torch.rand(b, h, m, d_k)
    prev_V = torch.rand(b, h, m, d_v)
    X = torch.rand(b, d_model)  # input used for the query
    M = torch.rand(b, d_model)  # input used for the key and value
    # projection matrices for q, k, v and the output
    P_q = torch.rand(h, d_model, d_k)  # W_q
    P_k = torch.rand(h, d_model, d_k)  # W_k
    P_v = torch.rand(h, d_model, d_v)  # W_v
    P_o = torch.rand(h, d_model, d_v)  # W_o
    # einsum projects the input to one query per head
    q = torch.einsum("bd,hdk->bhk", X, P_q)
    # prev_K is (batch, heads, cached tokens, key dim); compute the new token's key with einsum
    # and concatenate it to the cache along the token dimension
    new_K = torch.concat(
        [prev_K, torch.einsum("bd,hdk->bhk", M, P_k).unsqueeze(2)], axis=2
    )
    new_V = torch.concat(
        [prev_V, torch.einsum("bd,hdv->bhv", M, P_v).unsqueeze(2)], axis=2
    )
    # attention scores and softmax
    logits = torch.einsum("bhk,bhmk->bhm", q, new_K)  # q . k
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bhmv->bhv", weights, new_V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, new_K, new_V

if __name__ == "__main__":
    print(MultiheadSelfAttentionIncremental())
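As a quick sanity check on the incremental step (my own addition, assuming the function above is run as-is), the output is one new hidden state per batch element while the per-head cache grows from 5 to 6 tokens:

y, new_K, new_V = MultiheadSelfAttentionIncremental()
assert y.shape == (32, 512)           # one output vector per batch element
assert new_K.shape == (32, 8, 6, 64)  # (batch, heads, cached tokens 5 -> 6, key dim)
assert new_V.shape == (32, 8, 6, 64)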
import torch

# Incremental (decoding-time) multi-query attention
def MultiquerySelfAttentionIncremental():
    # hidden size, batch size, number of heads, key dim, value dim
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)
    m = 5  # assume the sequence already has 5 tokens
    # key and value caches for the 5 existing tokens; since multi-query attention shares a single
    # key/value projection across all heads, the head dimension present in the multi-head code is gone
    prev_K = torch.rand(b, m, k)
    prev_V = torch.rand(b, m, v)
    X = torch.rand(b, d)  # randomly initialised input for the query
    M = torch.rand(b, d)  # randomly initialised input for the key and value
    # projection matrices for q, k, v and the output
    P_q = torch.rand(h, d, k)  # W_q: still one projection per head
    P_k = torch.rand(d, k)     # W_k: shared by all heads
    P_v = torch.rand(d, v)     # W_v: shared by all heads
    P_o = torch.rand(h, d, v)  # W_o
    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.concat([prev_K, torch.einsum("bd,dk->bk", M, P_k).unsqueeze(1)], axis=1)
    V = torch.concat([prev_V, torch.einsum("bd,dv->bv", M, P_v).unsqueeze(1)], axis=1)
    logits = torch.einsum("bhk,bmk->bhm", q, K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bmv->bhv", weights, V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

if __name__ == "__main__":
    print(MultiquerySelfAttentionIncremental())
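For completeness, a grouped-query version can be written in the same incremental einsum style. The sketch below is my own and is not part of the implementation that follows; the group count g and the head-to-group assignment (consecutive heads share a group) are assumptions made for illustration.

import torch

# Hypothetical incremental grouped-query attention in the same einsum style (illustrative sketch)
def GroupedquerySelfAttentionIncremental(g=2):
    # d: hidden size, b: batch, h: query heads, g: KV-head groups, k/v: per-head key/value dims
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)
    assert h % g == 0
    m = 5  # tokens already cached
    prev_K = torch.rand(b, g, m, k)  # one key cache per group instead of per head
    prev_V = torch.rand(b, g, m, v)
    X = torch.rand(b, d)             # input used for the query
    M = torch.rand(b, d)             # input used for the key and value
    P_q = torch.rand(h, d, k)        # W_q: one projection per query head
    P_k = torch.rand(g, d, k)        # W_k: one projection per group
    P_v = torch.rand(g, d, v)        # W_v: one projection per group
    P_o = torch.rand(h, d, v)        # W_o
    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.concat([prev_K, torch.einsum("bd,gdk->bgk", M, P_k).unsqueeze(2)], axis=2)
    V = torch.concat([prev_V, torch.einsum("bd,gdv->bgv", M, P_v).unsqueeze(2)], axis=2)
    # regroup the query heads as (groups, heads per group) so each group attends to its own K/V
    q = q.reshape(b, g, h // g, k)
    logits = torch.einsum("bgik,bgmk->bgim", q, K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bgim,bgmv->bgiv", weights, V).reshape(b, h, v)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

With g equal to the number of heads the caches match the multi-head version above, and with g = 1 they match the multi-query version. The implementation below achieves the same effect the other way around: instead of reshaping the query, it expands the shared KV heads with repeat_kv.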
""" 在grouped-query attention中 当组数与头数相同时则为multi-head attention 当组数为1时则为multi-query attention """ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: # MHA return x return ( # MQA / GQA x[:, :, :, None, :] .expand(bs, slen, n_kv_heads, n_rep, head_dim) .reshape(bs, slen, n_kv_heads * n_rep, head_dim) ) class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads model_parallel_size = fs_init.get_model_parallel_world_size() self.n_local_heads = args.n_heads // model_parallel_size self.n_local_kv_heads = self.n_kv_heads // model_parallel_size # 此处 self.n_rep = self.n_local_heads // self.n_local_kv_heads # 此处 几个组 self.head_dim = args.dim // args.n_heads self.wq = ColumnParallelLinear( args.dim, args.n_heads * self.head_dim, bias=False, gather_output=False, init_method=lambda x: x, ) self.wk = ColumnParallelLinear( args.dim, self.n_kv_heads * self.head_dim, # 初始化为单个组内的一份 bias=False, gather_output=False, init_method=lambda x: x, ) self.wv = ColumnParallelLinear( args.dim, self.n_kv_heads * self.head_dim, # # 初始化为单个组内的一份 bias=False, gather_output=False, init_method=lambda x: x, ) self.wo = RowParallelLinear( args.n_heads * self.head_dim, args.dim, bias=False, input_is_parallel=True, init_method=lambda x: x, ) self.cache_k = torch.zeros( ( args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim, ) ).cuda() self.cache_v = torch.zeros( ( args.max_batch_size, args.max_seq_len, self.n_local_kv_heads, self.head_dim, ) ).cuda() def forward( self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor], ): bsz, seqlen, _ = x.shape xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim) xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim) xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) self.cache_k = self.cache_k.to(xq) self.cache_v = self.cache_v.to(xq) self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv keys = self.cache_k[:bsz, : start_pos + seqlen] values = self.cache_v[:bsz, : start_pos + seqlen] # repeat k/v heads if n_kv_heads < n_heads # 单个组扩展为完整head keys = repeat_kv(keys, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) values = repeat_kv(values, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) keys = keys.transpose(1, 2) values = values.transpose(1, 2) scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) if mask is not None: scores = scores + mask # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) output = torch.matmul(scores, values) # (bs, n_local_heads, seqlen, head_dim) output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1) return self.wo(output)