
Mastering PyTorch Model Compression, Pruning, and Quantization

When building and deploying deep learning models, we need to consider the number of weights, the size of the weight file, inference speed, and the amount of computation. This article walks through model compression, pruning, and quantization in PyTorch.

Weight Compression

Models are trained with float32 weights, but deployment rarely requires that much precision. Converting the weights to float16 before saving reduces the weight file size by roughly 45%.

  • Step 1: train and save the model
import torch
import timm

model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))
  • Step 2: convert the data type and save
params = torch.load('model_mobilevit_xxs.pth')  # float32 weights
for key in params.keys():
    params[key] = params[key].half()  # cast each tensor to float16

torch.save(params, 'model_mobilevit_xxs_half.pth')
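To use the float16 checkpoint for inference, load it and cast the weights back to float32. The following is a minimal sketch; casting back is the safer default, since many CPU operators do not support float16:

import torch
import timm

model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
params = torch.load('model_mobilevit_xxs_half.pth')  # float16 weights
params = {k: v.float() for k, v in params.items()}   # cast back to float32
model.load_state_dict(params)
model.eval()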

Weight Pruning

Once training is finished, redundant weights can be pruned away. Common pruning strategies include:

  • Randomly pruning a given proportion of weights
  • Pruning by weight magnitude

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

Example code:

import torch
import timm
import torch.nn.utils.prune as prune

model = timm.create_model('mobilevit_xxs', pretrained=False, num_classes=8)
model.load_state_dict(torch.load('model_mobilevit_xxs.pth'))

# select the layer to prune
module = model.head.fc

# random_unstructured: zero out 30% of the weights at random
prune.random_unstructured(module, name="weight", amount=0.3)

# l1_unstructured: zero out the 30% of weights with the smallest L1 norm
prune.l1_unstructured(module, name="weight", amount=0.3)

# ln_structured: remove the 50% of rows (dim=0) with the smallest L2 norm
prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

A few things to note when using weight pruning (see the sketch after this list):

  • Pruning does not shrink the weight file; it only makes the weights sparse.
  • Pruning alone does not speed up inference; it only reduces the theoretical amount of computation, since dense hardware cannot exploit unstructured sparsity without special kernels.
  • The pruning ratio affects model accuracy, so it needs to be tested and validated.
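The sparsity effect is easy to verify. A minimal sketch, reusing `module` and `prune` from the example above:

# fraction of weights currently zeroed out by the accumulated masks
sparsity = float((module.weight == 0).sum()) / module.weight.nelement()
print(f"sparsity of head.fc weight: {sparsity:.2%}")

# pruning reparametrizes the module: the original weights live on as
# 'weight_orig' (a parameter) and the binary mask as 'weight_mask' (a buffer)
print([name for name, _ in module.named_parameters()])  # ['bias', 'weight_orig']
print([name for name, _ in module.named_buffers()])     # ['weight_mask']

# make the pruning permanent: fold the mask into 'weight' and drop the
# reparametrization; the zeros remain, but the file size does not shrink
prune.remove(module, 'weight')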

Weight Quantization

Quantization turns 32-bit multiply-accumulate operations into 8-bit ones, shrinking the weight file and lowering memory requirements.

https://pytorch.org/docs/stable/quantization.html

Post Training Dynamic Quantization (Eager Mode)

import torch

# define a floating point model
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        self.fc1 = torch.nn.Linear(100, 40)

    def forward(self, x):
        x = self.fc1(x)
        return x

# create a model instance
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')

# create a quantized model instance
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32,  # the original model
    {torch.nn.Linear},  # a set of layers to dynamically quantize
    dtype=torch.qint8)  # the target dtype for quantized weights

# run the model; the input's last dimension must match fc1's in_features
input_fp32 = torch.randn(4, 100)
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')
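Comparing the two checkpoints on disk shows the effect of dynamic quantization on the weight size. A minimal sketch; exact numbers depend on the model, and only the Linear layer is quantized here:

import os

fp32_size = os.path.getsize('tmp_float32.pth')
int8_size = os.path.getsize('tmp_int8.pth')
print(f"float32 checkpoint: {fp32_size / 1024:.1f} KB")
print(f"int8 checkpoint:    {int8_size / 1024:.1f} KB")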

Post Training Static Quantization

import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    def __init__(self):
        super(M, self).__init__()
        # QuantStub converts tensors from floating point to quantized
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(1, 100, 1)
        self.relu = torch.nn.ReLU()
        # DeQuantStub converts tensors from quantized to floating point
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        x = self.conv(x)
        x = self.relu(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

# create a model instance
model_fp32 = M()
torch.save(model_fp32.state_dict(), 'tmp_float32.pth')

# static quantization requires eval mode
model_fp32.eval()

# 'fbgemm' is the default quantization backend for x86 CPUs
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# fuse conv+relu into a single module, then insert observers
model_fp32_fused = torch.quantization.fuse_modules(model_fp32, [['conv', 'relu']])
model_fp32_prepared = torch.quantization.prepare(model_fp32_fused)

# calibrate: run representative data through the model so the observers
# can collect the activation statistics used to compute quantization parameters
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# convert the calibrated model to int8 and run it
model_int8 = torch.quantization.convert(model_fp32_prepared)
res = model_int8(input_fp32)
torch.save(model_int8.state_dict(), 'tmp_int8.pth')
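To get a rough sense of the accuracy cost, compare the int8 output with the fp32 output on the same input. A minimal sketch using the variables from the example above; QuantStub and DeQuantStub are no-ops in the unconverted fp32 model, so model_fp32 can still be called directly:

with torch.no_grad():
    res_fp32 = model_fp32(input_fp32)  # stubs pass tensors through unchanged
    res_int8 = model_int8(input_fp32)  # quantize -> int8 conv+relu -> dequantize

# mean absolute error introduced by quantization
print(f"mean abs error: {(res_fp32 - res_int8).abs().mean().item():.6f}")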

PyTorch's quantization support is still not fully mature: quantized models may run only on the CPU, and inference can even become slower. If you have serious quantization needs, using TensorRT together with a GPU is recommended.
