
ChatGLM2-6B Model Analysis


This article walks through the structure of ChatGLM2-6B using the model source code from its Hugging Face repository, and breaks the model down with the help of related papers and blog posts.

Model parameters

The ChatGLM2-6B model consists of 28 GLM layers (each built from self-attention and an MLP) with 32 attention heads and Multi-Query Attention; the hidden size is 4096. Positional information is encoded with rotary position embedding (RoPE), the activation function is SwiGLU, and the normalization method is RMSNorm.

Overall model structure

ChatGLMModel (assuming an input X of shape 3x5)

  • (embedding) Embedding (5x3x4096 after transpose)
    • word_embeddings: Embedding(65024, 4096)
  • (rotary_pos_emb) RotaryEmbedding()
  • (encoder) GLMTransformer
    • (layers) ModuleList
      • 0-27: 28 x GLMBlock
        • (input_layernorm) RMSNorm() (input/output shape: 5x3x4096)
        • (self_attention) SelfAttention
          • (query_key_value) Linear(in_features=4096, out_features=4608, bias=True)
          • (core_attention) CoreAttention(
            • (attention_dropout) Dropout(p=0.0, inplace=False))
          • (dense) Linear(in_features=4096, out_features=4096, bias=False)
        • (post_attention_layernorm) RMSNorm()
        • (mlp) MLP
          • (dense_h_to_4h) Linear(in_features=4096, out_features=27392, bias=False)
          • (dense_4h_to_h) Linear(in_features=13696, out_features=4096, bias=False)
    • (final_layernorm) RMSNorm()
  • (output_layer) Linear(in_features=4096, out_features=65024, bias=False) (output shape: 3x5x65024)
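
The tree above can be reproduced by loading the checkpoint and printing the module hierarchy. Below is a minimal sketch, assuming the Hugging Face transformers library and the THUDM/chatglm2-6b model repository; trust_remote_code=True is needed because the modeling code lives inside the repo, and the config attribute names follow that repo's ChatGLMConfig.

# Minimal sketch: load chatglm2-6b and inspect its structure.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
print(config.num_layers, config.num_attention_heads, config.hidden_size)
# expected, per the structure above: 28 32 4096

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
print(model)  # prints the ChatGLMModel module tree shown above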

Activation function: SwiGLU

$$\operatorname{SwiGLU}(x, W, V, b, c, \beta)=\operatorname{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
where $\operatorname{Swish}_{\beta}(x)=x\,\sigma(\beta x)$ and $\beta$ is a fixed constant, usually 1.
This corresponds to the following source code in chatglm2-6b:

import torch
import torch.nn.functional as F

def swiglu(x):
    # Split the last dimension in half; one half gates the other through SiLU
    # (Swish with beta = 1), matching the formula above.
    x = torch.chunk(x, 2, dim=-1)
    return F.silu(x[0]) * x[1]
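
To connect this to the MLP sizes in the structure above (dense_h_to_4h outputs 27392 = 2 x 13696 features, which the chunk splits in half), here is a minimal, illustrative sketch of the MLP forward pass; the layer shapes come from the printed structure, while the input shape and tensors are arbitrary.

import torch
import torch.nn as nn
import torch.nn.functional as F

dense_h_to_4h = nn.Linear(4096, 27392, bias=False)     # 27392 = 2 * 13696
dense_4h_to_h = nn.Linear(13696, 4096, bias=False)

h = torch.randn(5, 3, 4096)                            # [seq_len, batch, hidden]
gate, up = torch.chunk(dense_h_to_4h(h), 2, dim=-1)    # two halves of 13696 features each
h = F.silu(gate) * up                                  # SwiGLU gating (Swish with beta = 1)
h = dense_4h_to_h(h)                                   # back to [5, 3, 4096]
print(h.shape)                                         # torch.Size([5, 3, 4096])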

Rotary position embedding: RoPE

The goal of rotary position embedding is to make use of the relative positions between tokens.
Assume that the inner product of the query vector $\boldsymbol{q}_m$ and the key vector $\boldsymbol{k}_n$ can be expressed by a function $g$ whose inputs are the word embeddings $\boldsymbol{x}_m$, $\boldsymbol{x}_n$ and their relative position $m-n$:
$$\left\langle f_q\left(\boldsymbol{x}_m, m\right), f_k\left(\boldsymbol{x}_n, n\right)\right\rangle=g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right)$$
This turns the original absolute position encoding into a relative one; what remains is to solve for $g$. Su Jianlin et al. propose the solution below in their paper; the detailed derivation can also be found on the author's blog.
$$\begin{aligned}
f_q\left(\boldsymbol{x}_m, m\right) &=\left(\boldsymbol{W}_q \boldsymbol{x}_m\right) e^{i m \theta} \\
f_k\left(\boldsymbol{x}_n, n\right) &=\left(\boldsymbol{W}_k \boldsymbol{x}_n\right) e^{i n \theta} \\
g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right) &=\operatorname{Re}\left[\left(\boldsymbol{W}_q \boldsymbol{x}_m\right)\left(\boldsymbol{W}_k \boldsymbol{x}_n\right)^{*} e^{i(m-n) \theta}\right]
\end{aligned}$$
Further, $f_q$ can be written as:
$$f_q\left(\boldsymbol{x}_m, m\right)=\begin{pmatrix}\cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta\end{pmatrix}\begin{pmatrix}W_q^{(1,1)} & W_q^{(1,2)} \\ W_q^{(2,1)} & W_q^{(2,2)}\end{pmatrix}\begin{pmatrix}x_m^{(1)} \\ x_m^{(2)}\end{pmatrix}=\begin{pmatrix}\cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta\end{pmatrix}\begin{pmatrix}q_m^{(1)} \\ q_m^{(2)}\end{pmatrix}$$

At this point it is clear that this is just the query vector multiplied by a rotation matrix, which is exactly why it is called rotary position embedding.
Similarly, $f_k$ can be written as:
$$f_k\left(\boldsymbol{x}_m, m\right)=\begin{pmatrix}\cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta\end{pmatrix}\begin{pmatrix}W_k^{(1,1)} & W_k^{(1,2)} \\ W_k^{(2,1)} & W_k^{(2,2)}\end{pmatrix}\begin{pmatrix}x_m^{(1)} \\ x_m^{(2)}\end{pmatrix}=\begin{pmatrix}\cos m\theta & -\sin m\theta \\ \sin m\theta & \cos m\theta\end{pmatrix}\begin{pmatrix}k_m^{(1)} \\ k_m^{(2)}\end{pmatrix}$$

Finally, $g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right)$ can be expressed as:
$$g\left(\boldsymbol{x}_m, \boldsymbol{x}_n, m-n\right)=\begin{pmatrix}q_m^{(1)} & q_m^{(2)}\end{pmatrix}\begin{pmatrix}\cos((m-n)\theta) & -\sin((m-n)\theta) \\ \sin((m-n)\theta) & \cos((m-n)\theta)\end{pmatrix}\begin{pmatrix}k_n^{(1)} \\ k_n^{(2)}\end{pmatrix}$$

Extending the above to arbitrary (even) dimensions gives:
$$f_{\{q, k\}}\left(\boldsymbol{x}_m, m\right)=\boldsymbol{R}_{\Theta, m}^{d} \boldsymbol{W}_{\{q, k\}} \boldsymbol{x}_m$$
Because the inner product is linear and additive, RoPE in any even dimension can be written as a concatenation of 2-D rotations, i.e.
$$\boldsymbol{R}_{\Theta, m}^{d}=\begin{pmatrix}
\cos m\theta_1 & -\sin m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\
\sin m\theta_1 & \cos m\theta_1 & 0 & 0 & \cdots & 0 & 0 \\
0 & 0 & \cos m\theta_2 & -\sin m\theta_2 & \cdots & 0 & 0 \\
0 & 0 & \sin m\theta_2 & \cos m\theta_2 & \cdots & 0 & 0 \\
\vdots & \vdots & \vdots & \vdots & \ddots & \vdots & \vdots \\
0 & 0 & 0 & 0 & \cdots & \cos m\theta_{d/2} & -\sin m\theta_{d/2} \\
0 & 0 & 0 & 0 & \cdots & \sin m\theta_{d/2} & \cos m\theta_{d/2}
\end{pmatrix}$$

Because this matrix is sparse, computing the full matrix product wastes compute, so the following element-wise implementation is recommended instead:
$$\boldsymbol{R}_{\Theta, m}^{d} \boldsymbol{x}=\begin{pmatrix}x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \\ x_{d-2} \\ x_{d-1}\end{pmatrix} \otimes\begin{pmatrix}\cos m\theta_0 \\ \cos m\theta_0 \\ \cos m\theta_1 \\ \cos m\theta_1 \\ \vdots \\ \cos m\theta_{d/2-1} \\ \cos m\theta_{d/2-1}\end{pmatrix}+\begin{pmatrix}-x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \\ -x_{d-1} \\ x_{d-2}\end{pmatrix} \otimes\begin{pmatrix}\sin m\theta_0 \\ \sin m\theta_0 \\ \sin m\theta_1 \\ \sin m\theta_1 \\ \vdots \\ \sin m\theta_{d/2-1} \\ \sin m\theta_{d/2-1}\end{pmatrix}$$

Here $\otimes$ denotes element-wise multiplication, which corresponds to the * operator in PyTorch.
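
Before looking at the chatglm2-6b implementation, here is a tiny sanity check, with purely illustrative names and values, that the 2-D rotation above makes the query-key score depend only on the relative position m - n:

import torch

def rotate(x, pos, theta=0.3):
    # Element-wise 2-D rotation: (x0*cos - x1*sin, x1*cos + x0*sin).
    angle = torch.tensor(pos * theta)
    cos, sin = torch.cos(angle), torch.sin(angle)
    return torch.stack([x[0] * cos - x[1] * sin, x[1] * cos + x[0] * sin])

q, k = torch.randn(2), torch.randn(2)
a = torch.dot(rotate(q, 5), rotate(k, 2))     # positions m=5,  n=2  -> m-n = 3
b = torch.dot(rotate(q, 13), rotate(k, 10))   # positions m=13, n=10 -> m-n = 3
print(torch.allclose(a, b))                   # True: the score depends only on m - n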
The implementation in the chatglm2-6b code:

import torch
import torch.nn as nn

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()
        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )

def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    # np: number of attention heads; hn: hidden size per head
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)
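
The two pieces above can be exercised together as follows. The shapes mirror the [sq, b, np, hn] comment in apply_rotary_pos_emb, and the concrete numbers (sequence length 5, batch 3, 32 heads of 128 dimensions) are only illustrative; in chatglm2-6b the RotaryEmbedding is built with dim equal to half the head size, so only the first half of each head is rotated and the rest is passed through unchanged.

import torch

sq, b, n_head, hn = 5, 3, 32, 128              # sequence length, batch, heads, head size

rotary = RotaryEmbedding(dim=hn // 2, dtype=torch.float32)
rope_cache = rotary(max_seq_len=sq)            # [sq, dim/2, 2] of (cos, sin) pairs
print(rope_cache.shape)                        # torch.Size([5, 32, 2])

q = torch.randn(sq, b, n_head, hn)
q_rot = apply_rotary_pos_emb(q, rope_cache)    # same shape, first hn/2 dims rotated
print(q_rot.shape)                             # torch.Size([5, 3, 32, 128])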

Attention layer: Multi-Query Attention

Multi-Query Attention is a variant of multi-head attention in which all heads share a single key and value projection (each head keeps its own query); its main purpose is to save memory and reduce computation cost.
The multi-head attention formulas are:
$$\operatorname{Attention}(Q, K, V)=\operatorname{softmax}\left(\frac{Q K^{T}}{\sqrt{d_k}}\right) V$$

$$\operatorname{MultiHead}(Q, K, V)=\operatorname{Concat}\left(\operatorname{head}_1, \ldots, \operatorname{head}_h\right) W^{O}, \quad \text{where } \operatorname{head}_i=\operatorname{Attention}\left(Q W_i^{Q}, K W_i^{K}, V W_i^{V}\right)$$

# The following comes from the paper "Fast Transformer Decoding: One Write-Head is All You Need"
import tensorflow as tf

def MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-head Attention.
    Args:
    X: a tensor with shape [b,n,d]
    M: a tensor with shape [b,m,d]
    mask: a tensor with shape [b,h,n,m]
    P_q: a tensor with shape [h,d,k]
    P_k: a tensor with shape [h,d,k]
    P_v: a tensor with shape [h,d,v]
    P_o: a tensor with shape [h,d,v]
    Returns:
    Y: a tensor with shape [b,n,d]
    """
    # b: batch size; m, n: sequence lengths; h: number of heads
    # k, v: key / value dimensions; d: hidden size
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)

    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y

def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-query Attention: keys and values are shared across all heads.
    Args:
    X: a tensor with shape [b,n,d]
    M: a tensor with shape [b,m,d]
    mask: a tensor with shape [b,h,n,m]
    P_q: a tensor with shape [h,d,k]
    P_k: a tensor with shape [d,k]
    P_v: a tensor with shape [d,v]
    P_o: a tensor with shape [h,d,v]
    Returns:
    Y: a tensor with shape [b,n,d]
    """
    # b: batch size; m, n: sequence lengths; h: number of heads
    # k, v: key / value dimensions; d: hidden size
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,dk->bmk", M, P_k)        # no head dimension: a single shared key projection
    V = tf.einsum("bmd,dv->bmv", M, P_v)        # no head dimension: a single shared value projection
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y
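
This is why, in the structure above, query_key_value has 4608 output features rather than 3 x 4096: 4096 query features (32 heads of 128 dimensions) plus a small shared key/value part. The released chatglm2-6b config actually uses two key/value groups (multi_query_group_num = 2), i.e. grouped-query attention, so the key and value parts are 2 x 128 = 256 features each. A rough sketch of the split, with the layer size taken from the structure above and everything else illustrative:

import torch
import torch.nn as nn

hidden, n_head, head_dim, kv_groups = 4096, 32, 128, 2

# One fused projection: 4096 (q) + 256 (k) + 256 (v) = 4608 output features.
query_key_value = nn.Linear(hidden, n_head * head_dim + 2 * kv_groups * head_dim, bias=True)

x = torch.randn(5, 3, hidden)                   # [seq_len, batch, hidden]
q, k, v = query_key_value(x).split(
    [n_head * head_dim, kv_groups * head_dim, kv_groups * head_dim], dim=-1
)
print(q.shape, k.shape, v.shape)                # 32 query heads share the 2 key/value groups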

Attention mask

chatglm2-6b still uses the attention masking scheme of GLM-10B.
[Figure: GLM self-attention mask; Part A attends bidirectionally, Part B attends autoregressively]

Part A tokens can attend to each other, but cannot attend to any
tokens in B. Part B tokens can attend to Part A and antecedents in B,
but cannot attend to any subsequent tokens in B. To enable
autoregressive generation, each span is padded with special tokens
[START] and [END], for input and output respectively. In this way, our
model automatically learns a bidirectional encoder (for Part A) and a
unidirectional decoder (for Part B) in a unified model. (GLM, 2022)

In other words, Part A tokens can attend to one another but not to any token in Part B, while Part B tokens can attend to Part A and to earlier tokens in B, but not to any later tokens in B.
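
A minimal sketch of this masking rule (1 = may attend, 0 = masked), assuming Part A has 3 tokens and Part B has 4; it only illustrates the quoted rule, not the exact tensor layout used in the chatglm2-6b code.

import torch

len_a, len_b = 3, 4                                      # Part A (context) and Part B (generation)
n = len_a + len_b

mask = torch.tril(torch.ones(n, n, dtype=torch.long))    # causal (lower-triangular) pattern
mask[:, :len_a] = 1                                      # every position may attend to all of Part A
print(mask)
# Rows are query positions, columns are key positions:
# Part A rows see all of A and none of B; Part B rows see A plus earlier B tokens.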
