当前位置:   article > 正文

TensorRT C++:C++ TensorRT 推理比 Python 慢很多

c++ tensorrt推理比python慢很多

使用 TensorRT C++ API 加速推理

上一篇介绍了 TensorRT 的 Python 加速,这一篇介绍 C++ 版本的加速。由于没有找到 PIL 预处理在 C++ 中的对应实现,预处理结果与 Python 版本不完全一致,因此精度无法复现,速度也不如 Python 版本。不过,如果要把模型做成服务,C++ 的并发性能优于 Python,所以部署时还是更适合使用 C++。

#pragma once
#include <algorithm>
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <fstream>
#include <functional>
#include <iostream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include "filesystem"

#include "opencv2/opencv.hpp"
#include "cuda_runtime.h"
#include "torch/script.h"
#include "torch/torch.h"
#include "torch/cuda.h"
#include "NvInfer.h"
#include "NvOnnxParser.h"


// Network input geometry.
// IMAGE_WIDTH x IMAGE_HEIGHT is the size prepareImage resizes every image to.
// INPUT_CHANNEL records the expected number of colour channels.
#define INPUT_CHANNEL 3
#define IMAGE_WIDTH 224
#define IMAGE_HEIGHT 224

// Logger handed to TensorRT.
// It prints a message to stdout when its severity compares <= kWARNING,
// which in TensorRT's ordering should mean warnings and errors (confirm
// against NvInfer.h). Every less important message is silently dropped.
// A single global instance, `logger`, is defined right after the class.
class Logger : public nvinfer1::ILogger{
    void log(nvinfer1::ILogger::Severity severity, const char* msg) noexcept override{
        // Anything ranked after kWARNING is considered noise here.
        if (severity > nvinfer1::ILogger::Severity::kWARNING)
            return;
        std::cout << msg << std::endl;
    }
} logger;

// Append to `fileLists` the path of every decodable image located directly
// inside the directory `testPath` (no recursion).
//
// A file counts as an image when cv::imread returns a non-empty Mat, so each
// candidate is fully decoded once here. Paths found by this call are appended
// in lexicographic order, after any entries `fileLists` already holds.
//
// Throws std::runtime_error if `testPath` does not exist. Errors while
// iterating (e.g. permissions) surface as std::filesystem::filesystem_error.
void getImageFiles(std::vector<std::string>& fileLists,std::string testPath) {
    const std::filesystem::path dir(testPath);
    if (!std::filesystem::exists(dir))
        throw std::runtime_error("getImageFiles: path does not exist: " + testPath);

    std::vector<std::string> found;
    for (const auto& entry : std::filesystem::directory_iterator(dir)) {
        // Directories and other special entries can never be decoded as images.
        if (!entry.is_regular_file())
            continue;
        // Explicit .string(): path only converts implicitly to std::string
        // where its native string type is std::string.
        const std::string file = entry.path().string();
        if (!cv::imread(file).empty())
            found.push_back(file);
    }

    // directory_iterator yields entries in unspecified order; sort so repeated
    // runs visit the images in the same sequence.
    std::sort(found.begin(), found.end());
    fileLists.insert(fileLists.end(), found.begin(), found.end());
}

