当前位置:   article > 正文

TensorRT INT8 量化原理和实现_tensorrt int8量化

tensorrt int8量化

一、什么是模型量化

模型量化 = 模型 + 量化,两个词组成。

计算机视觉(深度学习)中,模型特指卷积神经网络,用于提取图像/视频特征。

量化:将信号的连续取值近似为有限多个离散值的过程,即信息压缩的方法。常规精度一般使用FP32(32位浮点,单精度)存储模型权重;低精度(FP16,半精度浮点);INT8(8位的定点整数)等数字格式。目前,低精度往往指代INT8,因此也称为“定点化(定点化scale为2的幂次方线性量化,是一种更加实用的量化方法)”

简而言之,我们常说的模型量化就是将浮点存储(运算)转换为整形存储(运算)的一种模型压缩技术。

二、为什么要做模型量化?

现有的深度学习框架,如TensorFlow、pytorch,在训练深度神经网络时,往往都会使用FP32的数据精度来表示权值、偏置、激活值等。在深度学习模型提提高的同时,计算也越来越复杂,计算开销和内存需求逐渐增加。庞大的网络参数意味着更大的内存存储,而增长的浮点型计算次数意味着训练成本和计算时间的增长,这极大的限制了在资源受限社保,如手机、手环等社保的部署。

三、模型量化的目标是什么?

1.更小的模型尺寸;2.更低的运算功耗;3.更低的运存占用;4.更快的计算速度;5持平的推理精度

四、模型量化的必要条件

量化一定能加速计算嘛?非也。很多量化算法都无法带来实质性的加速。

理论计算峰值:单位是时钟周期内能完成的计算个数乘上芯片频率。

什么样的量化方法可以带来潜在、可落地的速度提升呢?1.量化数值的计算在部署硬件上的峰值性能更高;2.量化引入的额外计算(overhead)少。

已知提速概率较大的量化方法有如下三类:

1.二值化:其可以用简单的位运算来同时计算大量的数。该操作可以享受到SIMD(单指令多数据流)的加速收益;

2.线性量化:又可称为非对称。

3.对数量化。两个同底的幂指数相乘,等于其指数相加,降低了计算强度。同时加法也转化为索引计算。

五、模型量化的分类

5.1 线性量化和非线性量化

根据映射函数是否是线性可以分为两类,即线性量化和非线性量化,这里主要研究线性量化。

5.2 逐层量化、逐组量化和逐通道量化(依据量化的粒度进行区分)

逐层量化:以一个层为单位,整个layer的权重共用一组缩放因子 S 和偏移量 Z;

逐组量化:以组为单位,每个group使用一组 S 和 Z;

逐通道量化:以通道为单位,每个channel单独使用一组 S 和 Z。

当 group = 1 时,逐组量化 = 逐层量化;当 group = num_filters(即 DW 卷积)时,逐组量化=逐通道量化。

5.3 N 比特量化

根据存储一个权重元素所需的位数,可以将其分为 8bit量化、4bit量化、2bit和1bit 量化等

5.4 权重量化和权重激活量化
5.4.1 权重和激活的概念

其中滤波器就是权重,而输入和输出就是当前层和上一层的激活值,假设输入数据位[3,224,224],滤波器为[2, 3, 3, 3],使用如下公式可以计算得到输出维度为[2, 222, 222]

, 因此权重有2 *3 *3 *3=54个(不含偏置项),上一层的激活值有 3*244 *244=150528个,下一层的激活值有2*222*222=98568个,显然激活值的数量远大于权重。

5.4.2 权重量化和权重激活量化

权重量化:即仅仅需要对网络中的权重执行量化操作。由于网络的权重一般都保存下来了,因而我们可以提前根据权重获得相应的量化参数 S 和 Z, 而不需额外的校准数据集。一般来说,推理过程中,权重值的数量远小于激活值,仅仅对权重值的量化方法能带来的压缩力度和加速效果一般。

权重激活量化:即对网络中的权重和激活值都进行量化。

5.4.3 激活量化方法:根据激活值的量化方式分为在线量化和离线量化

在线量化:值激活值的 S和Z 在实际的推理过程中格局实际的激活值进行动态计算;

离线量化:即提前确定好激活值的 S 和 Z,需要小批量的一些校准数据集支撑。不须动态旧伞,所以速度较快。

通常使用以下三种方法来确定相关的量化参数:

指数平滑法:即将校准数据集送入模型,收集每个量化层的输出特征图,计算每个batch的S和Z,并通过指数平滑法来更新 S和Z;

直方图截断法:在计算量化参数 S和Z的过程中,由于有的特征图会出现偏离较远的奇异值,导致max非常大,所以可以通过直方图截断的形式截流;

KL散度校准法:通过计算KL散度(也称相对熵,用于描述两个分布之间的差异)来评估量化前后的两个分布之间存在的差异,搜索并选取KL散度最小的量化参数S和Z作为最终的结果。

5.5 训练时和训练后量化

训练后量化(Post-Training Quantization,PTQ),PTQ不需要再训练,因此是一种轻量级的量化方法。在大多数情况下,PTQ足以接近FP32性能的INT量化。

训练时量化(Quantization-Aware-Training,QAT),也叫量化感知训练,它可以获得高精度的低位量化,缺点是需要修改训练代码,并且反向传播过程对梯度的量化误差很大,容易出现不收敛。

