当前位置:   article > 正文

7.5.tensorRT高级(2)-RAII接口模式下的生产者消费者多batch实现_tensorflow 生产消费者

tensorflow 生产消费者

前言

杜老师推出的 tensorRT从零起步高性能部署 课程,之前有看过一遍,但是没有做笔记,很多东西也忘了。这次重新撸一遍,顺便记记笔记。

本次课程学习 tensorRT 高级-RAII 接口模式下的生产者消费者多 batch 实现

课程大纲可看下面的思维导图

在这里插入图片描述

1. RAII接口模式封装生产者消费者

这节课我们利用上节课学到的 RAII + 接口模式对我们的消费者生产者进行封装

我们来看代码

infer.hpp

#ifndef INFER_HPP
#define INFER_HPP

#include <memory>
#include <string>
#include <future>

class InferInterface{
public:
    virtual std::shared_future<std::string> forward(std::string pic) = 0;
};

std::shared_ptr<InferInterface> create_infer(const std::string& file);

#endif // INFER_HPP
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

infer.cpp

#include "infer.hpp"
#include <thread>
#include <queue>
#include <mutex>
#include <future>

using namespace std;

struct Job{
    shared_ptr<promise<string>> pro;
    string input;
};

class InferImpl : public InferInterface{
public:

    virtual ~InferImpl(){
        worker_running_ = false;
        cv_notify_one();

        if(worker_thread_.joinable())
            worker_thread_.join();
    }

    bool load_model(const string& file){
        
        // 尽量保证资源哪里分配哪里释放,哪里使用,这样使得程序足够简单,而不是太乱
        // 线程内传回返回值的问题
        promise<bool> pro;
        worker_running_ = true;
        worker_thread_ = thread(&InferImpl::worker, this, file, std::ref(pro));
        return pro.get_future().get();

    }

     virtual shared_future<string> forward(string pic) override{

        // printf("使用 %s 进行推理\n", context_.c_str());
        // 往队列抛任务
        Job job;
        job.pro.reset(new promise<string>());
        job.input = pic;

        lock_guard<mutex> l(job_lock_);
        qjobs_.push(job);

        // 被动通知,一旦有新的任务需要推理,通知我即可
        // 发生通知的家伙
        cv_.notify_one();
        return job.pro->get_future();
    }

    // 实际执行模型推理的部分
    void worker(string file, promise<bool>& pro){
        // worker内实现,模型的加载,使用,释放
        string context = file;
        if(context.empty()){
            pro.set_value(false);
            return;
        }else{
            pro.set_value(true);
        }

        int max_batch_size = 5;
        vector<Job> jobs;
        int batch_id = 0;
        while(worker_running_){
            // 等待接受的家伙
            // 在队列取任务并执行的过程
            unique_lock<mutex> l(job_lock_);
            cv_.wait(job_lock_, [&](){
                // true 退出等待
                // false 继续等待
                return !qjobs_.empty() || !worker_running_;
            });

            // 程序发送终止信号
            if(!worker_running_)
                break;

            while(jobs.size() < max_batch_size && !qjobs_.empty()){
                jobs.emplace_back(qjobs_.front());
                qjobs.pop();
            }
            // 可以在这里一次拿一批出来,最大拿 maxbatchsize 个 job 进行一次性处理
            // jobs inference -> batch inference

            // 执行 batch 推理
            for(int i = 0; i < jobs.size(); ++i){
                
                auto& job = jobs[i];
                char result[100];
                sprintf(result, "%s : batch-> %d[%d]", job.input.c_str(), batch_id, jobs.size());
                
                job.pro->set_value(result);
            }
            batch_id++;
            jobs.clear();
            // 模拟推理耗时
            this_thread::sleep_for(chrono::milliseconds(1000));
        }
        // 释放模型
        printf("释放: %s\n", context.c_str());
        context.clear();
        printf("Worker done.\n");
    }
private:
    atomic<bool> worker_running_{false};
    thread worker_thread_;
    queue<Job> qjobs_;
    mutex job_lock_;
    condition_variable cv_;
};

shared_ptr<InferInterface> create_infer(const string& file){
    
    shared_ptr<InferImpl> instance(new Infer());
    if(!instance->load_model(file))
        instance.reset();
    return instance;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121

main.cpp

#include "infer.hpp"

int main(){
    
    auto infer = create_infer("a");
    if(infer == nullptr){
        printf("failed.\n");
        return -1;
    }

    // 串行
    // auto fa = infer->forward("A").get();
    // auto fb = infer->forward("B").get();
    // auto fc = infer->forward("C").get();
    // printf("%s\n", fa.c_str());
    // printf("%s\n", fb.c_str());
    // printf("%s\n", fc.c_str());
    
    // 并行
    auto fa = infer->forward("A");
    auto fb = infer->forward("B");
    auto fc = infer->forward("C");
    printf("%s\n", fa.get().c_str());
    printf("%s\n", fb.get().c_str());
    printf("%s\n", fc.get().c_str());    
    printf("Program done.\n");

    return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29

上述示例代码相对复杂,结合了 RAII 和接口模式来实现模拟模型推理,具体是一个消费者-生产者模式的异步批处理机制,我们来简单解读下 infer.cpp 中具体干了些啥(form chatGPT

1. 数据结构和类定义

  • Job 结构体:这是一个任务结构,包含了一个 promise 对象(用于在工作线程中设置结果)和输入数据,promise 又通过 shared_ptr 封装了一层,可以让结构体传递效率更高
  • InferImpl 类,这是 InferInterface 的实现类,包含了异步处理的核心逻辑

2. InferImpl 类的方法和成员

  • 析构函数:在对象销毁时,将 worker_running_ 标志设置为 false,并通过条件变量唤醒工作线程。然后等待工作线程结束
  • load_model 方法:模型加载函数,它实际上启动了工作线程,并传递了一个 promise 对象来设置是否成功加载了模型
  • forward 方法:这是暴露给使用者的接口,用于提交一个新的推理任务。这个方法将任务添加到队列中,并通过条件变量唤醒工作线程
  • worker 方法:这是工作线程的核心函数,它从队列中取出任务并批量处理它们,然后使用 promise 设置结果
  • 私有成员
    • worker_running_:一个原子布尔标志,表示工作线程是否正在运行
    • worker_thread_:工作线程对象
    • qjobs_:包含待处理任务的队列
    • job_lock_:保护任务队列的互斥锁
    • cv_:条件变量,用于在有新任务到来或工作线程需要停止时唤醒工作线程

3. 工厂函数

  • create_infer 函数:RAII 的体现,这个函数创建了一个 InferImpl 的实例,并尝试加载模型。如果加载失败,它将返回一个空的智能指针。

这个示例清晰地展示了如何使用 RAII 和接口模式来实现一个异步批处理机制,同时也展示了如何使用 C++11 的并发特性(如 threadpromisecondition_variable 等)来实现这种机制。

2. 问答环节

博主对多线程相关的知识不怎么了解,因此疯狂询问 chatGPT,故此做个记录方便下次查看,以下内容来自于博主和 chatGPT 之间的对话

问题1:work_running_ 为什么是 atomic<boll> 类型,为什么不直接使用 bool 类型?什么是 atomic<bool> 类型?

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/羊村懒王/article/detail/73342
推荐阅读
相关标签