当前位置:   article > 正文

C++多线程推理、生产者消费者模式封装_c++ 生产者消费者模型 多线程

c++ 生产者消费者模型 多线程

C++多线程推理、生产者消费者模式封装

tensorRT从零起步迈向高性能工业级部署(就业导向) 课程笔记,讲师讲的不错,可以去看原视频支持下。

深度学习推理中的多线程知识概览

  1. 本章介绍的多线程主要是指算法部署时所涉及的多线程内容,对于其他多线程知识需要自行补充
  2. 常用组件有 thread、mutex、future、condition_variable
  3. 启动线程,thread,以及 join、joinable、detach、类函数启动为线程
  4. 生产者-消费者模式
  5. 具体问题:队列溢出的问题:生产太快,消费太慢;如何实现溢出控制
  6. 具体问题:生产者如何拿到消费反馈
  7. RAII 思想的生产者-消费者模式封装,多 batch 的体现

thread、join、joinable、detach、常规/引用传参、类函数

#include <thread>
#include <stdio.h>

using namespace std;

void worker() {
	printf("Hello World.\n");
}

int main() {
	thread t(worker);
	// thread t;
	t.join();
	printf("Done.\n");

	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

上面是一个最简单的 cpp 多线程的程序

  1. t.join() 等待线程结束,如果不加,就会在析构时提示异常,出现 core dumped,只要线程 t 启动了(如果只是声明 thread t; 不算启动),就必须要 join。

  2. 若 t 没有启动线程,如果 join ,也会 core dumped 异常;

  3. 根据以上两点,如果我们在某些条件下启动线程,某些条件下不启动,该怎么办呢? 用 joinable,如:

    if (t.joinable()) t.join();
    
    • 1
  4. detach 分离线程,取消管理权,使得线程称为野线程,不建议使用。野线程不需要 join,线程交给系统管理,程序退出后,所有线程才退出。

  5. 基本传参:

    void worker(int a) {
    	printf("Hello Thread, %d\n", a);
    }
    
    int main() {
    	thread t(worker, 12);
    	if (t.joinable()) t.join();
    	printf("Done.\n");
    
    	return 0;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
  6. 引用传参:

    void worker(string& s) {
    	printf("Hello Thread\n");
    	s = "reference string";
    }
    
    int main() {
    	string param;
    	thread t(worker, 12, std::ref(param));
    	// thread t(worker, 12, param); 错误的引用传参
    	if (t.joinable()) t.join();
    	printf("Done.\n");
    	cout << param << endl;
    
    	return 0;
    }
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15

    多线程的引用传参有两点需要注意:

    • 传入时需要使用 std::ref(param)
    • 注意引用变量的声明周期,如果在外面声明的引用变量传给子线程,而在子线程结束之前就在外面将变量释放掉了,则在子线程中可能引发错误
  7. 类的线程启动

    注释掉的方式是用类的静态方法的方式,不建议

    class Infer {
    public:
    	Infer() {
    		// worker_thread_ = thread(infer_worker, this);
    		worker_thread_ = thread(&Infer::infer_worker, this);
    	}
    
    private:
    	thread worker_thread_;
    
    	// static infer_worker(Infer* self) { /* ... */ }
    	void infer_worker() { /* ... */ }
    };
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13

图像处理的生产者消费者模式

首先看一个最简单的生产者消费者模式,两个线程分别执行 video_captureinfer_worker 两个函数来生产(获取)图片和推理图片。

其中 queue<string> qjobs_; 用于存储待处理的图片

#include <thread>
#include <queue>
#include <mutex>
#include <string>
#include <stdio.h>
#include <chrono>

using namespace std;

queue<string> qjobs_;
int get_image_time = 1000; // 先假设获取一张图片与推理一张图片都是一秒
int infer_image_time = 1000;

void video_capture() {
	int  pic_id = 0;
	while (true) {
		char name[100];
		sprintf(name, "PIC-%d", pic_id++);
		printf("生产了一张新图片: %s\n", name);
		qjobs_.push(name);
		this_thread::sleep_for(chrono::milliseconds(get_image_time));
	}
}

void infer_worker() {
	while (true) {
		if (!qjobs_.empty()) {
			auto pic = qjobs_.front();
			qjobs_.pop();
			printf("消费掉一张图片: %s\n", pic.c_str());
			this_thread::sleep_for(chrono::milliseconds(infer_image_time));
		}
		this_thread::yield(); // 没有要处理的图片,主动交出CPU,避免资源浪费
	}
}

int main() {
	thread t0(video_capture);
	thread t1(infer_worker);

	t0.join();
	t1.join();

	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
基本问题

共享资源访问的问题

stl 中的 queue 队列不是 thread-safe 的,我们需要自己加锁来保证共享资源访问的安全性

只需要将访问共享变量的代码部分用锁保护起来即可:

mutex lock_;

void video_capture() {
	int  pic_id = 0;
	while (true) {
		{
			lock_guard<mutex> l(lock_);
			char name[100];
			sprintf(name, "PIC-%d", pic_id++);
			printf("生产了一张新图片: %s\n", name);
			qjobs_.push(name);
		}
		this_thread::sleep_for(chrono::milliseconds(get_image_time));
	}
}

void infer_worker() {
	while (true) {
		if (!qjobs_.empty()) {
			{
				lock_guard<mutex> l(lock_);
				auto pic = qjobs_.front();
				qjobs_.pop();
				printf("消费掉一张图片: %s\n", pic.c_str());
			}
			this_thread::sleep_for(chrono::milliseconds(infer_image_time));
		}
		this_thread::yield(); // 没有要处理的图片,主动交出CPU,避免资源浪费
	}
}

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
问题1

队列溢出的问题,生产太快,消费太慢;如何实现溢出控制

之前我们设定的是生产与消费均为一秒,但是若生产速率高于消费速率,则必然会出现队列堆积现象。

解决方法:使用条件变量 condation_variable :如果队列满了,就不生产,等待队列有空间,再生产,即我们要达成类似如下的逻辑:

if (qjobs_.size() < limit) wait();
qjobs_.push(name);
  • 1
  • 2

这就又有另一个问题,如何在队列有空间时,通知 wait() 函数停止等待,实际上这可以在消费者的函数中进行,因为当我们消费掉队列中的一张图片,队列肯定就有空间来存放新的图片了。

完整的加 wait 的代码:

#include <thread>
#include <queue>
#include <mutex>
#include <condition_variable>
#include <string>
#include <stdio.h>
#include <chrono>


using namespace std;

queue<string> qjobs_;
mutex lock_;
condition_variable cv_;
int get_image_time_ = 300; // 先假设获取一张图片与推理一张图片都是一秒
int infer_image_time_ = 1000;
const int limit_ = 5;

void video_capture() {
	int  pic_id = 0;
	while (true) {
		{
			unique_lock<mutex> l(lock_);
			char name[100];
			sprintf(name, "PIC-%d", pic_id++);
			printf("生产了一张新图片: %s, 当前队列大小: %d\n", name, (int)qjobs_.size());
			qjobs_.push(name);

			// condition_variable.wait(lock, predicate);
			// predicate 指定什么时候等待,什么时候停止等待
			cv_.wait(l, [&](){
					// return false 表示继续等待; return true 表示停止等待
					return qjobs_.size() <= limit_;
			});
		}
		this_thread::sleep_for(chrono::milliseconds(get_image_time_));
	}
}

void infer_worker() {
	while (true) {
		if (!qjobs_.empty()) {
			{
				lock_guard<mutex> l(lock_);
				auto pic = qjobs_.front();
				qjobs_.pop();
				printf("消费掉一张图片: %s\n", pic.c_str());
				// 消费掉一个,就可以通知wait,停止等待
				cv_.notify_one();
			}
			this_thread::sleep_for(chrono::milliseconds(infer_image_time_));
		}
		this_thread::yield(); // 没有要处理的图片,主动交出CPU,避免资源浪费
	}
}

int main() {
	thread t0(video_capture);
	thread t1(infer_worker);

	t0.join();
	t1.join();

	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65

测试可以看到,在达到我们设置的队列上限之后,不会再一直生产新图片导致队列溢出:

生产了一张新图片: PIC-0, 当前队列大小: 0
消费掉一张图片: PIC-0
生产了一张新图片: PIC-1, 当前队列大小: 0
生产了一张新图片: PIC-2, 当前队列大小: 1
生产了一张新图片: PIC-3, 当前队列大小: 2
消费掉一张图片: PIC-1
生产了一张新图片: PIC-4, 当前队列大小: 2
生产了一张新图片: PIC-5, 当前队列大小: 3
生产了一张新图片: PIC-6, 当前队列大小: 4
消费掉一张图片: PIC-2
生产了一张新图片: PIC-7, 当前队列大小: 4
生产了一张新图片: PIC-8, 当前队列大小: 5
消费掉一张图片: PIC-3
生产了一张新图片: PIC-9, 当前队列大小: 5
消费掉一张图片: PIC-4
生产了一张新图片: PIC-10, 当前队列大小: 5
消费掉一张图片: PIC-5
生产了一张新图片: PIC-11, 当前队列大小: 5
消费掉一张图片: PIC-6
生产了一张新图片: PIC-12, 当前队列大小: 5
消费掉一张图片: PIC-7
生产了一张新图片: PIC-13, 当前队列大小: 5
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22

注意:一旦进入 wait() ,会自动释放锁;一旦退出 wait() ,会加锁

问题2

生产者如何拿到消费者的反馈

我们消费者将生产者的图片推理完成之后,肯定要将结果返回给生产者。比如在目标检测中,video_capture 将捕获到的图片交给消费者处理完之后,需要得到物体框的坐标,再将框画到原图上进行显示。那么这时,生产者应该如何拿到消费者的反馈呢?

这就要用到 promise 和 future,下面我们将 job 从单纯的 string 输入改为这样一个结构体:

struct Job {
	shared_ptr<promise<string>> pro;  // 返回结果,如果在目标检测的例子中就是框
	string input;  // 输入,图片
};
  • 1
  • 2
  • 3
  • 4

其中:

  • input:输入,还是输入,实际中可能是图片,这里还是用 string 代替
  • pro:指向 promise 对象的共享指针,用来得到返回的结果

具体过程见下面代码中的注释,完整的代码:

#include <thread>
#include <queue>
#include <mutex>
#include <condition_variable>
#include <stdio.h>
#include <string>
#include <memory>
#include <future>
#include <chrono>


using namespace std;

struct Job {
	shared_ptr<promise<string>> pro;  // 返回结果,如果在目标检测的例子中就是框
	string input;  // 输入,图片
};

queue<Job> qjobs_;
mutex lock_;
condition_variable cv_;
int get_image_time_ = 300; // 先假设获取一张图片与推理一张图片都是一秒
int infer_image_time_ = 1000;
const int limit_ = 5;

void video_capture() {
	int  pic_id = 0;
	while (true) {
		Job job;
		{
			unique_lock<mutex> l(lock_);
			char name[100];
			sprintf(name, "PIC-%d", pic_id++);
			printf("生产了一张新图片: %s, 当前队列大小: %d\n", name, (int)qjobs_.size());

			job.pro.reset(new promise<string> ());
			job.input = name;
			qjobs_.push(job);

			// condition_variable.wait(lock, predicate);
			// predicate 指定什么时候等待,什么时候停止等待
			cv_.wait(l, [&](){
					// return false 表示继续等待; return true 表示停止等待
					return qjobs_.size() <= limit_;
			});
		}
		// .get() 实现等待, 直到promise->set_value()被执行了,这里的返回值就是result
		// 另外要注意,这里等待结果要放在锁的外面,避免持有锁等待结果,造成死锁
		auto result = job.pro->get_future().get();
		// 处理result
		printf("Job %s -> %s\n", job.input.c_str(), result.c_str());

		this_thread::sleep_for(chrono::milliseconds(get_image_time_));
	}
}

void infer_worker() {
	while (true) {
		if (!qjobs_.empty()) {
			{
				lock_guard<mutex> l(lock_);
				auto pjob = qjobs_.front();
				qjobs_.pop();
				printf("消费掉一张图片: %s\n", pjob.input.c_str());

				auto res = pjob.input + " ---- infer result";
				pjob.pro->set_value(res);

				// 消费掉一个,就可以通知wait,停止等待
				cv_.notify_one();

			}
			this_thread::sleep_for(chrono::milliseconds(infer_image_time_));
		}
		this_thread::yield(); // 没有要处理的图片,主动交出CPU,避免资源浪费
	}
}

int main() {
	thread t0(video_capture);
	thread t1(infer_worker);

	t0.join();
	t1.join();

	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87

输出:

生产了一张新图片: PIC-0, 当前队列大小: 0
消费掉一张图片: PIC-0
Job PIC-0 -> PIC-0 ---- infer result
生产了一张新图片: PIC-1, 当前队列大小: 0
消费掉一张图片: PIC-1
Job PIC-1 -> PIC-1 ---- infer result
生产了一张新图片: PIC-2, 当前队列大小: 0
消费掉一张图片: PIC-2
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

可以看到结果中能够拿到对应图片的推理结果。

RAII+接口模式对模型加载进行单批多图推理封装

考虑下面的推理类加载模型和推理的过程:(context_ 来代替模型,实际案例中,模型的加载与释放比这要复杂的多,这里简单地用 string 来代替)

class Infer {
public:
	bool load_model(const string& file) {
		// 异常逻辑处理
		if (!context_.empty()) {
			destory();
		}
		// 正常逻辑
		context_ = file;
		return true;
	}

	void forward() {
		// 异常逻辑处理
		if (context_.empty()) {
			printf("模型尚未加载.\n");
			return;
		}
		// 正常逻辑
		printf("正在使用 %s 进行推理.\n", context_.c_str());
	}

	void destory() {
		context_.clear();
	}
private:
	string context_;

};

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

问题:

正常工作代码中,异常逻辑的处理(如模型推理前未进行模型加载、推理后未进行模型销毁等)需要耗费大量时间和代码量,如果异常逻辑写的不对,甚至会造成封装的不安全性,导致程序崩溃。这样封装又难写,又难用。

解决方法:

  • RAII:资源获取即初始化
  • 接口模式:设计模式,是一种封装模式,实现类与接口类分离的模式

我们分别来看这两种解决方法带来的好处:

RAII

我们使用这样一个 create_infer 函数来代替 Infer 类的直接初始化:

shared_ptr<Infer> create_infer(const string& file) {
	shared_ptr<Infer> instance(new Infer());
	if (!instance->load_model(file)) instance.reset();
	return instance;
}

int main() {

	// Infer infer;  // 直接获取类
	string file = "...";
	auto infer = create_infer(file);  // 通过封装的函数获取类
	if (infer == nullptr) printf("模型加载失败\n");

	infer->forward();

	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

RAII 的特点:获取 infer 实例,即表示加载模型。并且获取资源与加载模型强绑定,加载模型成功,则表示获取资源成功,加载模型失败,则直接表示获取资源失败。

好处:

  1. 避免外部执行 load_model ,只有在 create_infer 中调用,不会有任何另外的地方调用,后面会进一步通过接口模式直接禁止外部执行
  2. 一个实例的 load_model 不会执行超过一次
  3. 获取的模型一定初始化成功,因此 forward 时不必再做判断
    • 仅需在外部做一次 create 是否成功的判断
    • 不需要在 forward 函数、create 函数内再做异常判断

接口模式

  1. 解决成员函数(如load_model)外部仍可调用的问题,我们之前说过,要保证它只在 create_infer 中调用
  2. 解决成员变量(如context_) 对外可见的问题
    • 注意:这里的 context_ 虽然是 private 变量不可访问,但是是对外可见的。对外可见可能造成的问题是:特殊的成员变量类型对头文件的依赖,从而造成的命名空间污染/头文件污染。比如成员变量是 cudaStream_t 类型,那就必须包含 cuda_runtime.h 头文件。
  3. 接口类 (这里的 InferInterface 类) 是一个纯虚类,其原则是:**只暴露调用者需要的函数,其他一概不暴露。**比如 load_model 已通过 RAII 封装到 create_infer 内,这里 load_model 就属于不需要暴露的类,内部如果有启动线程如 start、stop 等,也不需要暴露。而 forward 这些函数肯定是需要暴露的。
  4. 此时,可以将这些声明与实现分别放到 infer.hpp 和 infer.cpp 中

最终我们的完整代码有三个文件: infer.hpp, infer.cpp, main.cpp 分别如下:

infer.hpp

// infer.hpp
#ifndef INFER_HPP
#define INFER_HPP

#include <memory>
#include <string>

class InferInterface {
public:
	virtual void forward() = 0;
};

std::shared_ptr<InferInterface> create_infer(const std::string& file);

#endif
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

infer.cpp

#include "infer.hpp"

using namespace std;

class InferImpl : public InferInterface {
public:
	bool load_model(const string& file) {
		context_ = file;
		return true;
	}

	virtual void forward() override {
		printf("正在使用 %s 进行推理.\n", context_.c_str());
	}

	void destory() {
		context_.clear();
	}

private:
	string context_;
};

shared_ptr<InferInterface> create_infer(const string& file) {
	shared_ptr<InferImpl> instance(new InferImpl());
	if (!instance->load_model(file)) instance.reset();
	return instance;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

main.cpp

#include "infer.hpp"

using namespace std;

int main() {

	string file = "model a";
	auto infer = create_infer(file);
	if (infer == nullptr) {
		printf("模型加载失败\n");
		return -1;
	}

	infer->forward();

	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17

原则总结:

  1. 头文件,尽量只包含需要的部分
  2. 外界不需要的,尽量不让外界看到,保持接口的简洁
  3. 不要在头文件中用 using namespace ... ,如果写了的话,所有包含改头文件的文件,就都打开了这个命名空间

多图推理

最终我们给出多图推理的代码,同样是三个文件,关键代码已经给出注释:

infer.hpp

// infer.hpp
#ifndef INFER_HPP
#define INFER_HPP

#include <memory>
#include <string>
#include <future>

class InferInterface {
public:
	virtual std::shared_future<std::string> forward(std::string pic) = 0;
};

std::shared_ptr<InferInterface> create_infer(const std::string& file);

#endif
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16

infer.cpp

// infer.cpp
#include "infer.hpp"
#include <mutex>
#include <thread>
#include <future>
#include <queue>
#include <string>
#include <memory>
#include <chrono>
#include <condition_variable>

using namespace std;

struct Job {
	shared_ptr<promise<string>> pro;
	string input;
};

class InferImpl : public InferInterface {
public:
	// 析构函数
	virtual ~InferImpl() {
		worker_running_ = false;
		cv_.notify_one();

		if (worker_thread_.joinable()) worker_thread_.join();
	}

	bool load_model(const string& file) {
		// 尽量保证资源在哪里分配,就在哪里使用,就在哪里释放,这样不会太乱。比如这里我们就都在 worker 函数内完成。
		// 这里的pro表示是否启动成功
		promise<bool> pro;
		worker_running_ = true;
		worker_thread_ = thread(&InferImpl::worker, this, file, std::ref(pro));
		return pro.get_future().get();
	}

	virtual shared_future<string> forward(string pic) override {
		// printf("正在使用 %s 进行推理.\n", context_.c_str());
		Job job;
		job.pro.reset(new promise<string>());
		job.input = pic;

		lock_guard<mutex> l(job_lock_);
		qjobs_.push(job);

		// 一旦有任务需要推理,发送通知
		cv_.notify_one();

		// return job.pro->get_future().get();	// 不能这样直接返回模型推理的结果,因为这样会等待模型推理结束,相当于还是串行
		return job.pro->get_future();  // 而是直接返回future对象,让外部按需要再.get()获取结果
	}

	void worker(string file, promise<bool>& pro) {
		// worker是实际执行推理的函数
		// context的加载、使用和释放都在worker内
		string context = file;
		if (context.empty()) {  // 未初始化,返回false
			pro.set_value(false);
			return;
		}
		else {  // 已初始化,返回true,之后正式开始进行推理
			pro.set_value(true);
		}

		int max_batch_size = 5;
		vector<Job> jobs;  // 拿多张图片 batch
		int batch_id = 0;
		while (worker_running_) {
			// 被动等待接收通知
			unique_lock<mutex> l(job_lock_);
			cv_.wait(l, [&](){
					// true:停止等待
					return !qjobs_.empty() || !worker_running_;
					});
			// 如果是因为程序发送终止信号而推出wait的
			if (!worker_running_) break;

			// 可以一次拿一批出来, 最大拿maxBatchSize个
			while (jobs.size() < max_batch_size && !qjobs_.empty()) {
				jobs.emplace_back(qjobs_.front());
				qjobs_.pop();
			}
			// 执行batch推理
			for (int i=0; i<jobs.size(); ++i) {
				auto& job = jobs[i];
				char name[100];
				sprintf(name, "%s : batch->%d[%d]", job.input.c_str(), batch_id, (int)jobs.size());
				job.pro->set_value(name);
			}
			batch_id++;
			jobs.clear();
			this_thread::sleep_for(chrono::milliseconds(infer_time_));
		}
		printf("释放模型: %s\n", context.c_str());
		context.clear(); // 释放模型
		printf("线程终止\n");
	}

private:
	atomic<bool> worker_running_{false}; // 表示程序是否正在运行
	thread worker_thread_;
	queue<Job> qjobs_;
	mutex job_lock_;
	condition_variable cv_;
	int infer_time_ = 1000;
};

shared_ptr<InferInterface> create_infer(const string& file) {
	shared_ptr<InferImpl> instance(new InferImpl());
	if (!instance->load_model(file)) instance.reset();
	return instance;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113

main.cpp

// main.cpp
#include "infer.hpp"

using namespace std;
int main() {

	string file = "model a";
	auto infer = create_infer(file);
	if (infer == nullptr) {
		printf("模型加载失败\n");
		return -1;
	}

	auto fa = infer->forward("A");
	auto fb = infer->forward("B");
	auto fc = infer->forward("C");
	printf("%s\n", fa.get().c_str());
	printf("%s\n", fb.get().c_str());
	printf("%s\n", fc.get().c_str());

	// auto fa = infer->forward("A").get();
	// auto fb = infer->forward("B").get();
	// auto fc = infer->forward("C").get();
	// printf("%s\n", fa.c_str());
	// printf("%s\n", fb.c_str());
	// printf("%s\n", fc.c_str());

	printf("程序终止\n");
	return 0;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30

想一下,如果按照注释掉的部分的方式来进行推理的话,会有什么不同呢?

会每次都等待结果,无法进行单批次多图处理。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/凡人多烦事01/article/detail/267916
推荐阅读
相关标签
  

闽ICP备14008679号