六、量化的数学基础

6.1、定点数和浮点数

量化过程一般分为两部分:将模型从FP32转换为INT8, 以及使用INT8进行推理。

定点浮点都是数值的表示方式,他们的区别在于,将整数部分和小数部分分开的点,位于哪里。定点保留特定位数整数和小数,而浮点保留特定位数的有效数字和指数。

在指数级的内置数据类型中,定点是整数,浮点是二进制格式。一般来说,指令集层面的定点是连续的,因为它是整数,且两个领巾的可表示数字的间隙是1.而浮点代表实数,其数值间隙由指数确定,因而具有非常宽的值域。同时也可知浮点的数字间隙是不均匀的,在相同的指数范围内,可表示数值数量也想吐,且值越接近0就越准确。另外也可得知定点数数值与想要表示的真值是一致的,而浮点数数值与想要表示的真值是有偏差的(表4)。

6.2 线性量化(线性映射)
6.2.1 量化

TensorRT使用的就是线性量化,它可以用以下数学表达式来表示:、

$X_int = clip([X/S] + Z; -2**(b-1), 2**(b-1) - 1)$。其中 X表示原始的FP32数值;Z表示映射的零点ZeroPoint;S表示缩放因子Scale;【.】表示近视取证的数学函数们可以四舍五入、向上取整、向下取整等;X_int 表示的量化后的一个整数值。

clip函数如下:

根据参数 Z 是否为0可以将线性量化分为对称量化和非对称量化,TensoRT使用的是对称量化,即Z=0。

6.2.2 反量化

当Z=0时,X_min= -2**(b-1)* S, X_max=(2**(b-) - 1) * S)。

可以发现当S取大值时,可以扩大量化域,但同时,单个INT8数值可表示的FP32范围也变广了,因此INT8数值域FP32数值的物质(量化误差)会增大;而当S取小值时,量化误差虽然减小了,但是量化域也缩小了,被舍弃的参数会增多。

举个例子,假设Z=0,使用向下取整。

七、TensoRT INT8量化原理

7.1 TensorRT 是什么?

NVIDIA TensorRT 的核心是一个C++库,可促进对NVIDIA图形处理单元(GPU)的高性能推理。它指在与TensorFlow、pytorch等框架以互补的方式工作。它专门致力于在GPU上快速有效地运行已经训练好的网络,已生成结果。一些训练好的框架已经集成了TensorRT,因此可以将其用于框架内加速推理。

7.2 使用TensorRT INT8量化的前提

1.硬件上必须是Nvidia的显卡,并且计算能力大于等于6.1。Nvidia GPU的计算能力可以在这个网上找到。        CUDA GPU | NVIDIA Developer

7.3 INT8量化流程

卷积的公式如下:Y = W*X + B

其中 X 是上一层的输出,即原始输入或者上一层的激活值;W 是当前层的权重; b 是当前层的偏置; Y是当前层的输出,即当前层的激活值。TensorRT的官方文档告知量化过程中偏置项可以忽略不计,即 Y = W * X

出去 bias 后,整个量化工程精简后如下

1.通过线性映射的方式将激活值和权重从FP32转化为INT8;

2.执行卷积层运算得到INT32位激活值,如果直接使用INT8保存的话会造成过多累计损失;

3.通过再量化的方式转换回INT8作为下一层的输入;

4.当网络为最后一层时,使用反量化转换回FP32。

整个过程的关键部分就是FP32至INT8的量化、INT32至INT8的再量化以及INT8至FP32的反量化,也就是前面所说的线性量化(线性映射)

7.4 INT校准
7.4.1 为什么需要校准?

首先需要明确的是,需要INT8校准的前提是使用到了激活量化。三个主要原因如下:

1.网络的激活值不会保存在网络参数中,属于运行过程中产生的值,因此我们难以与先确定它的范围;

2.当S取大时,可以扩大量化域,但同时,单个INT8数值可表示的FP32范围也变广了,因此INT8数值与FP32数值的误差会增大;当S取小时,量化误差减小,同时量化域也缩小了,被舍弃的参数会增多。

3. 为什么对于不同模型都可行?如果有些模型缩小量化域导致的精度下降更明显,那么INT8量化后的精度是不是比如有大幅下降呢?

其实不一定,量化属于浮点数向定点数转换的过程,由于浮点数的可表示数值间隙密度不同,导致零点附近的浮点数可表示数值很多,大于2^31个,约等于可表示数质量的一半。因此,越是靠近零点的浮点数表示越准确,越是远离原点的位置越有可能是噪声,并且网络的权重和激活大多分布在零点附近,因此适当的缩小量化域能提升量化精度几乎是必然的。

7.4.2 INT8 校准目的

就是一种权衡。为了找到合适的缩放参数,是的量化后的INT8数值能更准确的表示出狼花钱的FP32数值,并且又不能舍弃太多远离零点的非噪声参数。

7.4.3  如何实现 INT8 校准

7.4.3.1 校准前激活分布

举个例子,我们使用同一批图片在不同模型上训练,然后从不同网络层中可以得到对应的激活值分布,如下图:

可以发现分布都不相同,那么如何选取最优的阈值呢?

