当前位置:   article > 正文

yolov5-7.0 目标识别使用 TensorRT 推理(C++ API)

1、程序环境配置

        具体路径根据自己安装目录修改。

2、pt模型转engine模型

        首先需要获取Tensorrt推理时需要使用的engine模型。这里我直接使用yolov5的export.py进行导出。具体操作参考yolov5的pt模型转onnx模型,onnx模型转engine模型

 3、推理程序

        推理主要使用Tensorrt的C++的Api。程序里主要有logging.h和utils.h两个基础的头文件,还有含推理程序的yolo.hpp,我将主要的推理程序都封装在一个类里面了,后续需要使用时直接new一个就可以直接使用了。下面是具体程序:

logging.h

  1. #pragma once
  2. /*
  3. * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #ifndef TENSORRT_LOGGING_H
  18. #define TENSORRT_LOGGING_H
  19. #include "NvInferRuntimeCommon.h"
  20. #include <cassert>
  21. #include <ctime>
  22. #include <iomanip>
  23. #include <iostream>
  24. #include <ostream>
  25. #include <sstream>
  26. #include <string>
  27. using Severity = nvinfer1::ILogger::Severity;
  28. class LogStreamConsumerBuffer : public std::stringbuf
  29. {
  30. public:
  31. LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
  32. : mOutput(stream)
  33. , mPrefix(prefix)
  34. , mShouldLog(shouldLog)
  35. {
  36. }
  37. LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
  38. : mOutput(other.mOutput)
  39. {
  40. }
  41. ~LogStreamConsumerBuffer()
  42. {
  43. // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
  44. // std::streambuf::pptr() gives a pointer to the current position of the output sequence
  45. // if the pointer to the beginning is not equal to the pointer to the current position,
  46. // call putOutput() to log the output to the stream
  47. if (pbase() != pptr())
  48. {
  49. putOutput();
  50. }
  51. }
  52. // synchronizes the stream buffer and returns 0 on success
  53. // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
  54. // resetting the buffer and flushing the stream
  55. virtual int sync()
  56. {
  57. putOutput();
  58. return 0;
  59. }
  60. void putOutput()
  61. {
  62. if (mShouldLog)
  63. {
  64. // prepend timestamp
  65. std::time_t timestamp = std::time(nullptr);
  66. tm* tm_local = new tm();
  67. localtime_s(tm_local, &timestamp);
  68. std::cout << "[";
  69. std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
  70. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
  71. std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
  72. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
  73. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
  74. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
  75. // std::stringbuf::str() gets the string contents of the buffer
  76. // insert the buffer contents pre-appended by the appropriate prefix into the stream
  77. mOutput << mPrefix << str();
  78. // set the buffer to empty
  79. str("");
  80. // flush the stream
  81. mOutput.flush();
  82. }
  83. }
  84. void setShouldLog(bool shouldLog)
  85. {
  86. mShouldLog = shouldLog;
  87. }
  88. private:
  89. std::ostream& mOutput;
  90. std::string mPrefix;
  91. bool mShouldLog;
  92. };
//!
//! \class LogStreamConsumerBase
//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
//!
class LogStreamConsumerBase
{
public:
    // Constructs the buffer member. LogStreamConsumer lists this class FIRST in
    // its base list so mBuffer is fully constructed before the std::ostream
    // base is handed &mBuffer.
    LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mBuffer(stream, prefix, shouldLog)
    {
    }

protected:
    LogStreamConsumerBuffer mBuffer;  // buffer shared with the ostream base of the derived class
};
//!
//! \class LogStreamConsumer
//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
//! Order of base classes is LogStreamConsumerBase and then std::ostream.
//! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
//! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
//! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
//! Please do not change the order of the parent classes.
//!
class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
{
public:
    //! \brief Creates a LogStreamConsumer which logs messages with level severity.
    //! Reportable severity determines if the messages are severe enough to be logged.
    LogStreamConsumer(Severity reportableSeverity, Severity severity)
        : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(severity <= reportableSeverity)
        , mSeverity(severity)
    {
    }

    // Move constructor. NOTE: it builds a fresh buffer from `other`'s settings
    // rather than moving other's buffer -- any text still buffered in `other`
    // is flushed by other's destructor, not transferred.
    LogStreamConsumer(LogStreamConsumer&& other)
        : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(other.mShouldLog)
        , mSeverity(other.mSeverity)
    {
    }

    // Re-evaluate whether this consumer's severity passes the new threshold.
    void setReportableSeverity(Severity reportableSeverity)
    {
        mShouldLog = mSeverity <= reportableSeverity;
        mBuffer.setShouldLog(mShouldLog);
    }

private:
    // Errors and warnings go to stderr, info and below to stdout.
    static std::ostream& severityOstream(Severity severity)
    {
        return severity >= Severity::kINFO ? std::cout : std::cerr;
    }

    // Short tag prepended to every message of the given severity.
    static std::string severityPrefix(Severity severity)
    {
        switch (severity)
        {
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";
        }
    }

    bool mShouldLog;     // does mSeverity pass the reportable threshold?
    Severity mSeverity;  // severity of messages written through this consumer
};
//! \class Logger
//!
//! \brief Class which manages logging of TensorRT tools and samples
//!
//! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
//! and supports logging two types of messages:
//!
//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
//! - Test pass/fail messages
//!
//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
//!
//! In the future, this class could be extended to support dumping test results to a file in some standard format
//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
//!
//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
//! library and messages coming from the sample.
//!
//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
//! object.
class Logger : public nvinfer1::ILogger
{
public:
    // Messages below `severity` are suppressed; defaults to warnings-and-worse.
    Logger(Severity severity = Severity::kWARNING)
        : mReportableSeverity(severity)
    {
    }

    //!
    //! \enum TestResult
    //! \brief Represents the state of a given test
    //!
    enum class TestResult
    {
        kRUNNING, //!< The test is running
        kPASSED,  //!< The test passed
        kFAILED,  //!< The test failed
        kWAIVED   //!< The test was waived
    };

    //!
    //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
    //! \return The nvinfer1::ILogger associated with this Logger
    //!
    //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
    //! we can eliminate the inheritance of Logger from ILogger
    //!
    nvinfer1::ILogger& getTRTLogger()
    {
        return *this;
    }

    //!
    //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
    //!
    //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
    //! inheritance from nvinfer1::ILogger
    //!
    void log(Severity severity, const char* msg) noexcept override
    {
        // The temporary LogStreamConsumer flushes on destruction at the end of
        // this statement, so each TensorRT message becomes one complete line.
        LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
    }

    //!
    //! \brief Method for controlling the verbosity of logging output
    //!
    //! \param severity The logger will only emit messages that have severity of this level or higher.
    //!
    void setReportableSeverity(Severity severity)
    {
        mReportableSeverity = severity;
    }

    //!
    //! \brief Opaque handle that holds logging information for a particular test
    //!
    //! This object is an opaque handle to information used by the Logger to print test results.
    //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
    //! with Logger::reportTest{Start,End}().
    //!
    class TestAtom
    {
    public:
        TestAtom(TestAtom&&) = default;

    private:
        friend class Logger;

        // Only Logger::defineTest() may create TestAtoms.
        TestAtom(bool started, const std::string& name, const std::string& cmdline)
            : mStarted(started)
            , mName(name)
            , mCmdline(cmdline)
        {
        }

        bool mStarted;        // has reportTestStart() been called?
        std::string mName;    // dot-separated test name
        std::string mCmdline; // command line that reproduces the test
    };

    //!
    //! \brief Define a test for logging
    //!
    //! \param[in] name The name of the test. This should be a string starting with
    //! "TensorRT" and containing dot-separated strings containing
    //! the characters [A-Za-z0-9_].
    //! For example, "TensorRT.sample_googlenet"
    //! \param[in] cmdline The command line used to reproduce the test
    //
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    //!
    static TestAtom defineTest(const std::string& name, const std::string& cmdline)
    {
        return TestAtom(false, name, cmdline);
    }

    //!
    //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
    //! as input
    //!
    //! \param[in] name The name of the test
    //! \param[in] argc The number of command-line arguments
    //! \param[in] argv The array of command-line arguments (given as C strings)
    //!
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
    {
        auto cmdline = genCmdlineString(argc, argv);
        return defineTest(name, cmdline);
    }

    //!
    //! \brief Report that a test has started.
    //!
    //! \pre reportTestStart() has not been called yet for the given testAtom
    //!
    //! \param[in] testAtom The handle to the test that has started
    //!
    static void reportTestStart(TestAtom& testAtom)
    {
        reportTestResult(testAtom, TestResult::kRUNNING);
        assert(!testAtom.mStarted);
        testAtom.mStarted = true;
    }

    //!
    //! \brief Report that a test has ended.
    //!
    //! \pre reportTestStart() has been called for the given testAtom
    //!
    //! \param[in] testAtom The handle to the test that has ended
    //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
    //! TestResult::kFAILED, TestResult::kWAIVED
    //!
    static void reportTestEnd(const TestAtom& testAtom, TestResult result)
    {
        assert(result != TestResult::kRUNNING);
        assert(testAtom.mStarted);
        reportTestResult(testAtom, result);
    }

    // The report*() helpers return the matching process exit code so callers
    // can write `return Logger::reportPass(atom);` from main().
    static int reportPass(const TestAtom& testAtom)
    {
        reportTestEnd(testAtom, TestResult::kPASSED);
        return EXIT_SUCCESS;
    }

    static int reportFail(const TestAtom& testAtom)
    {
        reportTestEnd(testAtom, TestResult::kFAILED);
        return EXIT_FAILURE;
    }

    static int reportWaive(const TestAtom& testAtom)
    {
        reportTestEnd(testAtom, TestResult::kWAIVED);
        return EXIT_SUCCESS;
    }

    static int reportTest(const TestAtom& testAtom, bool pass)
    {
        return pass ? reportPass(testAtom) : reportFail(testAtom);
    }

    Severity getReportableSeverity() const
    {
        return mReportableSeverity;
    }

private:
    //!
    //! \brief returns an appropriate string for prefixing a log message with the given severity
    //!
    static const char* severityPrefix(Severity severity)
    {
        switch (severity)
        {
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";
        }
    }

    //!
    //! \brief returns an appropriate string for prefixing a test result message with the given result
    //!
    static const char* testResultString(TestResult result)
    {
        switch (result)
        {
        case TestResult::kRUNNING: return "RUNNING";
        case TestResult::kPASSED: return "PASSED";
        case TestResult::kFAILED: return "FAILED";
        case TestResult::kWAIVED: return "WAIVED";
        default: assert(0); return "";
        }
    }

    //!
    //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
    //!
    static std::ostream& severityOstream(Severity severity)
    {
        return severity >= Severity::kINFO ? std::cout : std::cerr;
    }

    //!
    //! \brief method that implements logging test results
    //!
    static void reportTestResult(const TestAtom& testAtom, TestResult result)
    {
        severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
                                         << testAtom.mCmdline << std::endl;
    }

    //!
    //! \brief generate a command line string from the given (argc, argv) values
    //!
    static std::string genCmdlineString(int argc, char const* const* argv)
    {
        std::stringstream ss;
        for (int i = 0; i < argc; i++)
        {
            if (i > 0)
                ss << " ";
            ss << argv[i];
        }
        return ss.str();
    }

    Severity mReportableSeverity;  // messages below this severity are dropped
};
  395. namespace
  396. {
  397. //!
  398. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
  399. //!
  400. //! Example usage:
  401. //!
  402. //! LOG_VERBOSE(logger) << "hello world" << std::endl;
  403. //!
  404. inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
  405. {
  406. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
  407. }
  408. //!
  409. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
  410. //!
  411. //! Example usage:
  412. //!
  413. //! LOG_INFO(logger) << "hello world" << std::endl;
  414. //!
  415. inline LogStreamConsumer LOG_INFO(const Logger& logger)
  416. {
  417. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
  418. }
  419. //!
  420. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
  421. //!
  422. //! Example usage:
  423. //!
  424. //! LOG_WARN(logger) << "hello world" << std::endl;
  425. //!
  426. inline LogStreamConsumer LOG_WARN(const Logger& logger)
  427. {
  428. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
  429. }
  430. //!
  431. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
  432. //!
  433. //! Example usage:
  434. //!
  435. //! LOG_ERROR(logger) << "hello world" << std::endl;
  436. //!
  437. inline LogStreamConsumer LOG_ERROR(const Logger& logger)
  438. {
  439. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
  440. }
  441. //!
  442. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
  443. // ("fatal" severity)
  444. //!
  445. //! Example usage:
  446. //!
  447. //! LOG_FATAL(logger) << "hello world" << std::endl;
  448. //!
  449. inline LogStreamConsumer LOG_FATAL(const Logger& logger)
  450. {
  451. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
  452. }
  453. } // anonymous namespace
  454. #endif // TENSORRT_LOGGING_H

utils.h 

  1. #pragma once
  2. #include <algorithm>
  3. #include <fstream>
  4. #include <iostream>
  5. #include <opencv2/opencv.hpp>
  6. #include <vector>
  7. #include <chrono>
  8. #include <cmath>
  9. #include <numeric> // std::iota
  10. using namespace cv;
// CHECK(status): evaluate a CUDA runtime call and abort the whole process if it
// did not return 0 (cudaSuccess), printing the numeric error code to stderr.
// The do/while(0) wrapper makes the macro behave like a single statement so it
// is safe inside un-braced if/else.
#define CHECK(status) \
do\
{\
    auto ret = (status);\
    if (ret != 0)\
    {\
        std::cerr << "Cuda failure: " << ret << std::endl;\
        abort();\
    }\
} while (0)
// One raw detection as laid out in the network's output buffer.
// alignas(float) keeps the struct float-aligned so a float buffer can be
// reinterpreted as an array of Detection without padding surprises.
struct alignas(float) Detection {
    // center_x, center_y, w, h (network-input coordinates)
    float bbox[4];
    float conf; // bbox_conf * cls_conf
    int class_id;
};
  27. static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h, std::vector<int>& padsize) {
  28. int w, h, x, y;
  29. float r_w = input_w / (img.cols*1.0);
  30. float r_h = input_h / (img.rows*1.0);
  31. if (r_h > r_w) {//宽大于高
  32. w = input_w;
  33. h = r_w * img.rows;
  34. x = 0;
  35. y = (input_h - h) / 2;
  36. }
  37. else {
  38. w = r_h * img.cols;
  39. h = input_h;
  40. x = (input_w - w) / 2;
  41. y = 0;
  42. }
  43. cv::Mat re(h, w, CV_8UC3);
  44. cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR);
  45. cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128));
  46. re.copyTo(out(cv::Rect(x, y, re.cols, re.rows)));
  47. padsize.push_back(h);
  48. padsize.push_back(w);
  49. padsize.push_back(y);
  50. padsize.push_back(x);// int newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3];
  51. return out;
  52. }
  53. cv::Rect get_rect(cv::Mat& img, float bbox[4], int INPUT_W, int INPUT_H) {
  54. int l, r, t, b;
  55. float r_w = INPUT_W / (img.cols * 1.0);
  56. float r_h = INPUT_H / (img.rows * 1.0);
  57. if (r_h > r_w) {
  58. l = bbox[0];
  59. r = bbox[2];
  60. t = bbox[1]- (INPUT_H - r_w * img.rows) / 2;
  61. b = bbox[3] - (INPUT_H - r_w * img.rows) / 2;
  62. l = l / r_w;
  63. r = r / r_w;
  64. t = t / r_w;
  65. b = b / r_w;
  66. }
  67. else {
  68. l = bbox[0] - bbox[2] / 2.f - (INPUT_W - r_h * img.cols) / 2;
  69. r = bbox[0] + bbox[2] / 2.f - (INPUT_W - r_h * img.cols) / 2;
  70. t = bbox[1] - bbox[3] / 2.f;
  71. b = bbox[1] + bbox[3] / 2.f;
  72. l = l / r_h;
  73. r = r / r_h;
  74. t = t / r_h;
  75. b = b / r_h;
  76. }
  77. return cv::Rect(l, t, r - l, b - t);
  78. }

yolo.hpp 

  1. #pragma once
  2. #include "NvInfer.h"
  3. #include "cuda_runtime_api.h"
  4. #include "NvInferPlugin.h"
  5. #include "logging.h"
  6. #include <opencv2/opencv.hpp>
  7. #include "utils.h"
  8. #include <string>
  9. using namespace nvinfer1;
  10. using namespace cv;
  11. // stuff we know about the network and the input/output blobs
  12. static const int batchSize = 1;
  13. static const int INPUT_H = 640;
  14. static const int INPUT_W = 640;
  15. static const int _segWidth = 160;
  16. static const int _segHeight = 160;
  17. static const int _segChannels = 32;
  18. static const int CLASSES = 80;
  19. static const int Num_box = 25200;
  20. static const int OUTPUT_SIZE = batchSize * Num_box * (CLASSES + 5);//output0
  21. static const int INPUT_SIZE = batchSize * 3 * INPUT_H * INPUT_W;//images
  22. //置信度阈值
  23. static const float CONF_THRESHOLD = 0.5;
  24. //nms阈值
  25. static const float NMS_THRESHOLD = 0.5;
  26. //输入结点名称
  27. const char* INPUT_BLOB_NAME = "images";
  28. //检测头的输出结点名称
  29. const char* OUTPUT_BLOB_NAME = "output0";//detect
  30. //分割头的输出结点名称
  31. //定义两个静态浮点,用于保存两个输出头的输出结果
  32. static float prob[OUTPUT_SIZE]; //box
  33. static Logger gLogger;
// One final detection after decoding and NMS.
struct OutputSeg {
    int id;            // class id
    float confidence;  // score (class probability * objectness)
    cv::Rect box;      // bounding box in original-image coordinates
};
// Intermediate storage for all candidate boxes before NMS. The three vectors
// are parallel: index i describes one candidate.
struct OutputObject
{
    std::vector<int> classIds;        // class id per candidate
    std::vector<float> confidences;   // score per candidate
    std::vector<cv::Rect> boxes;      // box per candidate
};
// 80 RGB colors in [0,1] -- one per class. NOTE(review): not referenced by the
// code visible in this file (DrawPred generates random colors instead); kept
// for callers that want stable per-class colors.
const float color_list[80][3] =
{
    {0.000, 0.447, 0.741},
    {0.850, 0.325, 0.098},
    {0.929, 0.694, 0.125},
    {0.494, 0.184, 0.556},
    {0.466, 0.674, 0.188},
    {0.301, 0.745, 0.933},
    {0.635, 0.078, 0.184},
    {0.300, 0.300, 0.300},
    {0.600, 0.600, 0.600},
    {1.000, 0.000, 0.000},
    {1.000, 0.500, 0.000},
    {0.749, 0.749, 0.000},
    {0.000, 1.000, 0.000},
    {0.000, 0.000, 1.000},
    {0.667, 0.000, 1.000},
    {0.333, 0.333, 0.000},
    {0.333, 0.667, 0.000},
    {0.333, 1.000, 0.000},
    {0.667, 0.333, 0.000},
    {0.667, 0.667, 0.000},
    {0.667, 1.000, 0.000},
    {1.000, 0.333, 0.000},
    {1.000, 0.667, 0.000},
    {1.000, 1.000, 0.000},
    {0.000, 0.333, 0.500},
    {0.000, 0.667, 0.500},
    {0.000, 1.000, 0.500},
    {0.333, 0.000, 0.500},
    {0.333, 0.333, 0.500},
    {0.333, 0.667, 0.500},
    {0.333, 1.000, 0.500},
    {0.667, 0.000, 0.500},
    {0.667, 0.333, 0.500},
    {0.667, 0.667, 0.500},
    {0.667, 1.000, 0.500},
    {1.000, 0.000, 0.500},
    {1.000, 0.333, 0.500},
    {1.000, 0.667, 0.500},
    {1.000, 1.000, 0.500},
    {0.000, 0.333, 1.000},
    {0.000, 0.667, 1.000},
    {0.000, 1.000, 1.000},
    {0.333, 0.000, 1.000},
    {0.333, 0.333, 1.000},
    {0.333, 0.667, 1.000},
    {0.333, 1.000, 1.000},
    {0.667, 0.000, 1.000},
    {0.667, 0.333, 1.000},
    {0.667, 0.667, 1.000},
    {0.667, 1.000, 1.000},
    {1.000, 0.000, 1.000},
    {1.000, 0.333, 1.000},
    {1.000, 0.667, 1.000},
    {0.333, 0.000, 0.000},
    {0.500, 0.000, 0.000},
    {0.667, 0.000, 0.000},
    {0.833, 0.000, 0.000},
    {1.000, 0.000, 0.000},
    {0.000, 0.167, 0.000},
    {0.000, 0.333, 0.000},
    {0.000, 0.500, 0.000},
    {0.000, 0.667, 0.000},
    {0.000, 0.833, 0.000},
    {0.000, 1.000, 0.000},
    {0.000, 0.000, 0.167},
    {0.000, 0.000, 0.333},
    {0.000, 0.000, 0.500},
    {0.000, 0.000, 0.667},
    {0.000, 0.000, 0.833},
    {0.000, 0.000, 1.000},
    {0.000, 0.000, 0.000},
    {0.143, 0.143, 0.143},
    {0.286, 0.286, 0.286},
    {0.429, 0.429, 0.429},
    {0.571, 0.571, 0.571},
    {0.714, 0.714, 0.714},
    {0.857, 0.857, 0.857},
    {0.000, 0.447, 0.741},
    {0.314, 0.717, 0.741},
    {0.50, 0.5, 0}
};
  129. //output中,包含了经过处理的id、conf、box和maskiamg信息
  130. static void DrawPred(Mat& img, std::vector<OutputSeg> result) {
  131. //生成随机颜色
  132. std::vector<Scalar> color;
  133. //这行代码的作用是将当前系统时间作为随机数种子,使得每次程序运行时都会生成不同的随机数序列。
  134. srand(time(0));
  135. //根据类别数,生成不同的颜色
  136. for (int i = 0; i < CLASSES; i++) {
  137. int b = rand() % 256;
  138. int g = rand() % 256;
  139. int r = rand() % 256;
  140. color.push_back(Scalar(b, g, r));
  141. }
  142. //Mat mask = img.clone();
  143. for (int i = 0; i < result.size(); i++) {
  144. int left, top;
  145. left = result[i].box.x;
  146. top = result[i].box.y;
  147. int color_num = i;
  148. //画矩形框,颜色是上面选的
  149. rectangle(img, result[i].box, color[result[i].id], 2, 8);
  150. //将box中的result[i].boxMask区域涂成color[result[i].id]颜色
  151. //mask(result[i].box).setTo(color[result[i].id], result[i].boxMask);
  152. std::string label = std::to_string(result[i].id) + ":" + std::to_string(result[i].confidence);
  153. int baseLine;
  154. //获取标签文本的尺寸
  155. Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  156. //确定一个最大的高
  157. top = max(top, labelSize.height);
  158. //把文本信息加到图像上
  159. putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 1, color[result[i].id], 2);
  160. }
  161. //用于对图像的加权融合
  162. //图像1、图像1权重、图像2、图像2权重,添加结果中的标量、输出图像
  163. //addWeighted(img, 0.5, mask, 0.5, 0, img); //将mask加在原图上面
  164. }
// Run one synchronous inference: copy `input` (host) to the GPU, execute the
// engine through `context`, and copy the detection output back into `output`.
static void doInference(IExecutionContext& context, float* input, float* output, int batchSize)
{
    // Engine the execution context was created from.
    const ICudaEngine& engine = context.getEngine();
    // Detection-only model: exactly two bindings, "images" and "output0".
    // (The original comment claimed three bindings; the assert checks for 2.)
    assert(engine.getNbBindings() == 2);
    // Device buffer pointers, indexed by binding index.
    void* buffers[2];
    // Resolve binding indices by tensor name for the copies below.
    const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
    const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);
    // Allocate device memory for the input blob and the raw output.
    CHECK(cudaMalloc(&buffers[inputIndex], INPUT_SIZE * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex], OUTPUT_SIZE * sizeof(float)));
    // A CUDA stream orders the copy -> enqueue -> copy sequence; the Async
    // calls below return immediately and execute in stream order.
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));
    // Asynchronous host -> device copy of the preprocessed input.
    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, INPUT_SIZE * sizeof(float), cudaMemcpyHostToDevice, stream));
    // Enqueue the inference itself (TensorRT 7-era implicit-batch API).
    context.enqueue(batchSize, buffers, stream, nullptr);
    // Asynchronous device -> host copy of the detections.
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    // Block until all work queued on the stream has finished, so `output` is
    // complete and the buffers are safe to free.
    cudaStreamSynchronize(stream);
    // Release the stream and device memory.
    cudaStreamDestroy(stream);
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
}
// YOLOv5 detector: wraps engine loading, input blob construction, output
// decoding and NMS. Typical lifecycle: init() -> detect/decode calls -> destroy().
class YOLO
{
public:
    void init(std::string engine_path);             // load + deserialize a .engine file
    void init(char* engine_path);                   // C-string overload of the above
    void destroy();                                 // release the TensorRT objects
    void blobFromImage(cv::Mat& img, float* data);  // HWC BGR uint8 -> CHW RGB float in [0,1]
    // Decode raw network rows into candidate boxes (padsize comes from preprocess_img).
    void decode_boxs(cv::Mat& src, float* prob, OutputObject& outputObject, std::vector<int> padsize);
    // Non-maximum suppression over the decoded candidates.
    void nms_outputs(cv::Mat& src, OutputObject& outputObject, std::vector<OutputSeg>& output);
    void detect_img(std::string image_path);
    void detect_img(char* image_path, float(*res_array)[6]);
private:
    // TensorRT objects created by init() and released by destroy().
    ICudaEngine* engine;
    IRuntime* runtime;
    IExecutionContext* context;
};
// Release the TensorRT objects in reverse order of creation
// (context -> engine -> runtime). destroy() is the pre-TensorRT-8 ownership
// API; call exactly once, after which the YOLO object must not be used.
void YOLO::destroy()
{
    this->context->destroy();
    this->engine->destroy();
    this->runtime->destroy();
}
  223. void YOLO::init(std::string engine_path)
  224. {
  225. //无符号整型类型,通常用于表示对象的大小或计数
  226. //{ 0 }: 这是初始化列表,用于初始化 size 变量。在这种情况下,size 被初始化为 0。
  227. size_t size{ 0 };
  228. //定义一个指针变量,通过trtModelStream = new char[size];分配size个字符的空间
  229. //nullptr表示指针针在开始时不指向任何有效的内存地址,空指针
  230. char* trtModelStream{ nullptr };
  231. //打开文件,即engine模型
  232. std::ifstream file(engine_path, std::ios::binary);
  233. if (file.good())
  234. {
  235. //指向文件的最后地址
  236. file.seekg(0, file.end);
  237. //计算文件的长度
  238. size = file.tellg();
  239. //指回文件的起始地址
  240. file.seekg(0, file.beg);
  241. //为trtModelStream指针分配内存,内存大小为size
  242. trtModelStream = new char[size]; //开辟一个char 长度是文件的长度
  243. assert(trtModelStream);
  244. //把file内容传递给trtModelStream,传递大小为size,即engine模型内容传递
  245. file.read(trtModelStream, size);
  246. //关闭文件
  247. file.close();
  248. }
  249. std::cout << "engine init finished" << std::endl;
  250. //创建了一个Inference运行时环境,返回一个指向新创建的运行时环境的指针
  251. runtime = createInferRuntime(gLogger);
  252. assert(runtime != nullptr);
  253. //反序列化一个CUDA引擎。这个引擎将用于执行模型的前向传播
  254. engine = runtime->deserializeCudaEngine(trtModelStream, size);
  255. assert(engine != nullptr);
  256. //使用上一步中创建的引擎创建一个执行上下文。这个上下文将在模型的前向传播期间使用
  257. context = engine->createExecutionContext();
  258. assert(context != nullptr);
  259. //释放了用于存储模型序列化的内存
  260. delete[] trtModelStream;
  261. }
  262. void YOLO::init(char* engine_path)
  263. {
  264. //无符号整型类型,通常用于表示对象的大小或计数
  265. //{ 0 }: 这是初始化列表,用于初始化 size 变量。在这种情况下,size 被初始化为 0。
  266. size_t size{ 0 };
  267. //定义一个指针变量,通过trtModelStream = new char[size];分配size个字符的空间
  268. //nullptr表示指针针在开始时不指向任何有效的内存地址,空指针
  269. char* trtModelStream{ nullptr };
  270. //打开文件,即engine模型
  271. std::ifstream file(engine_path, std::ios::binary);
  272. if (file.good())
  273. {
  274. //指向文件的最后地址
  275. file.seekg(0, file.end);
  276. //计算文件的长度
  277. size = file.tellg();
  278. //指回文件的起始地址
  279. file.seekg(0, file.beg);
  280. //为trtModelStream指针分配内存,内存大小为size
  281. trtModelStream = new char[size]; //开辟一个char 长度是文件的长度
  282. assert(trtModelStream);
  283. //把file内容传递给trtModelStream,传递大小为size,即engine模型内容传递
  284. file.read(trtModelStream, size);
  285. //关闭文件
  286. file.close();
  287. }
  288. std::cout << "engine init finished" << std::endl;
  289. //创建了一个Inference运行时环境,返回一个指向新创建的运行时环境的指针
  290. runtime = createInferRuntime(gLogger);
  291. assert(runtime != nullptr);
  292. //反序列化一个CUDA引擎。这个引擎将用于执行模型的前向传播
  293. engine = runtime->deserializeCudaEngine(trtModelStream, size);
  294. assert(engine != nullptr);
  295. //使用上一步中创建的引擎创建一个执行上下文。这个上下文将在模型的前向传播期间使用
  296. context = engine->createExecutionContext();
  297. assert(context != nullptr);
  298. //释放了用于存储模型序列化的内存
  299. delete[] trtModelStream;
  300. }
  301. void YOLO::blobFromImage(cv::Mat& src, float* data)
  302. {
  303. //定义一个浮点数组
  304. //float* data = new float[3 * INPUT_H * INPUT_W];
  305. int i = 0;// [1,3,INPUT_H,INPUT_W]
  306. for (int row = 0; row < INPUT_H; ++row)
  307. {
  308. //逐行对象素值和图像通道进行处理
  309. //pr_img.step=widthx3 就是每一行有width个3通道的值
  310. //第row行
  311. uchar* uc_pixel = src.data + row * src.step;
  312. for (int col = 0; col < INPUT_W; ++col)
  313. {
  314. //第col列
  315. //提取第第row行第col列数据进行处理
  316. //像素值处理
  317. data[i] = (float)uc_pixel[2] / 255.0;
  318. //通道变换
  319. data[i + INPUT_H * INPUT_W] = (float)uc_pixel[1] / 255.0;
  320. data[i + 2 * INPUT_H * INPUT_W] = (float)uc_pixel[0] / 255.0;
  321. uc_pixel += 3;//表示进行下一列
  322. ++i;//表示在3个通道中的第i个位置,rgb三个通道的值是分开的,如r123456g123456b123456
  323. }
  324. }
  325. //delete[] data;
  326. }
// Decode the raw detection rows in `prob` into candidate boxes (in original
// image coordinates) stored in `outputObject`, keeping only candidates whose
// objectness AND best class score pass CONF_THRESHOLD.
void YOLO::decode_boxs(cv::Mat& src, float* prob, OutputObject& outputObject, std::vector<int> padsize)
{
    // Letterbox geometry recorded by preprocess_img: {new_h, new_w, pad_y, pad_x}.
    int newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3];
    // Scale factors from network-input space back to the original image.
    float ratio_h = (float)src.rows / newh;
    float ratio_w = (float)src.cols / neww;
    // Each row of the output is [cx, cy, w, h, obj_conf, class_0 .. class_79].
    int net_width = CLASSES + 5;
    float* pdata = prob;
    for (int j = 0; j < Num_box; ++j) {
        float box_score = pdata[4]; ;// objectness: probability this row contains an object
        if (box_score >= CONF_THRESHOLD) {
            // Wrap the 80 class scores without copying and find the best class.
            cv::Mat scores(1, CLASSES, CV_32FC1, pdata + 5);
            Point classIdPoint;
            double max_class_socre;
            // args: src matrix, minVal (0 = unused), maxVal, minLoc (0 = unused), maxLoc
            minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint);
            max_class_socre = (float)max_class_socre;
            // Keep the candidate only if the best class score also clears the threshold.
            if (max_class_socre >= CONF_THRESHOLD) {
                // Undo the letterbox: remove padding, then rescale to the original image.
                float x = (pdata[0] - padw) * ratio_w; //x
                float y = (pdata[1] - padh) * ratio_h; //y
                float w = pdata[2] * ratio_w; //w
                float h = pdata[3] * ratio_h; //h
                // Center -> top-left corner, clamped at the image border.
                int left = MAX((x - 0.5 * w), 0);
                int top = MAX((y - 0.5 * h), 0);
                // Record the candidate: class id (column of the max score),
                // combined score, and box -- parallel vectors by construction.
                outputObject.classIds.push_back(classIdPoint.x);
                outputObject.confidences.push_back(max_class_socre * box_score);
                outputObject.boxes.push_back(Rect(left, top, int(w), int(h)));
            }
        }
        pdata += net_width;// advance to the next candidate row
    }
}
  366. void YOLO::nms_outputs(cv::Mat& src, OutputObject& outputObject, std::vector<OutputSeg>& output)
  367. {
  368. //执行非最大抑制以消除具有较低置信度的冗余重叠框(NMS)
  369. std::vector<int> nms_result;
  370. //通过opencv自带的nms函数进行,矩阵box、置信度大小,置信度阈值,nms阈值,结果
  371. cv::dnn::NMSBoxes(outputObject.boxes, outputObject.confidences, CONF_THRESHOLD, NMS_THRESHOLD, nms_result);
  372. 包括类别、置信度、框和mask
  373. //std::vector<std::vector<float>> temp_mask_proposals;
  374. //创建一个名为holeImgRect的Rect对象
  375. Rect holeImgRect(0, 0, src.cols, src.rows);
  376. //提取经过非极大值抑制后的结果
  377. for (int i = 0; i < nms_result.size(); ++i) {
  378. int idx = nms_result[i];
  379. OutputSeg result;
  380. result.id = outputObject.classIds[idx];
  381. result.confidence = outputObject.confidences[idx];
  382. result.box = outputObject.boxes[idx] & holeImgRect;
  383. output.push_back(result);
  384. }
  385. }
  386. //读取图片进行推理,并保存图片
  387. void YOLO::detect_img(std::string image_path)
  388. {
  389. cv::Mat img = cv::imread(image_path);
  390. //图像预处理,输入的是原图像和网络输入的高和宽,填充尺寸容器
  391. //输出的是重构后的图像,以及每条边填充的大小保存在padsize
  392. cv::Mat pr_img;
  393. std::vector<int> padsize;
  394. pr_img = preprocess_img(img, INPUT_H, INPUT_W, padsize); // Resize
  395. float* blob = new float[3 * INPUT_H * INPUT_W];
  396. blobFromImage(pr_img, blob);
  397. //推理
  398. auto start = std::chrono::system_clock::now();
  399. doInference(*context, blob, prob, batchSize);
  400. auto end = std::chrono::system_clock::now();
  401. std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
  402. //解析并绘制output
  403. OutputObject outputObject;
  404. std::vector<OutputSeg> output;
  405. std::vector<std::vector<float>> temp_mask_proposals;
  406. decode_boxs(img, prob, outputObject, padsize);
  407. nms_outputs(img, outputObject, output);
  408. //decode_mask(img, padsize, temp_mask_proposals, output);
  409. //绘制图像
  410. DrawPred(img, output);
  411. cv::imwrite("output.jpg", img);
  412. delete[] blob;
  413. }
  414. /// <summary>
  415. /// 读取图片进行推理,用数组将检测结果传出(包括label序号、置信度分数、矩形参数)
  416. /// </summary>
  417. /// <param name="image_path"></param>
  418. /// <param name="res_array"></param>
  419. void YOLO::detect_img(char* image_path, float(*res_array)[6])
  420. {
  421. cv::Mat img = cv::imread(image_path);
  422. //图像预处理,输入的是原图像和网络输入的高和宽,填充尺寸容器
  423. //输出的是重构后的图像,以及每条边填充的大小保存在padsize
  424. cv::Mat pr_img;
  425. std::vector<int> padsize;
  426. pr_img = preprocess_img(img, INPUT_H, INPUT_W, padsize); // Resize
  427. float* blob = new float[3 * INPUT_H * INPUT_W];
  428. blobFromImage(pr_img, blob);
  429. //推理
  430. auto start = std::chrono::system_clock::now();
  431. doInference(*context, blob, prob, batchSize);
  432. auto end = std::chrono::system_clock::now();
  433. std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
  434. //解析并绘制output
  435. OutputObject outputObject;
  436. std::vector<OutputSeg> output;
  437. std::vector<std::vector<float>> temp_mask_proposals;
  438. decode_boxs(img, prob, outputObject, padsize);
  439. nms_outputs(img, outputObject, output);
  440. //传出矩形框数据
  441. for (size_t j = 0; j < output.size(); j++)
  442. {
  443. res_array[j][0] = output[j].box.x;
  444. res_array[j][1] = output[j].box.y;
  445. res_array[j][2] = output[j].box.width;
  446. res_array[j][3] = output[j].box.height;
  447. res_array[j][4] = output[j].id;
  448. res_array[j][5] = output[j].confidence;
  449. }
  450. delete[] blob;
  451. }

4、测试使用 

        测试使用的话直接new一个,然后使用里面的函数即可,下面给出例子。

  1. #include <iostream>
  2. #include "yolo.hpp"
  3. int main()
  4. {
  5. YOLO yolo;
  6. yolo.init("E:\\OnnxTransEngine\\engine_infer2\\yolov5s.engine");
  7. yolo.detect_img("E:\\OnnxTransEngine\\zidane.jpg");
  8. yolo.destroy();
  9. }

5、C#调用 

        我自己通常使用C#编写上位机程序,所以一般都是将C++封装成DLL使用。

C++封装:

  1. #include <iostream>
  2. #include "yolo.hpp"
  3. //int main()
  4. //{
  5. // YOLO yolo;
  6. // yolo.init("E:\\OnnxTransEngine\\engine_infer2\\yolov5s.engine");
  7. // yolo.detect_img("E:\\OnnxTransEngine\\zidane.jpg");
  8. // yolo.destroy();
  9. //}
  10. YOLO yolo;
  11. extern "C" __declspec(dllexport) void Init(char* engine_path)
  12. {
  13. yolo.init(engine_path);
  14. return;
  15. }
  16. extern "C" __declspec(dllexport) void Detect_Img(char* image_path, float(*res_array)[6])
  17. {
  18. yolo.detect_img(image_path, res_array);
  19. return;
  20. }
  21. extern "C" __declspec(dllexport) void Destroy()
  22. {
  23. yolo.destroy();
  24. return;
  25. }

C#调用: 

  1. [DllImport("engine_infer.dll", CallingConvention = CallingConvention.Cdecl)]
  2. public static extern void Init(string engine_path);
  3. [DllImport("engine_infer.dll", CallingConvention = CallingConvention.Cdecl)]
  4. public static extern void Detect_Img(string path, float[,] resultArray);
  5. [DllImport("engine_infer.dll", CallingConvention = CallingConvention.Cdecl)]
  6. public static extern void Destroy();

6、具体项目修改参数 

        模型会有不同的训练参数,那么推理的参数也需要变化,主要修改的参数如图所示: 

        

         我们使用 Netron 打开对应的 onnx 模型,可以看到 images 和 output0 参数与模型一一对应。一般我们训练自己的模型都会修改 CLASSES,根据自己的参数对应修改即可;置信度阈值也根据自己的需要进行修改。

声明:本文内容由网友自发贡献,转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号