当前位置:   article > 正文

Qwen学习心得

Qwen学习心得

        最近参加了datawhale的五月组队学习,主题是手搓大模型实战课程。第一课学习的是Qwen大模型的一些原理,作为一位技术小白只能艰难地边生啃其中的内容边恶补一些基础的知识。以下是我个人的学习心得,可能多处有谬误,zhendehencai......

        Qwen是一个decoder-only的transformer模型,使用SwiGLU激活函数、Rotary Positional Embedding(ROPE)等,以提高模型存储和运行效率。

        先跟着课程的老师写了个运行的demo,通过在model处设断点调试可以看到代码进行的流程。

  1. from transformers_sc.src.transformers.models.qwen2 import Qwen2Config, Qwen2Model
  2. import torch
  3. def run_qwen2():
  4. qwen2config = Qwen2Config(vocab_size=151936,
  5. hidden_size=4096//2,
  6. intermediate_size=22016//2,
  7. num_hidden_layers=32//2,
  8. num_attention_heads=32,
  9. max_position_embeddings=2048//2)
  10. qwen2model = Qwen2Model(config=qwen2config)
  11. input_ids = torch.randint(0, qwen2config.vocab_size, (4,30))
  12. res = qwen2model(input_ids)
  13. print(type(res))
  14. if __name__ == "__main__":
  15. run_qwen2()

试着读了一下ROPE的具体操作。。。 先准备好ROPE需要的一些参数,主要是cos和sin值。

  1. #定义一个旋转位置嵌入类,通过预计算和缓存cos和sin值,提高模型处理长序列时的效率。
  2. class Qwen2RotaryEmbedding(nn.Module):
  3. #初始化方法,接受维度、最大位置嵌入数量、频率、设备四个参数
  4. def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
  5. super().__init__()
  6. self.dim = dim
  7. self.max_position_embeddings = max_position_embeddings
  8. self.base = base
  9. #计算逆频率(register_buffer对逆频率在模型中注册参数,'persistent=False'指该参数在训练过程中不会更新)
  10. inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
  11. self.register_buffer("inv_freq", inv_freq, persistent=False)
  12. # Build here to make `torch.jit.trace` work.
  13. self._set_cos_sin_cache(
  14. seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
  15. )
  16. def _set_cos_sin_cache(self, seq_len, device, dtype):
  17. self.max_seq_len_cached = seq_len
  18. t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
  19. #计算t和inv_ferq的外积freqs,其与自身在最后一个维度拼接成一个新张量emb
  20. freqs = torch.outer(t, self.inv_freq)
  21. # Different from paper, but it uses a different permutation in order to obtain the same calculation
  22. emb = torch.cat((freqs, freqs), dim=-1)
  23. self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
  24. self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
  25. #定义前向传播函数forward,其中x为输入张量,seq_len为序列长度
  26. def forward(self, x, seq_len=None):
  27. # x: [bs, num_attention_heads, seq_len, head_size]
  28. if seq_len > self.max_seq_len_cached:
  29. self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
  30. return (
  31. self.cos_cached[:seq_len].to(dtype=x.dtype),
  32. self.sin_cached[:seq_len].to(dtype=x.dtype),
  33. )

计算出ROPE方法的cos和sin值后,再来实现ROPE。

  1. # Copied from transformers.models.llama.modeling_llama.rotate_half
  2. #定义半旋转函数,将输入张量 x 的后一半维度旋转到前面,同时将前一半维度保持不变
  3. def rotate_half(x):
  4. """Rotates half the hidden dims of the input."""
  5. x1 = x[..., : x.shape[-1] // 2]
  6. x2 = x[..., x.shape[-1] // 2 :]
  7. return torch.cat((-x2, x1), dim=-1)
  8. # Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
  9. def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
  10. """Applies Rotary Position Embedding to the query and key tensors.
  11. Args:
  12. q (`torch.Tensor`): The query tensor.
  13. k (`torch.Tensor`): The key tensor.
  14. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  15. sin (`torch.Tensor`): The sine part of the rotary embedding.
  16. position_ids (`torch.Tensor`):
  17. The position indices of the tokens corresponding to the query and key tensors. For example, this can be
  18. used to pass offsetted position ids when working with a KV-cache.
  19. unsqueeze_dim (`int`, *optional*, defaults to 1):
  20. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  21. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  22. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  23. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  24. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  25. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  26. Returns:
  27. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  28. """
  29. """
  30. q: 查询张量Query
  31. k: 键张量Key
  32. cos: 旋转位置编码中的余弦部分
  33. sin: 旋转位置编码中的正弦部分
  34. position_ids: 查询和键张量对应的位置信息。
  35. unsqueeze_dim: 用于广播的维度,默认值为1。
  36. 返回是包含了两个torch.Tensor对象即经过旋转位置编码处理后的查询张量(q_embed)和键张量(k_embed)的元组
  37. """
  38. cos = cos[position_ids].unsqueeze(unsqueeze_dim)
  39. sin = sin[position_ids].unsqueeze(unsqueeze_dim)
  40. q_embed = (q * cos) + (rotate_half(q) * sin)
  41. k_embed = (k * cos) + (rotate_half(k) * sin)
  42. return q_embed, k_embed

        旋转位置编码是一种特殊的相对位置编码,通过在注意力计算中引入旋转变换,将相对位置编码隐式引入查询和键的点积计算中。也就是说,在做查询和键张量内积的时候实现了相对位置的效果,这点可以从apply_rotary_pos_emb函数看出。

        瞄了眼论文,找了找相关的数学公式。

        

        后两项乘起来就是Query向量了,这不就是一个旋转矩阵和Query向量相乘吗?!

        然后由于内积线性可加,2D可以推广到General Form,即任意偶数维:

        后面的多层感知机函数模块:

  1. class Qwen2MLP(nn.Module):
  2. def __init__(self, config):
  3. super().__init__()
  4. self.config = config
  5. self.hidden_size = config.hidden_size
  6. self.intermediate_size = config.intermediate_size
  7. # 三个Linear层,有门控机制,将输入张量映射到中间层大小
  8. self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  9. self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
  10. self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
  11. # 激活函数,根据ACT2FN字典里的配置选择相应的激活函数
  12. self.act_fn = ACT2FN[config.hidden_act]
  13. def forward(self, x):
  14. return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

        forward 方法定义了模型的前向传播逻辑:

        1.首先,通过 gate_proj 将输入 x 映射到中间大小,并应用激活函数 act_fn。

        2.将门控后的结果乘以通过 up_proj 映射的输入 x,实现门控的作用。

        3.最后,将门控和更新后的结果通过 down_proj 映射回隐藏大小,作为最终的输出。

        其中config.hidden_act应该是使用了SwiGLU函数。在configuration的init里找,发现是“silu”.

         

        就先到这里吧。。。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/707071
推荐阅读
相关标签
  

闽ICP备14008679号