这就需要一个定量的衡量指标:常用的手段是指数平滑法、直方图截断阀、KL散度校准法,TensorRT使用的是 KL散度校准法。

7.4.3.2 KL散度校准法原理

KL散度校准法也叫相对熵。KL公式如下:

其中 p表示真实分布;q 表示非真实分布、模型分布或 p的近似分布。可以发现

相对熵=交叉熵 - 信息熵。

交叉熵:用其来衡量在给定的真实分布下,使用非真实分布所指定的策略消除系统的不确定性所需要付出的努力的大小;

信息熵:随机变量分布的混乱程度或整个系统的不确定性,随机变量越混乱(无序性)或系统的不确定就越大,熵越大。当随机分布为均匀分布时,熵最大。

交叉熵一定大于等于信息熵。

相对熵搜那个:用来衡量真实分布与非真实分布的差异大小

7.4.3.3 具体校准流程

1.需要准备500-800张校准用的数据集(tensorrt官方推荐);

2.使用校准数据集在FP32精度的网络下推理,并搜集激活值的直方图;

3.不断调整阈值,并计算相对熵,得到最优解。

  1. Entropy Calibration-pseudocodeInput: FP32 histogram H with 2048 bins: bin[ 0 ], ..., bin[ 2047 ]
  2. For i in range( 128 ,2048 ):
  3. reference distribution_P = [bin[0], ..., bin[i-1]]
  4. outliers_count = sum( bin[i], bin[i+1], ... , bin[2047])
  5. reference_distribution_P[i-1] += outliers_count // take first 'i' bins from H
  6. P /= sum(P) // normalize distribution P
  7. // explained later
  8. candidate_distribution_Q = quantize [bin[0], ..., bin[i-1]] into 128 levels
  9. expand candidate_distribution_Q to'i’ bins // explained later
  10. Q /= sum(Q) // normalize distribution Q
  11. divergence[i]=KL_divergence(reference_distribution_P, candidate_distribution_Q)
  12. End For
  13. Find index'm' for which divergencel m lis minimal
  14. threshold=(m + 0.5 )*( width of a bin)
  15. """以上是校准的官方伪代码"""

1.将校准集下得到的直方图划分成2048个bins;

2.在[128, 2048]范围内循环执行以下 3 -5步骤;

3.将第i个bin后的所有数值累加到第i-1 个bin上,并对前i个bins归一化, 作为P分布(真实分布);

4.对P 量化得到 Q 并归一化;

5.计算 P与 Q的相对熵

6.得到最小相对熵的 I,阈值 T = (i+0.5) * bin的宽度

7.4.3.4 校准后数据分布

7.5 总结

- 一种自动化,五参数的FP32 到 INT 8 的转换方法;

- 通过最小化 KL 散度来选择量化的阈值;

- 量化后精度几乎持平,速度有很大提升。

八、 C++实现 TensorRT INT8 量化

8.1 程序流程

TensorRT在做的其实只有一件事,就是把不同框架下训练得到 模型转换成 Engine,然后使用Engine进行推理。这里支持的框架包括ONNX、TensorFlow,如图

主要流程如下图:

1.构建builder,并使用builder 构建network 用于存储模型信息;

2.使用NetWork 构建Parser用于从onnx文件中解析模型信息并回传给network;

3.使用builder 构建profile 用于设置动态维度,并从dnnamicBinding中获取动态维度信息;

4.构建Calibrator 用于校准模型,并使用BatchStream加载校准数据集;

5.使用Builder构建Config用于配置生成Engine的参数,包括Calibrator 和Profile;

6.Builder 使用network 中的模型信息和Config 中的参数生成Engine 一级校准参数calParameter;

7.通过BatchStream 加载待测试数据集并传入Engine, 输出最终结果 result。

其中,Calibrator BatchStream 两个类都是需要根据项目重写的

8.2 Calibrator

为了将校准数据集输入TensorRT,我们需要用到Int8Calibrator抽象类,TensorRT一共提供了四种Int8Calibrator(后面两种已弃用,未列出);

1.IEntropyCalibratorV2:最适合卷积网络CNN校准器,并且本文也是使用这个类实现的;

2.IMinMaxCalibrator:这适合自然语言处理 NLP 中;

IInt8Calibrator实现的功能也很简单:

1.getBatchSize: 获取校准过程中的 batchsize;

2.getBatch: 获取校准过程中的输入;

3.writeCalibrationCache:由于校准花费的时间较长,调用该函数将校准参数结果写入本地文件;

