For example, in PyTorch we can conveniently use the nn.TransformerEncoderLayer or nn.TransformerDecoderLayer classes. Each one bundles a multi-head attention block and an FFN (feed-forward network), with residual connections between the two sub-layers and layer normalization.
from torch import nn
# Encoder: a stack of Transformer encoder layers
encoder_layer = nn.TransformerEncoderLayer(hidden_dim, num_head, dim_feedforward, dropout, activation)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
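To make this concrete, here is a minimal, self-contained usage sketch; the hyperparameter values (hidden_dim=256 and so on) are made up purely for illustration and are not from any particular model:

import torch
from torch import nn

# illustrative hyperparameters
hidden_dim, num_head, dim_feedforward, num_layers = 256, 8, 1024, 6

encoder_layer = nn.TransformerEncoderLayer(
    d_model=hidden_dim, nhead=num_head, dim_feedforward=dim_feedforward,
    dropout=0.1, activation="relu", batch_first=True)
transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

x = torch.randn(2, 10, hidden_dim)   # (batch, seq_len, d_model) since batch_first=True
out = transformer(x)                 # same shape as the input: (2, 10, 256)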
Looking at the PyTorch source of nn.TransformerEncoderLayer, you can see that it contains the multi-head attention module MultiheadAttention:
class TransformerEncoderLayer(Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            batch_first=batch_first, **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            activation = _get_activation_fn(activation)

        # We can't test self.activation in forward() in TorchScript,
        # so stash some information about it instead.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super().__setstate__(state)
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(
            self,
            src: Tensor,
            src_mask: Optional[Tensor] = None,
            src_key_padding_mask: Optional[Tensor] = None,
            is_causal: bool = False) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional).
            is_causal: If specified, applies a causal mask as src_mask.
              Default: ``False``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        src_key_padding_mask = F._canonical_mask(
            mask=src_key_padding_mask,
            mask_name="src_key_padding_mask",
            other_type=F._none_or_dtype(src_mask),
            other_name="src_mask",
            target_type=src.dtype
        )

        src_mask = F._canonical_mask(
            mask=src_mask,
            mask_name="src_mask",
            other_type=None,
            other_name="",
            target_type=src.dtype,
            check_other=False,
        )

        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        why_not_sparsity_fast_path = ''
        if not src.dim() == 3:
            why_not_sparsity_fast_path = f"input not batched; expected src.dim() of 3 but got {src.dim()}"
        elif self.training:
            why_not_sparsity_fast_path = "training is enabled"
        elif not self.self_attn.batch_first:
            why_not_sparsity_fast_path = "self_attn.batch_first was not True"
        elif not self.self_attn._qkv_same_embed_dim:
            why_not_sparsity_fast_path = "self_attn._qkv_same_embed_dim was not True"
        elif not self.activation_relu_or_gelu:
            why_not_sparsity_fast_path = "activation_relu_or_gelu was not True"
        elif not (self.norm1.eps == self.norm2.eps):
            why_not_sparsity_fast_path = "norm1.eps is not equal to norm2.eps"
        elif src.is_nested and (src_key_padding_mask is not None or src_mask is not None):
            why_not_sparsity_fast_path = "neither src_key_padding_mask nor src_mask are not supported with NestedTensor input"
        elif self.self_attn.num_heads % 2 == 1:
            why_not_sparsity_fast_path = "num_head is odd"
        elif torch.is_autocast_enabled():
            why_not_sparsity_fast_path = "autocast is enabled"

        if not why_not_sparsity_fast_path:
            tensor_args = (
                src,
                self.self_attn.in_proj_weight,
                self.self_attn.in_proj_bias,
                self.self_attn.out_proj.weight,
                self.self_attn.out_proj.bias,
                self.norm1.weight,
                self.norm1.bias,
                self.norm2.weight,
                self.norm2.bias,
                self.linear1.weight,
                self.linear1.bias,
                self.linear2.weight,
                self.linear2.bias,
            )

            # We have to use list comprehensions below because TorchScript does not support
            # generator expressions.
            if torch.overrides.has_torch_function(tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument has_torch_function"
            elif not all((x.is_cuda or 'cpu' in str(x.device)) for x in tensor_args):
                why_not_sparsity_fast_path = "some Tensor argument is neither CUDA nor CPU"
            elif torch.is_grad_enabled() and any(x.requires_grad for x in tensor_args):
                why_not_sparsity_fast_path = ("grad is enabled and at least one of query or the "
                                              "input/output projection weights or biases requires_grad")

            if not why_not_sparsity_fast_path:
                merged_mask, mask_type = self.self_attn.merge_masks(src_mask, src_key_padding_mask, src)
                return torch._transformer_encoder_layer_fwd(
                    src,
                    self.self_attn.embed_dim,
                    self.self_attn.num_heads,
                    self.self_attn.in_proj_weight,
                    self.self_attn.in_proj_bias,
                    self.self_attn.out_proj.weight,
                    self.self_attn.out_proj.bias,
                    self.activation_relu_or_gelu == 2,
                    self.norm_first,
                    self.norm1.eps,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.norm2.weight,
                    self.norm2.bias,
                    self.linear1.weight,
                    self.linear1.bias,
                    self.linear2.weight,
                    self.linear2.bias,
                    merged_mask,
                    mask_type,
                )

        x = src
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask, is_causal=is_causal)
            x = x + self._ff_block(self.norm2(x))
        else:
            x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask, is_causal=is_causal))
            x = self.norm2(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                  is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)
As the source above shows, the encoder uses self-attention, where Q, K and V usually come from the same sequence: for every token, the Q, K and V vectors are produced by passing the same input X through three different linear transformations, whose projection matrices W_Q, W_K and W_V are learned weights. Take the sentence "i love large language model" as an example.
# self-attention block
def _sa_block(self, x: Tensor,
              attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
              is_causal: bool = False) -> Tensor:
    x = self.self_attn(x, x, x,
                       attn_mask=attn_mask,
                       key_padding_mask=key_padding_mask,
                       need_weights=False, is_causal=is_causal)[0]
    return self.dropout1(x)
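To make the "same input X, three different linear maps" point concrete, here is a minimal single-head sketch; the names (W_q, W_k, W_v) and shapes are illustrative and are not taken from the PyTorch source:

import math
import torch
from torch import nn

d_model, seq_len = 256, 5            # e.g. the five tokens of "i love large language model"
X = torch.randn(seq_len, d_model)    # the same X feeds all three projections

# three different learnable projection matrices W_Q, W_K, W_V
W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

Q, K, V = W_q(X), W_k(X), W_v(X)

# scaled dot-product attention: softmax(Q K^T / sqrt(d)) V
scores = Q @ K.T / math.sqrt(d_model)
out = torch.softmax(scores, dim=-1) @ V   # (seq_len, d_model)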
The decoder layer contains both a self-attention and a cross-attention module:
class TransformerDecoderLayer(Module):
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            batch_first=batch_first, **factory_kwargs)
        self.multihead_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                                 batch_first=batch_first, **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)

        # Legacy string support for activation function.
        if isinstance(activation, str):
            self.activation = _get_activation_fn(activation)
        else:
            self.activation = activation

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(
        self,
        tgt: Tensor,
        memory: Tensor,
        tgt_mask: Optional[Tensor] = None,
        memory_mask: Optional[Tensor] = None,
        tgt_key_padding_mask: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
        tgt_is_causal: bool = False,
        memory_is_causal: bool = False,
    ) -> Tensor:
        # see Fig. 1 of https://arxiv.org/pdf/2002.04745v1.pdf
        x = tgt
        if self.norm_first:
            x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
            x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
            x = x + self._ff_block(self.norm3(x))
        else:
            x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
            x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
            x = self.norm3(x + self._ff_block(x))

        return x

    # self-attention block
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                  is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           is_causal=is_causal,
                           need_weights=False)[0]
        return self.dropout1(x)

    # multihead attention block
    def _mha_block(self, x: Tensor, mem: Tensor,
                   attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor],
                   is_causal: bool = False) -> Tensor:
        x = self.multihead_attn(x, mem, mem,
                                attn_mask=attn_mask,
                                key_padding_mask=key_padding_mask,
                                is_causal=is_causal,
                                need_weights=False)[0]
        return self.dropout2(x)

    # feed forward block
    def _ff_block(self, x: Tensor) -> Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)
As the source above shows, in forward the decoder first runs self-attention (_sa_block) and then cross-attention (self._mha_block). The first argument of _mha_block, self.norm2(x), is the (normalized) embedding of the target sequence and serves as Q; the second argument memory is the encoder output and serves as K and V. This is what lets the model capture the dependencies between the input sequence and the target sequence.
x = tgt
if self.norm_first:
    x = x + self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask, tgt_is_causal)
    x = x + self._mha_block(self.norm2(x), memory, memory_mask, memory_key_padding_mask, memory_is_causal)
    x = x + self._ff_block(self.norm3(x))
else:
    x = self.norm1(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask, tgt_is_causal))
    x = self.norm2(x + self._mha_block(x, memory, memory_mask, memory_key_padding_mask, memory_is_causal))
    x = self.norm3(x + self._ff_block(x))

return x
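The same idea can be reproduced outside the source with nn.MultiheadAttention directly; the shapes below are illustrative. The decoder-side hidden states provide the query, while the encoder output memory provides the key and value:

import torch
from torch import nn

d_model, nhead = 256, 8
cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)

tgt = torch.randn(2, 7, d_model)      # target-side hidden states (used as Q)
memory = torch.randn(2, 12, d_model)  # encoder output (used as K and V)

out, _ = cross_attn(query=tgt, key=memory, value=memory)
print(out.shape)                      # (2, 7, 256): one output per target position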
MHA (Multi-Head Attention) is the standard multi-head attention mechanism, with h Query, Key and Value projections; the Key and Value weights are not shared across heads. It is essentially the concatenation of several independent single-head attentions.
In MHA the number of KV heads equals the number of query heads: each query head owns its own KV head, and attention is computed against that head alone. As models get deeper and the number of heads grows, the compute and I/O of QKV attention increase quickly. To alleviate this, MQA and GQA were proposed.
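The head split in MHA can be sketched as follows (illustrative shapes, not the PyTorch implementation): every one of the h query heads has its own K head and V head.

import math
import torch

batch, seq_len, d_model, num_heads = 2, 10, 256, 8
head_dim = d_model // num_heads

x = torch.randn(batch, seq_len, d_model)
w_qkv = torch.randn(d_model, 3 * d_model)        # fused Q/K/V projection (toy weights)
q, k, v = (x @ w_qkv).chunk(3, dim=-1)           # each (batch, seq_len, d_model)

def split_heads(t):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
    return t.view(batch, seq_len, num_heads, head_dim).transpose(1, 2)

q, k, v = split_heads(q), split_heads(k), split_heads(v)    # 8 Q heads, 8 K heads, 8 V heads

scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)      # attention is computed per head
out = torch.softmax(scores, dim=-1) @ v                     # (batch, num_heads, seq_len, head_dim)
out = out.transpose(1, 2).reshape(batch, seq_len, d_model)  # concatenate the heads back

During decoding, the per-head k and v above are exactly what gets stored in the KV cache, which is why the number of KV heads matters so much for memory.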
MQA (Multi-Query Attention, from "Fast Transformer Decoding: One Write-Head is All You Need") is a variant of multi-query attention used for autoregressive decoding. It is the extreme case: only one KV head is kept, and all query heads share it. All differences between heads therefore have to live in the queries, so the model must capture the different aspects of the input hidden states through the query heads alone. The benefit is a drastic reduction in the KV cache; the cost is some drop in model quality.
GQA (Grouped-Query Attention, from "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints") takes the middle ground instead. The query heads are split into groups, and each group shares one KV head. For example, splitting 8 query heads into 4 groups gives 2 query heads per group and 4 KV heads in total. GQA reduces both compute and KV cache while keeping model quality largely unaffected.
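A minimal sketch of the grouping idea follows; the names and shapes are illustrative. With num_kv_heads = 4 and num_heads = 8 this is the GQA example above, num_kv_heads = 1 gives MQA, and num_kv_heads = num_heads falls back to ordinary MHA.

import math
import torch

batch, seq_len, head_dim = 2, 10, 32
num_heads, num_kv_heads = 8, 4               # 8 query heads share 4 KV heads -> group size 2

q = torch.randn(batch, num_heads,    seq_len, head_dim)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim)   # only num_kv_heads K/V heads are cached
v = torch.randn(batch, num_kv_heads, seq_len, head_dim)

group_size = num_heads // num_kv_heads
# naive version: expand each KV head so every query head in its group sees the same K/V
k_exp = k.repeat_interleave(group_size, dim=1)            # (batch, num_heads, seq_len, head_dim)
v_exp = v.repeat_interleave(group_size, dim=1)

scores = q @ k_exp.transpose(-2, -1) / math.sqrt(head_dim)
out = torch.softmax(scores, dim=-1) @ v_exp               # (batch, num_heads, seq_len, head_dim)

The repeat_interleave copy is only for clarity; as described next, optimized kernels avoid materializing the expanded K/V and instead index into the shared heads.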
Most mainstream training and inference frameworks and algorithms already support MQA/GQA; FlashAttention, for instance, supports both. For MQA and GQA, FlashAttention uses indexing rather than physically replicating the KV heads in GPU memory before the computation: the KV / KV-head indices are passed into the kernel, the memory addresses are computed from them, and K/V are read directly from memory.
DeepSeek uses MLA, an improvement over GQA. (To be updated.)
YOCO is a Decoder-Decoder architecture, very close to Decoder-Only; it is called Decoder-Decoder because the two decoders do not mean exactly the same thing. YOCO has two parts: a Self-Decoder, which is the same as a common decoder Transformer, and a Cross-Decoder. The Self-Decoder produces the global KV cache, which is then consumed directly by the Cross-Decoder. That is also why the second half is called the Cross-Decoder: it runs cross-attention against the KV cache produced by the Self-Decoder, and does not produce a KV cache of its own.
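Purely as a shape-level toy (not the official YOCO implementation, which uses efficient self-attention in the Self-Decoder and projects its output to obtain the shared K/V), the data flow can be sketched like this: the Self-Decoder runs once and its output plays the role of the single global KV cache, and every Cross-Decoder layer attends to that same cache without producing a cache of its own.

import torch
from torch import nn

d_model, nhead, seq_len = 256, 8, 16

# stand-ins: one causal "self-decoder" block and a stack of cross-attention-only layers
self_decoder = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
cross_layers = nn.ModuleList(
    nn.MultiheadAttention(d_model, nhead, batch_first=True) for _ in range(4)
)

x = torch.randn(2, seq_len, d_model)
causal_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# 1) the self-decoder runs once; its output stands in for the global KV cache
global_kv = self_decoder(x, src_mask=causal_mask)

# 2) each cross-decoder layer reuses the same global_kv as K/V and caches nothing itself
h = global_kv
for layer in cross_layers:
    attn_out, _ = layer(query=h, key=global_kv, value=global_kv, attn_mask=causal_mask)
    h = h + attn_out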