
Multi Query Attention in chatGLM2


Contents

Principle overview

Implementation and timing comparison

Summary and analysis

I have been playing with large language models lately. Among those with good Chinese support is Tsinghua's chatGLM, which has now been upgraded from v1 to chatGLM2 (see the project introduction on GitHub).

After trying it out, both the output quality and the speed have indeed improved.

This is thanks to the many optimizations applied in chatGLM2, including the FlashAttention technique, Multi Query Attention (MQA) and int4 quantization mentioned in its introduction. MQA is an optimized variant of Multi Head Attention (MHA) that speeds up computation while keeping the drop in quality small.

Principle overview

MQA first appeared in a 2019 Google paper, Fast Transformer Decoding: One Write-Head is All You Need. The reason it drew little attention back then is that text generation was less common and decoding sequence lengths were nowhere near what today's large models demand. The idea behind MQA is actually quite simple, especially if you are already familiar with MHA.

In the paper's words, MQA is identical to MHA except that the different attention heads share a single set of key and value weights. Take an attention module with 4 heads, each computing softmax(QKᵀ)V attention; the MHA and MQA setups then look like this:

[Figure: schematic comparison of 4-head MHA and MQA]

As the figure shows, the only difference between MHA and MQA is that in MQA every head shares the same K and V weights, while each head still has its own Q.
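As a minimal sketch of what that sharing means for the projection layers (the layer names here are just illustrative; the sizes assume hidden_size = 4096 and 32 heads, i.e. a head size of 128, the same configuration benchmarked later in this post):

import torch.nn as nn

hidden_size, num_heads = 4096, 32
head_dim = hidden_size // num_heads  # 128

# MHA: every head owns its own K/V slice, so K and V project to num_heads * head_dim
mha_key = nn.Linear(hidden_size, num_heads * head_dim)  # 4096 -> 4096
# MQA: one K/V shared by all heads, so K and V project to a single head_dim
mqa_key = nn.Linear(hidden_size, head_dim)               # 4096 -> 128
# Q is unchanged in both cases: each head keeps its own query projection
query = nn.Linear(hidden_size, num_heads * head_dim)     # 4096 -> 4096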

The paper also compares quality and speed. On inference speed, generating one token costs 1.7µs with MHA versus 1.5µs with MQA on the encoder side, but 46µs versus 3.8µs on the decoder side, so MQA is much faster than MHA in the decoder. On quality, MQA's perplexity (lower is better) rises slightly and its BLEU (higher is better) drops slightly; in other words, quality degrades a little.

Implementation and timing comparison

Referring to the BertSelfAttention source in Hugging Face's transformers package, I implemented a version of both MHA and MQA; the code is below:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
import math
import time

import torch
import torch.nn as nn
from tqdm import tqdm


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        # every head has its own Q/K/V slice: hidden_size -> num_heads * head_size
        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(0.1)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, L, h*d) -> (B, h, L, d)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return context_layer


class MultiQuerySelfAttention(nn.Module):
    def __init__(self, num_attention_heads, hidden_size):
        super().__init__()
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = nn.Linear(hidden_size, self.all_head_size)
        # K and V are shared by all heads: hidden_size -> a single head_size
        self.key = nn.Linear(hidden_size, self.attention_head_size)
        self.value = nn.Linear(hidden_size, self.attention_head_size)
        self.dropout = nn.Dropout(0.1)

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (B, L, h*d) -> (B, h, L, d)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        # hidden_states (B, L, D)
        mixed_query_layer = self.query(hidden_states)
        # query_layer (B, h, L, d)
        query_layer = self.transpose_for_scores(mixed_query_layer)
        # all heads share the same key/value parameters, so K and V are computed only once
        key = self.key(hidden_states)
        # key_layer (B, 1, L, d)
        key_layer = key.unsqueeze(1)
        value = self.value(hidden_states)
        # value_layer (B, 1, L, d)
        value_layer = value.unsqueeze(1)
        # key_layer (B, 1, d, L)
        key_layer = key_layer.transpose(-1, -2)
        # broadcasting: (B, h, L, d) @ (B, 1, d, L) -> (B, h, L, d) @ (B, h, d, L) = (B, h, L, L)
        attention_scores = torch.matmul(query_layer, key_layer)
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)
        attention_probs = self.dropout(attention_probs)
        # broadcasting: (B, h, L, L) @ (B, 1, L, d) -> (B, h, L, L) @ (B, h, L, d) = (B, h, L, d)
        context_layer = torch.matmul(attention_probs, value_layer)
        # (B, h, L, d) -> (B, L, h, d)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        # (B, L, h, d) -> (B, L, h*d) = (B, L, D)
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)
        return context_layer


if __name__ == '__main__':
    seed = 100
    num_attention_heads, hidden_size = 32, 4096
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    device = "cuda:0"
    embeddings = torch.randn(5, 128, hidden_size).to(device)

    multiquery = MultiQuerySelfAttention(num_attention_heads, hidden_size).to(device)
    print(multiquery)
    # count parameters (2-D tensors are weights, 1-D tensors are biases)
    total = 0
    for name, param in multiquery.named_parameters():
        if len(param.size()) == 2:
            total += param.shape[0] * param.shape[1]
        else:
            total += param.shape[0]
    print(f"multiquery parameters {total}")

    count = 100
    start = time.time()
    for _ in tqdm(range(count), ncols=50):
        input = embeddings.clone()
        # generate 100 tokens, each passing through 24 attention layers
        for _ in range(100):
            for i in range(24):
                output = multiquery(input)
            input = torch.cat([input, output[:, -1:, :]], dim=1)
    end = time.time()
    print(f"multiquery time total cost {round(end - start, 8)} mean cost {round((end - start) / count, 8)}")

    multihead = MultiHeadSelfAttention(num_attention_heads, hidden_size).to(device)
    print(multihead)
    total = 0
    for name, param in multihead.named_parameters():
        if len(param.size()) == 2:
            total += param.shape[0] * param.shape[1]
        else:
            total += param.shape[0]
    print(f"multihead parameters {total}")

    count = 100
    start = time.time()
    for _ in tqdm(range(count), ncols=50):
        input = embeddings.clone()
        for _ in range(100):
            for i in range(24):
                output = multihead(input)
            input = torch.cat([input, output[:, -1:, :]], dim=1)
    end = time.time()
    print(f"multihead time total cost {round(end - start, 8)} mean cost {round((end - start) / count, 8)}")

The implementation relies mainly on the broadcasting mechanism of matrix multiplication to compute all heads in parallel, so there is no need to compute each head separately and then concatenate the results, which keeps things efficient; a minimal shape sketch follows.
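This toy snippet only illustrates how torch.matmul broadcasts the size-1 head dimension of the shared K/V across all query heads (the sizes are made up for illustration, not taken from the benchmark):

import torch

B, h, L, d = 2, 32, 16, 128              # toy sizes for illustration
q = torch.randn(B, h, L, d)              # per-head queries
k = torch.randn(B, 1, L, d)              # single shared key head
v = torch.randn(B, 1, L, d)              # single shared value head

# torch.matmul broadcasts the size-1 head dimension of k/v across all h query heads
scores = q @ k.transpose(-1, -2)         # (B, h, L, d) @ (B, 1, d, L) -> (B, h, L, L)
context = scores.softmax(dim=-1) @ v     # (B, h, L, L) @ (B, 1, L, d) -> (B, h, L, d)
print(scores.shape, context.shape)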

To compare timing, I mimic chatGLM2's configuration (hidden_size = 4096, num_heads = 32, num_layers = 24) and decode from an input tensor of shape (5, 128, 4096), generating 100 tokens. In this setup MQA decoding takes 2.7826 s on average versus 6.4796 s for MHA, so simply put, MQA roughly halves decoding time. Looking at the module structure, one MHA layer has about 50.34M parameters while the MQA layer has only about 17.83M, so the savings in GPU memory and inference time ultimately come from compressing the parameter count.
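For reference, those per-layer parameter counts can be reproduced with a quick calculation over the Q/K/V projections (weights plus biases; neither class above includes an output projection):

hidden_size, num_heads = 4096, 32
head_dim = hidden_size // num_heads  # 128

# MHA: Q, K and V are each full hidden_size -> hidden_size projections (weight + bias)
mha = 3 * (hidden_size * hidden_size + hidden_size)
# MQA: Q stays full-size, while K and V shrink to hidden_size -> head_dim
mqa = (hidden_size * hidden_size + hidden_size) + 2 * (hidden_size * head_dim + head_dim)
print(mha, mqa)  # 50343936 vs 17830144, i.e. roughly 50.34M vs 17.83M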

Summary and analysis

The reduction in GPU memory usage and inference time is easy to explain: the parameter count shrinks. As for why the quality changes so little, it suggests that fully independent heads in multi-head attention are not strictly necessary. Earlier work on BERT explored changing the number of heads and also found that quality stayed roughly the same. For large models, it may be enough that different heads use different query projections while sharing the same keys and values, which still lets each head extract different features.

When is MQA worthwhile?

1. For attention-based models, the larger the model, the more pronounced the benefit.

2. Decoder-side generation tasks benefit far more than encoder-side tasks. The gain in generation comes from accumulating the tiny per-call time difference of each softmax(QKᵀ)V attention computation: one generation run produces many tokens, and every token passes through as many attention computations as the model has layers.
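As a rough sanity check against the benchmark above (24 layers and 100 generated tokens per run, i.e. 2,400 attention calls):

layers, tokens = 24, 100
calls = layers * tokens              # 2,400 attention calls per generation run
gap = 6.4796 - 2.7826                # measured per-run gap between MHA and MQA, in seconds
print(f"~{gap / calls * 1000:.2f} ms saved per attention call")  # roughly 1.54 ms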

References

Fast Transformer Decoding: One Write-Head is All You Need

ChatGLM2-6B

huggingface/transformers
