当前位置:   article > 正文

【LLM】大模型之扩展Context长度(RoPE等方法)_parallel context windows for large language models

parallel context windows for large language models

note

零、位置编码

我们都知道在经典的transformer模型中,输入的文本序列经过embedding层,为每个token转为对应向量表示后,还需要对词嵌入加入位置编码进行上下文语义的建模。为了得到不同位置对应的编码,transformer模型使用不同频率的正余弦函数(其中POS表示单词所在的位置,2i和2i+1表示位置编码向量中对应的维度,d是对应位置编码向量的总维度): PE ⁡ ( pos ⁡ , 2 i ) = sin ⁡ ( pos ⁡ 1000 0 2 i / d ) PE ⁡ ( pos ⁡ , 2 i + 1 ) = cos ⁡ ( pos ⁡ 100 0 2 i / d ) PE(pos,2i)=sin(pos100002i/d)PE(pos,2i+1)=cos(pos10002i/d) PE(pos,2i)=sin(100002i/dpos)PE(pos,2i+1)=cos(10002i/dpos)

在这里插入图片描述

这里我们也可以看下典型的位置编码代码,首先是pe = torch.zeros(max_len, d_model)是创建一个全0矩阵用于存储位置编码:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):
        super(PositionalEncoding, self).__init__()
        # 根据pos和i创建一个常量PE矩阵
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 缩放因子
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # sin and cos position encoding
        # 对偶数位置进行编码
        pe[:, 0::2] = torch.sin(position * div_term)
        # 对奇数位置进行编码
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        # 不对位置编码层求梯度
        self.register_buffer('pe', pe)

    def forward(self, x):
        # 输入的词向量与位置编码相加
        x = x + self.pe[:x.size(0), :]
        return x
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

一、扩展LLM的Context长度

1. 常见方法

扩展LLM的Context长度其实已有不少,但多数是通过结合检索或者摘要的方式来缩短样本的长Context,如Unlimiformer。由于不是直接处理长Context,因此通常无法做精细的阅读理解,而且这些方案往往需要在训练阶段就考虑进去,而不是事后即插即用到已有的LLM模型中。

2. PCW方法

以前能够不微调地扩展Context长度的方案是Parallel Context Window(下面简称PCW),出自论文《Parallel Context Windows for Large Language Models》和《Structured Prompting: Scaling In-Context Learning to 1,000 Examples》,两篇论文是同一时期不同作者的工作,但所提的方法只有细微的差别。

PCW适用于Self Attention模型,主要修改包括Position Encoding和Attention Mask:
在这里插入图片描述

二、NBCE方法

使用朴素贝叶斯方法。
在这里插入图片描述
为了改善Random Sample的效果,将Pooling方式改为直接输出不确定性最低的那个分布:
[ log ⁡ p ( T ∣ S ) ] = log ⁡ p ( T ∣ S k ) k = arg ⁡ min ⁡ { H 1 , H 2 , ⋯   , H n } H i = − ∑ T p ( T ∣ S i ) log ⁡ p ( T ∣ S i ) [logp(TS)]=logp(TSk)k=argmin{H1,H2,,Hn}Hi=Tp(TSi)logp(TSi) [logp(TS)]=logp(TSk)k=argmin{H1,H2,,Hn}Hi=Tp(TSi)logp(TSi)

三、RoPE方法