// Convert an image into the network input layout.
//
// Steps, in order:
//   1. BGR -> RGB conversion.
//   2. Resize to IMAGE_WIDTH x IMAGE_HEIGHT.
//   3. Scale each channel value to [0, 1] (divide by 255).
//   4. Normalize with mean 0.5 / std 0.5, so the values fall in [-1, 1].
//   5. Write the result to `inputData` in planar CHW order: channel-major,
//      then row, then column.
//
// `inputData` must have room for 3 * IMAGE_HEIGHT * IMAGE_WIDTH floats.
// Throws std::invalid_argument if `vec_img` holds no pixel data.
//
// NOTE(review): the pixel loop reads cv::Vec3b, so the image is assumed to be
// 8-bit 3-channel BGR (cv::imread's default). Confirm for other image sources.
// NOTE(review): the Python pipeline this mirrors may resize with PIL, which is
// not guaranteed to match cv::resize bit-for-bit.
void prepareImage(cv::Mat &vec_img, float* inputData) {
    if (vec_img.empty())
        throw std::invalid_argument("prepareImage: empty image");

    // Resize the RGB-converted image so the channel swap is kept.
    cv::Mat rgb_img, rsz_img;
    cv::cvtColor(vec_img, rgb_img, cv::COLOR_BGR2RGB);
    cv::resize(rgb_img, rsz_img, cv::Size(IMAGE_WIDTH, IMAGE_HEIGHT));

    const int height = rsz_img.rows;
    const int width = rsz_img.cols;
    const int plane = height * width;
    for (int y = 0; y < height; ++y) {
        const cv::Vec3b* row = rsz_img.ptr<cv::Vec3b>(y);
        for (int x = 0; x < width; ++x) {
            for (int c = 0; c < 3; ++c) {
                // Same float arithmetic as before: /255, then (v - 0.5) / 0.5.
                const float v = static_cast<float>(row[x][c]) / 255.0f;
                inputData[c * plane + y * width + x] = (v - 0.5f) / 0.5f;
            }
        }
    }
}

// Product of the first d.nbDims extents of `d`, i.e. the element count of a
// tensor with those dimensions. Returns 1 when nbDims is 0.
// Negative (dynamic) extents are multiplied in as-is; callers that need a
// static shape must check the extents themselves.
int64_t volume(const nvinfer1::Dims& d)
{
    // std::accumulate computes in the type of its initial value. An int seed
    // would truncate the running product to 32 bits, so seed with int64_t.
    return std::accumulate(d.d, d.d + d.nbDims, int64_t{1}, std::multiplies<int64_t>());
}

// Size in bytes of one element of TensorRT data type `t`.
// Only kFLOAT, kINT32, kHALF, kINT8 and kBOOL are recognized. Any other value
// throws std::runtime_error("Invalid DataType.").
unsigned int getElementSize(nvinfer1::DataType t)
{
    switch (t)
    {
        // 32-bit types.
        case nvinfer1::DataType::kFLOAT:
        case nvinfer1::DataType::kINT32:
            return 4;
        // 16-bit float.
        case nvinfer1::DataType::kHALF:
            return 2;
        // Single-byte types.
        case nvinfer1::DataType::kINT8:
        case nvinfer1::DataType::kBOOL:
            return 1;
        default:
            break;
    }
    throw std::runtime_error("Invalid DataType.");
}

// Return the index of the largest value among a[0] .. a[length - 1].
// Ties resolve to the lowest index, and 0 is returned when `length` <= 0.
// `length` defaults to 10, the number of classes main() evaluates, so existing
// one-argument calls keep their behavior.
int returnMax(float a[], int length = 10){
    if (length <= 0)
        return 0;
    // max_element keeps the first maximum, matching a strict '<' scan that
    // starts from a[0]. The old code started from an uninitialized value.
    return static_cast<int>(std::max_element(a, a + length) - a);
}

// Return the position of the first element of `class_` equal to `str`, or -1
// if no element matches. Arguments are taken by const reference, so the list
// and the string are not copied on each call.
int getIndex(const std::vector<std::string>& class_, const std::string& str){
    const auto it = std::find(class_.begin(), class_.end(), str);
    if (it == class_.end())
        return -1;
    return static_cast<int>(it - class_.begin());
}