4.readCalibrationCache:读取保存在本地发校准参数文件,在生成Engine过程中会自动调用。

  1. """官方代码.
  2. 校准器类中没有直接实现getBatchSize 和getBatch,而是使用TBatchStream模板类实现,这就是BatchStream的作用
  3. """
  4. /*
  5. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  6. *
  7. * Licensed under the Apache License, Version 2.0 (the "License");
  8. * you may not use this file except in compliance with the License.
  9. * You may obtain a copy of the License at
  10. *
  11. * http://www.apache.org/licenses/LICENSE-2.0
  12. *
  13. * Unless required by applicable law or agreed to in writing, software
  14. * distributed under the License is distributed on an "AS IS" BASIS,
  15. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  16. * See the License for the specific language governing permissions and
  17. * limitations under the License.
  18. */
  19. #ifndef ENTROPY_CALIBRATOR_H
  20. #define ENTROPY_CALIBRATOR_H
  21. #include "BatchStream.h"
  22. #include "NvInfer.h"
  23. //! \class EntropyCalibratorImpl
  24. //!
  25. //! \brief Implements common functionality for Entropy calibrators.
  26. //!
  27. template <typename TBatchStream>
  28. class EntropyCalibratorImpl
  29. {
  30. public:
  31. EntropyCalibratorImpl(
  32. TBatchStream stream, int firstBatch, std::string networkName, const char* inputBlobName, bool readCache = true)
  33. : mStream{stream}
  34. , mCalibrationTableName("CalibrationTable" + networkName)
  35. , mInputBlobName(inputBlobName)
  36. , mReadCache(readCache)
  37. {
  38. nvinfer1::Dims dims = mStream.getDims();
  39. mInputCount = samplesCommon::volume(dims);
  40. CHECK(cudaMalloc(&mDeviceInput, mInputCount * sizeof(float)));
  41. mStream.reset(firstBatch);
  42. }
  43. virtual ~EntropyCalibratorImpl()
  44. {
  45. CHECK(cudaFree(mDeviceInput));
  46. }
  47. int getBatchSize() const noexcept
  48. {
  49. return mStream.getBatchSize();
  50. }
  51. bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept
  52. {
  53. if (!mStream.next())
  54. {
  55. return false;
  56. }
  57. CHECK(cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice));
  58. ASSERT(!strcmp(names[0], mInputBlobName));
  59. bindings[0] = mDeviceInput;
  60. return true;
  61. }
  62. const void* readCalibrationCache(size_t& length) noexcept
  63. {
  64. mCalibrationCache.clear();
  65. std::ifstream input(mCalibrationTableName, std::ios::binary);
  66. input >> std::noskipws;
  67. if (mReadCache && input.good())
  68. {
  69. std::copy(std::istream_iterator<char>(input), std::istream_iterator<char>(),
  70. std::back_inserter(mCalibrationCache));
  71. }
  72. length = mCalibrationCache.size();
  73. return length ? mCalibrationCache.data() : nullptr;
  74. }
  75. void writeCalibrationCache(const void* cache, size_t length) noexcept
  76. {
  77. std::ofstream output(mCalibrationTableName, std::ios::binary);
  78. output.write(reinterpret_cast<const char*>(cache), length);
  79. }
  80. private:
  81. TBatchStream mStream;
  82. size_t mInputCount;
  83. std::string mCalibrationTableName;
  84. const char* mInputBlobName;
  85. bool mReadCache{true};
  86. void* mDeviceInput{nullptr};
  87. std::vector<char> mCalibrationCache;
  88. };
  89. //! \class Int8EntropyCalibrator2
  90. //!
  91. //! \brief Implements Entropy calibrator 2.
  92. //! CalibrationAlgoType is kENTROPY_CALIBRATION_2.
  93. //!
  94. template <typename TBatchStream>
  95. class Int8EntropyCalibrator2 : public IInt8EntropyCalibrator2
  96. {
  97. public:
  98. Int8EntropyCalibrator2(
  99. TBatchStream stream, int firstBatch, const char* networkName, const char* inputBlobName, bool readCache = true)
  100. : mImpl(stream, firstBatch, networkName, inputBlobName, readCache)
  101. {
  102. }
  103. int getBatchSize() const noexcept override
  104. {
  105. return mImpl.getBatchSize();
  106. }
  107. bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override
  108. {
  109. return mImpl.getBatch(bindings, names, nbBindings);
  110. }
  111. const void* readCalibrationCache(size_t& length) noexcept override
  112. {
  113. return mImpl.readCalibrationCache(length);
  114. }
  115. void writeCalibrationCache(const void* cache, size_t length) noexcept override
  116. {
  117. mImpl.writeCalibrationCache(cache, length);
  118. }
  119. private:
  120. EntropyCalibratorImpl<TBatchStream> mImpl;
  121. };
  122. #endif // ENTROPY_CALIBRATOR_H
8.3 BatchStream

BatchStream类继承与IBatchStream,它实现的功能就是从给定的数据集中读取数据和标签,实现预处理并能按要求的BatchSize遍历数据和标签,具体如下:

1.reset:设置其实的 Batch 索引;

2.next:索引+1,准确读取下一个batch,直到数据遍历完成;

3.skip:跳转到指定索引的 batch;

4.getBatch:获取当前索引的数据;

5.getLabels:获取当前索引的标签;

6.getBatchesRead:获取当前索引;

7.getBatchSize:获取 BatchSize

8.getDims:获取当前数据的维度;

9.readDataFile:读取数据集中的数据

10readLabelsFile:读取数据集中的标签。

九、量化效果测试

最后分别用Alexnet、Resnet50、VGG13进行测试量化后的效果

9.1 测试环境

GPU NVIDIA GeForce RTX 3060;CUDA 11.8 ; CUDNN 8.7.6; TensorRT 8.6.5

9.2 总结

1.FP32-FP16及FP16-INT8 转换均能减少约50%Engine尺寸,并能有效降低运算功耗和显存占用

2.从FP32-INT8可大幅提升推理速度,且与模型FLOPS成正比,但从FP16-INT8只能提高约2倍;

3.INT8量化后准确度相比FP32几乎没有下降,但随校准数据集增大略微下降;