RoPE的目标:构建一个位置相关的投影矩阵, 使得
( R m q ) ⊤ ( R n k ) = q ⊤ R m ⊤ R n k = q ⊤ R n − m k \left(\mathbf{R}_m \mathbf{q}\right)^{\top}\left(\mathbf{R}_n \mathbf{k}\right)=\mathbf{q}^{\top} \mathbf{R}_m^{\top} \mathbf{R}_n \mathbf{k}=\mathbf{q}^{\top} \mathbf{R}_{n-m} \mathbf{k} (Rmq)(Rnk)=qRmRnk=qRnmk

  • 对位置编码的转换称为位置插值。在这一步中,我们将位置索引从 [ 0 , L ′ ) \left[0, L^{\prime}\right) [0,L)减小到 [ 0 ,   L ) [0, \mathrm{~L}) [0, L),以匹配计算 RoPE 之前的原始索引范围。
  • 因此,作为 RoPE 的输入,任意两个标记之间的最大相对距离从 L ′ L^{\prime} L减小到 L L L 。由于我们在扩展之前和之后对位置索引和相对距离的范围进行了对齐,减轻了上下文窗口扩展对注意力得分计算的影响,这使得模型更容易适应。
  • 为了进一步证明这一点,下面的定理表明插值后的注意力得分具有良好的性质:

在这里插入图片描述

比如在chatGLM中也用到了旋转位置编码(下面的GLMBlock中的SelfAttention模块):

ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (word_embeddings): Embedding(130528, 4096)
    (layers): ModuleList(
      (0-27): 28 x GLMBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attention): SelfAttention(
          (rotary_emb): RotaryEmbedding()
          (query_key_value): QuantizedLinear(in_features=4096, out_features=12288, bias=True)
          (dense): QuantizedLinear(in_features=4096, out_features=4096, bias=True)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): GLU(
          (dense_h_to_4h): QuantizedLinear(in_features=4096, out_features=16384, bias=True)
          (dense_4h_to_h): QuantizedLinear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (final_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=4096, out_features=130528, bias=False)
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

具体的旋转编码类代码如下:

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        inv_freq = inv_freq.half()
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        pass

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16:
                emb = emb.float()

            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

    def _apply(self, fn):
        if self.cos_cached is not None:
            self.cos_cached = fn(self.cos_cached)
        if self.sin_cached is not None:
            self.sin_cached = fn(self.sin_cached)
        return super()._apply(fn)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

四、FlashAttention方法

  • 最新发布的ChatGLM2-2B就用了这种方法将上下文长度(Context Length)由 ChatGLM-6B 的 2K 扩展到了 32K,并在对话阶段使用 8K 的上下文长度训练,允许更多轮次的对话。
  • 用了哈希感知(hash-aware)的技术,可以根据它们的相似性将输入序列中的元素分配到不同的桶(bucket)中。这样,模型只需要计算桶元素之间的注意力权重,而不是整个序列。

在这里插入图片描述

Reference

[1] Transformer升级之路:10、RoPE是一种β进制编码. 苏剑林
[2] NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.
[3] ​Bias项的神奇作用:RoPE + Bias = 更好的长度外推性
[4] NBCE:使用朴素贝叶斯扩展LLM的Context处理长度.苏剑林
[5] Improving language models by retrieving from trillions of tokens. DeepMind的RETRO做法(以前是将文本分块chunk依次提供给LLM,但这只是暂时的做法)
[6] 【自然语言处理】【大模型】ChatGLM-6B模型结构代码解析(单机版)
[7] https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py
[8] 也谈langchain大模型外挂知识库问答系统核心部件:如何更好地解析、分割复杂非结构化文本
[9] 田渊栋团队新作:通过位置插值来扩展大语言模型的上下文窗口 Extending Context Window of Large Language Models via Positional Interpolation
[10] RoPE可能是LLM时代的Resnet.刘俊是
[12] flash-attention:https://github.com/Dao-AILab/flash-attention(chatglm2-6b就用了这个实现支持更长的上下文)
[13] https://huggingface.co/THUDM/chatglm2-6b
[14] NBCE:使用朴素贝叶斯扩展LLM的Context处理长度
[15] 为什么gpt模型输入的token最大数量被限制在几万,是有技术问题吗?
[16] 图解RoPE旋转位置编码及其特性.Yeungnlp(公式推倒)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/我家小花儿/article/detail/371425
推荐阅读
相关标签
  

闽ICP备14008679号