Grouped-query attention is an interpolation of multi-query and multi-head attention that achieves quality close to multi-head attention at a speed comparable to multi-query attention.
The figure above is from https://arxiv.org/pdf/2305.13245v3.pdf
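To make the interpolation concrete, here is a minimal sketch (not from the paper; the dimensions are made-up example sizes, and num_kv_heads matches the constructor argument of the implementation further down): setting num_kv_heads equal to num_heads recovers multi-head attention, setting it to 1 recovers multi-query attention, and values in between give grouped-query attention with proportionally smaller K/V projections and KV cache.

# Illustrative only: example sizes, assuming embed_dim = num_heads * head_dim and
# separate K and V projections of shape (embed_dim, head_dim * num_kv_heads).
embed_dim, num_heads, head_dim = 4096, 32, 128

for num_kv_heads in (32, 8, 1):  # multi-head, grouped-query, multi-query
    kv_proj_params = 2 * embed_dim * head_dim * num_kv_heads   # K and V weight entries
    kv_cache_per_token = 2 * head_dim * num_kv_heads           # K and V values cached per token per layer
    print('num_kv_heads={:2d}  K/V proj params={:>9d}  KV cache per token={}'.format(
        num_kv_heads, kv_proj_params, kv_cache_per_token))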
Now for the full implementation:
import math

import torch
import torch.nn as nn
from torch import Tensor


class MyGQA(nn.Module):
    def __init__(self, embed_dim, num_heads, num_kv_heads):
        # Split the num_heads query heads into num_kv_heads groups; heads in the same
        # group share one K/V head, which is broadcast with repeat_interleave.
        super(MyGQA, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.num_key_value_groups = num_heads // num_kv_heads
        head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, head_dim * num_kv_heads)
        self.v_proj = nn.Linear(embed_dim, head_dim * num_kv_heads)
        self.fc = nn.Linear(embed_dim, embed_dim)
        self.ln = nn.LayerNorm(embed_dim)

    def scaled_dot_product_attention(self, q: Tensor, k: Tensor, v: Tensor):
        bs, q_len, head_dim = q.shape
        q = q / math.sqrt(head_dim)
        # (bs, q_len, head_dim) x (bs, k_len, head_dim) -> (bs, q_len, k_len)
        attn = torch.bmm(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        # (bs, q_len, k_len) x (bs, v_len, head_dim) -> (bs, q_len, head_dim)
        output = torch.bmm(attn, v)
        return output, attn

    def forward(self, query: Tensor, key: Tensor, value: Tensor):
        # query, key, value are assumed to have the same shape: (bs, q_len, embed_dim)
        bs, q_len, embed_dim = query.shape
        head_dim = embed_dim // self.num_heads
        q = (self.q_proj(query)
             .reshape(bs, q_len, self.num_heads, head_dim)
             .transpose(1, 2)
             .reshape(bs * self.num_heads, q_len, head_dim))
        # Project to num_kv_heads heads, then repeat each K/V head
        # num_key_value_groups times along the head dimension so that
        # every query head has a matching K/V head.
        k = (self.k_proj(key)
             .reshape(bs, q_len, self.num_kv_heads, head_dim)
             .transpose(1, 2)
             .repeat_interleave(self.num_key_value_groups, dim=1)
             .reshape(bs * self.num_heads, q_len, head_dim))
        v = (self.v_proj(value)
             .reshape(bs, q_len, self.num_kv_heads, head_dim)
             .transpose(1, 2)
             .repeat_interleave(self.num_key_value_groups, dim=1)
             .reshape(bs * self.num_heads, q_len, head_dim))
        self_output, attn = self.scaled_dot_product_attention(q, k, v)
        # self_output: (bs * num_heads, q_len, head_dim)
        # attn:        (bs * num_heads, q_len, k_len)
        output = self.fc(self_output
                         .reshape(bs, self.num_heads, q_len, head_dim)
                         .transpose(1, 2)
                         .reshape(bs, q_len, self.num_heads * head_dim))
        # The Hugging Face version puts this fc inside BertSelfOutput instead.
        return self.ln(output + query), attn


embed_dim, num_heads = 100, 5
q_len, bs = 2, 3
# batch_first=True so nn.MultiheadAttention takes (bs, q_len, embed_dim) like MyGQA
multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
query = torch.ones(bs, q_len, embed_dim)
key = torch.ones(bs, q_len, embed_dim)
value = torch.ones(bs, q_len, embed_dim)
attn_output, attn_output_weights = multihead_attn(query, key, value)
print('attn_output={}'.format(attn_output.shape))
print('attn_output_weights={}'.format(attn_output_weights.shape))
print('--------------')
# num_kv_heads == num_heads degenerates to ordinary multi-head attention
my_multihead_attn = MyGQA(embed_dim, num_heads, num_heads)
my_attn_output, my_attn_output_weights = my_multihead_attn(query, key, value)
print('my_attn_output={}'.format(my_attn_output.shape))
print('my_attn_output_weights={}'.format(my_attn_output_weights.shape))
'''
Output:
attn_output=torch.Size([3, 2, 100])
attn_output_weights=torch.Size([3, 2, 2])
--------------
my_attn_output=torch.Size([3, 2, 100])
my_attn_output_weights=torch.Size([15, 2, 2])
'''
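As a quick follow-up (not in the original post; my_gqa is just an illustrative name), the same class can be run with fewer K/V heads than query heads. With num_kv_heads=1 the five query heads share a single K/V head, i.e. the multi-query extreme:

# Hypothetical usage example: 5 query heads sharing 1 K/V head (multi-query attention).
my_gqa = MyGQA(embed_dim, num_heads, num_kv_heads=1)
gqa_output, gqa_weights = my_gqa(query, key, value)
print('gqa_output={}'.format(gqa_output.shape))               # torch.Size([3, 2, 100])
print('gqa_weights={}'.format(gqa_weights.shape))             # torch.Size([15, 2, 2])
print('k_proj weight={}'.format(my_gqa.k_proj.weight.shape))  # torch.Size([20, 100]) vs (100, 100) for MHA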