Qwen2 project repository: QwenLM/Qwen2
Starting with version 4.37.0, transformers includes the Qwen2 code, so using Qwen2 requires transformers>=4.37.0. The Qwen2 code lives in the transformers/models/qwen2 directory.
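As a quick sanity check of the setup, a Qwen2 checkpoint can be loaded through the standard transformers API. This is a minimal sketch; Qwen/Qwen2-0.5B-Instruct is used purely as an example checkpoint name, and any Qwen2 model works the same way:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2-0.5B-Instruct"  # example checkpoint, any Qwen2 model works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32)

inputs = tokenizer("Hello, Qwen2!", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(out[0], skip_special_tokens=True))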
Like Qwen, Qwen2 is still a decoder-only transformer model, built with RMSNorm, the SwiGLU activation function, RoPE, and multi-head attention.
Layer normalization uses Root Mean Square Layer Normalization (RMSNorm). The idea behind RMSNorm is simple: it normalizes a layer's output by its root mean square (RMS), as the following formula shows:
$$\bar{a}_{i}=\frac{a_{i}}{\operatorname{RMS}(\mathbf{a})} g_{i}\,, \quad \text{where}\ \operatorname{RMS}(\mathbf{a})=\sqrt{\frac{1}{n} \sum_{i=1}^{n} a_{i}^{2}}$$
where $\mathbf{a} \in \mathbb{R}^n$ is the output vector of the layer and $\mathbf{g} \in \mathbb{R}^n$ is a gain parameter used to rescale the normalized output, initialized to 1.
Below is the code of Qwen2RMSNorm, which is essentially a direct implementation of the formula above:
import torch
from torch import nn

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # the gain parameter g, initialized to 1
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        # mean of squares over the hidden dimension, i.e. RMS(a)^2
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
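A quick way to convince yourself that the code matches the formula (a minimal check written for this post, not part of the Qwen2 source):

import torch

norm = Qwen2RMSNorm(hidden_size=8)
x = torch.randn(2, 4, 8)

# direct implementation of the formula: a_i / RMS(a) * g_i
expected = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6) * norm.weight

print(torch.allclose(norm(x), expected))  # True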
Qwen2 applies LayerNorm (i.e. RMSNorm) in three places:
1. hidden_states goes through LayerNorm before being fed into the self-attention sub-layer:

class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        ...
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        ...

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states
        ...
2. hidden_states goes through LayerNorm before being fed into the fully connected (MLP) sub-layer:

class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        ...
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Self Attention
        ...

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        ...
3. The hidden_states produced by the last decoder layer goes through LayerNorm once more in Qwen2Model:

class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config: Qwen2Config):
        ...
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        ...

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        ...
        for decoder_layer in self.layers:
            ...
        hidden_states = self.norm(hidden_states)
        ...
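Putting the three together, the overall pre-norm data flow looks roughly like this. This is a schematic sketch written for this post, with the attention call simplified to a single argument; it is not the actual transformers code:

# Schematic pre-norm structure of Qwen2: normalize -> sub-layer -> residual add.
def decoder_layer(hidden_states, layer):
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)            # 1. norm before self-attention
    hidden_states = residual + layer.self_attn(hidden_states)

    residual = hidden_states
    hidden_states = layer.post_attention_layernorm(hidden_states)   # 2. norm before the MLP
    hidden_states = residual + layer.mlp(hidden_states)
    return hidden_states

def qwen2_model(hidden_states, layers, final_norm):
    for layer in layers:
        hidden_states = decoder_layer(hidden_states, layer)
    return final_norm(hidden_states)                                 # 3. norm after the last decoder layer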
The positional encoding is the Rotary Position Embedding (RoPE). RoPE encodes absolute positions with a rotation matrix while building explicit relative-position dependency into the self-attention formula; in other words, it integrates relative position information directly into self-attention. The positional encoding therefore happens during the attention computation, instead of the earlier approach of adding a position embedding to the token embedding at the input.
The RoPE module code is shown below; its main job is to compute the cos and sin values used by RoPE:
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
With the cos and sin values in hand, the following functions perform the absolute position encoding and inject the relative-position dependency:
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
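A small usage sketch of the two pieces above (the batch size, head count, sequence length, and head dimension are toy values chosen for this post):

import torch

batch, num_heads, seq_len, head_dim = 1, 2, 5, 16
q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)

rotary = Qwen2RotaryEmbedding(dim=head_dim, max_position_embeddings=32)
cos, sin = rotary(q, seq_len=seq_len)               # each of shape [seq_len, head_dim]
position_ids = torch.arange(seq_len).unsqueeze(0)   # [1, seq_len]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 2, 5, 16])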
Note that the implementation here differs slightly from the one in the paper. The paper applies the rotation to adjacent pairs of dimensions:
$$\begin{pmatrix} x_m^{(2i-1)\prime} \\ x_m^{(2i)\prime} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_m^{(2i-1)} \\ x_m^{(2i)} \end{pmatrix}, \quad i \in [1, \dots, d/2]$$

while the implementation pairs dimension $i$ with dimension $i + d/2$:

$$\begin{pmatrix} x_m^{(i)\prime} \\ x_m^{(i+d/2)\prime} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_m^{(i)} \\ x_m^{(i+d/2)} \end{pmatrix}, \quad i \in [1, \dots, d/2]$$
Although the two implementations arrange the elements of the input vector differently (presumably for convenience of implementation), the principle is the same and the result is unaffected: both achieve the goal of injecting relative-position dependency.
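A small self-contained check of that claim (reusing the rotate_half function from above; the head dimension and positions are arbitrary toy values chosen for this post): after applying RoPE, the dot product between a query and a key depends only on their relative offset, which is exactly the relative-position dependency RoPE is designed to inject.

import torch

torch.manual_seed(0)
head_dim = 16
q = torch.randn(head_dim, dtype=torch.float64)
k = torch.randn(head_dim, dtype=torch.float64)

# same frequency schedule as Qwen2RotaryEmbedding (base 10000)
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))

def rope(x, pos):
    # half-split RoPE exactly as in apply_rotary_pos_emb: x * cos + rotate_half(x) * sin
    freqs = pos * inv_freq
    emb = torch.cat((freqs, freqs))
    return x * emb.cos() + rotate_half(x) * emb.sin()

score_a = torch.dot(rope(q, 7.0), rope(k, 3.0))      # positions (7, 3), offset 4
score_b = torch.dot(rope(q, 107.0), rope(k, 103.0))  # positions (107, 103), same offset 4
print(torch.allclose(score_a, score_b))              # True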
The activation function is used only in the feed-forward network of the decoder block; it is SwiGLU (a variant of the gated linear unit). The code is as follows:
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
Here config.hidden_act, as set in the configuration file, has the value "silu", i.e. the SiLU (Sigmoid Linear Unit) function:
$$\text{SiLU}(x) = x \ast \sigma(x), \quad \text{where } \sigma(x) \text{ is the sigmoid function}$$
The SiLU function is also known as the Swish function. Replacing the sigmoid in the original GLU with Swish gives SwiGLU:
$$\text{SwiGLU}(x, W, V) = \text{Swish}(xW) \otimes (xV), \quad \text{where } \otimes \text{ denotes element-wise product}$$
The line self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) therefore implements the forward computation of an FFN with the SwiGLU activation:
$$\text{FFN}_{\text{SwiGLU}}(x, W, V, W_2) = (\text{Swish}(xW) \otimes xV) W_2$$
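Written out directly from the formula, the same computation looks like this (a standalone sketch with made-up toy dimensions, using torch.nn.functional.silu as Swish):

import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 8, 16
x = torch.randn(2, hidden_size)

W = torch.randn(hidden_size, intermediate_size)    # corresponds to gate_proj
V = torch.randn(hidden_size, intermediate_size)    # corresponds to up_proj
W2 = torch.randn(intermediate_size, hidden_size)   # corresponds to down_proj

# FFN_SwiGLU(x, W, V, W2) = (Swish(xW) ⊗ xV) W2, with Swish(x) = SiLU(x) = x * sigmoid(x)
out = (F.silu(x @ W) * (x @ V)) @ W2
print(out.shape)  # torch.Size([2, 8])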
Qwen2 has three implementations of the self-attention mechanism:
QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
    "sdpa": Qwen2SdpaAttention,
}
- eager: the custom (default) implementation
- flash_attention_2: based on FlashAttention2 (supports sliding-window attention)
- sdpa: based on PyTorch's scaled_dot_product_attention
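Which implementation gets used is decided when the model is loaded, via the attn_implementation argument of from_pretrained. This is a minimal sketch; the checkpoint name is only an example, and "flash_attention_2" additionally requires the flash-attn package and a supported GPU:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",               # example checkpoint
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
    torch_dtype="auto",
)
print(model.config._attn_implementation)      # which attention class was actually selected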
In addition, Qwen2 supports three types of self-attention: MHA, MQA, and GQA.
Part of the self-attention layer's initialization code is shown below:
class Qwen2Attention(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        ...
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        ...
The value of config.num_key_value_heads determines which attention type is used:
- config.num_key_value_heads = config.num_attention_heads means MHA (multi-head attention)
- config.num_key_value_heads = 1 means MQA (multi-query attention)
- 1 < config.num_key_value_heads < config.num_attention_heads means GQA (grouped-query attention)
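For MQA and GQA, the eager implementation repeats each key/value head num_key_value_groups times so that every query head has a matching key/value head. A sketch along the lines of the repeat_kv helper used in the modeling code (the shapes below are toy values chosen for this post):

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Duplicate each key/value head n_rep times:
    # [batch, num_key_value_heads, seq_len, head_dim] -> [batch, num_key_value_heads * n_rep, seq_len, head_dim]
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

# GQA example: 8 query heads share 2 key/value heads, so num_key_value_groups = 8 // 2 = 4.
k = torch.randn(1, 2, 5, 16)      # [batch, num_key_value_heads, seq_len, head_dim]
print(repeat_kv(k, 4).shape)      # torch.Size([1, 8, 5, 16])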