当前位置:   article > 正文

全网最细剖析SparseGPT算法实现 | 图解SparseGPT算法

sparsegpt

下面笔者将通过图文形式,来深入理解sparsegpt这个算法的核心实现过程。读者可以借助图来理解代码,看图就能读懂代码。

模型结构

我们需要先对模型结构有整体了解,才能知道每行代码的具体作用。

下面是查看opt.py文件中的模型结构:
在这里插入图片描述

然后查看这个模型的配置:
在这里插入图片描述

对于当执行layers = model.model.decoder.layers时,我们打印看看model.model输出啥:
在这里插入图片描述

# model.model
OPTModel(
  (decoder): OPTDecoder(
    (embed_tokens): Embedding(50272, 768, padding_idx=1)
    (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-11): 12 x OPTDecoderLayer(
        (self_attn): OPTAttention(
          (k_proj): Linear(in_features=768, out_features=768, bias=True)
          (v_proj): Linear(in_features=768, out_features=768, bias=True)
          (q_proj): Linear(in_features=768, out_features=768, bias=True)
          (out_proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (activation_fn): ReLU()
        (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      )
    )
  )
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
# model.model.decoder
OPTDecoder(
  (embed_tokens): Embedding(50272, 768, padding_idx=1)
  (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
  (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (layers): ModuleList(
    (0-11): 12 x OPTDecoderLayer(
      (self_attn): OPTAttention(
        (k_proj): Linear(in_features=768, out_features=768, bias=True)
        (v_proj): Linear(in_features=768, out_features=768, bias=True)
        (q_proj): Linear(in_features=768, out_features=768, bias=True)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
      )
      (activation_fn): ReLU()
      (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=768, out_features=3072, bias=True)
      (fc2): Linear(in_features=3072, out_features=768, bias=True)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
  )
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

对于当执行layers = model.model.decoder.layers时,我们打印看看layers 输出啥:

# layers 
ModuleList(
  (0-11): 12 x OPTDecoderLayer(
    (self_attn): OPTAttention(
      (k_proj): Linear(in_features=768, out_features=768, bias=True)
      (v_proj): Linear(in_features=768, out_features=768, bias=True)
      (q_proj): Linear(in_features=768, out_features=768, bias=True)
      (out_proj): Linear(in_features=768, out_features=768, bias=True)
    )
    (activation_fn): ReLU()
    (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=768, out_features=3072, bias=True)
    (fc2): Linear(in_features=3072, out_features=768, bias=True)
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

构造方法

下面将详细剖析sparsegpt.py这个文件中的代码。

    # 这里是debug过程中,会从opt.py中传入layer	可以看到这里传入的是Linear线性层
    # layer: Linear(in_features=768, out_features=3072, bias=True)
    # type(self.layer) -> <class 'torch.nn.modules.linear.Linear'>
    def __init__(self, layer):
        self.layer = layer
        self.dev = self.layer.weight.device
        # W.shape -> torch.Size([3072, 768])
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)

        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()

        self.rows = W.shape[0]      # 3072
        self.columns = W.shape[1]   # 768
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

计算海森矩阵的逆

    # 这个方法的主要作用,其实就是计算出H
    def add_batch(self, inp, out, blocksize=1024):
        if DEBUG:
            self.inp1 = inp
            self.out1 = out
        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        # inp.shape[0]表示当前批次输入的批量大小
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()

        """
        表达式 self.H *= self.nsamples / (self.nsamples + tmp) 是在更新海森矩阵的估计值。理解为什么需要这样做:
        (1)加权平均:这个表达式实际上是在计算一个加权平均值。当新的样本(数量为 tmp)加入时,我们需要更新海森矩阵的估计值,以反映新样本和旧样本的综合信息。
        self.nsamples / (self.nsamples + tmp) 是一个权重,它表示在新的海森矩阵估计中,旧样本的比重。
        (2)平滑更新:这种方法允许海森矩阵的估计平滑地随着新数据的加入而更新,而不是每次都完全由最新的数据决定。
        这样可以防止由于新批次样本的随机性或异常值导致的海森矩阵估计的剧烈波动。
        (3)逐步逼近真实值:在理想情况下,如果我们能处理所有的样本,那么计算出的海森矩阵将是最准确的。但在实际应用中,由于计算和存储的限制,我们通常只能处理一部分样本。
        通过逐步更新的方法,我们可以尽可能地接近使用全部数据计算出的海森矩阵的真实值。
        """
        self.H *= self.nsamples / (self.nsamples + tmp)     # 更新海森矩阵的估计值

        # self.nsamples是之前样本的数量,tmp是当前批次的样本数量,累加样本数量
        self.nsamples += tmp

        """
        math.sqrt(2 / self.nsamples) * inp: 这是对 inp 进行缩放的操作。这里的缩放因子是 math.sqrt(2 / self.nsamples)。
        原因可能是为了对数据进行规范化,使其分布具有一定的标准性质。具体来说,这个缩放因子可能是基于某种统计假设,
        比如希望输入数据的方差保持不变。这种规范化在处理数据时很常见,尤其是在计算统计量(如海森矩阵)时。
        """
        inp = math.sqrt(2 / self.nsamples) * inp.float()

        """
        inp.matmul(inp.t()): 这是在计算 inp 和其转置 inp.t() 的矩阵乘积。
        在数学上,这个操作相当于计算外积,它产生一个方阵,其中的每个元素是输入向量的不同元素的乘积。
        self.H += ...: 这是将计算出的外积加到当前的海森矩阵估计 self.H 上。
        这个操作是在累积海森矩阵的估计值。通过这种方式,self.H 逐渐包含了更多批次数据的信息。
        """
        self.H += inp.matmul(inp.t())       # 更新海森矩阵的估计值
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42

计算海森矩阵H的核心代码其实就是最后一步:self.H += inp.matmul(inp.t())。如下图所示:
在这里插入图片描述

但是作者做了两点很巧妙的优化:

(1)self.H *= self.nsamples / (self.nsamples + tmp)

这个其实是指数移动平均。

    表达式 self.H *= self.nsamples / (self.nsamples + tmp) 是在更新海森矩阵的估计值。理解为什么需要这样做:
    (1)加权平均:这个表达式实际上是在计算一个加权平均值。当新的样本(数量为 tmp)加入时,我们需要更新海森矩阵的估计值,以反映新样本和旧样本的综合信息。
    self.nsamples / (self.nsamples + tmp) 是一个权重,它表示在新的海森矩阵估计中,旧样本的比重。
    (2)平滑更新:这种方法允许海森矩阵的估计平滑地随着新数据的加入而更新,而不是每次都完全由最新的数据决定。
    这样可以防止由于新批次样本的随机性或异常值导致的海森矩阵估计的剧烈波动。
    (3)逐步逼近真实值:在理想情况下,如果我们能处理所有的样本,那么计算出的海森矩阵将是最准确的。但在实际应用中,由于计算和存储的限制,我们通常只能处理一部分样本。
    通过逐步更新的方法,我们可以尽可能地接近使用全部数据计算出的海森矩阵的真实值。
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

(2)inp = math.sqrt(2 / self.nsamples) * inp.float()

这个其实是在对inp进行缩放操作。

    math.sqrt(2 / self.nsamples) * inp: 这是对 inp 进行缩放的操作。这里的缩放因子是 math.sqrt(2 / self.nsamples)。
    原因可能是为了对数据进行规范化,使其分布具有一定的标准性质。具体来说,这个缩放因子可能是基于某种统计假设,
    比如希望输入数据的方差保持不变。这种规范化在处理数据时很常见,尤其是在计算统计量(如海森矩阵)时。
  • 1
  • 2
  • 3

快速近似重构

 """
    实现快速近似重构、自适应掩码选择等sparsegpt中最核心的算法
    sparsity:稀疏度
    prunen : prunem   n:m
    blocksize:块大小
    """
    def fasterprune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        if hasattr(self, 'quantizer'):      # 对W进行量化
            if not self.quantizer.ready():
                self.quantizer.find_params(W, weight=True)

        tick = time.time()

        # H.shape -> torch.Size([768, 768])
        H = self.H
        del self.H
        # 这行代码首先调用 torch.diag(H) 来提取张量 H 的对角线元素,然后通过 == 0 比较操作,检查这些对角线元素是否等于0。
        # 如果对角线元素等于0,则对应的结果为 True;否则为 False。这样,dead 成为一个布尔型张量,其每个元素表示 H 的对角线上相应元素是否为0。
        dead = torch.diag(H) == 0
        # 这行代码使用 dead 张量作为索引,来修改 H 张量。具体来说,它将 H 的对角线上那些原本为0的元素设置为1。
        # 这是一种常见的技术,用于避免数值计算中的除以零错误或提高数值稳定性。
        H[dead, dead] = 1
        # 这行代码同样使用 dead 张量作为索引,但这次是对 W 张量进行操作。它将 W 中所有与 dead 中为 True 的列对应的元素设置为0。
        # 这意味着如果 H 的某个对角线元素原本是0(即 dead 中相应元素为 True),那么 W 中对应的整列都会被设置为0。
        W[:, dead] = 0

        # Losses.shape -> torch.Size([768])
        Losses = torch.zeros(self.rows, device=self.dev)

        # 对于较小的模型,采用dampening ,即在H的对角线元素上添加一个小常数λ(我们总是选择平均对角线值的 1%),似乎足以避免数值问题。 percdamp: 0.01
        # damp -> tensor(3.6504, device='cuda:0')
        damp = percdamp * torch.mean(torch.diag(H))
        # 这个张量用于海森矩阵H的对角线元素     self.columns -> 768
        diag = torch.arange(self.columns, device=self.dev)
        # 在H的对角线元素上添加dampening
        H[diag, diag] += damp
        # 这行代码执行Cholesky分解。Cholesky分解是一种将正定矩阵分解为一个下三角矩阵和其转置的上三角矩阵的乘积的方法。
        H = torch.linalg.cholesky(H)
        # 这行代码计算Cholesky分解后的矩阵 H 的逆。torch.cholesky_inverse 是一种高效计算逆矩阵的方法,特别是当矩阵已经通过Cholesky分解时。
        H = torch.cholesky_inverse(H)       # 此时为LL^T
        # 这行代码再次执行Cholesky分解,但这次是生成上三角矩阵。参数 upper=True 指定了生成的是上三角矩阵。
        H = torch.linalg.cholesky(H, upper=True)    # 此时为L^TL
        # 这行代码将 H 赋值给 Hinv。这里 Hinv 可能代表了 H 的逆矩阵。
        Hinv = H

        mask = None

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            if prunen == 0:     # 如果 prunen 等于0,意味着需要进行剪枝操作。
                if mask is not None:
                    mask1 = mask[:, i1:i2]
                else:
                    # tmp是伪代码中的W
                    tmp = W1 ** 2 / (torch.diag(Hinv1).reshape((1, -1))) ** 2
                    # 根据稀疏度sparsity求出阈值
                    threshold = torch.sort(tmp.flatten())[0][int(tmp.numel() * sparsity)]
                    # 根据阈值,求出W所对应的掩码矩阵mask
                    mask1 = tmp <= threshold
            else:   # 如果 prunen 不等于0,这意味着不需要进行剪枝操作。
                # 这行代码创建了一个与 W1 形状相同的零矩阵,然后检查每个元素是否等于1。
                # 由于零矩阵中的所有元素都是0,这将产生一个全为 False 的掩码矩阵。这意味着没有权重会被剪枝。
                mask1 = torch.zeros_like(W1) == 1

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if prunen != 0 and i % prunem == 0:
                    tmp = W1[:, i:(i + prunem)] ** 2 / (torch.diag(Hinv1)[i:(i + prunem)].reshape((1, -1))) ** 2
                    mask1.scatter_(1, i + torch.topk(tmp, prunen, dim=1, largest=False)[1], True)

                q = w.clone()
                q[mask1[:, i]] = 0

                if hasattr(self, 'quantizer'):
                    q = quantize(q.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq).flatten()

                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # (w - q) / d它的shape其实是一维的,就是一行有x个元素的序列,因此err1.shape -> torch.Size([x])
                err1 = (w - q) / d
                # 使用err1.unsqueeze(1)将这个只有一维的序列升级为二维数组,因此此时err1.shape -> torch.Size([x, 1])
                # Hinv1[i, i:]得到的其实是一维的,就是一行有y个元素的序列,那么使用unsqueeze(0)升级为二维数组,因此此时shape -> torch.Size([1, y])
                # (x, 1) * (1, y)维度,显然就是可以进行矩阵相乘了
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            W[:, i1:i2] = Q1
            Losses += torch.sum(Losses1, 1) / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

            if DEBUG:
                self.layer.weight.data[:, :i2] = W[:, :i2]
                self.layer.weight.data[:, i2:] = W[:, i2:]
                print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
                print(torch.sum(Losses))

        torch.cuda.synchronize()
        print('time %.2f' % (time.time() - tick))
        print('error', torch.sum(Losses).item())

        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.layer.weight.data = W.reshape(self.layer.weight.shape).to(self.layer.weight.data.dtype)

        if DEBUG:
            print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125

先计算得到H的逆矩阵Hinv:
在这里插入图片描述

我们重点来看那两个for循环。

先看外层循环for i1 in range(0, self.columns, blocksize),如下图所示:
在这里插入图片描述

它的总体大致变化过程如下图所示,注意关注变化部分:
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

然后再来看第二个内层for循环for i in range(count),注意关注变化的部分:
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

如何理解下面这段代码呢?
在这里插入图片描述

如何理解下面这段代码呢?
在这里插入图片描述

如何理解下面的代码呢?
在这里插入图片描述

如何理解下面这段代码呢?
在这里插入图片描述

至此以上就大体讲述了sparsegpt算法的核心操作流程,先熟悉模型结构,然后再根据图示就会很容易理解代码。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/人工智能uu/article/detail/868150
推荐阅读
相关标签
  

闽ICP备14008679号