Using ONNX Runtime to Load a YOLOv8-Generated ONNX File for Object Detection


I downloaded some 60 images containing watermelons and winter melons from the internet to form a melon dataset, annotated them with the LabelMe tool, and then used a labelme2yolov8 script to convert the JSON files into the .txt label format YOLOv8 supports (an example label line is shown after the yaml below) and to automatically generate the directory structure YOLOv8 expects, including a melon.yaml file with the following content:

path: ../datasets/melon # dataset root dir
train: images/train # train images (relative to 'path')
val: images/val # val images (relative to 'path')
test: # test images (optional)

# Classes
names:
  0: watermelon
  1: wintermelon
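
Each line of the .txt label files produced by the conversion uses the standard YOLO layout, one object per line, with all coordinates normalized to [0, 1] by the image width and height:

<class_id> <x_center> <y_center> <width> <height>

For example, a watermelon box might be stored as (illustrative values): 0 0.512 0.447 0.380 0.290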

The following Python script trains the model and exports the ONNX file (a sample invocation follows the script):

import argparse
import colorama
from ultralytics import YOLO

def parse_args():
    parser = argparse.ArgumentParser(description="YOLOv8 train")
    parser.add_argument("--yaml", required=True, type=str, help="yaml file")
    parser.add_argument("--epochs", required=True, type=int, help="number of training epochs")
    parser.add_argument("--task", required=True, type=str, choices=["detect", "segment"], help="specify what kind of task")

    args = parser.parse_args()
    return args

def train(task, yaml, epochs):
    if task == "detect":
        model = YOLO("yolov8n.pt") # load a pretrained model
    elif task == "segment":
        model = YOLO("yolov8n-seg.pt") # load a pretrained model
    else:
        print(colorama.Fore.RED + "Error: unsupported task:", task)
        raise ValueError("unsupported task: " + task) # a bare 'raise' outside an except block would itself fail

    results = model.train(data=yaml, epochs=epochs, imgsz=640) # train the model
    metrics = model.val() # automatically evaluates on the data you trained with; dataset and settings are remembered, no arguments needed
    model.export(format="onnx") # export the model; dynamic=True cannot be specified because OpenCV does not support it
    # model.export(format="onnx", opset=12, simplify=True, dynamic=False, imgsz=640)
    model.export(format="torchscript") # libtorch

if __name__ == "__main__":
    colorama.init()
    args = parse_args()
    train(args.task, args.yaml, args.epochs)
    print(colorama.Fore.GREEN + "====== execution completed ======")
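
A sample invocation, assuming the script above is saved as yolov8_train.py (the file name and epoch count are illustrative, not given in the original):

python yolov8_train.py --yaml melon.yaml --epochs 100 --task detect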

The following is the implementation that uses the onnxruntime API to load the ONNX file and run object detection:

// headers required by this snippet
#include <onnxruntime_cxx_api.h>
#include <opencv2/opencv.hpp>
#include <chrono>
#include <cwchar>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <random>
#include <string>
#include <vector>

namespace {

constexpr bool cuda_enabled{ false };
constexpr int image_size[2]{ 640, 640 }; // {height, width}, input shape (1, 3, 640, 640) BCHW and output shape(s) (1, 6, 8400)
constexpr float model_score_threshold{ 0.45 }; // confidence threshold
constexpr float model_nms_threshold{ 0.50 }; // iou threshold

#ifdef _MSC_VER
constexpr const char* onnx_file{ "../../../data/best.onnx" };
constexpr const char* torchscript_file{ "../../../data/best.torchscript" };
constexpr const char* images_dir{ "../../../data/images/predict" };
constexpr const char* result_dir{ "../../../data/result" };
constexpr const char* classes_file{ "../../../data/images/labels.txt" };
#else
constexpr const char* onnx_file{ "data/best.onnx" };
constexpr const char* torchscript_file{ "data/best.torchscript" };
constexpr const char* images_dir{ "data/images/predict" };
constexpr const char* result_dir{ "data/result" };
constexpr const char* classes_file{ "data/images/labels.txt" };
#endif

std::vector<std::string> parse_classes_file(const char* name)
{
    std::vector<std::string> classes;
    std::ifstream file(name);
    if (!file.is_open()) {
        std::cerr << "Error: fail to open classes file: " << name << std::endl;
        return classes;
    }

    std::string line;
    while (std::getline(file, line)) {
        auto pos = line.find_first_of(" "); // keep only the class name, i.e. the first token of each line
        classes.emplace_back(line.substr(0, pos));
    }

    file.close();
    return classes;
}

auto get_dir_images(const char* name)
{
    std::map<std::string, std::string> images; // image name, image path + image name

    for (auto const& dir_entry : std::filesystem::directory_iterator(name)) {
        if (dir_entry.is_regular_file())
            images[dir_entry.path().filename().string()] = dir_entry.path().string();
    }

    return images;
}

void draw_boxes(const std::vector<std::string>& classes, const std::vector<int>& ids, const std::vector<float>& confidences,
    const std::vector<cv::Rect>& boxes, const std::string& name, cv::Mat& frame)
{
    if (ids.size() != confidences.size() || ids.size() != boxes.size() || confidences.size() != boxes.size()) {
        std::cerr << "Error: their lengths are inconsistent: " << ids.size() << ", " << confidences.size() << ", " << boxes.size() << std::endl;
        return;
    }

    std::cout << "image name: " << name << ", number of detections: " << ids.size() << std::endl;

    std::random_device rd;
    std::mt19937 gen(rd());
    std::uniform_int_distribution<int> dis(100, 255); // random color per detection

    for (size_t i = 0; i < ids.size(); ++i) {
        auto color = cv::Scalar(dis(gen), dis(gen), dis(gen));
        cv::rectangle(frame, boxes[i], color, 2);

        std::string class_string = classes[ids[i]] + ' ' + std::to_string(confidences[i]).substr(0, 4);
        cv::Size text_size = cv::getTextSize(class_string, cv::FONT_HERSHEY_DUPLEX, 1, 2, 0);
        cv::Rect text_box(boxes[i].x, boxes[i].y - 40, text_size.width + 10, text_size.height + 20);

        cv::rectangle(frame, text_box, color, cv::FILLED);
        cv::putText(frame, class_string, cv::Point(boxes[i].x + 5, boxes[i].y - 10), cv::FONT_HERSHEY_DUPLEX, 1, cv::Scalar(0, 0, 0), 2, 0);
    }

    //cv::imshow("Inference", frame);
    //cv::waitKey(-1);

    std::string path(result_dir);
    path += "/" + name;
    cv::imwrite(path, frame);
}

std::wstring ctow(const char* str)
{
    constexpr size_t len{ 128 };
    wchar_t wch[len];
    swprintf(wch, len, L"%hs", str);

    return std::wstring(wch);
}

float image_preprocess(const cv::Mat& src, cv::Mat& dst)
{
    cv::cvtColor(src, dst, cv::COLOR_BGR2RGB);

    float resize_scales{ 1.f };
    if (src.cols >= src.rows) {
        resize_scales = src.cols * 1.f / image_size[1];
        cv::resize(dst, dst, cv::Size(image_size[1], static_cast<int>(src.rows / resize_scales)));
    } else {
        resize_scales = src.rows * 1.f / image_size[0];
        cv::resize(dst, dst, cv::Size(static_cast<int>(src.cols / resize_scales), image_size[0]));
    }

    // pad the resized image into a 640x640 canvas (top-left aligned)
    cv::Mat tmp = cv::Mat::zeros(image_size[0], image_size[1], CV_8UC3);
    dst.copyTo(tmp(cv::Rect(0, 0, dst.cols, dst.rows)));
    dst = tmp;

    return resize_scales;
}

template<typename T>
void image_to_blob(const cv::Mat& src, T* blob) // HWC uint8 -> CHW float in [0, 1]
{
    for (auto c = 0; c < 3; ++c) {
        for (auto h = 0; h < src.rows; ++h) {
            for (auto w = 0; w < src.cols; ++w) {
                blob[c * src.rows * src.cols + h * src.cols + w] = (src.at<cv::Vec3b>(h, w)[c]) / 255.f;
            }
        }
    }
}

void post_process(const float* data, int rows, int stride, float xfactor, float yfactor, const std::vector<std::string>& classes,
    cv::Mat& frame, const std::string& name)
{
    std::vector<int> class_ids;
    std::vector<float> confidences;
    std::vector<cv::Rect> boxes;

    for (auto i = 0; i < rows; ++i) {
        const float* classes_scores = data + 4; // each row: x, y, w, h followed by per-class scores
        cv::Mat scores(1, static_cast<int>(classes.size()), CV_32FC1, (float*)classes_scores);
        cv::Point class_id;
        double max_class_score;
        cv::minMaxLoc(scores, 0, &max_class_score, 0, &class_id);

        if (max_class_score > model_score_threshold) {
            confidences.push_back(static_cast<float>(max_class_score));
            class_ids.push_back(class_id.x);

            float x = data[0];
            float y = data[1];
            float w = data[2];
            float h = data[3];

            int left = int((x - 0.5 * w) * xfactor);
            int top = int((y - 0.5 * h) * yfactor);
            int width = int(w * xfactor);
            int height = int(h * yfactor);

            boxes.push_back(cv::Rect(left, top, width, height));
        }

        data += stride;
    }

    std::vector<int> nms_result;
    cv::dnn::NMSBoxes(boxes, confidences, model_score_threshold, model_nms_threshold, nms_result);

    std::vector<int> ids;
    std::vector<float> confs;
    std::vector<cv::Rect> rects;
    for (size_t i = 0; i < nms_result.size(); ++i) {
        ids.emplace_back(class_ids[nms_result[i]]);
        confs.emplace_back(confidences[nms_result[i]]);
        rects.emplace_back(boxes[nms_result[i]]);
    }

    draw_boxes(classes, ids, confs, rects, name, frame);
}

} // namespace

int test_yolov8_detect_onnxruntime()
{
    // reference: ultralytics/examples/YOLOv8-ONNXRuntime-CPP
    try {
        Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "Yolo");
        Ort::SessionOptions session_option;

        if (cuda_enabled) {
            OrtCUDAProviderOptions cuda_option;
            cuda_option.device_id = 0;
            session_option.AppendExecutionProvider_CUDA(cuda_option);
        }

        session_option.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
        session_option.SetIntraOpNumThreads(1);
        session_option.SetLogSeverityLevel(3);

#ifdef _MSC_VER // on Windows the model path parameter is wchar_t*, on Linux it is char* (see note 4)
        Ort::Session session(env, ctow(onnx_file).c_str(), session_option);
#else
        Ort::Session session(env, onnx_file, session_option);
#endif

        Ort::AllocatorWithDefaultOptions allocator;
        std::vector<const char*> input_node_names, output_node_names;
        std::vector<std::string> input_node_names_, output_node_names_;

        for (size_t i = 0; i < session.GetInputCount(); ++i) {
            Ort::AllocatedStringPtr input_node_name = session.GetInputNameAllocated(i, allocator);
            input_node_names_.emplace_back(input_node_name.get());
        }
        for (size_t i = 0; i < session.GetOutputCount(); ++i) {
            Ort::AllocatedStringPtr output_node_name = session.GetOutputNameAllocated(i, allocator);
            output_node_names_.emplace_back(output_node_name.get());
        }
        for (size_t i = 0; i < input_node_names_.size(); ++i)
            input_node_names.emplace_back(input_node_names_[i].c_str());
        for (size_t i = 0; i < output_node_names_.size(); ++i)
            output_node_names.emplace_back(output_node_names_[i].c_str());

        Ort::RunOptions options(nullptr);
        std::unique_ptr<float[]> blob(new float[image_size[0] * image_size[1] * 3]);
        std::vector<int64_t> input_node_dims{ 1, 3, image_size[0], image_size[1] }; // {batch, channel, height, width}

        auto classes = parse_classes_file(classes_file);
        if (classes.size() == 0) {
            std::cerr << "Error: fail to parse classes file: " << classes_file << std::endl;
            return -1;
        }

        for (const auto& [key, val] : get_dir_images(images_dir)) {
            cv::Mat frame = cv::imread(val, cv::IMREAD_COLOR);
            if (frame.empty()) {
                std::cerr << "Warning: unable to load image: " << val << std::endl;
                continue;
            }

            auto tstart = std::chrono::high_resolution_clock::now();
            cv::Mat rgb;
            auto resize_scales = image_preprocess(frame, rgb);
            image_to_blob(rgb, blob.get());

            Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
                Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), blob.get(), 3 * image_size[1] * image_size[0], input_node_dims.data(), input_node_dims.size());
            auto output_tensors = session.Run(options, input_node_names.data(), &input_tensor, 1, output_node_names.data(), output_node_names.size());

            Ort::TypeInfo type_info = output_tensors.front().GetTypeInfo();
            auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
            std::vector<int64_t> output_node_dims = tensor_info.GetShape();
            auto output = output_tensors.front().GetTensorMutableData<float>();

            int stride_num = output_node_dims[1]; // 4 box values + number of classes
            int signal_result_num = output_node_dims[2]; // number of candidate boxes
            cv::Mat raw_data = cv::Mat(stride_num, signal_result_num, CV_32F, output);
            raw_data = raw_data.t(); // (6, 8400) -> (8400, 6): one candidate box per row
            float* data = (float*)raw_data.data;

            auto tend = std::chrono::high_resolution_clock::now();
            std::cout << "elapsed milliseconds: " << std::chrono::duration_cast<std::chrono::milliseconds>(tend - tstart).count() << " ms" << std::endl;

            post_process(data, signal_result_num, stride_num, resize_scales, resize_scales, classes, frame, key);
        }
    }
    catch (const std::exception& e) {
        std::cerr << "Error: " << e.what() << std::endl;
        return -1;
    }

    return 0;
}
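
For this two-class model the output shape (1, 6, 8400) breaks down as follows: each of the 8400 candidate boxes carries 4 box values (center x, center y, width, height) plus one score per class, giving 4 + 2 = 6 values per box, and 8400 = 80×80 + 40×40 + 20×20 is the total number of grid cells across the three detection heads at a 640×640 input. This is why post_process transposes the (6, 8400) output into an (8400, 6) matrix and walks it one row, that is one candidate box, at a time.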

The content of the labels.txt file is as follows (only 2 classes; note that parse_classes_file keeps only the first token of each line, so the line order must match the class ids used in training):

watermelon 0
wintermelon 1

Notes:

1. The onnxruntime version used here is 1.18.0.

2. On Windows, onnxruntime ships a single set of libraries shared by debug and release; the program runs under both configurations.

3. The cuda_enabled variable selects between the CPU and GPU execution paths.

4. On Windows some onnxruntime interfaces take wchar_t* parameters where their Linux counterparts take char*, so a separate conversion is needed on Windows; here the ctow function converts from char* to wchar_t* (a portable alternative is sketched after these notes).

5. The sample provided with YOLOv8 (ultralytics/examples/YOLOv8-ONNXRuntime-CPP) has problems and needs adjustments.
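
Regarding note 4, a minimal portable alternative sketch (not part of the original code): onnxruntime's headers define ORTCHAR_T, which is wchar_t on Windows and char elsewhere, so a small helper, here named to_ort_string, can hide the difference. The naive widening below is only adequate for ASCII paths:

// minimal sketch: hide the wchar_t*/char* difference behind ORTCHAR_T
#include <onnxruntime_cxx_api.h> // defines ORTCHAR_T
#include <string>

std::basic_string<ORTCHAR_T> to_ort_string(const std::string& s)
{
#ifdef _WIN32
    return std::wstring(s.begin(), s.end()); // widen char by char; correct only for ASCII
#else
    return s;
#endif
}

// usage: Ort::Session session(env, to_ort_string(onnx_file).c_str(), session_option);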

The execution results are shown in the figure below. On the same set of prediction images the results are similar to those of the OpenCV DNN implementation, since both share the same post-processing pipeline. The elapsed times shown are on CPU; on GPU each image takes only about 20 ms.

The detection result for one of the images is shown below:

GitHub: https://github.com/fengbingchun/NN_Test