4.INT8量化后推理速度随BatchSize 增大而增大,但受显卡限制

十、附录

  1. /*
  2. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. //!
  17. //! SampleINT8.cpp
  18. //! This file contains the implementation of the sample. It creates the network using
  19. //! the caffe model.
  20. //! It can be run with the following command line:
  21. //! Command: ./sample_int8 [-h or --help] [-d=/path/to/data/dir or --datadir=/path/to/data/dir]
  22. //!
  23. #include "BatchStream.h"
  24. #include "EntropyCalibrator.h"
  25. #include "argsParser.h"
  26. #include "buffers.h"
  27. #include "common.h"
  28. #include "logger.h"
  29. #include "NvCaffeParser.h"
  30. #include "NvInfer.h"
  31. #include <cuda_runtime_api.h>
  32. #include <cstdlib>
  33. #include <fstream>
  34. #include <iostream>
  35. #include <sstream>
  36. using samplesCommon::SampleUniquePtr;
  37. const std::string gSampleName = "TensorRT.sample_int8";
  38. //!
  39. //! \brief The SampleINT8Params structure groups the additional parameters required by
  40. //! the INT8 sample.
  41. //!
  42. struct SampleINT8Params : public samplesCommon::CaffeSampleParams
  43. {
  44. int nbCalBatches; //!< The number of batches for calibration
  45. int calBatchSize; //!< The calibration batch size
  46. std::string networkName; //!< The name of the network
  47. };
  48. //! \brief The SampleINT8 class implements the INT8 sample
  49. //!
  50. //! \details It creates the network using a caffe model
  51. //!
  52. class SampleINT8
  53. {
  54. public:
  55. SampleINT8(const SampleINT8Params& params)
  56. : mParams(params)
  57. , mEngine(nullptr)
  58. {
  59. initLibNvInferPlugins(&sample::gLogger.getTRTLogger(), "");
  60. }
  61. //!
  62. //! \brief Function builds the network engine
  63. //!
  64. bool build(DataType dataType);
  65. //!
  66. //! \brief Runs the TensorRT inference engine for this sample
  67. //!
  68. bool infer(std::vector<float>& score, int firstScoreBatch, int nbScoreBatches);
  69. //!
  70. //! \brief Cleans up any state created in the sample class
  71. //!
  72. bool teardown();
  73. private:
  74. SampleINT8Params mParams; //!< The parameters for the sample.
  75. nvinfer1::Dims mInputDims; //!< The dimensions of the input to the network.
  76. std::shared_ptr<nvinfer1::ICudaEngine> mEngine; //!< The TensorRT engine used to run the network
  77. //!
  78. //! \brief Parses a Caffe model and creates a TensorRT network
  79. //!
  80. bool constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
  81. SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
  82. SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser, DataType dataType);
  83. //!
  84. //! \brief Reads the input and stores it in a managed buffer
  85. //!
  86. bool processInput(const samplesCommon::BufferManager& buffers, const float* data);
  87. //!
  88. //! \brief Scores model
  89. //!
  90. int calculateScore(
  91. const samplesCommon::BufferManager& buffers, float* labels, int batchSize, int outputSize, int threshold);
  92. };
  93. //!
  94. //! \brief Creates the network, configures the builder and creates the network engine
  95. //!
  96. //! \details This function creates the network by parsing the caffe model and builds
  97. //! the engine that will be used to run the model (mEngine)
  98. //!
  99. //! \return Returns true if the engine was created successfully and false otherwise
  100. //!
  101. bool SampleINT8::build(DataType dataType)
  102. {
  103. auto builder = SampleUniquePtr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(sample::gLogger.getTRTLogger()));
  104. if (!builder)
  105. {
  106. return false;
  107. }
  108. if ((dataType == DataType::kINT8 && !builder->platformHasFastInt8())
  109. || (dataType == DataType::kHALF && !builder->platformHasFastFp16()))
  110. {
  111. return false;
  112. }
  113. auto network = SampleUniquePtr<nvinfer1::INetworkDefinition>(builder->createNetworkV2(0));
  114. if (!network)
  115. {
  116. return false;
  117. }
  118. auto config = SampleUniquePtr<nvinfer1::IBuilderConfig>(builder->createBuilderConfig());
  119. if (!config)
  120. {
  121. return false;
  122. }
  123. auto parser = SampleUniquePtr<nvcaffeparser1::ICaffeParser>(nvcaffeparser1::createCaffeParser());
  124. if (!parser)
  125. {
  126. return false;
  127. }
  128. auto constructed = constructNetwork(builder, network, config, parser, dataType);
  129. if (!constructed)
  130. {
  131. return false;
  132. }
  133. ASSERT(network->getNbInputs() == 1);
  134. mInputDims = network->getInput(0)->getDimensions();
  135. ASSERT(mInputDims.nbDims == 3);
  136. return true;
  137. }
  138. //!
  139. //! \brief Uses a caffe parser to create the network and marks the
  140. //! output layers
  141. //!
  142. //! \param network Pointer to the network that will be populated with the network
  143. //!
  144. //! \param builder Pointer to the engine builder
  145. //!
  146. bool SampleINT8::constructNetwork(SampleUniquePtr<nvinfer1::IBuilder>& builder,
  147. SampleUniquePtr<nvinfer1::INetworkDefinition>& network, SampleUniquePtr<nvinfer1::IBuilderConfig>& config,
  148. SampleUniquePtr<nvcaffeparser1::ICaffeParser>& parser, DataType dataType)
  149. {
  150. mEngine = nullptr;
  151. const nvcaffeparser1::IBlobNameToTensor* blobNameToTensor
  152. = parser->parse(locateFile(mParams.prototxtFileName, mParams.dataDirs).c_str(),
  153. locateFile(mParams.weightsFileName, mParams.dataDirs).c_str(), *network,
  154. dataType == DataType::kINT8 ? DataType::kFLOAT : dataType);
  155. for (auto& s : mParams.outputTensorNames)
  156. {
  157. network->markOutput(*blobNameToTensor->find(s.c_str()));
  158. }
  159. // Calibrator life time needs to last until after the engine is built.
  160. std::unique_ptr<IInt8Calibrator> calibrator;
  161. config->setAvgTimingIterations(1);
  162. config->setMinTimingIterations(1);
  163. config->setMaxWorkspaceSize(1_GiB);
  164. if (dataType == DataType::kHALF)
  165. {
  166. config->setFlag(BuilderFlag::kFP16);
  167. }
  168. if (dataType == DataType::kINT8)
  169. {
  170. config->setFlag(BuilderFlag::kINT8);
  171. }
  172. builder->setMaxBatchSize(mParams.batchSize);
  173. if (dataType == DataType::kINT8)
  174. {
  175. MNISTBatchStream calibrationStream(mParams.calBatchSize, mParams.nbCalBatches, "train-images-idx3-ubyte",
  176. "train-labels-idx1-ubyte", mParams.dataDirs);
  177. calibrator.reset(new Int8EntropyCalibrator2<MNISTBatchStream>(
  178. calibrationStream, 0, mParams.networkName.c_str(), mParams.inputTensorNames[0].c_str()));
  179. config->setInt8Calibrator(calibrator.get());
  180. }
  181. if (mParams.dlaCore >= 0)
  182. {
  183. samplesCommon::enableDLA(builder.get(), config.get(), mParams.dlaCore);
  184. if (mParams.batchSize > builder->getMaxDLABatchSize())
  185. {
  186. sample::gLogError << "Requested batch size " << mParams.batchSize
  187. << " is greater than the max DLA batch size of " << builder->getMaxDLABatchSize()
  188. << ". Reducing batch size accordingly." << std::endl;
  189. return false;
  190. }
  191. }
  192. // CUDA stream used for profiling by the builder.
  193. auto profileStream = samplesCommon::makeCudaStream();
  194. if (!profileStream)
  195. {
  196. return false;
  197. }
  198. config->setProfileStream(*profileStream);
  199. SampleUniquePtr<IHostMemory> plan{builder->buildSerializedNetwork(*network, *config)};
  200. if (!plan)
  201. {
  202. return false;
  203. }
  204. SampleUniquePtr<IRuntime> runtime{createInferRuntime(sample::gLogger.getTRTLogger())};
  205. if (!runtime)
  206. {
  207. return false;
  208. }
  209. mEngine = std::shared_ptr<nvinfer1::ICudaEngine>(
  210. runtime->deserializeCudaEngine(plan->data(), plan->size()), samplesCommon::InferDeleter());
  211. if (!mEngine)
  212. {
  213. return false;
  214. }
  215. return true;
  216. }
  217. //!
  218. //! \brief Runs the TensorRT inference engine for this sample
  219. //!
  220. //! \details This function is the main execution function of the sample. It allocates the buffer,
  221. //! sets inputs and executes the engine.
  222. //!
  223. bool SampleINT8::infer(std::vector<float>& score, int firstScoreBatch, int nbScoreBatches)
  224. {
  225. float ms{0.0f};
  226. // Create RAII buffer manager object
  227. samplesCommon::BufferManager buffers(mEngine, mParams.batchSize);
  228. auto context = SampleUniquePtr<nvinfer1::IExecutionContext>(mEngine->createExecutionContext());
  229. if (!context)
  230. {
  231. return false;
  232. }
  233. MNISTBatchStream batchStream(mParams.batchSize, nbScoreBatches + firstScoreBatch, "train-images-idx3-ubyte",
  234. "train-labels-idx1-ubyte", mParams.dataDirs);
  235. batchStream.skip(firstScoreBatch);
  236. Dims outputDims = context->getEngine().getBindingDimensions(
  237. context->getEngine().getBindingIndex(mParams.outputTensorNames[0].c_str()));
  238. int64_t outputSize = samplesCommon::volume(outputDims);
  239. int top1{0}, top5{0};
  240. float totalTime{0.0f};
  241. while (batchStream.next())
  242. {
  243. // Read the input data into the managed buffers
  244. ASSERT(mParams.inputTensorNames.size() == 1);
  245. if (!processInput(buffers, batchStream.getBatch()))
  246. {
  247. return false;
  248. }
  249. // Memcpy from host input buffers to device input buffers
  250. buffers.copyInputToDevice();
  251. cudaStream_t stream;
  252. CHECK(cudaStreamCreate(&stream));
  253. // Use CUDA events to measure inference time
  254. cudaEvent_t start, end;
  255. CHECK(cudaEventCreateWithFlags(&start, cudaEventBlockingSync));
  256. CHECK(cudaEventCreateWithFlags(&end, cudaEventBlockingSync));
  257. cudaEventRecord(start, stream);
  258. bool status = context->enqueue(mParams.batchSize, buffers.getDeviceBindings().data(), stream, nullptr);
  259. if (!status)
  260. {
  261. return false;
  262. }
  263. cudaEventRecord(end, stream);
  264. cudaEventSynchronize(end);
  265. cudaEventElapsedTime(&ms, start, end);
  266. cudaEventDestroy(start);
  267. cudaEventDestroy(end);
  268. totalTime += ms;
  269. // Memcpy from device output buffers to host output buffers
  270. buffers.copyOutputToHost();
  271. CHECK(cudaStreamDestroy(stream));
  272. top1 += calculateScore(buffers, batchStream.getLabels(), mParams.batchSize, outputSize, 1);
  273. top5 += calculateScore(buffers, batchStream.getLabels(), mParams.batchSize, outputSize, 5);
  274. if (batchStream.getBatchesRead() % 100 == 0)
  275. {
  276. sample::gLogInfo << "Processing next set of max 100 batches" << std::endl;
  277. }
  278. }
  279. int imagesRead = (batchStream.getBatchesRead() - firstScoreBatch) * mParams.batchSize;
  280. score[0] = float(top1) / float(imagesRead);
  281. score[1] = float(top5) / float(imagesRead);
  282. sample::gLogInfo << "Top1: " << score[0] << ", Top5: " << score[1] << std::endl;
  283. sample::gLogInfo << "Processing " << imagesRead << " images averaged " << totalTime / imagesRead << " ms/image and "
  284. << totalTime / batchStream.getBatchesRead() << " ms/batch." << std::endl;
  285. return true;
  286. }
  287. //!
  288. //! \brief Cleans up any state created in the sample class
  289. //!
  290. bool SampleINT8::teardown()
  291. {
  292. //! Clean up the libprotobuf files as the parsing is complete
  293. //! \note It is not safe to use any other part of the protocol buffers library after
  294. //! ShutdownProtobufLibrary() has been called.
  295. nvcaffeparser1::shutdownProtobufLibrary();
  296. return true;
  297. }
  298. //!
  299. //! \brief Reads the input and stores it in a managed buffer
  300. //!
  301. bool SampleINT8::processInput(const samplesCommon::BufferManager& buffers, const float* data)
  302. {
  303. // Fill data buffer
  304. float* hostDataBuffer = static_cast<float*>(buffers.getHostBuffer(mParams.inputTensorNames[0]));
  305. std::memcpy(hostDataBuffer, data, mParams.batchSize * samplesCommon::volume(mInputDims) * sizeof(float));
  306. return true;
  307. }
  308. //!
  309. //! \brief Scores model
  310. //!
  311. int SampleINT8::calculateScore(
  312. const samplesCommon::BufferManager& buffers, float* labels, int batchSize, int outputSize, int threshold)
  313. {
  314. float* probs = static_cast<float*>(buffers.getHostBuffer(mParams.outputTensorNames[0]));
  315. int success = 0;
  316. for (int i = 0; i < batchSize; i++)
  317. {
  318. float *prob = probs + outputSize * i, correct = prob[(int) labels[i]];
  319. int better = 0;
  320. for (int j = 0; j < outputSize; j++)
  321. {
  322. if (prob[j] >= correct)
  323. {
  324. better++;
  325. }
  326. }
  327. if (better <= threshold)
  328. {
  329. success++;
  330. }
  331. }
  332. return success;
  333. }
  334. //!
  335. //! \brief Initializes members of the params struct using the command line args
  336. //!
  337. SampleINT8Params initializeSampleParams(const samplesCommon::Args& args, int batchSize)
  338. {
  339. SampleINT8Params params;
  340. // Use directories provided by the user, in addition to default directories.
  341. params.dataDirs = args.dataDirs;
  342. params.dataDirs.emplace_back("data/mnist/");
  343. params.dataDirs.emplace_back("int8/mnist/");
  344. params.dataDirs.emplace_back("samples/mnist/");
  345. params.dataDirs.emplace_back("data/samples/mnist/");
  346. params.dataDirs.emplace_back("data/int8/mnist/");
  347. params.dataDirs.emplace_back("data/int8_samples/mnist/");
  348. params.batchSize = batchSize;
  349. params.dlaCore = args.useDLACore;
  350. params.nbCalBatches = 10;
  351. params.calBatchSize = 50;
  352. params.inputTensorNames.push_back("data");
  353. params.outputTensorNames.push_back("prob");
  354. params.prototxtFileName = "deploy.prototxt";
  355. params.weightsFileName = "mnist_lenet.caffemodel";
  356. params.networkName = "mnist";
  357. return params;
  358. }
  359. //!
  360. //! \brief Prints the help information for running this sample
  361. //!
  362. void printHelpInfo()
  363. {
  364. std::cout << "Usage: ./sample_int8 [-h or --help] [-d or --datadir=<path to data directory>] "
  365. "[--useDLACore=<int>]"
  366. << std::endl;
  367. std::cout << "--help, -h Display help information" << std::endl;
  368. std::cout << "--datadir Specify path to a data directory, overriding the default. This option can be used "
  369. "multiple times to add multiple directories."
  370. << std::endl;
  371. std::cout << "--useDLACore=N Specify a DLA engine for layers that support DLA. Value can range from 0 to n-1, "
  372. "where n is the number of DLA engines on the platform."
  373. << std::endl;
  374. std::cout << "batch=N Set batch size (default = 32)." << std::endl;
  375. std::cout << "start=N Set the first batch to be scored (default = 16). All batches before this batch will "
  376. "be used for calibration."
  377. << std::endl;
  378. std::cout << "score=N Set the number of batches to be scored (default = 1800)." << std::endl;
  379. }
  380. int main(int argc, char** argv)
  381. {
  382. if (argc >= 2 && (!strncmp(argv[1], "--help", 6) || !strncmp(argv[1], "-h", 2)))
  383. {
  384. printHelpInfo();
  385. return EXIT_SUCCESS;
  386. }
  387. // By default we score over 57600 images starting at 512, so we don't score those used to search calibration
  388. int batchSize = 32;
  389. int firstScoreBatch = 16;
  390. int nbScoreBatches = 1800;
  391. // Parse extra arguments
  392. for (int i = 1; i < argc; ++i)
  393. {
  394. if (!strncmp(argv[i], "batch=", 6))
  395. {
  396. batchSize = atoi(argv[i] + 6);
  397. }
  398. else if (!strncmp(argv[i], "start=", 6))
  399. {
  400. firstScoreBatch = atoi(argv[i] + 6);
  401. }
  402. else if (!strncmp(argv[i], "score=", 6))
  403. {
  404. nbScoreBatches = atoi(argv[i] + 6);
  405. }
  406. }
  407. if (batchSize > 128)
  408. {
  409. sample::gLogError << "Please provide batch size <= 128" << std::endl;
  410. return EXIT_FAILURE;
  411. }
  412. if ((firstScoreBatch + nbScoreBatches) * batchSize > 60000)
  413. {
  414. sample::gLogError << "Only 60000 images available" << std::endl;
  415. return EXIT_FAILURE;
  416. }
  417. samplesCommon::Args args;
  418. samplesCommon::parseArgs(args, argc, argv);
  419. SampleINT8 sample(initializeSampleParams(args, batchSize));
  420. auto sampleTest = sample::gLogger.defineTest(gSampleName, argc, argv);
  421. sample::gLogger.reportTestStart(sampleTest);
  422. sample::gLogInfo << "Building and running a GPU inference engine for INT8 sample" << std::endl;
  423. std::vector<std::string> dataTypeNames = {"FP32", "FP16", "INT8"};
  424. std::vector<std::string> topNames = {"Top1", "Top5"};
  425. std::vector<DataType> dataTypes = {DataType::kFLOAT, DataType::kHALF, DataType::kINT8};
  426. std::vector<std::vector<float>> scores(3, std::vector<float>(2, 0.0f));
  427. for (size_t i = 0; i < dataTypes.size(); i++)
  428. {
  429. sample::gLogInfo << dataTypeNames[i] << " run:" << nbScoreBatches << " batches of size " << batchSize
  430. << " starting at " << firstScoreBatch << std::endl;
  431. if (!sample.build(dataTypes[i]))
  432. {
  433. if (!samplesCommon::isDataTypeSupported(dataTypes[i]))
  434. {
  435. sample::gLogWarning << "Skipping " << dataTypeNames[i]
  436. << " since the platform does not support this data type." << std::endl;
  437. continue;
  438. }
  439. return sample::gLogger.reportFail(sampleTest);
  440. }
  441. if (!sample.infer(scores[i], firstScoreBatch, nbScoreBatches))
  442. {
  443. return sample::gLogger.reportFail(sampleTest);
  444. }
  445. }
  446. auto isApproximatelyEqual = [](float a, float b, double tolerance) { return (std::abs(a - b) <= tolerance); };
  447. const double tolerance{0.01};
  448. const double goldenMNIST{0.99};
  449. if ((scores[0][0] < goldenMNIST) || (scores[0][1] < goldenMNIST))
  450. {
  451. sample::gLogError << "FP32 accuracy is less than 99%: Top1 = " << scores[0][0] << ", Top5 = " << scores[0][1]
  452. << "." << std::endl;
  453. return sample::gLogger.reportFail(sampleTest);
  454. }
  455. for (unsigned i = 0; i < topNames.size(); i++)
  456. {
  457. for (unsigned j = 1; j < dataTypes.size(); j++)
  458. {
  459. if (scores[j][i] != 0.0f && !isApproximatelyEqual(scores[0][i], scores[j][i], tolerance))
  460. {
  461. sample::gLogError << "FP32(" << scores[0][i] << ") and " << dataTypeNames[j] << "(" << scores[j][i]
  462. << ") " << topNames[i] << " accuracy differ by more than " << tolerance << "."
  463. << std::endl;
  464. return sample::gLogger.reportFail(sampleTest);
  465. }
  466. }
  467. }
  468. if (!sample.teardown())
  469. {
  470. return sample::gLogger.reportFail(sampleTest);
  471. }
  472. return sample::gLogger.reportPass(sampleTest);
  473. }

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/一键难忘520/article/detail/901217
推荐阅读
相关标签
  

闽ICP备14008679号