
YOLOv8 instance segmentation inference with TensorRT (C++ API)


        YOLOv8 instance segmentation is essentially the detection model with one extra output, output1, which carries the mask (prototype) data, so inference is the detection pipeline plus some additional mask handling. The added output1 has the same layout as in yolov5 (1×32×160×160 mask prototypes), while output0 carries 4 box values + class scores + 32 mask coefficients for each candidate box.

yolov5 vs. yolov8 segmentation output heads:

1. Environment setup

        Adjust the specific paths to match your own install directories.
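The code further down only depends on TensorRT, CUDA and OpenCV. As a hedged example of a typical Visual Studio layout (folder names depend on your versions and install locations):

Include directories: <TensorRT>\include, <CUDA>\include, <OpenCV>\build\include
Library directories: <TensorRT>\lib, <CUDA>\lib\x64, <OpenCV>\build\x64\vc15\lib
Linker inputs: nvinfer.lib, nvinfer_plugin.lib, cudart.lib, opencv_world4xx.lib (names depend on the versions; add nvonnxparser.lib if you also build the ONNX-to-engine converter shown in section 2)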

  

2. Converting the .pt model to an engine

        First you need the engine file that TensorRT loads at inference time. I export an ONNX model first with yolov5's export.py and then convert the ONNX to an engine with a separate program, because in my tests an engine exported directly by export.py caused errors later on. For the details, see my earlier post on converting a yolov5 .pt model to ONNX and an ONNX model to an engine.
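For reference, such a converter can be written with the TensorRT C++ API itself (or with the trtexec tool shipped with TensorRT). The sketch below is a minimal, hedged example rather than the exact program used here: it assumes TensorRT 8.x with the same older-style API used in yolo.hpp further down, the file names yolov8s-seg.onnx / yolov8s_seg.engine are placeholders, and it additionally needs nvonnxparser.lib at link time.

#include "NvInfer.h"
#include "NvOnnxParser.h"
#include "logging.h"
#include <fstream>
#include <iostream>

static Logger gLogger;

int main()
{
    using namespace nvinfer1;
    IBuilder* builder = createInferBuilder(gLogger);
    // ONNX models require an explicit-batch network
    const uint32_t flag = 1U << static_cast<uint32_t>(NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
    INetworkDefinition* network = builder->createNetworkV2(flag);
    nvonnxparser::IParser* parser = nvonnxparser::createParser(*network, gLogger);
    if (!parser->parseFromFile("yolov8s-seg.onnx", static_cast<int>(ILogger::Severity::kWARNING)))
    {
        std::cerr << "failed to parse the onnx file" << std::endl;
        return -1;
    }
    IBuilderConfig* config = builder->createBuilderConfig();
    config->setMaxWorkspaceSize(1ULL << 30);      // 1 GB build workspace
    //config->setFlag(BuilderFlag::kFP16);        // optional: enable FP16 if the GPU supports it
    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);
    IHostMemory* serialized = engine->serialize();
    std::ofstream out("yolov8s_seg.engine", std::ios::binary);
    out.write(reinterpret_cast<const char*>(serialized->data()), serialized->size());
    // release everything with the same old-style destroy() calls used elsewhere in this post
    serialized->destroy();
    engine->destroy();
    config->destroy();
    parser->destroy();
    network->destroy();
    builder->destroy();
    return 0;
}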

3. Inference code

        Inference mainly uses TensorRT's C++ API. The project has two basic headers, logging.h and utils.h, which are unchanged; the inference code itself lives in yolo.hpp, which has a few modifications and differs somewhat from the yolov5 version. The full code is below:

logging.h:

  1. #pragma once
  2. /*
  3. * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
  4. *
  5. * Licensed under the Apache License, Version 2.0 (the "License");
  6. * you may not use this file except in compliance with the License.
  7. * You may obtain a copy of the License at
  8. *
  9. * http://www.apache.org/licenses/LICENSE-2.0
  10. *
  11. * Unless required by applicable law or agreed to in writing, software
  12. * distributed under the License is distributed on an "AS IS" BASIS,
  13. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. * See the License for the specific language governing permissions and
  15. * limitations under the License.
  16. */
  17. #ifndef TENSORRT_LOGGING_H
  18. #define TENSORRT_LOGGING_H
  19. #include "NvInferRuntimeCommon.h"
  20. #include <cassert>
  21. #include <ctime>
  22. #include <iomanip>
  23. #include <iostream>
  24. #include <ostream>
  25. #include <sstream>
  26. #include <string>
  27. using Severity = nvinfer1::ILogger::Severity;
  28. class LogStreamConsumerBuffer : public std::stringbuf
  29. {
  30. public:
  31. LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
  32. : mOutput(stream)
  33. , mPrefix(prefix)
  34. , mShouldLog(shouldLog)
  35. {
  36. }
  37. LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
  38. : mOutput(other.mOutput)
  39. {
  40. }
  41. ~LogStreamConsumerBuffer()
  42. {
  43. // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
  44. // std::streambuf::pptr() gives a pointer to the current position of the output sequence
  45. // if the pointer to the beginning is not equal to the pointer to the current position,
  46. // call putOutput() to log the output to the stream
  47. if (pbase() != pptr())
  48. {
  49. putOutput();
  50. }
  51. }
  52. // synchronizes the stream buffer and returns 0 on success
  53. // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
  54. // resetting the buffer and flushing the stream
  55. virtual int sync()
  56. {
  57. putOutput();
  58. return 0;
  59. }
  60. void putOutput()
  61. {
  62. if (mShouldLog)
  63. {
  64. // prepend timestamp
  65. std::time_t timestamp = std::time(nullptr);
  66. tm* tm_local = new tm();
  67. localtime_s(tm_local, &timestamp);
  68. std::cout << "[";
  69. std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
  70. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
  71. std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
  72. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
  73. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
  74. std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
  75. // std::stringbuf::str() gets the string contents of the buffer
  76. // insert the buffer contents pre-appended by the appropriate prefix into the stream
  77. mOutput << mPrefix << str();
  78. // set the buffer to empty
  79. str("");
  80. // flush the stream
  81. mOutput.flush();
  82. }
  83. }
  84. void setShouldLog(bool shouldLog)
  85. {
  86. mShouldLog = shouldLog;
  87. }
  88. private:
  89. std::ostream& mOutput;
  90. std::string mPrefix;
  91. bool mShouldLog;
  92. };
  93. //!
  94. //! \class LogStreamConsumerBase
  95. //! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
  96. //!
  97. class LogStreamConsumerBase
  98. {
  99. public:
  100. LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
  101. : mBuffer(stream, prefix, shouldLog)
  102. {
  103. }
  104. protected:
  105. LogStreamConsumerBuffer mBuffer;
  106. };
  107. //!
  108. //! \class LogStreamConsumer
  109. //! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
  110. //! Order of base classes is LogStreamConsumerBase and then std::ostream.
  111. //! This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
  112. //! in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
  113. //! This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
  114. //! Please do not change the order of the parent classes.
  115. //!
  116. class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
  117. {
  118. public:
  119. //! \brief Creates a LogStreamConsumer which logs messages with level severity.
  120. //! Reportable severity determines if the messages are severe enough to be logged.
  121. LogStreamConsumer(Severity reportableSeverity, Severity severity)
  122. : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
  123. , std::ostream(&mBuffer) // links the stream buffer with the stream
  124. , mShouldLog(severity <= reportableSeverity)
  125. , mSeverity(severity)
  126. {
  127. }
  128. LogStreamConsumer(LogStreamConsumer&& other)
  129. : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
  130. , std::ostream(&mBuffer) // links the stream buffer with the stream
  131. , mShouldLog(other.mShouldLog)
  132. , mSeverity(other.mSeverity)
  133. {
  134. }
  135. void setReportableSeverity(Severity reportableSeverity)
  136. {
  137. mShouldLog = mSeverity <= reportableSeverity;
  138. mBuffer.setShouldLog(mShouldLog);
  139. }
  140. private:
  141. static std::ostream& severityOstream(Severity severity)
  142. {
  143. return severity >= Severity::kINFO ? std::cout : std::cerr;
  144. }
  145. static std::string severityPrefix(Severity severity)
  146. {
  147. switch (severity)
  148. {
  149. case Severity::kINTERNAL_ERROR: return "[F] ";
  150. case Severity::kERROR: return "[E] ";
  151. case Severity::kWARNING: return "[W] ";
  152. case Severity::kINFO: return "[I] ";
  153. case Severity::kVERBOSE: return "[V] ";
  154. default: assert(0); return "";
  155. }
  156. }
  157. bool mShouldLog;
  158. Severity mSeverity;
  159. };
  160. //! \class Logger
  161. //!
  162. //! \brief Class which manages logging of TensorRT tools and samples
  163. //!
  164. //! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
  165. //! and supports logging two types of messages:
  166. //!
  167. //! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
  168. //! - Test pass/fail messages
  169. //!
  170. //! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
  171. //! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
  172. //!
  173. //! In the future, this class could be extended to support dumping test results to a file in some standard format
  174. //! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
  175. //!
  176. //! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
  177. //! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
  178. //! library and messages coming from the sample.
  179. //!
  180. //! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
  181. //! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
  182. //! object.
  183. class Logger : public nvinfer1::ILogger
  184. {
  185. public:
  186. Logger(Severity severity = Severity::kWARNING)
  187. : mReportableSeverity(severity)
  188. {
  189. }
  190. //!
  191. //! \enum TestResult
  192. //! \brief Represents the state of a given test
  193. //!
  194. enum class TestResult
  195. {
  196. kRUNNING, //!< The test is running
  197. kPASSED, //!< The test passed
  198. kFAILED, //!< The test failed
  199. kWAIVED //!< The test was waived
  200. };
  201. //!
  202. //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
  203. //! \return The nvinfer1::ILogger associated with this Logger
  204. //!
  205. //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
  206. //! we can eliminate the inheritance of Logger from ILogger
  207. //!
  208. nvinfer1::ILogger& getTRTLogger()
  209. {
  210. return *this;
  211. }
  212. //!
  213. //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
  214. //!
  215. //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
  216. //! inheritance from nvinfer1::ILogger
  217. //!
  218. void log(Severity severity, const char* msg) noexcept override
  219. {
  220. LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
  221. }
  222. //!
  223. //! \brief Method for controlling the verbosity of logging output
  224. //!
  225. //! \param severity The logger will only emit messages that have severity of this level or higher.
  226. //!
  227. void setReportableSeverity(Severity severity)
  228. {
  229. mReportableSeverity = severity;
  230. }
  231. //!
  232. //! \brief Opaque handle that holds logging information for a particular test
  233. //!
  234. //! This object is an opaque handle to information used by the Logger to print test results.
  235. //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
  236. //! with Logger::reportTest{Start,End}().
  237. //!
  238. class TestAtom
  239. {
  240. public:
  241. TestAtom(TestAtom&&) = default;
  242. private:
  243. friend class Logger;
  244. TestAtom(bool started, const std::string& name, const std::string& cmdline)
  245. : mStarted(started)
  246. , mName(name)
  247. , mCmdline(cmdline)
  248. {
  249. }
  250. bool mStarted;
  251. std::string mName;
  252. std::string mCmdline;
  253. };
  254. //!
  255. //! \brief Define a test for logging
  256. //!
  257. //! \param[in] name The name of the test. This should be a string starting with
  258. //! "TensorRT" and containing dot-separated strings containing
  259. //! the characters [A-Za-z0-9_].
  260. //! For example, "TensorRT.sample_googlenet"
  261. //! \param[in] cmdline The command line used to reproduce the test
  262. //
  263. //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
  264. //!
  265. static TestAtom defineTest(const std::string& name, const std::string& cmdline)
  266. {
  267. return TestAtom(false, name, cmdline);
  268. }
  269. //!
  270. //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
  271. //! as input
  272. //!
  273. //! \param[in] name The name of the test
  274. //! \param[in] argc The number of command-line arguments
  275. //! \param[in] argv The array of command-line arguments (given as C strings)
  276. //!
  277. //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
  278. static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
  279. {
  280. auto cmdline = genCmdlineString(argc, argv);
  281. return defineTest(name, cmdline);
  282. }
  283. //!
  284. //! \brief Report that a test has started.
  285. //!
  286. //! \pre reportTestStart() has not been called yet for the given testAtom
  287. //!
  288. //! \param[in] testAtom The handle to the test that has started
  289. //!
  290. static void reportTestStart(TestAtom& testAtom)
  291. {
  292. reportTestResult(testAtom, TestResult::kRUNNING);
  293. assert(!testAtom.mStarted);
  294. testAtom.mStarted = true;
  295. }
  296. //!
  297. //! \brief Report that a test has ended.
  298. //!
  299. //! \pre reportTestStart() has been called for the given testAtom
  300. //!
  301. //! \param[in] testAtom The handle to the test that has ended
  302. //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
  303. //! TestResult::kFAILED, TestResult::kWAIVED
  304. //!
  305. static void reportTestEnd(const TestAtom& testAtom, TestResult result)
  306. {
  307. assert(result != TestResult::kRUNNING);
  308. assert(testAtom.mStarted);
  309. reportTestResult(testAtom, result);
  310. }
  311. static int reportPass(const TestAtom& testAtom)
  312. {
  313. reportTestEnd(testAtom, TestResult::kPASSED);
  314. return EXIT_SUCCESS;
  315. }
  316. static int reportFail(const TestAtom& testAtom)
  317. {
  318. reportTestEnd(testAtom, TestResult::kFAILED);
  319. return EXIT_FAILURE;
  320. }
  321. static int reportWaive(const TestAtom& testAtom)
  322. {
  323. reportTestEnd(testAtom, TestResult::kWAIVED);
  324. return EXIT_SUCCESS;
  325. }
  326. static int reportTest(const TestAtom& testAtom, bool pass)
  327. {
  328. return pass ? reportPass(testAtom) : reportFail(testAtom);
  329. }
  330. Severity getReportableSeverity() const
  331. {
  332. return mReportableSeverity;
  333. }
  334. private:
  335. //!
  336. //! \brief returns an appropriate string for prefixing a log message with the given severity
  337. //!
  338. static const char* severityPrefix(Severity severity)
  339. {
  340. switch (severity)
  341. {
  342. case Severity::kINTERNAL_ERROR: return "[F] ";
  343. case Severity::kERROR: return "[E] ";
  344. case Severity::kWARNING: return "[W] ";
  345. case Severity::kINFO: return "[I] ";
  346. case Severity::kVERBOSE: return "[V] ";
  347. default: assert(0); return "";
  348. }
  349. }
  350. //!
  351. //! \brief returns an appropriate string for prefixing a test result message with the given result
  352. //!
  353. static const char* testResultString(TestResult result)
  354. {
  355. switch (result)
  356. {
  357. case TestResult::kRUNNING: return "RUNNING";
  358. case TestResult::kPASSED: return "PASSED";
  359. case TestResult::kFAILED: return "FAILED";
  360. case TestResult::kWAIVED: return "WAIVED";
  361. default: assert(0); return "";
  362. }
  363. }
  364. //!
  365. //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
  366. //!
  367. static std::ostream& severityOstream(Severity severity)
  368. {
  369. return severity >= Severity::kINFO ? std::cout : std::cerr;
  370. }
  371. //!
  372. //! \brief method that implements logging test results
  373. //!
  374. static void reportTestResult(const TestAtom& testAtom, TestResult result)
  375. {
  376. severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
  377. << testAtom.mCmdline << std::endl;
  378. }
  379. //!
  380. //! \brief generate a command line string from the given (argc, argv) values
  381. //!
  382. static std::string genCmdlineString(int argc, char const* const* argv)
  383. {
  384. std::stringstream ss;
  385. for (int i = 0; i < argc; i++)
  386. {
  387. if (i > 0)
  388. ss << " ";
  389. ss << argv[i];
  390. }
  391. return ss.str();
  392. }
  393. Severity mReportableSeverity;
  394. };
  395. namespace
  396. {
  397. //!
  398. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
  399. //!
  400. //! Example usage:
  401. //!
  402. //! LOG_VERBOSE(logger) << "hello world" << std::endl;
  403. //!
  404. inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
  405. {
  406. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
  407. }
  408. //!
  409. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
  410. //!
  411. //! Example usage:
  412. //!
  413. //! LOG_INFO(logger) << "hello world" << std::endl;
  414. //!
  415. inline LogStreamConsumer LOG_INFO(const Logger& logger)
  416. {
  417. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
  418. }
  419. //!
  420. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
  421. //!
  422. //! Example usage:
  423. //!
  424. //! LOG_WARN(logger) << "hello world" << std::endl;
  425. //!
  426. inline LogStreamConsumer LOG_WARN(const Logger& logger)
  427. {
  428. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
  429. }
  430. //!
  431. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
  432. //!
  433. //! Example usage:
  434. //!
  435. //! LOG_ERROR(logger) << "hello world" << std::endl;
  436. //!
  437. inline LogStreamConsumer LOG_ERROR(const Logger& logger)
  438. {
  439. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
  440. }
  441. //!
  442. //! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
  443. // ("fatal" severity)
  444. //!
  445. //! Example usage:
  446. //!
  447. //! LOG_FATAL(logger) << "hello world" << std::endl;
  448. //!
  449. inline LogStreamConsumer LOG_FATAL(const Logger& logger)
  450. {
  451. return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
  452. }
  453. } // anonymous namespace
  454. #endif // TENSORRT_LOGGING_H

utils.h: 

  1. #pragma once
  2. #include <algorithm>
  3. #include <fstream>
  4. #include <iostream>
  5. #include <opencv2/opencv.hpp>
  6. #include <vector>
  7. #include <chrono>
  8. #include <cmath>
  9. #include <numeric> // std::iota
  10. using namespace cv;
  11. #define CHECK(status) \
  12. do\
  13. {\
  14. auto ret = (status);\
  15. if (ret != 0)\
  16. {\
  17. std::cerr << "Cuda failure: " << ret << std::endl;\
  18. abort();\
  19. }\
  20. } while (0)
  21. struct alignas(float) Detection {
  22. //center_x center_y w h
  23. float bbox[4];
  24. float conf; // bbox_conf * cls_conf
  25. int class_id;
  26. };
  27. static inline cv::Mat preprocess_img(cv::Mat& img, int input_w, int input_h, std::vector<int>& padsize) {
  28. int w, h, x, y;
  29. float r_w = input_w / (img.cols*1.0);
  30. float r_h = input_h / (img.rows*1.0);
  31. if (r_h > r_w) {//the image is wider than it is tall relative to the input, pad top and bottom
  32. w = input_w;
  33. h = r_w * img.rows;
  34. x = 0;
  35. y = (input_h - h) / 2;
  36. }
  37. else {
  38. w = r_h * img.cols;
  39. h = input_h;
  40. x = (input_w - w) / 2;
  41. y = 0;
  42. }
  43. cv::Mat re(h, w, CV_8UC3);
  44. cv::resize(img, re, re.size(), 0, 0, cv::INTER_LINEAR);
  45. cv::Mat out(input_h, input_w, CV_8UC3, cv::Scalar(128, 128, 128));
  46. re.copyTo(out(cv::Rect(x, y, re.cols, re.rows)));
  47. padsize.push_back(h);
  48. padsize.push_back(w);
  49. padsize.push_back(y);
  50. padsize.push_back(x);// int newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3];
  51. return out;
  52. }
  53. cv::Rect get_rect(cv::Mat& img, float bbox[4], int INPUT_W, int INPUT_H) {
  54. int l, r, t, b;
  55. float r_w = INPUT_W / (img.cols * 1.0);
  56. float r_h = INPUT_H / (img.rows * 1.0);
  57. if (r_h > r_w) {
  58. l = bbox[0];
  59. r = bbox[2];
  60. t = bbox[1]- (INPUT_H - r_w * img.rows) / 2;
  61. b = bbox[3] - (INPUT_H - r_w * img.rows) / 2;
  62. l = l / r_w;
  63. r = r / r_w;
  64. t = t / r_w;
  65. b = b / r_w;
  66. }
  67. else {
  68. l = bbox[0] - bbox[2] / 2.f - (INPUT_W - r_h * img.cols) / 2;
  69. r = bbox[0] + bbox[2] / 2.f - (INPUT_W - r_h * img.cols) / 2;
  70. t = bbox[1] - bbox[3] / 2.f;
  71. b = bbox[1] + bbox[3] / 2.f;
  72. l = l / r_h;
  73. r = r / r_h;
  74. t = t / r_h;
  75. b = b / r_h;
  76. }
  77. return cv::Rect(l, t, r - l, b - t);
  78. }

yolo.hpp: 

  1. #pragma once
  2. #include "NvInfer.h"
  3. #include "cuda_runtime_api.h"
  4. #include "NvInferPlugin.h"
  5. #include "logging.h"
  6. #include <opencv2/opencv.hpp>
  7. #include "utils.h"
  8. #include <string>
  9. using namespace nvinfer1;
  10. using namespace cv;
  11. // stuff we know about the network and the input/output blobs
  12. static const int batchSize = 1;
  13. static const int INPUT_H = 640;
  14. static const int INPUT_W = 640;
  15. static const int _segWidth = 160;
  16. static const int _segHeight = 160;
  17. static const int _segChannels = 32;
  18. static const int CLASSES = 80;
  19. static const int Num_box = 8400;
  20. static const int OUTPUT_SIZE = batchSize * Num_box * (CLASSES + 4 + _segChannels);//output0
  21. //分割的输出头尺寸大小,输出是32*160*160
  22. static const int OUTPUT_SIZE1 = batchSize * _segChannels * _segWidth * _segHeight;//output1
  23. static const int INPUT_SIZE = batchSize * 3 * INPUT_H * INPUT_W;//images
  24. //置信度阈值
  25. static const float CONF_THRESHOLD = 0.5;
  26. //nms阈值
  27. static const float NMS_THRESHOLD = 0.5;
  28. //mask阈值
  29. static const float MASK_THRESHOLD = 0.5;
  30. //输入结点名称
  31. const char* INPUT_BLOB_NAME = "images";
  32. //检测头的输出结点名称
  33. const char* OUTPUT_BLOB_NAME = "output0";//detect
  34. //分割头的输出结点名称
  35. const char* OUTPUT_BLOB_NAME1 = "output1";//mask
  36. //定义两个静态浮点,用于保存两个输出头的输出结果
  37. static float prob[OUTPUT_SIZE]; //box
  38. static float prob1[OUTPUT_SIZE1]; //mask
  39. static Logger gLogger;
  40. struct OutputSeg {
  41. int id; //结果类别id
  42. float confidence; //结果置信度
  43. cv::Rect box; //矩形框
  44. cv::Mat boxMask; //矩形框内mask,节省内存空间和加快速度
  45. };
  46. //中间储存
  47. struct OutputObject
  48. {
  49. std::vector<int> classIds;//结果id数组
  50. std::vector<float> confidences;//结果每个id对应置信度数组
  51. std::vector<cv::Rect> boxes;//每个id矩形框
  52. std::vector<std::vector<float>> picked_proposals; //存储output0[:,:, 5 + _className.size():net_width]用以后续计算mask
  53. };
  54. const float color_list[80][3] =
  55. {
  56. {0.000, 0.447, 0.741},
  57. {0.850, 0.325, 0.098},
  58. {0.929, 0.694, 0.125},
  59. {0.494, 0.184, 0.556},
  60. {0.466, 0.674, 0.188},
  61. {0.301, 0.745, 0.933},
  62. {0.635, 0.078, 0.184},
  63. {0.300, 0.300, 0.300},
  64. {0.600, 0.600, 0.600},
  65. {1.000, 0.000, 0.000},
  66. {1.000, 0.500, 0.000},
  67. {0.749, 0.749, 0.000},
  68. {0.000, 1.000, 0.000},
  69. {0.000, 0.000, 1.000},
  70. {0.667, 0.000, 1.000},
  71. {0.333, 0.333, 0.000},
  72. {0.333, 0.667, 0.000},
  73. {0.333, 1.000, 0.000},
  74. {0.667, 0.333, 0.000},
  75. {0.667, 0.667, 0.000},
  76. {0.667, 1.000, 0.000},
  77. {1.000, 0.333, 0.000},
  78. {1.000, 0.667, 0.000},
  79. {1.000, 1.000, 0.000},
  80. {0.000, 0.333, 0.500},
  81. {0.000, 0.667, 0.500},
  82. {0.000, 1.000, 0.500},
  83. {0.333, 0.000, 0.500},
  84. {0.333, 0.333, 0.500},
  85. {0.333, 0.667, 0.500},
  86. {0.333, 1.000, 0.500},
  87. {0.667, 0.000, 0.500},
  88. {0.667, 0.333, 0.500},
  89. {0.667, 0.667, 0.500},
  90. {0.667, 1.000, 0.500},
  91. {1.000, 0.000, 0.500},
  92. {1.000, 0.333, 0.500},
  93. {1.000, 0.667, 0.500},
  94. {1.000, 1.000, 0.500},
  95. {0.000, 0.333, 1.000},
  96. {0.000, 0.667, 1.000},
  97. {0.000, 1.000, 1.000},
  98. {0.333, 0.000, 1.000},
  99. {0.333, 0.333, 1.000},
  100. {0.333, 0.667, 1.000},
  101. {0.333, 1.000, 1.000},
  102. {0.667, 0.000, 1.000},
  103. {0.667, 0.333, 1.000},
  104. {0.667, 0.667, 1.000},
  105. {0.667, 1.000, 1.000},
  106. {1.000, 0.000, 1.000},
  107. {1.000, 0.333, 1.000},
  108. {1.000, 0.667, 1.000},
  109. {0.333, 0.000, 0.000},
  110. {0.500, 0.000, 0.000},
  111. {0.667, 0.000, 0.000},
  112. {0.833, 0.000, 0.000},
  113. {1.000, 0.000, 0.000},
  114. {0.000, 0.167, 0.000},
  115. {0.000, 0.333, 0.000},
  116. {0.000, 0.500, 0.000},
  117. {0.000, 0.667, 0.000},
  118. {0.000, 0.833, 0.000},
  119. {0.000, 1.000, 0.000},
  120. {0.000, 0.000, 0.167},
  121. {0.000, 0.000, 0.333},
  122. {0.000, 0.000, 0.500},
  123. {0.000, 0.000, 0.667},
  124. {0.000, 0.000, 0.833},
  125. {0.000, 0.000, 1.000},
  126. {0.000, 0.000, 0.000},
  127. {0.143, 0.143, 0.143},
  128. {0.286, 0.286, 0.286},
  129. {0.429, 0.429, 0.429},
  130. {0.571, 0.571, 0.571},
  131. {0.714, 0.714, 0.714},
  132. {0.857, 0.857, 0.857},
  133. {0.000, 0.447, 0.741},
  134. {0.314, 0.717, 0.741},
  135. {0.50, 0.5, 0}
  136. };
  137. //output中,包含了经过处理的id、conf、box和maskiamg信息
  138. static void DrawPred(Mat& img, std::vector<OutputSeg> result) {
  139. //生成随机颜色
  140. std::vector<Scalar> color;
  141. //这行代码的作用是将当前系统时间作为随机数种子,使得每次程序运行时都会生成不同的随机数序列。
  142. srand(time(0));
  143. //根据类别数,生成不同的颜色
  144. for (int i = 0; i < CLASSES; i++) {
  145. int b = rand() % 256;
  146. int g = rand() % 256;
  147. int r = rand() % 256;
  148. color.push_back(Scalar(b, g, r));
  149. }
  150. Mat mask = img.clone();
  151. for (int i = 0; i < result.size(); i++) {
  152. int left, top;
  153. left = result[i].box.x;
  154. top = result[i].box.y;
  155. int color_num = i;
  156. //画矩形框,颜色是上面选的
  157. rectangle(img, result[i].box, color[result[i].id], 2, 8);
  158. //将box中的result[i].boxMask区域涂成color[result[i].id]颜色
  159. mask(result[i].box).setTo(color[result[i].id], result[i].boxMask);
  160. std::string label = std::to_string(result[i].id) + ":" + std::to_string(result[i].confidence);
  161. int baseLine;
  162. //获取标签文本的尺寸
  163. Size labelSize = getTextSize(label, FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine);
  164. //确定一个最大的高
  165. top = max(top, labelSize.height);
  166. //把文本信息加到图像上
  167. putText(img, label, Point(left, top), FONT_HERSHEY_SIMPLEX, 1, color[result[i].id], 2);
  168. }
  169. //用于对图像的加权融合
  170. //图像1、图像1权重、图像2、图像2权重,添加结果中的标量、输出图像
  171. addWeighted(img, 0.5, mask, 0.5, 0, img); //将mask加在原图上面
  172. }
  173. //输入引擎文本、图像数据、定义的检测输出和分割输出、batchSize
  174. static void doInference(IExecutionContext& context, float* input, float* output, float* output1, int batchSize)
  175. {
  176. //从上下文中获取一个CUDA引擎。这个引擎加载了一个深度学习模型
  177. const ICudaEngine& engine = context.getEngine();
  178. //判断该引擎是否有三个绑定,intput, output0, output1
  179. assert(engine.getNbBindings() == 3);
  180. //定义了一个指向void的指针数组,用于存储GPU缓冲区的地址
  181. void* buffers[3];
  182. //获取输入和输出blob的索引,这些索引用于之后的缓冲区操作
  183. const int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
  184. const int outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);
  185. const int outputIndex1 = engine.getBindingIndex(OUTPUT_BLOB_NAME1);
  186. // 使用cudaMalloc分配了GPU内存。这些内存将用于存储模型的输入和输出
  187. CHECK(cudaMalloc(&buffers[inputIndex], INPUT_SIZE * sizeof(float)));//
  188. CHECK(cudaMalloc(&buffers[outputIndex], OUTPUT_SIZE * sizeof(float)));
  189. CHECK(cudaMalloc(&buffers[outputIndex1], OUTPUT_SIZE1 * sizeof(float)));
  190. // cudaMalloc分配内存 cudaFree释放内存 cudaMemcpy或 cudaMemcpyAsync 在主机和设备之间传输数据
  191. // cudaMemcpy cudaMemcpyAsync 显式地阻塞传输 显式地非阻塞传输
  192. //创建一个CUDA流。CUDA流是一种特殊的并发执行环境,可以在其中安排任务以并发执行。流使得任务可以并行执行,从而提高了GPU的利用率。
  193. cudaStream_t stream;
  194. //判断是否创建成功
  195. CHECK(cudaStreamCreate(&stream));
  196. // 使用cudaMemcpyAsync将输入数据异步地复制到GPU缓冲区。这个操作是非阻塞的,意味着它不会立即完成。
  197. CHECK(cudaMemcpyAsync(buffers[inputIndex], input, INPUT_SIZE * sizeof(float), cudaMemcpyHostToDevice, stream));
  198. //将输入和输出缓冲区以及流添加到上下文的执行队列中。这将触发模型的推理。
  199. context.enqueue(batchSize, buffers, stream, nullptr);
  200. //使用cudaMemcpyAsync函数将GPU上的数据复制到主内存中。这是异步的,意味着该函数立即返回,而数据传输可以在后台进行。
  201. CHECK(cudaMemcpyAsync(output, buffers[outputIndex], OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
  202. CHECK(cudaMemcpyAsync(output1, buffers[outputIndex1], OUTPUT_SIZE1 * sizeof(float), cudaMemcpyDeviceToHost, stream));
  203. //等待所有在给定流上的操作都完成。这可以确保在释放流和缓冲区之前,所有的数据都已经被复制完毕。
  204. //这对于保证内存操作的正确性和防止数据竞争非常重要。
  205. cudaStreamSynchronize(stream);
  206. //释放内存
  207. cudaStreamDestroy(stream);
  208. CHECK(cudaFree(buffers[inputIndex]));
  209. CHECK(cudaFree(buffers[outputIndex]));
  210. CHECK(cudaFree(buffers[outputIndex1]));
  211. }
  212. //
  213. //检测YOLO类
  214. class YOLO
  215. {
  216. public:
  217. void init(std::string engine_path);
  218. void init(char* engine_path);
  219. void destroy();
  220. void blobFromImage(cv::Mat& img, float* data);
  221. void decode_boxs(cv::Mat& src, float* prob, OutputObject& outputObject, std::vector<int> padsize);
  222. void nms_outputs(cv::Mat& src, OutputObject& outputObject, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output);
  223. void decode_mask(cv::Mat& src, float* prob1, std::vector<int> padsize, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output);
  224. void drawMask(Mat& img, std::vector<OutputSeg> result);
  225. void detect_img(std::string image_path);
  226. void detect_img(std::string image_path, float(*res_array)[6], uchar(*mask_array));
  227. private:
  228. ICudaEngine* engine;
  229. IRuntime* runtime;
  230. IExecutionContext* context;
  231. };
  232. void YOLO::destroy()
  233. {
  234. this->context->destroy();
  235. this->engine->destroy();
  236. this->runtime->destroy();
  237. }
  238. void YOLO::init(std::string engine_path)
  239. {
  240. //无符号整型类型,通常用于表示对象的大小或计数
  241. //{ 0 }: 这是初始化列表,用于初始化 size 变量。在这种情况下,size 被初始化为 0。
  242. size_t size{ 0 };
  243. //定义一个指针变量,通过trtModelStream = new char[size];分配size个字符的空间
  244. //nullptr表示指针针在开始时不指向任何有效的内存地址,空指针
  245. char* trtModelStream{ nullptr };
  246. //打开文件,即engine模型
  247. std::ifstream file(engine_path, std::ios::binary);
  248. if (file.good())
  249. {
  250. //指向文件的最后地址
  251. file.seekg(0, file.end);
  252. //计算文件的长度
  253. size = file.tellg();
  254. //指回文件的起始地址
  255. file.seekg(0, file.beg);
  256. //为trtModelStream指针分配内存,内存大小为size
  257. trtModelStream = new char[size]; //开辟一个char 长度是文件的长度
  258. assert(trtModelStream);
  259. //把file内容传递给trtModelStream,传递大小为size,即engine模型内容传递
  260. file.read(trtModelStream, size);
  261. //关闭文件
  262. file.close();
  263. }
  264. std::cout << "engine init finished" << std::endl;
  265. //创建了一个Inference运行时环境,返回一个指向新创建的运行时环境的指针
  266. runtime = createInferRuntime(gLogger);
  267. assert(runtime != nullptr);
  268. //反序列化一个CUDA引擎。这个引擎将用于执行模型的前向传播
  269. engine = runtime->deserializeCudaEngine(trtModelStream, size);
  270. assert(engine != nullptr);
  271. //使用上一步中创建的引擎创建一个执行上下文。这个上下文将在模型的前向传播期间使用
  272. context = engine->createExecutionContext();
  273. assert(context != nullptr);
  274. //释放了用于存储模型序列化的内存
  275. delete[] trtModelStream;
  276. }
  277. void YOLO::init(char* engine_path)
  278. {
  279. //无符号整型类型,通常用于表示对象的大小或计数
  280. //{ 0 }: 这是初始化列表,用于初始化 size 变量。在这种情况下,size 被初始化为 0。
  281. size_t size{ 0 };
  282. //定义一个指针变量,通过trtModelStream = new char[size];分配size个字符的空间
  283. //nullptr表示指针针在开始时不指向任何有效的内存地址,空指针
  284. char* trtModelStream{ nullptr };
  285. //打开文件,即engine模型
  286. std::ifstream file(engine_path, std::ios::binary);
  287. if (file.good())
  288. {
  289. //指向文件的最后地址
  290. file.seekg(0, file.end);
  291. //计算文件的长度
  292. size = file.tellg();
  293. //指回文件的起始地址
  294. file.seekg(0, file.beg);
  295. //为trtModelStream指针分配内存,内存大小为size
  296. trtModelStream = new char[size]; //开辟一个char 长度是文件的长度
  297. assert(trtModelStream);
  298. //把file内容传递给trtModelStream,传递大小为size,即engine模型内容传递
  299. file.read(trtModelStream, size);
  300. //关闭文件
  301. file.close();
  302. }
  303. std::cout << "engine init finished" << std::endl;
  304. //创建了一个Inference运行时环境,返回一个指向新创建的运行时环境的指针
  305. runtime = createInferRuntime(gLogger);
  306. assert(runtime != nullptr);
  307. //反序列化一个CUDA引擎。这个引擎将用于执行模型的前向传播
  308. engine = runtime->deserializeCudaEngine(trtModelStream, size);
  309. assert(engine != nullptr);
  310. //使用上一步中创建的引擎创建一个执行上下文。这个上下文将在模型的前向传播期间使用
  311. context = engine->createExecutionContext();
  312. assert(context != nullptr);
  313. //释放了用于存储模型序列化的内存
  314. delete[] trtModelStream;
  315. }
  316. void YOLO::blobFromImage(cv::Mat& src, float* data)
  317. {
  318. //定义一个浮点数组
  319. //float* data = new float[3 * INPUT_H * INPUT_W];
  320. int i = 0;// [1,3,INPUT_H,INPUT_W]
  321. for (int row = 0; row < INPUT_H; ++row)
  322. {
  323. //逐行对象素值和图像通道进行处理
  324. //pr_img.step=widthx3 就是每一行有width个3通道的值
  325. //第row行
  326. uchar* uc_pixel = src.data + row * src.step;
  327. for (int col = 0; col < INPUT_W; ++col)
  328. {
  329. //第col列
  330. //提取第第row行第col列数据进行处理
  331. //像素值处理
  332. data[i] = (float)uc_pixel[2] / 255.0;
  333. //通道变换
  334. data[i + INPUT_H * INPUT_W] = (float)uc_pixel[1] / 255.0;
  335. data[i + 2 * INPUT_H * INPUT_W] = (float)uc_pixel[0] / 255.0;
  336. uc_pixel += 3;//表示进行下一列
  337. ++i;//表示在3个通道中的第i个位置,rgb三个通道的值是分开的,如r123456g123456b123456
  338. }
  339. }
  340. //return data;
  341. }
  342. void YOLO::decode_boxs(cv::Mat& src, float* prob, OutputObject& outputObject, std::vector<int> padsize)
  343. {
  344. int newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3];
  345. float ratio_h = (float)src.rows / newh;
  346. float ratio_w = (float)src.cols / neww;
  347. // 处理box
  348. int net_length = CLASSES + 4 + _segChannels;
  349. float* pdata = prob;
  350. cv::Mat out1 = cv::Mat(net_length, Num_box, CV_32F, prob);
  351. for (int j = 0; j < Num_box; ++j)
  352. {
  353. //输出是1*net_length*Num_box;所以每个box的属性是每隔Num_box取一个值,共net_length个值
  354. cv::Mat scores = out1(Rect(j, 4, 1, CLASSES)).clone();
  355. Point classIdPoint;
  356. double max_class_socre;
  357. minMaxLoc(scores, 0, &max_class_socre, 0, &classIdPoint);
  358. max_class_socre = (float)max_class_socre;
  359. if (max_class_socre >= CONF_THRESHOLD) {
  360. cv::Mat temp_proto = out1(Rect(j, 4 + CLASSES, 1, _segChannels)).clone();
  361. outputObject.picked_proposals.push_back(temp_proto);
  362. float x = (out1.at<float>(0, j) - padw) * ratio_w; //cx
  363. float y = (out1.at<float>(1, j) - padh) * ratio_h; //cy
  364. float w = out1.at<float>(2, j) * ratio_w; //w
  365. float h = out1.at<float>(3, j) * ratio_h; //h
  366. int left = MAX((x - 0.5 * w), 0);
  367. int top = MAX((y - 0.5 * h), 0);
  368. int width = (int)w;
  369. int height = (int)h;
  370. if (width <= 0 || height <= 0) { continue; }
  371. outputObject.classIds.push_back(classIdPoint.y);
  372. outputObject.confidences.push_back(max_class_socre);
  373. outputObject.boxes.push_back(Rect(left, top, width, height));
  374. }
  375. }
  376. }
  377. void YOLO::nms_outputs(cv::Mat& src, OutputObject& outputObject, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output)
  378. {
  379. //执行非最大抑制以消除具有较低置信度的冗余重叠框(NMS)
  380. std::vector<int> nms_result;
  381. //通过opencv自带的nms函数进行,矩阵box、置信度大小,置信度阈值,nms阈值,结果
  382. cv::dnn::NMSBoxes(outputObject.boxes, outputObject.confidences, CONF_THRESHOLD, NMS_THRESHOLD, nms_result);
  383. //the kept results include class id, confidence, box and the mask coefficients
  384. //std::vector<std::vector<float>> temp_mask_proposals;
  385. //创建一个名为holeImgRect的Rect对象
  386. Rect holeImgRect(0, 0, src.cols, src.rows);
  387. //提取经过非极大值抑制后的结果
  388. for (int i = 0; i < nms_result.size(); ++i) {
  389. int idx = nms_result[i];
  390. OutputSeg result;
  391. result.id = outputObject.classIds[idx];
  392. result.confidence = outputObject.confidences[idx];
  393. result.box = outputObject.boxes[idx] & holeImgRect;
  394. output.push_back(result);
  395. temp_mask_proposals.push_back(outputObject.picked_proposals[idx]);
  396. }
  397. }
  398. void YOLO::decode_mask(cv::Mat& src, float* prob1, std::vector<int> padsize, std::vector<std::vector<float>>& temp_mask_proposals, std::vector<OutputSeg>& output)
  399. {
  400. int newh = padsize[0], neww = padsize[1], padh = padsize[2], padw = padsize[3];
  401. // 处理mask
  402. Mat maskProposals;
  403. for (int i = 0; i < temp_mask_proposals.size(); ++i)
  404. //std::cout<< Mat(temp_mask_proposals[i]).t().size();
  405. maskProposals.push_back(Mat(temp_mask_proposals[i]).t());
  406. //开始处理分割头的输出32*160*160
  407. //把分割结果重构为32,160*160
  408. float* pdata = prob1;
  409. std::vector<float> mask(pdata, pdata + _segChannels * _segWidth * _segHeight);
  410. Mat mask_protos = Mat(mask);
  411. Mat protos = mask_protos.reshape(0, { _segChannels,_segWidth * _segHeight });//将prob1的值 赋给mask_protos
  412. Mat matmulRes = (maskProposals * protos).t();//n*32 32*25600 A*B是以数学运算中矩阵相乘的方式实现的,要求A的列数等于B的行数时
  413. Mat masks = matmulRes.reshape(output.size(), { _segWidth,_segHeight });
  414. //std::cout << protos.size();
  415. std::vector<Mat> maskChannels;
  416. //将masks分割成多个通道,保存到maskChannels
  417. split(masks, maskChannels);
  418. //处理和获得原始图像中改变像素点颜色的区域
  419. for (int i = 0; i < output.size(); ++i) {
  420. Mat dest, mask;
  421. //进行sigmoid
  422. cv::exp(-maskChannels[i], dest);
  423. dest = 1.0 / (1.0 + dest);//160*160
  424. Rect roi(int((float)padw / INPUT_W * _segWidth), int((float)padh / INPUT_H * _segHeight), int(_segWidth - padw / 2), int(_segHeight - padh / 2));
  425. //截取相应区域,避免填充影响
  426. dest = dest(roi);
  427. //把mask的大小重构到原始图像大小
  428. resize(dest, mask, src.size(), 0, 0, INTER_NEAREST);
  429. //crop----截取box中的mask作为该box对应的mask
  430. Rect temp_rect = output[i].box;
  431. //判断mask中box区域的值是否大于mask阈值,大于为true,小于为false
  432. //提取出mask中与temp_rect相交的部分,然后判断这部分的值是否大于预设的阈值MASK_THRESHOLD。结果保存在mask中
  433. mask = mask(temp_rect) > MASK_THRESHOLD;
  434. //把掩码图像进行保存,大小和原图像大小一样,目标区域已经为true
  435. output[i].boxMask = mask;
  436. }
  437. }
  438. void YOLO::drawMask(Mat& img, std::vector<OutputSeg> result)
  439. {
  440. //生成随机颜色
  441. std::vector<Scalar> color;
  442. //这行代码的作用是将当前系统时间作为随机数种子,使得每次程序运行时都会生成不同的随机数序列。
  443. srand(time(0));
  444. //根据类别数,生成不同的颜色
  445. for (int i = 0; i < CLASSES; i++) {
  446. int b = color_list[i][0] * 255;
  447. int g = color_list[i][1] * 255;
  448. int r = color_list[i][2] * 255;
  449. color.push_back(Scalar(b, g, r));
  450. }
  451. Mat mask = img.clone();
  452. for (int i = 0; i < result.size(); i++) {
  453. //将box中的result[i].boxMask区域涂成color[result[i].id]颜色
  454. mask(result[i].box).setTo(color[result[i].id], result[i].boxMask);
  455. }
  456. //用于对图像的加权融合
  457. //图像1、图像1权重、图像2、图像2权重,添加结果中的标量、输出图像
  458. addWeighted(img, 0.5, mask, 0.5, 0, img); //将mask加在原图上面
  459. }
  460. void YOLO::detect_img(std::string image_path)
  461. {
  462. cv::Mat img = cv::imread(image_path);
  463. //图像预处理,输入的是原图像和网络输入的高和宽,填充尺寸容器
  464. //输出的是重构后的图像,以及每条边填充的大小保存在padsize
  465. cv::Mat pr_img;
  466. std::vector<int> padsize;
  467. pr_img = preprocess_img(img, INPUT_H, INPUT_W, padsize); // Resize
  468. float* blob = new float[3 * INPUT_H * INPUT_W];
  469. blobFromImage(pr_img, blob);
  470. //推理
  471. auto start = std::chrono::system_clock::now();
  472. doInference(*context, blob, prob, prob1, batchSize);
  473. auto end = std::chrono::system_clock::now();
  474. std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
  475. //解析并绘制output
  476. OutputObject outputObject;
  477. std::vector<OutputSeg> output;
  478. std::vector<std::vector<float>> temp_mask_proposals;
  479. decode_boxs(img, prob, outputObject, padsize);
  480. nms_outputs(img, outputObject, temp_mask_proposals, output);
  481. //判断如果没有识别出东西直接输出原图
  482. if (output.size() != 0)
  483. {
  484. decode_mask(img, prob1, padsize, temp_mask_proposals, output);
  485. }
  486. //绘制mask图像
  487. Mat mask = Mat(img.rows, img.cols, img.type(), Scalar(255, 255, 255));
  488. drawMask(mask, output);
  489. //传出数据
  490. //for (size_t j = 0; j < output.size(); j++)
  491. //{
  492. // res_array[j][0] = output[j].box.x;
  493. // res_array[j][1] = output[j].box.y;
  494. // res_array[j][2] = output[j].box.width;
  495. // res_array[j][3] = output[j].box.height;
  496. // res_array[j][4] = output[j].id;
  497. // res_array[j][5] = output[j].confidence;
  498. // //mask_array = output[j].boxMask.data;
  499. //}
  500. //export the mask data
  501. //for (int i = 0; i < mask.rows; i++) {
  502. // for (int j = 0; j < mask.cols; j++) {
  503. // for (int k = 0; k < 3; k++) {
  504. // mask_array[i * mask.cols * 3 + j * 3 + k] = mask.at<cv::Vec3b>(i, j)[k];
  505. // }
  506. // }
  507. //}
  508. cv::imshow("output.jpg", mask);
  509. cv::imwrite("output.jpg", mask);
  510. cv::waitKey(0);
  511. delete[] blob;
  512. }
  513. void YOLO::detect_img(std::string image_path, float(*res_array)[6], uchar(*mask_array))
  514. {
  515. cv::Mat img = cv::imread(image_path);
  516. //图像预处理,输入的是原图像和网络输入的高和宽,填充尺寸容器
  517. //输出的是重构后的图像,以及每条边填充的大小保存在padsize
  518. cv::Mat pr_img;
  519. std::vector<int> padsize;
  520. pr_img = preprocess_img(img, INPUT_H, INPUT_W, padsize); // Resize
  521. float* blob = new float[3 * INPUT_H * INPUT_W];
  522. blobFromImage(pr_img, blob);
  523. //推理
  524. auto start = std::chrono::system_clock::now();
  525. doInference(*context, blob, prob, prob1, batchSize);
  526. auto end = std::chrono::system_clock::now();
  527. std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(end - start).count() << "ms" << std::endl;
  528. //解析并绘制output
  529. OutputObject outputObject;
  530. std::vector<OutputSeg> output;
  531. std::vector<std::vector<float>> temp_mask_proposals;
  532. decode_boxs(img, prob, outputObject, padsize);
  533. nms_outputs(img, outputObject, temp_mask_proposals, output);
  534. //判断如果没有识别出东西直接输出原图
  535. if (output.size() != 0)
  536. {
  537. decode_mask(img, prob1, padsize, temp_mask_proposals, output);
  538. }
  539. //绘制mask图像
  540. Mat mask = Mat(img.rows, img.cols, img.type(), Scalar(255, 255, 255));
  541. drawMask(mask, output);
  542. //传出数据
  543. for (size_t j = 0; j < output.size(); j++)
  544. {
  545. res_array[j][0] = output[j].box.x;
  546. res_array[j][1] = output[j].box.y;
  547. res_array[j][2] = output[j].box.width;
  548. res_array[j][3] = output[j].box.height;
  549. res_array[j][4] = output[j].id;
  550. res_array[j][5] = output[j].confidence;
  551. //mask_array = output[j].boxMask.data;
  552. }
  553. //传出mask数据
  554. for (int i = 0; i < mask.rows; i++) {
  555. for (int j = 0; j < mask.cols; j++) {
  556. for (int k = 0; k < 3; k++) {
  557. mask_array[i * mask.cols * 3 + j * 3 + k] = mask.at<cv::Vec3b>(i, j)[k];
  558. }
  559. }
  560. }
  561. delete[] blob;
  562. }

4. Testing

        To test it, just create a YOLO object and call its member functions. For example:

  1. #include <iostream>
  2. #include "yolo.hpp"
  3. int main()
  4. {
  5. YOLO yolo;
  6. yolo.init("E:\\yolov8s_seg.engine");
  7. yolo.detect_img("E:\\bus.jpg");
  8. yolo.destroy();
  9. }
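If you want the results back in plain arrays (as the DLL wrapper in the next section does), the second detect_img overload can be used. Below is a hedged sketch: the 100-row result buffer is an assumption (the function does not bounds-check it), and the mask buffer must hold rows*cols*3 bytes of the input image.

#include <vector>
#include "yolo.hpp"

int main()
{
    YOLO yolo;
    yolo.init("E:\\yolov8s_seg.engine");

    cv::Mat img = cv::imread("E:\\bus.jpg");
    static float res_array[100][6];                          // x, y, w, h, class id, confidence per detection
    std::vector<uchar> mask_array(img.rows * img.cols * 3);  // BGR mask image, same size as the input

    yolo.detect_img("E:\\bus.jpg", res_array, mask_array.data());
    yolo.destroy();
    return 0;
}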

5. Calling from C#

        I usually write the host (HMI) application in C#, so I normally wrap the C++ code into a DLL.

C++ wrapper:

  1. #include <iostream>
  2. #include "yolo.hpp"
  3. //int main()
  4. //{
  5. // YOLO yolo;
  6. // yolo.init("E:\\yolov8s_seg.engine");
  7. // yolo.detect_img("E:\\bus.jpg");
  8. // yolo.destroy();
  9. //}
  10. YOLO yolo;
  11. extern "C" __declspec(dllexport) void Init(char* engine_path)
  12. {
  13. yolo.init(engine_path);
  14. return;
  15. }
  16. extern "C" __declspec(dllexport) void Detect_Img(const char* image_path, float(*res_array)[6], uchar* mask_array) //take a C string: std::string cannot safely cross the extern "C" / P-Invoke boundary
  17. {
  18. yolo.detect_img(std::string(image_path), res_array, mask_array);
  19. return;
  20. }
  21. extern "C" __declspec(dllexport) void Destroy()
  22. {
  23. yolo.destroy();
  24. return;
  25. }

C# declarations:

  1. [DllImport("engine_infer_mask.dll", CallingConvention = CallingConvention.Cdecl)]
  2. public static extern void Init(string engine_path);
  3. [DllImport("engine_infer_mask.dll", CallingConvention = CallingConvention.Cdecl)]
  4. public static extern void Detect_Img(string path, float[,] resultArray, byte[] maskArray);
  5. [DllImport("engine_infer_mask.dll", CallingConvention = CallingConvention.Cdecl)]
  6. public static extern void Destroy();
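A hedged example of calling these imports from C# (the buffer sizes are assumptions and must match what the C++ side writes: at most 100 detections of 6 floats each, and width*height*3 bytes for the mask image of the picture you pass in; bus.jpg is assumed to be 810×1080 here):

Init("E:\\yolov8s_seg.engine");

float[,] results = new float[100, 6];          // x, y, w, h, class id, confidence per detection
byte[] maskPixels = new byte[810 * 1080 * 3];  // width * height * 3 of the input image

Detect_Img("E:\\bus.jpg", results, maskPixels);
Destroy();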

6. Adapting the parameters to your own project

         Models trained with different settings have different parameters, so the inference-side constants have to change with them. The main ones to modify are the constants at the top of yolo.hpp (see the sketch at the end of this section).

         Open the corresponding ONNX model in Netron and you can see that images, output0 and output1 line up one-to-one with the constants and node names in yolo.hpp. When training your own model you will usually have changed the number of classes, so update CLASSES (and any other constant that differs) accordingly, and adjust the confidence threshold to your needs.
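For example, for a hypothetical model trained with 3 classes at the default 640×640 input, the top of yolo.hpp might look as follows (the values are illustrative; read the real ones off your own ONNX graph in Netron):

static const int INPUT_H = 640;
static const int INPUT_W = 640;
static const int _segWidth = 160;         // mask prototype width  = INPUT_W / 4
static const int _segHeight = 160;        // mask prototype height = INPUT_H / 4
static const int _segChannels = 32;       // number of mask prototypes, fixed for yolov8-seg
static const int CLASSES = 3;             // the number of classes the model was trained with
static const int Num_box = 8400;          // (640/8)^2 + (640/16)^2 + (640/32)^2 candidate boxes
static const float CONF_THRESHOLD = 0.25; // confidence threshold, tune as needed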
