



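Qwen is a decoder-only Transformer model. Among other design choices, it uses the SwiGLU activation function in its MLP and Rotary Positional Embedding (RoPE) for position information, with the aim of improving the model's quality and efficiency.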


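# The original imports from a local copy of the transformers source tree
# (transformers_sc.src.transformers.models.qwen2); the installed
# transformers package exports the same two classes.
from transformers import Qwen2Config, Qwen2Model
import torch

def run_qwen2():
    # Reduced configuration: hidden size, MLP width and layer count are halved
    qwen2config = Qwen2Config(vocab_size=151936,
                              hidden_size=4096 // 2,
                              intermediate_size=22016 // 2,
                              num_hidden_layers=32 // 2,
                              num_attention_heads=32,
                              max_position_embeddings=2048 // 2)
    # Randomly initialized model; no pretrained weights are loaded
    qwen2model = Qwen2Model(config=qwen2config)
    # A batch of 4 sequences with 30 random token ids each
    input_ids = torch.randint(0, qwen2config.vocab_size, (4, 30))
    res = qwen2model(input_ids)
    print(type(res))                    # a model output object (BaseModelOutputWithPast)
    print(res.last_hidden_state.shape)  # [batch, seq_len, hidden_size]

if __name__ == "__main__":
    run_qwen2()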
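To get a feel for how large even this halved model is, here is a rough parameter count. It is a sketch under assumptions about the Qwen2 layer layout: q/k/v projections with bias, o_proj and the three MLP projections without bias, two RMSNorm weight vectors per decoder layer plus one final norm, and no LM head because Qwen2Model is the bare decoder. num_key_value_heads is assumed to equal num_attention_heads (32) since the snippet does not set it. estimate_qwen2_params is a hypothetical helper, not a transformers function.

```python
def estimate_qwen2_params(vocab_size, hidden_size, intermediate_size,
                          num_hidden_layers, num_attention_heads,
                          num_key_value_heads=None):
    """Rough parameter count of a bare Qwen2 decoder (no LM head), assumed layout."""
    if num_key_value_heads is None:
        num_key_value_heads = num_attention_heads
    head_dim = hidden_size // num_attention_heads
    kv_dim = num_key_value_heads * head_dim
    # Attention: q_proj, k_proj, v_proj with bias; o_proj without bias (assumption)
    attn = (hidden_size * hidden_size + hidden_size
            + 2 * (hidden_size * kv_dim + kv_dim)
            + hidden_size * hidden_size)
    # MLP: gate_proj, up_proj, down_proj, all without bias
    mlp = 3 * hidden_size * intermediate_size
    # Two RMSNorm weight vectors per decoder layer
    norms = 2 * hidden_size
    per_layer = attn + mlp + norms
    embed = vocab_size * hidden_size   # token embedding table
    final_norm = hidden_size           # norm after the last layer
    return embed + num_hidden_layers * per_layer + final_norm

n = estimate_qwen2_params(151936, 4096 // 2, 22016 // 2, 32 // 2, 32)
print(f"{n:,} parameters, about {n * 4 / 1e9:.1f} GB in float32")
# 1,661,896,704 parameters, about 6.6 GB in float32
```

If the layout assumptions hold, the number should agree with `sum(p.numel() for p in qwen2model.parameters())`; note that the embedding table alone (151936 × 2048) is roughly a fifth of the total.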

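Next, I tried reading through how RoPE is actually computed. The first step is to prepare the quantities RoPE needs, mainly the cos and sin values.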

import torch
from torch import nn

# Rotary position embedding: precomputes and caches the cos and sin values so they
# are not recomputed on every forward pass, which helps when processing long sequences.
class Qwen2RotaryEmbedding(nn.Module):
    # dim: rotary dimension (the per-head size); max_position_embeddings: number of
    # positions to cache up front; base: base of the geometric frequency sequence;
    # device: where to create the buffers
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies theta_i = base^(-2i/dim) for i = 0 .. dim/2 - 1.
        # register_buffer stores the tensor on the module (it follows .to(device)) without
        # making it a trainable parameter; persistent=False also keeps it out of the state_dict.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
        # freqs[m, i] = m * theta_i is the outer product of the positions t and inv_freq;
        # concatenating freqs with itself along the last dim gives emb of shape [seq_len, dim]
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    # Forward pass: x is only used for its device, dtype and (if needed) sequence length;
    # seq_len is the number of positions requested
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len is None:
            seq_len = x.shape[-2]
        # Rebuild a longer cache if more positions are requested than are cached
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
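The comment "Different from paper, but it uses a different permutation" can be checked numerically. The sketch below re-implements both layouts in a few lines; rope_half_split and rope_interleaved are hypothetical names for illustration, not transformers functions. The half-split layout (what `torch.cat((freqs, freqs))` plus `rotate_half` produce) rotates the pairs (x_i, x_{i+d/2}), while the RoFormer paper rotates adjacent pairs (x_{2i}, x_{2i+1}); the two agree once the dimensions are permuted.

```python
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) along the last dimension, as in transformers
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rope_half_split(x, pos, base=10000.0):
    """Layout of Qwen2RotaryEmbedding: dimension i is paired with i + d/2."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    freqs = pos * inv_freq                    # [d/2] angles m * theta_i
    emb = torch.cat((freqs, freqs), dim=-1)   # [d], one row of cos_cached/sin_cached
    return x * emb.cos() + rotate_half(x) * emb.sin()

def rope_interleaved(x, pos, base=10000.0):
    """Layout of the RoFormer paper: dimension 2i is paired with 2i + 1."""
    d = x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, d, 2, dtype=torch.float64) / d))
    angle = pos * inv_freq
    cos, sin = angle.cos(), angle.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
d, pos = 8, 5
x = torch.randn(d, dtype=torch.float64)
# Permutation that gathers even indices into the first half, odd ones into the second
perm = torch.cat((torch.arange(0, d, 2), torch.arange(1, d, 2)))

lhs = rope_half_split(x[perm], pos)
rhs = rope_interleaved(x, pos)[perm]
print(torch.allclose(lhs, rhs))  # True
```

Because q and k receive the same fixed permutation, every dot product q·k, and therefore every attention score, is unchanged; the permutation can also be absorbed into the learned q/k projection weights.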


import torch

# Copied from transformers.models.llama.modeling_llama.rotate_half
# Split the last dimension into halves x1 and x2 and return (-x2, x1): the negated
# second half moves to the front and the first half moves to the back
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Pick each token's cos/sin row by position, then add a head axis for broadcasting
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    # Rotate every (x_i, x_{i+d/2}) pair by its angle: x * cos + rotate_half(x) * sin
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
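The property that makes RoPE a relative position encoding can be observed with the two functions above. In this sketch (sizes are arbitrary; the cos/sin tables are built the same way as in `_set_cos_sin_cache`), the same query vector and the same key vector are placed at every position. After rotation, the score between positions m and n depends only on m - n, so the score matrix is constant along each diagonal.

```python
import torch

def rotate_half(x):
    # (x1, x2) -> (-x2, x1) along the last dimension
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    # Same computation as apply_rotary_pos_emb above
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    return q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin

seq_len, head_dim, base = 16, 8, 10000.0
# cos/sin tables built as in Qwen2RotaryEmbedding._set_cos_sin_cache
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float64) / head_dim))
freqs = torch.outer(torch.arange(seq_len, dtype=torch.float64), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()                        # [seq_len, head_dim]

torch.manual_seed(0)
# The same query vector and the same key vector at every position
q = torch.randn(head_dim, dtype=torch.float64).expand(1, 1, seq_len, head_dim)
k = torch.randn(head_dim, dtype=torch.float64).expand(1, 1, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)      # [batch=1, seq_len]

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
scores = (q_embed @ k_embed.transpose(-1, -2))[0, 0]   # [seq_len, seq_len]

# Shifting both positions by one leaves every score unchanged
print(torch.allclose(scores[1:, 1:], scores[:-1, :-1]))  # True
```

Without the rotation, identical vectors at every position would give identical scores everywhere; with it, the scores vary with the offset m - n but not with the absolute positions.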





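Then, because the inner product is a sum over dimension pairs, the 2D case generalizes to any even dimension d: the vector is split into d/2 two-dimensional pairs, and each pair is rotated by its own angle.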
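For reference, the general form as usually written for RoPE (following the RoFormer paper), using the symbols of the code above: m is the token position, d the even head dimension, and base the constructor argument (10000 by default), so θ_i corresponds to `inv_freq[i-1]`.

```latex
f(\boldsymbol{x}, m) = \boldsymbol{R}^{d}_{\Theta, m}\,\boldsymbol{x}, \qquad
\boldsymbol{R}^{d}_{\Theta, m} =
\begin{pmatrix}
\cos m\theta_1 & -\sin m\theta_1 & \cdots & 0 & 0 \\
\sin m\theta_1 & \cos m\theta_1  & \cdots & 0 & 0 \\
\vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\
0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2}
\end{pmatrix},
\qquad \theta_i = \mathrm{base}^{-2(i-1)/d},\quad i = 1, \dots, d/2

\left(\boldsymbol{R}^{d}_{\Theta, m}\,\boldsymbol{q}\right)^{\top}
\left(\boldsymbol{R}^{d}_{\Theta, n}\,\boldsymbol{k}\right)
= \boldsymbol{q}^{\top}\,\boldsymbol{R}^{d}_{\Theta, n-m}\,\boldsymbol{k}
```

The second identity holds because the 2×2 blocks are rotations, so the transpose of a rotation by mθ_i undoes it and the angles combine to (n - m)θ_i: the score between a query at position m and a key at position n depends on the positions only through their offset. The code pairs dimension i with dimension i + d/2 (through `torch.cat((freqs, freqs))` and `rotate_half`) instead of adjacent dimensions, which is the same rotation up to a fixed permutation of the dimensions.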


from torch import nn
from transformers.activations import ACT2FN

class Qwen2MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Three bias-free Linear layers forming a gated MLP: gate_proj and up_proj map
        # the input up to intermediate_size, down_proj maps the result back to hidden_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Activation function, looked up in the ACT2FN dictionary by config.hidden_act
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

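The forward method defines the forward pass:

1. First, gate_proj projects the input x to the intermediate size, and the activation function act_fn is applied.

2. The activated result is multiplied element-wise by up_proj(x), the input projected to the same size by up_proj; this multiplication is the gating.

3. Finally, down_proj projects the gated result back to the hidden size, giving the final output. When hidden_act is "silu" (the Qwen2Config default), steps 1 and 2 together are exactly the SwiGLU activation mentioned at the start.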
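The three steps can be traced on small tensors. This sketch assumes hidden_act is "silu", uses arbitrary tiny sizes instead of a real config, and builds the three projections directly instead of instantiating Qwen2MLP, so it runs with PyTorch alone. It also checks that steps 1 and 2 compute SwiGLU, using SiLU(z) = z · sigmoid(z).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_size, intermediate_size = 8, 16   # arbitrary toy sizes

# The three bias-free projections of Qwen2MLP
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

x = torch.randn(2, 5, hidden_size)        # [batch, seq_len, hidden_size]

gate = F.silu(gate_proj(x))               # step 1: project up and activate
hidden = gate * up_proj(x)                # step 2: element-wise gating
out = down_proj(hidden)                   # step 3: project back to hidden_size

# Same result written out as SwiGLU with SiLU(z) = z * sigmoid(z)
z = gate_proj(x)
manual = down_proj(z * torch.sigmoid(z) * up_proj(x))
print(out.shape, torch.allclose(out, manual, atol=1e-6))  # torch.Size([2, 5, 8]) True
```

The output keeps the input's shape, which lets the MLP output be added back onto the residual stream of the decoder layer.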