int getResult(std::string image_path, nvinfer1::ICudaEngine* engine, nvinfer1::IExecutionContext* context){
    // 读取图片图片
    cv::Mat image = cv::imread(image_path);
    assert(!image.empty());
    // 图片预处理
    float a[3*224*224];
    prepareImage(image, a);


    // 预测图片
    void *buffers[2];
    std::vector<int64_t> bufferSize;
    int nbindings = engine->getNbBindings();
    bufferSize.resize(nbindings);

    for(int i = 0;i < nbindings; ++i){
        nvinfer1::Dims dims = engine->getBindingDimensions(i);
        nvinfer1::DataType dtype = engine->getBindingDataType(i);
        int64_t totalSize = volume(dims) * 1 * getElementSize(dtype);
        // std::cout << i << " : " << totalSize << std::endl;
        bufferSize[i] = totalSize;
        cudaMalloc(&buffers[i], totalSize);
    }


    cudaStream_t stream;
    cudaStreamCreate(&stream);

    int outSize = bufferSize[1] / sizeof(float);

    cudaMemcpyAsync(buffers[0],&a, bufferSize[0],cudaMemcpyHostToDevice,stream);
    //注意:下面方法线程不安全
    context->execute(1, buffers);

    float out[outSize];
    cudaMemcpyAsync(out, buffers[1], bufferSize[1], cudaMemcpyDeviceToHost, stream);
    cudaStreamSynchronize(stream);
    cudaFree(buffers[0]);
    cudaFree(buffers[1]);
    cudaStreamDestroy(stream);
    return returnMax(out);
}

// Evaluate the serialized TensorRT engine on the per-class test folders.
// Prints each prediction, then the overall accuracy and the elapsed
// wall-clock time in milliseconds.
// Returns 0 on success, or 1 if the engine cannot be loaded or evaluation fails.
int main(){
    // Path of the serialized engine.
    std::string model_path = "/data/kile/other/Inception/mobile_net/onnx_/mobilev2_onnx2.trt";

    // Read the whole engine file into memory.
    std::ifstream inFile(model_path, std::ios_base::in | std::ios_base::binary);
    if (!inFile){
        std::cerr << "cannot open engine file: " << model_path << std::endl;
        return 1;
    }
    std::stringstream buffer;
    buffer << inFile.rdbuf();
    const std::string cached_engine = buffer.str();
    inFile.close();

    // Deserialize the engine and create an inference context. Each step is
    // skipped when the previous one failed.
    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger);
    nvinfer1::ICudaEngine* engine = runtime
        ? runtime->deserializeCudaEngine(cached_engine.data(), cached_engine.size(), nullptr)
        : nullptr;
    nvinfer1::IExecutionContext* context = engine ? engine->createExecutionContext() : nullptr;

    // Release the TensorRT objects, newest first.
    auto cleanup = [&]() {
        if (context) context->destroy();
        if (engine) engine->destroy();
        if (runtime) runtime->destroy();
    };
    if (!context){
        std::cerr << "failed to create TensorRT runtime/engine/context from " << model_path << std::endl;
        cleanup();
        return 1;
    }

    // Class names. Index i is the label of output score i.
    const std::vector<std::string> class_ = {"airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"};
    // Each class has its own sub-folder under this directory.
    const std::string test_root = "/data/kile/other/Inception/mobile_net/dataset/test_data/";
    int correct = 0;
    int total = 0;
    int status = 0;

    // Wall-clock timing. clock() would report processor time instead.
    const auto start = std::chrono::steady_clock::now();
    try {
        for (const auto& str : class_){
            int classId = getIndex(class_, str);
            std::cout << classId << std::endl;
            // Evaluate the images stored under the folder of class `str`.
            std::vector<std::string> fileLists;
            getImageFiles(fileLists, test_root + str);
            for (const auto& filePath : fileLists){
                int predictClassId = getResult(filePath, engine, context);
                total += 1;
                std::cout << str << " " << class_[predictClassId] << " " << correct << std::endl;
                if (str == class_[predictClassId]){
                    correct += 1;
                }
            }
        }
    } catch (const std::exception& e){
        std::cerr << "evaluation failed: " << e.what() << std::endl;
        status = 1;
    }
    const auto end = std::chrono::steady_clock::now();

    const double elapsed_ms = std::chrono::duration<double, std::milli>(end - start).count();
    // Avoid 0/0 when no images were found.
    const float accuracy = total > 0 ? static_cast<float>(correct) / total : 0.0f;
    std::cout << "correct :" << accuracy << " time(ms): " << elapsed_ms << std::endl;

    cleanup();
    return status;
}


声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop】
推荐阅读
相关标签
  

闽ICP备14008679号