
Post-Training Static Quantization of a PyTorch Model: Quantizing, Saving, and Loading an int8 Model


1. PyTorch Model Quantization Methods

There are plenty of introductions to PyTorch model quantization. Two detailed references worth reading for a general overview are the article "Pytorch的量化" and the official PyTorch quantization documentation.

PyTorch quantization roughly comes in three flavors: post-training dynamic quantization, post-training static quantization, and quantization-aware training (quantization enabled during training). This article walks through post-training static quantization using a real project (pose estimation).

The underlying quantization theory can be learned from the two references recommended above.
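For contrast, post-training dynamic quantization is the simplest of the three modes: weights are quantized ahead of time and activations are quantized on the fly at inference, so no calibration pass is needed. A minimal sketch is shown below; the toy two-layer model is purely illustrative and is not the pose model used in this article.

    import torch
    import torch.nn as nn

    # Toy float32 model, for illustration only (not the pose model).
    model_fp32 = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()

    # Dynamic quantization: only the listed layer types (here nn.Linear) get
    # int8 weights; activations are quantized dynamically at inference time.
    model_int8 = torch.quantization.quantize_dynamic(model_fp32, {nn.Linear}, dtype=torch.qint8)
    print(model_int8)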

2. Preparation for Quantization

Runtime environment: PyTorch 1.9.0, Python 3.6.4.

1. Download the dataset. Static quantization needs a calibration pass: inference is run over a small dataset so the distribution of activations can be observed and the quantization parameters determined. Use about 100 images from the MSCOCO validation set: MSCOCO_val2017.

2. The PyTorch model file can be downloaded from Pose_Model (extraction code: s7qh).

3. Download the quantization code: Pytorch_Model_Quantization.

After downloading the code, place the ~100 selected MSCOCO images in the data directory and the downloaded model file coco_pose_iter_440000.pth.tar in the models directory.

pth_to_int.py converts the float32 PyTorch model into an int8 model.

evaluate_model.py loads the int8 model and runs inference.

3. Post-Training Static Quantization

The core static-quantization code is shown below: it loads the float32 model, converts it to int8, and saves the result as openpose_vgg_quant.pth. The full code is in pth_to_int.py; the comments explain what each step does.

    import torch

    # loading model
    state_dict = torch.load('./models/coco_pose_iter_440000.pth.tar')['state_dict']
    # create a model instance (get_pose_model() is defined in the project code, see pth_to_int.py)
    model_fp32 = get_pose_model()
    model_fp32.load_state_dict(state_dict)
    model_fp32.float()
    # model must be set to eval mode for static quantization logic to work
    model_fp32.eval()
    # attach a global qconfig, which contains information about what kind
    # of observers to attach. Use 'fbgemm' for server inference and
    # 'qnnpack' for mobile inference. Other quantization configurations such
    # as selecting symmetric or asymmetric quantization and MinMax or L2Norm
    # calibration techniques can be specified here.
    model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # Prepare the model for static quantization. This inserts observers in
    # the model that will observe activation tensors during calibration.
    model_fp32_prepared = torch.quantization.prepare(model_fp32)
    # calibrate the prepared model to determine quantization parameters for activations;
    # in a real-world setting, the calibration would be done with a representative dataset
    evaluate(model_fp32_prepared)
    # Convert the observed model to a quantized model. This does several things:
    # quantizes the weights, computes and stores the scale and zero-point to be
    # used with each activation tensor, and replaces key operators with quantized
    # implementations.
    model_int8 = torch.quantization.convert(model_fp32_prepared)
    print("model int8", model_int8)
    # save model
    torch.save(model_int8.state_dict(), "./openpose_vgg_quant.pth")
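The evaluate() call above is the calibration pass: it simply runs the prepared model over the ~100 COCO images in the data directory so that the inserted observers can record activation ranges. A minimal sketch of such a calibration loop is given below; the resizing and normalization are assumptions for illustration, and the project's own preprocessing should be used in practice.

    import os
    import cv2
    import numpy as np
    import torch

    def evaluate(model, data_dir='./data', input_size=368):
        # Calibration loop (sketch): forward each image once so the observers
        # inserted by torch.quantization.prepare() can record activation ranges.
        # NOTE: the resize/normalization below is an assumption for illustration,
        # not necessarily the preprocessing used by the project.
        model.eval()
        with torch.no_grad():
            for name in os.listdir(data_dir):
                img = cv2.imread(os.path.join(data_dir, name))
                if img is None:
                    continue
                img = cv2.resize(img, (input_size, input_size)).astype(np.float32) / 255.0
                x = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).unsqueeze(0)  # NCHW
                model(x)  # outputs are discarded; only the observed statistics matter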

4. Loading the Quantized Model for Inference

Note: the forward pass of the quantized model needs a small change: quantization and dequantization stubs must be inserted around the model's input and output. For example:

    import torch
    import torch.nn as nn

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            # quantize the float input to int8
            self.quant = torch.quantization.QuantStub()
            # model structure (placeholder; the real network layers go here)
            self.layer = nn.Conv2d(3, 64, kernel_size=3, padding=1)
            # dequantize the int8 output back to float
            self.dequant = torch.quantization.DeQuantStub()

        def forward(self, input):
            x = self.quant(input)
            x = self.layer(x)
            x = self.dequant(x)
            return x

The quant/dequant stubs can be seen in lines 34-86 of pose_estimation.py.

Loading the int8 model differs from loading the earlier float32 model: the model first has to be turned into a quantized model with prepare() and convert(), and only then can the saved weights be loaded with load_state_dict.

    import torch

    # Load int8 model
    state_dict = torch.load('./openpose_vgg_quant.pth')
    model_fp32 = get_pose_model()  # get_pose_model() is defined in the project code
    model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    model_fp32_prepared = torch.quantization.prepare(model_fp32)
    model_int8 = torch.quantization.convert(model_fp32_prepared)
    model_int8.load_state_dict(state_dict)
    model = model_int8
    model.eval()
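Once loaded, the quantized model is called like any other PyTorch model; the QuantStub/DeQuantStub inside forward() handle the conversion to and from int8. A quick sanity check with a dummy input might look like the sketch below (the 1x3x368x368 input shape is an assumption; use the project's real preprocessing for actual inference).

    import torch

    with torch.no_grad():
        dummy = torch.randn(1, 3, 368, 368)  # assumed input shape, for illustration only
        out = model(dummy)                   # runs the int8 kernels on CPU (fbgemm backend)
        print(type(out))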

5. Performance

The results after quantization show only a small overall loss in accuracy: model size drops from 200 MB to 50 MB and inference time from 5.7 s to 3.4 s. In other words, the model shrinks to about 1/4 of its original size and inference time is reduced by roughly 40%.
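A rough way to reproduce these numbers is sketched below; it assumes `model` is the int8 model loaded in the previous section, and the file paths, input shape, and number of timed iterations are illustrative (CPU timings vary by machine).

    import os
    import time
    import torch

    # Compare the on-disk size of the float32 checkpoint and the int8 checkpoint.
    fp32_mb = os.path.getsize('./models/coco_pose_iter_440000.pth.tar') / 1024 / 1024
    int8_mb = os.path.getsize('./openpose_vgg_quant.pth') / 1024 / 1024
    print("fp32: %.1f MB, int8: %.1f MB" % (fp32_mb, int8_mb))

    # Time a few forward passes of the loaded int8 model on CPU.
    dummy = torch.randn(1, 3, 368, 368)  # assumed input shape
    with torch.no_grad():
        start = time.time()
        for _ in range(10):
            model(dummy)
    print("avg inference time: %.3f s" % ((time.time() - start) / 10))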

The complete code is on GitHub and can be reproduced directly. Questions and discussion are welcome!

My GitHub | My personal blog
