
Converting a PyTorch ResNet50 (with multiple outputs) to C++ TensorRT


For a first, minimal LeNet example, see the earlier article.

1. PyTorch stage

Test image: ./test2.jpg (shown in the original post).

PyTorch code:

# coding:utf-8
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import os
import struct
import time
import cv2
import numpy as np


def main():
    print('cuda device count: ', torch.cuda.device_count())
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    model = torchvision.models.resnet50(pretrained=True)
    # net.fc = nn.Linear(512, 2)
    model = model.to('cuda:0')
    model.eval()
    # print(model)
    st_time = time.time()
    nums = 10000
    for i in range(nums):
        input_ = torch.ones(1, 3, 224, 224).to('cuda:0')
        out = model(input_)
        # print('====out.shape:===', out.shape)  # (1, 1000)
    end_time = time.time()
    print('==avg cost time {}'.format((end_time - st_time) / nums))
    # input_ = torch.ones(1, 3, 224, 224).to('cuda:0')
    # save_pth(model, input_)   # save .pth
    # save_onnx(input_, model)  # save .onnx, handy for visualizing the network
    # get_wts(model)            # dump key/value weights


def save_pth(model, input_):
    conv1 = model.conv1(input_)
    print('===conv1.shape:', conv1.shape)
    # maxpool_1 = model.maxpool(conv1)
    # print('===maxpool_1.shape:', maxpool_1.shape)
    # layer1 = model.layer1(maxpool_1)
    # print('===layer1.shape:', layer1.shape)
    # layer2 = model.layer2(layer1)
    # print('===layer2.shape:', layer2.shape)
    # layer3 = model.layer3(layer2)
    # print('===layer3.shape:', layer3.shape)
    # layer4 = model.layer4(layer3)
    # print('===layer4.shape:', layer4.shape)
    # print('resnet50 out:', out.shape)
    torch.save(model, "resnet50.pth")


def get_wts(model):
    f = open("resnet50.wts", 'w')
    f.write("{}\n".format(len(model.state_dict().keys())))
    for k, v in model.state_dict().items():
        # print('key: ', k)          # weight name
        # print('value: ', v.shape)  # weight shape
        vr = v.reshape(-1).cpu().numpy()
        f.write("{} {}".format(k, len(vr)))
        for vv in vr:
            f.write(" ")
            f.write(struct.pack(">f", float(vv)).hex())
        f.write("\n")


def save_onnx(input_, model):
    # torch.onnx.export(model, input_, "./resnet50.onnx", verbose=True)
    torch.onnx.export(model,            # model being run
                      input_,           # model input (or a tuple for multiple inputs)
                      "./resnet50.onnx",
                      opset_version=10,
                      verbose=False,
                      training=False,
                      do_constant_folding=True,
                      input_names=['input'],
                      output_names=['output'])


def test_real_img():
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    model = torchvision.models.resnet50(pretrained=True)
    # net.fc = nn.Linear(512, 2)
    model = model.to('cuda:0')
    model.eval()
    # print(model)
    img = cv2.imread('./test2.jpg')
    print('===img.shape', img.shape)
    img = cv2.resize(img, (224, 224))
    mean = np.array([0.406, 0.456, 0.485]).astype(np.float32)  # BGR order
    std = np.array([0.225, 0.224, 0.229]).astype(np.float32)
    img = (img / 255. - mean) / std
    img = np.expand_dims(img, axis=0)
    print('===img.shape', img.shape)
    img = np.transpose(img, (0, 3, 1, 2)).astype(np.float32)
    # img = np.ones((1, 3, 224, 224)).astype(np.float32)
    nums = 10000
    img = torch.from_numpy(img)
    st_time = time.time()
    for i in range(nums):
        with torch.no_grad():
            out = model(img.cuda())
    end_time = time.time()
    print('==avg cost time {}'.format((end_time - st_time) / nums))
    print('====out.shape:===', out.shape)  # (1, 1000)
    with open('./pytorch_result.txt', 'w', encoding='utf-8') as file:
        for i in range(1000):
            file.write(str(out.cpu().numpy()[0][i]) + '\n')
    torch_value, torch_index = torch.max(out, dim=1)
    print('====torch_value:===', torch_value)  # 13.8998
    print('====torch_index:===', torch_index)  # 285 Egyptian cat
    topk = 5
    topk_index = torch.argsort(out, dim=1, descending=True)[:, :topk]
    print('===topk_index:', topk_index)
    out = out.cpu().numpy()
    index = np.where(out == np.max(out))
    print('===index:===', index)


if __name__ == '__main__':
    # main()
    test_real_img()

Here, get_wts generates the hex weight file resnet50.wts, which TensorRT later parses to populate the network's weights; the format is sketched below.

save_onnx generates resnet50.onnx, which is useful for visualizing the network structure.
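The .wts layout written by get_wts is plain text: the first line holds the number of weight tensors; each following line holds a tensor name, its element count, and that many space-separated hex strings, each being the big-endian IEEE-754 bit pattern of one float32 (this is why the C++ loadWeights below can read them with input >> std::hex). A sketch, with placeholder hex values:

<number of state_dict entries>
conv1.weight 9408 3c8e5b21 bd13a6f0 ... (9408 hex values)
bn1.weight 64 3f7fe2a1 ... (64 hex values)
...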

Result (console screenshot in the original post): the top-1 value is about 13.8998 at index 285, as noted in the code comments.

Looking up index 285 in the ImageNet label list gives: Egyptian cat.

A screenshot of the generated pytorch_result.txt follows in the original post.

2. TensorRT conversion stage

2.1 Serialization: generating the .engine

1. File layout

resnet50.wts was generated in the PyTorch stage; resnet50.engine is what this stage produces. See the sketch below.
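The structure diagram in the original post is an image; judging from the files listed in this section, the directory looks roughly like this:

.
├── CMakeLists.txt
├── Resnet50Serial.cpp
├── common.hpp
├── logging.h
└── resnet50.wts        (input, produced by get_wts in the PyTorch stage)

resnet50.engine is written into this directory after running Resnet50Serial.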

2. Code:

logging.h

/*
 * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TENSORRT_LOGGING_H
#define TENSORRT_LOGGING_H

#include "NvInferRuntimeCommon.h"
#include <cassert>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <ostream>
#include <sstream>
#include <string>

using Severity = nvinfer1::ILogger::Severity;

class LogStreamConsumerBuffer : public std::stringbuf
{
public:
    LogStreamConsumerBuffer(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mOutput(stream)
        , mPrefix(prefix)
        , mShouldLog(shouldLog)
    {
    }

    LogStreamConsumerBuffer(LogStreamConsumerBuffer&& other)
        : mOutput(other.mOutput)
    {
    }

    ~LogStreamConsumerBuffer()
    {
        // std::streambuf::pbase() gives a pointer to the beginning of the buffered part of the output sequence
        // std::streambuf::pptr() gives a pointer to the current position of the output sequence
        // if the pointer to the beginning is not equal to the pointer to the current position,
        // call putOutput() to log the output to the stream
        if (pbase() != pptr())
        {
            putOutput();
        }
    }

    // synchronizes the stream buffer and returns 0 on success
    // synchronizing the stream buffer consists of inserting the buffer contents into the stream,
    // resetting the buffer and flushing the stream
    virtual int sync()
    {
        putOutput();
        return 0;
    }

    void putOutput()
    {
        if (mShouldLog)
        {
            // prepend timestamp
            std::time_t timestamp = std::time(nullptr);
            tm* tm_local = std::localtime(&timestamp);
            std::cout << "[";
            std::cout << std::setw(2) << std::setfill('0') << 1 + tm_local->tm_mon << "/";
            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_mday << "/";
            std::cout << std::setw(4) << std::setfill('0') << 1900 + tm_local->tm_year << "-";
            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_hour << ":";
            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_min << ":";
            std::cout << std::setw(2) << std::setfill('0') << tm_local->tm_sec << "] ";
            // std::stringbuf::str() gets the string contents of the buffer
            // insert the buffer contents pre-appended by the appropriate prefix into the stream
            mOutput << mPrefix << str();
            // set the buffer to empty
            str("");
            // flush the stream
            mOutput.flush();
        }
    }

    void setShouldLog(bool shouldLog)
    {
        mShouldLog = shouldLog;
    }

private:
    std::ostream& mOutput;
    std::string mPrefix;
    bool mShouldLog;
};

//!
//! \class LogStreamConsumerBase
//! \brief Convenience object used to initialize LogStreamConsumerBuffer before std::ostream in LogStreamConsumer
//!
class LogStreamConsumerBase
{
public:
    LogStreamConsumerBase(std::ostream& stream, const std::string& prefix, bool shouldLog)
        : mBuffer(stream, prefix, shouldLog)
    {
    }

protected:
    LogStreamConsumerBuffer mBuffer;
};

//!
//! \class LogStreamConsumer
//! \brief Convenience object used to facilitate use of C++ stream syntax when logging messages.
//!  Order of base classes is LogStreamConsumerBase and then std::ostream.
//!  This is because the LogStreamConsumerBase class is used to initialize the LogStreamConsumerBuffer member field
//!  in LogStreamConsumer and then the address of the buffer is passed to std::ostream.
//!  This is necessary to prevent the address of an uninitialized buffer from being passed to std::ostream.
//!  Please do not change the order of the parent classes.
//!
class LogStreamConsumer : protected LogStreamConsumerBase, public std::ostream
{
public:
    //! \brief Creates a LogStreamConsumer which logs messages with level severity.
    //!  Reportable severity determines if the messages are severe enough to be logged.
    LogStreamConsumer(Severity reportableSeverity, Severity severity)
        : LogStreamConsumerBase(severityOstream(severity), severityPrefix(severity), severity <= reportableSeverity)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(severity <= reportableSeverity)
        , mSeverity(severity)
    {
    }

    LogStreamConsumer(LogStreamConsumer&& other)
        : LogStreamConsumerBase(severityOstream(other.mSeverity), severityPrefix(other.mSeverity), other.mShouldLog)
        , std::ostream(&mBuffer) // links the stream buffer with the stream
        , mShouldLog(other.mShouldLog)
        , mSeverity(other.mSeverity)
    {
    }

    void setReportableSeverity(Severity reportableSeverity)
    {
        mShouldLog = mSeverity <= reportableSeverity;
        mBuffer.setShouldLog(mShouldLog);
    }

private:
    static std::ostream& severityOstream(Severity severity)
    {
        return severity >= Severity::kINFO ? std::cout : std::cerr;
    }

    static std::string severityPrefix(Severity severity)
    {
        switch (severity)
        {
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";
        }
    }

    bool mShouldLog;
    Severity mSeverity;
};

//! \class Logger
//!
//! \brief Class which manages logging of TensorRT tools and samples
//!
//! \details This class provides a common interface for TensorRT tools and samples to log information to the console,
//! and supports logging two types of messages:
//!
//! - Debugging messages with an associated severity (info, warning, error, or internal error/fatal)
//! - Test pass/fail messages
//!
//! The advantage of having all samples use this class for logging as opposed to emitting directly to stdout/stderr is
//! that the logic for controlling the verbosity and formatting of sample output is centralized in one location.
//!
//! In the future, this class could be extended to support dumping test results to a file in some standard format
//! (for example, JUnit XML), and providing additional metadata (e.g. timing the duration of a test run).
//!
//! TODO: For backwards compatibility with existing samples, this class inherits directly from the nvinfer1::ILogger
//! interface, which is problematic since there isn't a clean separation between messages coming from the TensorRT
//! library and messages coming from the sample.
//!
//! In the future (once all samples are updated to use Logger::getTRTLogger() to access the ILogger) we can refactor the
//! class to eliminate the inheritance and instead make the nvinfer1::ILogger implementation a member of the Logger
//! object.
class Logger : public nvinfer1::ILogger
{
public:
    Logger(Severity severity = Severity::kWARNING)
        : mReportableSeverity(severity)
    {
    }

    //!
    //! \enum TestResult
    //! \brief Represents the state of a given test
    //!
    enum class TestResult
    {
        kRUNNING, //!< The test is running
        kPASSED,  //!< The test passed
        kFAILED,  //!< The test failed
        kWAIVED   //!< The test was waived
    };

    //!
    //! \brief Forward-compatible method for retrieving the nvinfer::ILogger associated with this Logger
    //! \return The nvinfer1::ILogger associated with this Logger
    //!
    //! TODO Once all samples are updated to use this method to register the logger with TensorRT,
    //! we can eliminate the inheritance of Logger from ILogger
    //!
    nvinfer1::ILogger& getTRTLogger()
    {
        return *this;
    }

    //!
    //! \brief Implementation of the nvinfer1::ILogger::log() virtual method
    //!
    //! Note samples should not be calling this function directly; it will eventually go away once we eliminate the
    //! inheritance from nvinfer1::ILogger
    //!
    void log(Severity severity, const char* msg) override
    {
        LogStreamConsumer(mReportableSeverity, severity) << "[TRT] " << std::string(msg) << std::endl;
    }

    //!
    //! \brief Method for controlling the verbosity of logging output
    //!
    //! \param severity The logger will only emit messages that have severity of this level or higher.
    //!
    void setReportableSeverity(Severity severity)
    {
        mReportableSeverity = severity;
    }

    //!
    //! \brief Opaque handle that holds logging information for a particular test
    //!
    //! This object is an opaque handle to information used by the Logger to print test results.
    //! The sample must call Logger::defineTest() in order to obtain a TestAtom that can be used
    //! with Logger::reportTest{Start,End}().
    //!
    class TestAtom
    {
    public:
        TestAtom(TestAtom&&) = default;

    private:
        friend class Logger;

        TestAtom(bool started, const std::string& name, const std::string& cmdline)
            : mStarted(started)
            , mName(name)
            , mCmdline(cmdline)
        {
        }

        bool mStarted;
        std::string mName;
        std::string mCmdline;
    };

    //!
    //! \brief Define a test for logging
    //!
    //! \param[in] name The name of the test. This should be a string starting with
    //!                 "TensorRT" and containing dot-separated strings containing
    //!                 the characters [A-Za-z0-9_].
    //!                 For example, "TensorRT.sample_googlenet"
    //! \param[in] cmdline The command line used to reproduce the test
    //
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    //!
    static TestAtom defineTest(const std::string& name, const std::string& cmdline)
    {
        return TestAtom(false, name, cmdline);
    }

    //!
    //! \brief A convenience overloaded version of defineTest() that accepts an array of command-line arguments
    //!        as input
    //!
    //! \param[in] name The name of the test
    //! \param[in] argc The number of command-line arguments
    //! \param[in] argv The array of command-line arguments (given as C strings)
    //!
    //! \return a TestAtom that can be used in Logger::reportTest{Start,End}().
    static TestAtom defineTest(const std::string& name, int argc, char const* const* argv)
    {
        auto cmdline = genCmdlineString(argc, argv);
        return defineTest(name, cmdline);
    }

    //!
    //! \brief Report that a test has started.
    //!
    //! \pre reportTestStart() has not been called yet for the given testAtom
    //!
    //! \param[in] testAtom The handle to the test that has started
    //!
    static void reportTestStart(TestAtom& testAtom)
    {
        reportTestResult(testAtom, TestResult::kRUNNING);
        assert(!testAtom.mStarted);
        testAtom.mStarted = true;
    }

    //!
    //! \brief Report that a test has ended.
    //!
    //! \pre reportTestStart() has been called for the given testAtom
    //!
    //! \param[in] testAtom The handle to the test that has ended
    //! \param[in] result The result of the test. Should be one of TestResult::kPASSED,
    //!                   TestResult::kFAILED, TestResult::kWAIVED
    //!
    static void reportTestEnd(const TestAtom& testAtom, TestResult result)
    {
        assert(result != TestResult::kRUNNING);
        assert(testAtom.mStarted);
        reportTestResult(testAtom, result);
    }

    static int reportPass(const TestAtom& testAtom)
    {
        reportTestEnd(testAtom, TestResult::kPASSED);
        return EXIT_SUCCESS;
    }

    static int reportFail(const TestAtom& testAtom)
    {
        reportTestEnd(testAtom, TestResult::kFAILED);
        return EXIT_FAILURE;
    }

    static int reportWaive(const TestAtom& testAtom)
    {
        reportTestEnd(testAtom, TestResult::kWAIVED);
        return EXIT_SUCCESS;
    }

    static int reportTest(const TestAtom& testAtom, bool pass)
    {
        return pass ? reportPass(testAtom) : reportFail(testAtom);
    }

    Severity getReportableSeverity() const
    {
        return mReportableSeverity;
    }

private:
    //!
    //! \brief returns an appropriate string for prefixing a log message with the given severity
    //!
    static const char* severityPrefix(Severity severity)
    {
        switch (severity)
        {
        case Severity::kINTERNAL_ERROR: return "[F] ";
        case Severity::kERROR: return "[E] ";
        case Severity::kWARNING: return "[W] ";
        case Severity::kINFO: return "[I] ";
        case Severity::kVERBOSE: return "[V] ";
        default: assert(0); return "";
        }
    }

    //!
    //! \brief returns an appropriate string for prefixing a test result message with the given result
    //!
    static const char* testResultString(TestResult result)
    {
        switch (result)
        {
        case TestResult::kRUNNING: return "RUNNING";
        case TestResult::kPASSED: return "PASSED";
        case TestResult::kFAILED: return "FAILED";
        case TestResult::kWAIVED: return "WAIVED";
        default: assert(0); return "";
        }
    }

    //!
    //! \brief returns an appropriate output stream (cout or cerr) to use with the given severity
    //!
    static std::ostream& severityOstream(Severity severity)
    {
        return severity >= Severity::kINFO ? std::cout : std::cerr;
    }

    //!
    //! \brief method that implements logging test results
    //!
    static void reportTestResult(const TestAtom& testAtom, TestResult result)
    {
        severityOstream(Severity::kINFO) << "&&&& " << testResultString(result) << " " << testAtom.mName << " # "
                                         << testAtom.mCmdline << std::endl;
    }

    //!
    //! \brief generate a command line string from the given (argc, argv) values
    //!
    static std::string genCmdlineString(int argc, char const* const* argv)
    {
        std::stringstream ss;
        for (int i = 0; i < argc; i++)
        {
            if (i > 0)
                ss << " ";
            ss << argv[i];
        }
        return ss.str();
    }

    Severity mReportableSeverity;
};

namespace
{
//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kVERBOSE
//!
//! Example usage:
//!
//!     LOG_VERBOSE(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_VERBOSE(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kVERBOSE);
}

//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINFO
//!
//! Example usage:
//!
//!     LOG_INFO(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_INFO(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINFO);
}

//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kWARNING
//!
//! Example usage:
//!
//!     LOG_WARN(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_WARN(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kWARNING);
}

//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kERROR
//!
//! Example usage:
//!
//!     LOG_ERROR(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_ERROR(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kERROR);
}

//!
//! \brief produces a LogStreamConsumer object that can be used to log messages of severity kINTERNAL_ERROR
//!        ("fatal" severity)
//!
//! Example usage:
//!
//!     LOG_FATAL(logger) << "hello world" << std::endl;
//!
inline LogStreamConsumer LOG_FATAL(const Logger& logger)
{
    return LogStreamConsumer(logger.getReportableSeverity(), Severity::kINTERNAL_ERROR);
}

} // anonymous namespace

#endif // TENSORRT_LOGGING_H

Resnet50Serial.cpp

#include <map>
#include <chrono>
#include <fstream>
#include <string>
#include "NvInfer.h"
#include "logging.h"
#include "cuda_runtime_api.h"
#include <NvInferRuntimeCommon.h>
#include "common.hpp"
#include <opencv2/opencv.hpp>
#include <limits.h>

static Logger gLogger;

#define DEVICE 0 // gpu id
#define BATCH_SIZE 1
static const int INPUT_H = 224;
static const int INPUT_W = 224;
// static const int BATCH_SIZE = 32;
static const int OUTPUT_SIZE = 1000;
static const int INFER_NUMS = 10000;
const char* INPUT_BLOB_NAME = "image";
const char* OUTPUT_BLOB_NAME1 = "output1";
const char* OUTPUT_BLOB_NAME2 = "output2";

using namespace nvinfer1;
using namespace std;

#define CHECK(status) \
    do\
    {\
        auto ret = (status);\
        if (ret != 0)\
        {\
            std::cerr << "Cuda failure: " << ret << endl;\
            abort();\
        }\
    } while (0)

// Parse the .wts file written by get_wts() into a name -> Weights map
map<string, Weights> loadWeights(const string file)
{
    cout << "Loading weights: " << file << endl;
    map<string, Weights> weightMap;

    // Open weights file
    ifstream input(file);
    assert(input.is_open() && "Unable to load weight file.");

    // Read number of weight blobs
    int32_t count;
    input >> count;
    assert(count > 0 && "Invalid weight map file.");

    while (count--)
    {
        Weights wt{DataType::kFLOAT, nullptr, 0};
        uint32_t size;

        // Read name and size of blob
        string name;
        input >> name >> std::dec >> size;
        wt.type = DataType::kFLOAT;

        // Load blob: each token is the hex of a big-endian float32 bit pattern
        uint32_t* val = reinterpret_cast<uint32_t*>(malloc(sizeof(uint32_t) * size));
        for (uint32_t x = 0, y = size; x < y; ++x)
        {
            input >> std::hex >> val[x];
        }
        wt.values = val;
        wt.count = size;
        weightMap[name] = wt;
    }
    return weightMap;
}

// Print every dimension of a tensor
void debug_print(ITensor* input_tensor, string head)
{
    cout << "==head:" << head << ":";
    for (int i = 0; i < input_tensor->getDimensions().nbDims; i++)
    {
        cout << input_tensor->getDimensions().d[i] << " ";
    }
    cout << endl;
}

ICudaEngine* createEngine(const char* weightPath, unsigned int maxBatchSize, IBuilder* builder, IBuilderConfig* config, DataType dt)
{
    // start defining the network; 0U = implicit-batch network definition
    INetworkDefinition* network = builder->createNetworkV2(0U);
    ITensor* input = network->addInput(INPUT_BLOB_NAME, dt, Dims3{3, INPUT_H, INPUT_W});
    assert(input);
    map<string, Weights> weightMap = loadWeights(weightPath); // load weights into weightMap

    auto id_323 = convBnRelu(network, weightMap, *input, 64, 7, 2, 3, "conv1", "bn1", false);
    IPoolingLayer* pool1 = network->addPoolingNd(*id_323->getOutput(0), PoolingType::kMAX, DimsHW{3, 3});
    assert(pool1);
    pool1->setStrideNd(DimsHW{2, 2});
    pool1->setPaddingNd(DimsHW{1, 1});

    // intermediate shapes can be checked at any point with debug_print(id_xxx->getOutput(0), "id_xxx")
    auto id_336 = bottleneck(network, weightMap, *pool1->getOutput(0), 64, 1, "layer1.0", false);
    auto id_346 = bottleneck(network, weightMap, *id_336->getOutput(0), 64, 1, "layer1.1", true);
    auto id_356 = bottleneck(network, weightMap, *id_346->getOutput(0), 64, 1, "layer1.2", true);
    auto id_368 = bottleneck(network, weightMap, *id_356->getOutput(0), 128, 2, "layer2.0", false);
    auto id_378 = bottleneck(network, weightMap, *id_368->getOutput(0), 128, 1, "layer2.1", true);
    auto id_388 = bottleneck(network, weightMap, *id_378->getOutput(0), 128, 1, "layer2.2", true);
    auto id_398 = bottleneck(network, weightMap, *id_388->getOutput(0), 128, 1, "layer2.3", true);
    auto id_410 = bottleneck(network, weightMap, *id_398->getOutput(0), 256, 2, "layer3.0", false);
    auto id_420 = bottleneck(network, weightMap, *id_410->getOutput(0), 256, 1, "layer3.1", true);
    auto id_430 = bottleneck(network, weightMap, *id_420->getOutput(0), 256, 1, "layer3.2", true);
    auto id_440 = bottleneck(network, weightMap, *id_430->getOutput(0), 256, 1, "layer3.3", true);
    auto id_450 = bottleneck(network, weightMap, *id_440->getOutput(0), 256, 1, "layer3.4", true);
    auto id_460 = bottleneck(network, weightMap, *id_450->getOutput(0), 256, 1, "layer3.5", true);
    auto id_472 = bottleneck(network, weightMap, *id_460->getOutput(0), 512, 2, "layer4.0", false);
    auto id_482 = bottleneck(network, weightMap, *id_472->getOutput(0), 512, 1, "layer4.1", true);
    auto id_492 = bottleneck(network, weightMap, *id_482->getOutput(0), 512, 1, "layer4.2", true);

    IPoolingLayer* pool2 = network->addPoolingNd(*id_492->getOutput(0), PoolingType::kAVERAGE, DimsHW{7, 7});
    assert(pool2);
    IFullyConnectedLayer* fc1 = network->addFullyConnected(*pool2->getOutput(0), 1000, weightMap["fc.weight"], weightMap["fc.bias"]);
    assert(fc1);
    IActivationLayer* fc1_relu = network->addActivation(*fc1->getOutput(0), ActivationType::kRELU);
    assert(fc1_relu);
    // // classification (softmax) layer
    // ISoftMaxLayer* prob = network->addSoftMax(*fc1->getOutput(0));
    // assert(prob);

    // mark both the FC output and the ReLU output as network outputs (multi-output)
    fc1->getOutput(0)->setName(OUTPUT_BLOB_NAME1);
    fc1_relu->getOutput(0)->setName(OUTPUT_BLOB_NAME2);
    network->markOutput(*fc1->getOutput(0));
    network->markOutput(*fc1_relu->getOutput(0));

    // build the engine
    builder->setMaxBatchSize(maxBatchSize);
    config->setMaxWorkspaceSize(1 << 20);
    ICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);

    // the network has been built into the engine, so the definition can be destroyed
    network->destroy();
    // free the host-side weight buffers
    for (auto& mem : weightMap)
    {
        free((void*)(mem.second.values));
    }
    return engine;
}

void APIToModel(const char* weightPath, unsigned int maxBatchSize, IHostMemory** modelStream)
{
    // create the builder: the entry point of network construction, analogous to a PyTorch model
    IBuilder* builder = createInferBuilder(gLogger);
    IBuilderConfig* config = builder->createBuilderConfig();
    // create the model: build the network layer by layer
    ICudaEngine* engine = createEngine(weightPath, maxBatchSize, builder, config, DataType::kFLOAT);
    assert(engine != nullptr);
    // serialize the engine
    (*modelStream) = engine->serialize();
    // destroy objects
    engine->destroy();
    config->destroy();
    builder->destroy();
}

int main(int argc, char** argv)
{
    // serialize the model to a .engine file
    string engine_name = "./resnet50.engine";
    const char* weightPath = "./resnet50.wts";
    IHostMemory* modelStream{nullptr}; // a host memory block holding the serialized engine
    APIToModel(weightPath, BATCH_SIZE, &modelStream);
    assert(modelStream != nullptr);
    // write it out as a .engine file
    ofstream p(engine_name, std::ios::binary);
    if (!p)
    {
        std::cerr << "can not open plan file" << endl;
        return -1;
    }
    p.write(reinterpret_cast<const char*>(modelStream->data()), modelStream->size());
    p.close();
    // destroy objects
    modelStream->destroy();
    return 0;
}

common.hpp

#ifndef COMMON_HPP
#define COMMON_HPP

#include <map>
#include <chrono>
#include <fstream>
#include <vector>
#include <dirent.h>
#include <math.h>
#include <assert.h>
#include "NvInfer.h"
#include "logging.h"
#include "cuda_runtime_api.h"

using namespace nvinfer1;

// Fold an inference-time BatchNorm into a TensorRT Scale layer (see the note after this file)
IScaleLayer* addBatchNorm2d(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, std::string bnname, float eps)
{
    float* gamma = (float*)weightMap[bnname + ".weight"].values;
    float* beta = (float*)weightMap[bnname + ".bias"].values;
    float* mean = (float*)weightMap[bnname + ".running_mean"].values;
    float* var = (float*)weightMap[bnname + ".running_var"].values;
    int length = weightMap[bnname + ".running_var"].count;

    float* scval = reinterpret_cast<float*>(malloc(sizeof(float) * length));
    for (int i = 0; i < length; i++)
    {
        scval[i] = gamma[i] / sqrt(var[i] + eps);
    }
    Weights scale{DataType::kFLOAT, scval, length}; // Weights object holding the scval pointer

    float* shval = reinterpret_cast<float*>(malloc(sizeof(float) * length));
    for (int i = 0; i < length; i++)
    {
        shval[i] = beta[i] - mean[i] * gamma[i] / sqrt(var[i] + eps);
    }
    Weights shift{DataType::kFLOAT, shval, length}; // Weights object holding the shval pointer

    float* pval = reinterpret_cast<float*>(malloc(sizeof(float) * length));
    for (int i = 0; i < length; i++)
    {
        pval[i] = 1.0;
    }
    Weights power{DataType::kFLOAT, pval, length}; // Weights object holding the pval pointer

    // keep the buffers in weightMap so createEngine() frees them after the build
    weightMap[bnname + ".scale"] = scale;
    weightMap[bnname + ".shift"] = shift;
    weightMap[bnname + ".power"] = power;

    IScaleLayer* scale_1 = network->addScale(input, ScaleMode::kCHANNEL, shift, scale, power);
    assert(scale_1);
    return scale_1;
}

// conv -> bn -> relu block
IActivationLayer* convBnRelu(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, int outch, int ksize, int s, int p, std::string convname, std::string bnname, bool bias = false)
{
    Weights emptywts{DataType::kFLOAT, nullptr, 0}; // empty Weights: null pointer, length 0

    // convolution layer
    IConvolutionLayer* conv1;
    if (!bias)
    {
        conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[convname + ".weight"], emptywts);
    }
    else
    {
        conv1 = network->addConvolutionNd(input, outch, DimsHW{ksize, ksize}, weightMap[convname + ".weight"], weightMap[convname + ".bias"]);
    }
    assert(conv1);
    // set stride and padding
    conv1->setStrideNd(DimsHW{s, s});
    conv1->setPaddingNd(DimsHW{p, p});

    IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), bnname, 1e-5);
    assert(bn1);
    // activation layer
    IActivationLayer* relu = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
    assert(relu);
    return relu;
}

// ResNet50 bottleneck: 1x1 -> 3x3 -> 1x1 convs with a (possibly projected) shortcut
IActivationLayer* bottleneck(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, int outch, int stride, std::string lname, bool shortcut_clean)
{
    Weights emptywts{DataType::kFLOAT, nullptr, 0};

    IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{1, 1}, weightMap[lname + ".conv1.weight"], emptywts);
    assert(conv1);
    IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + ".bn1", 1e-5);
    assert(bn1);
    IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
    assert(relu1);

    IConvolutionLayer* conv2 = network->addConvolutionNd(*relu1->getOutput(0), outch, DimsHW{3, 3}, weightMap[lname + ".conv2.weight"], emptywts);
    assert(conv2);
    conv2->setStrideNd(DimsHW{stride, stride});
    conv2->setPaddingNd(DimsHW{1, 1});
    IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), lname + ".bn2", 1e-5);
    assert(bn2);
    IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU);
    assert(relu2);

    IConvolutionLayer* conv3 = network->addConvolutionNd(*relu2->getOutput(0), outch * 4, DimsHW{1, 1}, weightMap[lname + ".conv3.weight"], emptywts);
    assert(conv3);
    IScaleLayer* bn3 = addBatchNorm2d(network, weightMap, *conv3->getOutput(0), lname + ".bn3", 1e-5);
    assert(bn3);

    IElementWiseLayer* ew1;
    if (!shortcut_clean)
    {
        // projection shortcut: 1x1 conv + bn to match channels/stride
        IConvolutionLayer* conv4 = network->addConvolutionNd(input, outch * 4, DimsHW{1, 1}, weightMap[lname + ".downsample.0.weight"], emptywts);
        assert(conv4);
        conv4->setStrideNd(DimsHW{stride, stride});
        IScaleLayer* bn4 = addBatchNorm2d(network, weightMap, *conv4->getOutput(0), lname + ".downsample.1", 1e-5);
        assert(bn4);
        ew1 = network->addElementWise(*bn4->getOutput(0), *bn3->getOutput(0), ElementWiseOperation::kSUM);
    }
    else
    {
        // identity shortcut
        ew1 = network->addElementWise(input, *bn3->getOutput(0), ElementWiseOperation::kSUM);
    }
    assert(ew1);
    IActivationLayer* relu3 = network->addActivation(*ew1->getOutput(0), ActivationType::kRELU);
    assert(relu3);
    return relu3;
}

// (unused here) residual block variants kept from the original post
ILayer* ResBlock(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, int inch, int outch, int stride, std::string lname)
{
    Weights emptywts{DataType::kFLOAT, nullptr, 0};
    IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{1, 1}, weightMap[lname + ".conv1.weight"], emptywts);
    assert(conv1);
    conv1->setStrideNd(DimsHW{stride, stride});
    IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + ".bn1", 1e-5);
    assert(bn1);
    IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
    assert(relu1);
    IConvolutionLayer* conv2 = network->addConvolutionNd(*relu1->getOutput(0), outch, DimsHW{3, 3}, weightMap[lname + ".conv2.weight"], emptywts);
    assert(conv2);
    conv2->setStrideNd(DimsHW{stride, stride});
    conv2->setPaddingNd(DimsHW{1, 1});
    IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), lname + ".bn2", 1e-5);
    assert(bn2);
    IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU);
    assert(relu2);
    IConvolutionLayer* conv3 = network->addConvolutionNd(*relu2->getOutput(0), inch, DimsHW{1, 1}, weightMap[lname + ".conv3.weight"], emptywts);
    assert(conv3);
    conv3->setStrideNd(DimsHW{stride, stride});
    IScaleLayer* bn3 = addBatchNorm2d(network, weightMap, *conv3->getOutput(0), lname + ".bn3", 1e-5);
    assert(bn3);
    IElementWiseLayer* ew1 = network->addElementWise(input, *bn3->getOutput(0), ElementWiseOperation::kSUM);
    assert(ew1);
    IActivationLayer* relu3 = network->addActivation(*ew1->getOutput(0), ActivationType::kRELU);
    assert(relu3);
    return relu3;
}

ILayer* liteResBlock(INetworkDefinition* network, std::map<std::string, Weights>& weightMap, ITensor& input, int outch, int stride, std::string lname)
{
    Weights emptywts{DataType::kFLOAT, nullptr, 0};
    IConvolutionLayer* conv1 = network->addConvolutionNd(input, outch, DimsHW{3, 3}, weightMap[lname + ".conv1.weight"], emptywts);
    assert(conv1);
    conv1->setStrideNd(DimsHW{stride, stride});
    conv1->setPaddingNd(DimsHW{1, 1});
    IScaleLayer* bn1 = addBatchNorm2d(network, weightMap, *conv1->getOutput(0), lname + ".bn1", 1e-5);
    assert(bn1);
    IActivationLayer* relu1 = network->addActivation(*bn1->getOutput(0), ActivationType::kRELU);
    assert(relu1);
    IConvolutionLayer* conv2 = network->addConvolutionNd(*relu1->getOutput(0), outch, DimsHW{3, 3}, weightMap[lname + ".conv2.weight"], emptywts);
    assert(conv2);
    conv2->setStrideNd(DimsHW{stride, stride});
    conv2->setPaddingNd(DimsHW{1, 1});
    IScaleLayer* bn2 = addBatchNorm2d(network, weightMap, *conv2->getOutput(0), lname + ".bn2", 1e-5);
    assert(bn2);
    IActivationLayer* relu2 = network->addActivation(*bn2->getOutput(0), ActivationType::kRELU);
    assert(relu2);
    IElementWiseLayer* ew1 = network->addElementWise(input, *bn2->getOutput(0), ElementWiseOperation::kSUM);
    assert(ew1);
    IActivationLayer* relu3 = network->addActivation(*ew1->getOutput(0), ActivationType::kRELU);
    assert(relu3);
    return relu3;
}

#endif
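For reference, the algebra behind addBatchNorm2d: at inference time, batch norm computes y = gamma * (x - mean) / sqrt(var + eps) + beta per channel, which is an affine function of x. TensorRT's Scale layer computes y = (x * scale + shift) ^ power per channel, so batch norm folds into it exactly with

scale = gamma / sqrt(var + eps)
shift = beta - mean * gamma / sqrt(var + eps)
power = 1

which is precisely what the scval/shval/pval loops above fill in.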

CMakeLists.txt

cmake_minimum_required(VERSION 2.6)
project(resnet)
add_definitions(-std=c++11)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
# tensorrt
include_directories(/usr/include/x86_64-linux-gnu/)
link_directories(/usr/lib/x86_64-linux-gnu/)

add_executable(Resnet50Serial ${PROJECT_SOURCE_DIR}/Resnet50Serial.cpp)
target_link_libraries(Resnet50Serial nvinfer)
target_link_libraries(Resnet50Serial cudart)
target_link_libraries(Resnet50Serial ${OpenCV_LIBS})

#add_executable(resnext50 ${PROJECT_SOURCE_DIR}/resnext50_32x4d.cpp)
#target_link_libraries(resnext50 nvinfer)
#target_link_libraries(resnext50 cudart)

add_definitions(-O2 -pthread)

Building and running this produces the .engine file. To quantize the engine to FP16, it is enough to add

builder->setHalf2Mode(true);

before building. The following call checks whether the platform has fast FP16 support:

bool useFp16 = builder->platformHasFastFp16();
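Note that setHalf2Mode belongs to the older builder API and is deprecated in recent TensorRT releases. With the IBuilderConfig already used in createEngine above, a sketch of the equivalent (assuming TensorRT 7.x) is:

// inside createEngine, before buildEngineWithConfig
if (builder->platformHasFastFp16())
{
    config->setFlag(nvinfer1::BuilderFlag::kFP16); // allow FP16 kernels where beneficial
}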

2.2 Deserialization and inference stage

1. File layout

resnet50.engine was generated in the previous stage, and logging.h is identical to the one above; see the sketch below.
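Again the original diagram is a screenshot; from the files in this section, the layout is roughly:

.
├── CMakeLists.txt
├── main.cpp
├── Resnet50Classify.h
├── Resnet50Classify.cpp
├── logging.h
├── resnet50.engine     (produced in stage 2.1)
└── test.jpg            (the test image passed on the command line)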

2. Code:

main.cpp

#include <complex>
#include <fstream>
#include <iostream>
#include <vector>
#include <algorithm>
#include <cfloat>
#include "Resnet50Classify.h"
using namespace std;

bool cmp(int x, int y)
{
    return x > y;
}

// return the indices that would sort v (descending if reverse == true)
template <typename T>
vector<int> sort_indexes(const vector<T>& v, bool reverse = false)
{
    // initialize original index locations
    vector<int> idx(v.size());
    for (int i = 0; i < (int)idx.size(); ++i) idx[i] = i;
    // sort indexes based on comparing values in v
    if (reverse)
    {
        sort(idx.begin(), idx.end(),
             [&v](int i1, int i2) { return v[i1] > v[i2]; });
    }
    else
    {
        sort(idx.begin(), idx.end(),
             [&v](int i1, int i2) { return v[i1] < v[i2]; });
    }
    return idx;
}

// argmax over prob: res[0] = index of the max, res[1] = max value
void get_index_value(int OUTPUT_SIZE, float* prob, vector<float>& res)
{
    float maxp = -FLT_MAX;
    int index = 0;
    for (int i = 0; i < OUTPUT_SIZE; i++)
    {
        if (prob[i] > maxp)
        {
            maxp = prob[i];
            index = i;
        }
    }
    res[0] = index;
    res[1] = maxp;
}

vector<int> topk_index(int OUTPUT_SIZE, float* prob, vector<float>& ProbIndex)
{
    vector<int> sorted_indx;
    sorted_indx = sort_indexes(ProbIndex, true);
    return sorted_indx;
}

int main(int argc, char** argv)
{
    if (argc != 2)
    {
        cout << "no image path given" << endl;
        return -1;
    }
    ResNet50* model = new ResNet50();
    // load the engine onto the GPU: deserialize it and set up the execution context
    const char* enginePath = "./resnet50.engine";
    model->InferenceInit(enginePath);
    const char* imgPath = argv[1];
    cout << "=====main cv::CV_VERSION:====" << CV_VERSION << endl;
    auto start = chrono::system_clock::now(); // start time
    model->preProcess(imgPath); // image preprocessing
    // run inference INFER_NUMS times to measure the average latency
    for (int i = 0; i < model->INFER_NUMS; i++)
    {
        model->doInference(model->data, model->prob1, model->prob2, model->batchSize);
    }
    auto end = chrono::system_clock::now(); // end time
    std::cout << chrono::duration_cast<chrono::milliseconds>(end - start).count() << "ms" << std::endl;
    cout << "====model->prob1:" << model->prob1 << endl; // prints the buffer address
    cout << "====model->prob2:" << model->prob2 << endl; // prints the buffer address
    cout << "========================================" << endl;

    vector<float> res1(2, 0);
    get_index_value(model->OUTPUT_SIZE, model->prob1, res1);
    vector<float> res2(2, 0);
    get_index_value(model->OUTPUT_SIZE, model->prob2, res2);
    for (int i = 0; i < 2; i++)
    {
        cout << "===res1[i]:===" << res1[i] << endl; // index of the max, then the max value
        cout << "===res2[i]:===" << res2[i] << endl;
    }
    cout << "========================================" << endl;

    // dump the first topk values of both outputs as "fc,relu" lines
    ofstream trt_result("./fc_and_relu.txt");
    int topk = 100;
    for (int i = 0; i < topk; i++)
    {
        trt_result << model->prob1[i];
        trt_result << ",";
        trt_result << model->prob2[i] << endl;
        cout << "===model->prob1[i]===" << model->prob1[i] << endl;
        cout << "===model->prob2[i]===" << model->prob2[i] << endl;
    }
    trt_result.close();

    // vector<float> ProbIndex1(model->prob1, model->prob1 + model->OUTPUT_SIZE);
    // vector<float> ProbIndex2(model->prob2, model->prob2 + model->OUTPUT_SIZE);
    // vector<int> sorted_indx1 = topk_index(model->OUTPUT_SIZE, model->prob1, ProbIndex1);
    // vector<int> sorted_indx2 = topk_index(model->OUTPUT_SIZE, model->prob2, ProbIndex2);
    // for (int i = 0; i < topk; i++)
    // {
    //     cout << "===sorted_indx1[i]===" << sorted_indx1[i] << endl;
    //     cout << "===sorted_indx2[i]===" << sorted_indx2[i] << endl;
    // }

    delete model;
    model = nullptr;
    return 0;
}

Resnet50Classify.h

#ifndef TENSORRT_H
#define TENSORRT_H

#include <map>
#include <chrono>
#include <fstream>
#include <string>
#include "NvInfer.h"
#include "logging.h"
#include "cuda_runtime_api.h"
#include <NvInferRuntimeCommon.h>
#include <opencv2/opencv.hpp>
#include <limits.h>

using namespace std;
using namespace nvinfer1;

class ResNet50
{
public:
    void InferenceInit(const char* enginePath);
    void doInference(float* input, float* output1, float* output2, int batchSize);
    void preProcess(const char* imgPath);
    ResNet50() {};
    ~ResNet50();

public:
    Logger gLogger;
    static const int INPUT_H = 224;
    static const int INPUT_W = 224;
    static const int OUTPUT_SIZE = 1000;
    static const int INFER_NUMS = 10000;
    const int batchSize = 1;
    const char* INPUT_BLOB_NAME = "image";
    const char* OUTPUT_BLOB_NAME1 = "output1";
    const char* OUTPUT_BLOB_NAME2 = "output2";
    float prob1[OUTPUT_SIZE]; // FC output
    float prob2[OUTPUT_SIZE]; // ReLU output
    char* trtModelStream = nullptr;
    vector<float> mean_value{0.406, 0.456, 0.485}; // BGR
    vector<float> std_value{0.225, 0.224, 0.229};
    float* data = new float[3 * INPUT_H * INPUT_W];
    IRuntime* m_runtime = nullptr;
    ICudaEngine* m_engine = nullptr;
    IExecutionContext* m_context = nullptr;
};
#endif

Resnet50Classify.cpp

#include <opencv2/core/core.hpp>
#include <opencv2/core/types_c.h>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/opencv.hpp>
#include "cuda_runtime_api.h"
#include <fstream>
#include <string>
#include <NvInferRuntimeCommon.h>
#include <cassert>
#include <limits.h>
#include "Resnet50Classify.h"
using namespace std;
using namespace nvinfer1;

#define CHECK(status) \
    do\
    {\
        auto ret = (status);\
        if (ret != 0)\
        {\
            std::cerr << "Cuda failure: " << ret << std::endl;\
            abort();\
        }\
    } while (0)

void ResNet50::doInference(float* input, float* output1, float* output2, int batchSize)
{
    // one input and two outputs: three bindings in total
    assert(m_engine->getNbBindings() == 3);
    // void* device buffers, one per binding
    void* buffers[3];
    // get the binding indices of the input/output tensors of this engine
    const int inputIndex = m_engine->getBindingIndex(INPUT_BLOB_NAME);
    const int outputIndex1 = m_engine->getBindingIndex(OUTPUT_BLOB_NAME1);
    const int outputIndex2 = m_engine->getBindingIndex(OUTPUT_BLOB_NAME2);
    // allocate device memory for the input and output tensors
    CHECK(cudaMalloc(&buffers[inputIndex], batchSize * 3 * INPUT_H * INPUT_W * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex1], batchSize * OUTPUT_SIZE * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex2], batchSize * OUTPUT_SIZE * sizeof(float)));
    // create a CUDA stream to manage concurrent copies and compute
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));
    // host -> device: input is the preprocessed host data; buffers[inputIndex] is the device input region
    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * 3 * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream));
    // launch inference asynchronously
    m_context->enqueue(batchSize, buffers, stream, nullptr);
    // device -> host: copy both outputs back from device memory
    CHECK(cudaMemcpyAsync(output1, buffers[outputIndex1], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    CHECK(cudaMemcpyAsync(output2, buffers[outputIndex2], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    // synchronize the stream before reading the outputs
    cudaStreamSynchronize(stream);
    // release stream and buffers
    cudaStreamDestroy(stream);
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex1]));
    CHECK(cudaFree(buffers[outputIndex2]));
}

void ResNet50::preProcess(const char* imgPath)
{
    cv::Mat img = cv::imread(imgPath);
    cv::Mat src_img;
    cv::resize(img, src_img, cv::Size(INPUT_W, INPUT_H));
    // HWC BGR uint8 -> planar CHW BGR float, normalized with the same mean/std as the PyTorch side
    int count = 0;
    for (int i = 0; i < INPUT_H; i++)
    {
        uchar* uc_pixel = src_img.data + i * src_img.step;
        for (int j = 0; j < INPUT_W; j++)
        {
            data[count] = (uc_pixel[0] / 255. - mean_value[0]) / std_value[0];
            data[count + src_img.rows * src_img.cols] = (uc_pixel[1] / 255. - mean_value[1]) / std_value[1];
            data[count + 2 * src_img.rows * src_img.cols] = (uc_pixel[2] / 255. - mean_value[2]) / std_value[2];
            uc_pixel += 3;
            count++;
        }
    }
}

void ResNet50::InferenceInit(const char* enginePath)
{
    size_t size = 0;
    ifstream file(enginePath, std::ios::binary);
    if (file.good())
    {
        // get length of file
        file.seekg(0, file.end);
        size = file.tellg();
        file.seekg(0, file.beg);
        // allocate memory
        trtModelStream = new char[size];
        assert(trtModelStream);
        // read the data as one block
        file.read(trtModelStream, size);
        file.close();
    }
    // create the IRuntime
    IRuntime* runtime = createInferRuntime(gLogger);
    assert(runtime != nullptr);
    m_runtime = runtime;
    // deserialize the engine
    ICudaEngine* engine = m_runtime->deserializeCudaEngine(trtModelStream, size, nullptr);
    assert(engine != nullptr);
    m_engine = engine;
    // create the execution context used to launch inference in doInference()
    IExecutionContext* context = m_engine->createExecutionContext();
    assert(context != nullptr);
    m_context = context;
}

ResNet50::~ResNet50()
{
    if (m_context)
    {
        m_context->destroy();
        m_context = nullptr;
    }
    if (m_engine)
    {
        m_engine->destroy();
        m_engine = nullptr;
    }
    if (m_runtime)
    {
        m_runtime->destroy();
        m_runtime = nullptr;
    }
    if (data)
    {
        delete[] data;
        data = nullptr;
    }
    if (trtModelStream)
    {
        delete[] trtModelStream; // allocated with new char[size]
        trtModelStream = nullptr;
    }
}

CMakeLists.txt

cmake_minimum_required(VERSION 2.6)
project(resnet)
add_definitions(-std=c++11)
option(CUDA_USE_STATIC_CUDA_RUNTIME OFF)
set(CMAKE_CXX_STANDARD 11)
set(CMAKE_BUILD_TYPE Debug)
find_package(OpenCV REQUIRED)
include_directories(${OpenCV_INCLUDE_DIRS})
include_directories(${PROJECT_SOURCE_DIR}/include)
# include and link dirs of cuda and tensorrt, you need adapt them if yours are different
# cuda
include_directories(/usr/local/cuda/include)
link_directories(/usr/local/cuda/lib64)
# tensorrt
include_directories(/usr/include/x86_64-linux-gnu/)
link_directories(/usr/lib/x86_64-linux-gnu/)

add_executable(Resnet50Classify ${PROJECT_SOURCE_DIR}/main.cpp Resnet50Classify.cpp)
target_link_libraries(Resnet50Classify nvinfer)
target_link_libraries(Resnet50Classify cudart)
target_link_libraries(Resnet50Classify ${OpenCV_LIBS})

add_definitions(-O2 -pthread)
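A standard out-of-source CMake build (the usual steps, assuming the paths in CMakeLists.txt match your machine):

mkdir build && cd build
cmake ..
make

Then run the binary from the directory containing resnet50.engine and the test image: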

./Resnet50Classify test.jpg

Result:

Contents of the generated fc_and_relu.txt (screenshot in the original post): each line holds the FC output and the ReLU output for one class, separated by a comma.

2.3 Comparing the results

import numpy as np

pytorch_res_path = './pytorch_result.txt'
pytorch_res = []
trt_res_path = './fc_and_relu.txt'
trt_res = []
with open(pytorch_res_path, 'r', encoding='utf-8') as file:
    for i, read_info in enumerate(file.readlines()):
        pytorch_res.append(float(read_info))
with open(trt_res_path, 'r', encoding='utf-8') as file:
    for i, read_info in enumerate(file.readlines()):
        trt_res.append(float(read_info.split(',')[0]))  # first column = FC output
print('==trt_res:', trt_res)
# fc_and_relu.txt only stores the first 100 values, so compare the common prefix
n = min(len(pytorch_res), len(trt_res))
pytorch_res = np.array(pytorch_res[:n])
trt_res = np.array(trt_res[:n])
abs_error = np.sum(np.abs((pytorch_res - trt_res) / pytorch_res)) / n  # mean relative error
print('===abs_error===', abs_error)

The error against the PyTorch results is very small. The total time for 10000 inferences is 28656 ms, i.e. 28656/10000 ≈ 2.87 ms per image versus roughly 12 ms in PyTorch, about a 4x speedup, and GPU memory usage drops by about 100 MB. Note also that the second (ReLU) output is simply the FC output with negative values clamped to zero.
