Qwen2 (Tongyi Qianwen) Architecture Analysis: qwen1.5 Source Code Walkthrough

Overview

Qwen2 project repository: QwenLM/Qwen2

Starting with version 4.37.0, transformers ships the Qwen2 code, so using Qwen2 requires transformers>=4.37.0. The Qwen2 code lives in the transformers/models/qwen2 directory.
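
As a quick usage sketch (my own, not from the article; the model ID is only an example), loading a Qwen1.5 checkpoint just needs a recent enough transformers:

# Minimal sketch: loading a Qwen1.5/Qwen2-architecture checkpoint with transformers >= 4.37.0.
# The model ID below is only an example.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen1.5-0.5B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello, Qwen!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))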

Like Qwen, Qwen2 is still a decoder-only transformer, using RMSNorm, the SwiGLU activation, RoPE, and multi-head attention.

Layer Normalization

Layer normalization uses Root Mean Square Layer Normalization (RMSNorm). The idea behind RMSNorm is simple: it normalizes a layer's output by its root mean square (RMS), as shown in the following formula:

$$\bar{a}_i = \frac{a_i}{\operatorname{RMS}(\mathbf{a})}\, g_i, \quad \text{where } \operatorname{RMS}(\mathbf{a}) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2}$$

where $\mathbf{a} \in \mathbb{R}^n$ is the layer's output vector and $\mathbf{g} \in \mathbb{R}^n$ is a gain parameter used to rescale the normalized output, initialized to 1.

Below is the code for Qwen2RMSNorm, which is essentially a direct implementation of the formula above:

class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        # Compute the statistics in float32 for numerical stability
        hidden_states = hidden_states.to(torch.float32)
        # "variance" here is the mean of squares (RMS^2); unlike LayerNorm, no mean is subtracted
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Apply the learnable gain and cast back to the original dtype
        return self.weight * hidden_states.to(input_dtype)
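
As a quick sanity check (a sketch of my own, assuming the Qwen2RMSNorm class above is in scope), the module's output matches the formula computed by hand:

import torch

torch.manual_seed(0)
norm = Qwen2RMSNorm(hidden_size=8)
x = torch.randn(2, 4, 8)

# Manual RMSNorm following the formula: a_i / RMS(a) * g_i
rms = x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt()
manual = x / rms * norm.weight

print(torch.allclose(norm(x), manual, atol=1e-5))  # True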

Where LayerNorm Is Applied

Qwen2 applies LayerNorm in three places (a condensed sketch of the overall pattern follows the third excerpt):

  1. In each decoder layer, hidden_states are passed through LayerNorm before entering the self-attention sub-layer:
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        ...
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        ...
    
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states
        
        ...
  2. In each decoder layer, the hidden_states produced by the self-attention sub-layer are passed through LayerNorm before entering the feed-forward (MLP) sub-layer:
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, config: Qwen2Config, layer_idx: int):
        ...
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        # Self Attention
        ...
        
        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        
        ...
  3. The hidden_states produced by the last decoder layer are passed through LayerNorm:
class Qwen2Model(Qwen2PreTrainedModel):
    def __init__(self, config: Qwen2Config):
        ...
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        ...
    
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        ...
        
        for decoder_layer in self.layers:...
        
        hidden_states = self.norm(hidden_states)
        
        ...
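
Putting the three call sites together, each Qwen2 decoder layer follows the standard pre-norm residual pattern, with a final norm after the last layer. Below is a condensed sketch of my own (not code from the repository), using the Qwen2RMSNorm class defined above and plain nn.Linear layers as stand-ins for the attention and MLP sub-layers:

import torch
import torch.nn as nn

class PreNormBlock(nn.Module):
    """Pre-norm residual wrapper: normalize, run the sub-layer, add the residual."""
    def __init__(self, hidden_size: int, sublayer: nn.Module):
        super().__init__()
        self.norm = Qwen2RMSNorm(hidden_size)
        self.sublayer = sublayer

    def forward(self, hidden_states):
        residual = hidden_states
        hidden_states = self.norm(hidden_states)      # normalize *before* the sub-layer
        hidden_states = self.sublayer(hidden_states)  # self-attention or MLP in the real model
        return residual + hidden_states               # residual connection

hidden_size = 16
attn_block = PreNormBlock(hidden_size, nn.Linear(hidden_size, hidden_size))  # stand-in for self_attn
mlp_block = PreNormBlock(hidden_size, nn.Linear(hidden_size, hidden_size))   # stand-in for mlp
final_norm = Qwen2RMSNorm(hidden_size)                                       # applied after the last layer

x = torch.randn(1, 4, hidden_size)
out = final_norm(mlp_block(attn_block(x)))
print(out.shape)  # torch.Size([1, 4, 16])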

Positional Encoding

The positional encoding is Rotary Position Embedding (RoPE). RoPE encodes absolute positions with rotation matrices while incorporating explicit relative-position dependencies into the self-attention formulation; in other words, the relative-position information is injected into self-attention itself. Positional encoding therefore happens during the attention computation, rather than by adding a position embedding to the token embedding at the input, as earlier approaches did.

The RoPE module is shown below; its main purpose is to compute the cos and sin values used by RoPE:

# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
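
As a quick shape check (a sketch of my own, assuming the Qwen2RotaryEmbedding class above is in scope), the module returns cos and sin tables of shape [seq_len, dim]:

import torch

head_dim, seq_len = 8, 10
rope = Qwen2RotaryEmbedding(dim=head_dim, max_position_embeddings=32)

# Only x's dtype/device (and the seq_len argument) matter for this call
x = torch.zeros(1, 2, seq_len, head_dim)  # [bs, num_heads, seq_len, head_dim]
cos, sin = rope(x, seq_len=seq_len)

print(cos.shape, sin.shape)  # torch.Size([10, 8]) torch.Size([10, 8])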

Once the cos and sin values are available, the following functions apply the absolute-position encoding and inject the relative-position dependency:

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

Note that this implementation differs slightly from the one in the paper:

  • In the paper, the $d$ elements of the input vector $\mathbf{x}_m \in \mathbb{R}^d$, taken from left to right, are grouped two adjacent elements at a time into 2-D vectors, and each pair is rotated by the corresponding 2-D rotation matrix, as in the formula and figure below:

$$\begin{pmatrix} x_m^{\prime\,(2i-1)} \\ x_m^{\prime\,(2i)} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_m^{(2i-1)} \\ x_m^{(2i)} \end{pmatrix}, \quad i \in [1,\dots,d/2]$$

(Figure: RoPE pairing schematic)

  • The implementation here instead splits the $d$ elements of $\mathbf{x}_m \in \mathbb{R}^d$, taken from left to right, into a first half and a second half of equal length; the elements at matching positions in the two halves form the 2-D vectors, and each pair is rotated by the corresponding 2-D rotation matrix, as in the formula below:

$$\begin{pmatrix} x_m^{\prime\,(i)} \\ x_m^{\prime\,(i+d/2)} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_m^{(i)} \\ x_m^{(i+d/2)} \end{pmatrix}, \quad i \in [1,\dots,d/2]$$

Although the two implementations pair up the elements of the input vector differently (presumably for implementation convenience), the principle is the same and the result is unaffected: both inject the relative-position dependency.
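
A small numerical check (a sketch of my own, assuming Qwen2RotaryEmbedding and apply_rotary_pos_emb above are in scope) that the post-RoPE attention score between a query at position m and a key at position n depends only on the offset m - n:

import torch

torch.manual_seed(0)
head_dim = 8
rope = Qwen2RotaryEmbedding(head_dim, max_position_embeddings=32)
cos, sin = rope(torch.zeros(1, 1, 16, head_dim), seq_len=16)

q = torch.randn(1, 1, 1, head_dim)  # a single query vector
k = torch.randn(1, 1, 1, head_dim)  # a single key vector

def score(m, n):
    """Attention score between the query placed at position m and the key at position n."""
    q_rot, _ = apply_rotary_pos_emb(q, q, cos, sin, torch.tensor([[m]]))
    k_rot, _ = apply_rotary_pos_emb(k, k, cos, sin, torch.tensor([[n]]))
    return (q_rot * k_rot).sum().item()

print(score(3, 1), score(10, 8))  # same offset (2): the two scores coincide (up to float error)
print(score(5, 1))                # different offset: generally a different score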

Activation Function

The activation function is used only in the feed-forward network of the decoder block. Qwen2 uses SwiGLU, a variant of the gated linear unit (GLU). The code is as follows:

# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Qwen2
class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

Here, config.hidden_act is set to "silu" in the configuration file, i.e. the SiLU (Sigmoid Linear Unit) function:

$$\text{SiLU}(x) = x \cdot \sigma(x), \quad \text{where } \sigma(x) \text{ is the sigmoid function}$$

SiLU is also known as the Swish function. Replacing the sigmoid in the original GLU with Swish yields SwiGLU:

$$\text{SwiGLU}(x, W, V) = \text{Swish}(xW) \otimes (xV), \quad \text{where } \otimes \text{ denotes the element-wise product}$$

The line self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) implements the forward pass of an FFN with the SwiGLU activation:

$$\text{FFN}_{\text{SwiGLU}}(x, W, V, W_2) = (\text{Swish}(xW) \otimes xV)\, W_2$$
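
As a self-contained sketch of my own (not the repository code), the formula maps onto the three projections in Qwen2MLP as W -> gate_proj, V -> up_proj, W2 -> down_proj:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFFN(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)  # W
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)    # V
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)  # W2

    def forward(self, x):
        # (Swish(xW) ⊗ xV) W2, with Swish = SiLU
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

ffn = SwiGLUFFN(hidden_size=16, intermediate_size=44)
out = ffn(torch.randn(2, 5, 16))
print(out.shape)  # torch.Size([2, 5, 16])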

Self-Attention

Self-Attention Implementations

Qwen2 provides three implementations of self-attention:

QWEN2_ATTENTION_CLASSES = {
    "eager": Qwen2Attention,
    "flash_attention_2": Qwen2FlashAttention2,
    "sdpa": Qwen2SdpaAttention,
}
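
As a usage note (a sketch of my own, not from the article; the model ID is only an example), the implementation can be selected via the attn_implementation argument when loading the model:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen1.5-0.5B",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)
# e.g. Qwen2SdpaAttention (the exact class name depends on the transformers version)
print(type(model.model.layers[0].self_attn).__name__)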

Self-Attention Types

Qwen2 supports three types of self-attention: multi-head attention (MHA), multi-query attention (MQA), and grouped-query attention (GQA).

Part of the self-attention layer's initialization code is shown below:

class Qwen2Attention(nn.Module):

    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        ...
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        # Number of query heads that share each key/value head (1 for MHA, num_heads for MQA)
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        ...

The value of config.num_key_value_heads determines which attention type is used (see the sketch after this list):

  • config.num_key_value_heads == config.num_attention_heads means MHA is used
  • config.num_key_value_heads == 1 means MQA is used
  • 1 < config.num_key_value_heads < config.num_attention_heads means GQA is used
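
Below is a minimal sketch of how MQA/GQA share key/value heads; it mirrors the repeat_kv helper used in the Qwen2/Llama modeling code, while the demo values around it are my own. Each KV head is repeated num_key_value_groups times so its shape matches the query heads before the attention scores are computed:

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand [batch, num_kv_heads, seq_len, head_dim] to [batch, num_kv_heads * n_rep, seq_len, head_dim]."""
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)

num_heads, num_kv_heads, head_dim, seq_len = 16, 2, 64, 8  # GQA: 8 query heads share each KV head
k = torch.randn(1, num_kv_heads, seq_len, head_dim)
k = repeat_kv(k, num_heads // num_kv_heads)
print(k.shape)  # torch.Size([1, 16, 8, 64])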