
YOLOv8 ONNX Model Inference


YOLOv8 ONNX model inference refers to running object detection with a YOLOv8 model that has been exported to the ONNX (Open Neural Network Exchange) format. ONNX is a cross-platform deep learning model format that supports converting and running models across frameworks, so a model can execute efficiently on different hardware and software platforms.
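Since ONNX Runtime is used for inference below, a quick way to see which execution providers are available on the current machine (the exact list depends on the installed onnxruntime build):

import onnxruntime as ort

# List the execution providers available in this onnxruntime build,
# e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider'] or just the CPU one
print(ort.get_available_providers())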

Use the Ultralytics YOLO library to load a YOLOv8 PyTorch model and export it to ONNX format:

from ultralytics import YOLO

# Load the YOLOv8 PyTorch weights
model = YOLO("yolov8n.pt")
# Export the model to ONNX format (produces yolov8n.onnx)
model.export(format="onnx")
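Before writing the inference code, it is worth confirming that the exported model loads and reports the expected input shape. A minimal check, assuming the export produced yolov8n.onnx in the current directory:

import onnxruntime as ort

# Load the exported model on CPU and inspect its input metadata
session = ort.InferenceSession("yolov8n.onnx", providers=["CPUExecutionProvider"])
inp = session.get_inputs()[0]
# For YOLOv8n this typically prints: images [1, 3, 640, 640]
print(inp.name, inp.shape)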

Prepare a coco128.yaml file to hold the class names.
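The tool module below reads the names field from this file. An abbreviated example of what that section looks like in the standard Ultralytics coco128.yaml (80 COCO classes in total):

names:
  0: person
  1: bicycle
  2: car
  ...
  79: toothbrush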

The code is as follows:

onnx_inference.py

import os
import time
import random

import cv2

from tool import *


def main():
    model_path = "yolov8n.onnx"
    session, model_inputs, input_width, input_height = init_detect_model(model_path)
    modes = {
        1: process_images,
        2: webcam_detection,
        3: video_processing
    }
    mode = 1
    if mode in modes:
        modes[mode](session, model_inputs, input_width, input_height)
    else:
        print("Invalid mode. Please choose from 1, 2, or 3.")


def process_images(session, model_inputs, input_width, input_height):
    image_dir = './images'
    image_list = os.listdir(image_dir)
    random.shuffle(image_list)
    for image_item in image_list:
        path = os.path.join(image_dir, image_item)
        im0 = cv2.imread(path)
        result_image = detect_object(im0, session, model_inputs, input_width, input_height)
        cv2.imwrite("output_image.jpg", result_image)
        cv2.imshow('Output', result_image)
        cv2.waitKey(0)


def webcam_detection(session, model_inputs, input_width, input_height):
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open camera.")
        return
    frame_count, start_time = 0, time.time()
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Error: Could not read frame.")
            break
        output_image = detect_object(frame, session, model_inputs, input_width, input_height)
        frame_count += 1
        elapsed_time = time.time() - start_time
        fps = frame_count / elapsed_time
        cv2.putText(output_image, f"FPS: {fps:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 255), 2, cv2.LINE_AA)
        cv2.imshow("Video", output_image)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    cv2.destroyAllWindows()


def video_processing(session, model_inputs, input_width, input_height):
    input_video_path = 'kun1.mp4'
    output_video_path = 'kun_det1.mp4'
    cap = cv2.VideoCapture(input_video_path)
    if not cap.isOpened():
        print("Error: Could not open video.")
        return
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (frame_width, frame_height))
    frame_count, start_time = 0, time.time()
    while True:
        ret, frame = cap.read()
        if not ret:
            print("Info: End of video file.")
            break
        output_image = detect_object(frame, session, model_inputs, input_width, input_height)
        frame_count += 1
        elapsed_time = time.time() - start_time
        fps = frame_count / elapsed_time if elapsed_time > 0 else 0
        print(f"FPS: {fps:.2f}")
        out.write(output_image)
        cv2.imshow("Output Video", output_image)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    cap.release()
    out.release()
    cv2.destroyAllWindows()


if __name__ == "__main__":
    main()

tool.py

import cv2
import yaml
import torch.cuda
import numpy as np
from PIL import Image
import onnxruntime as ort

# IoU threshold for non-maximum suppression
iou_thresh = 0.6
# Confidence threshold
confidence_thresh = 0.55
# Path to the class-name file
label_path = 'coco128.yaml'


# Read the YAML file
def yaml_load(file=label_path):
    with open(file, errors='ignore') as f:
        return yaml.safe_load(f)


classes = yaml_load(label_path)['names']
color_palette = np.random.uniform(100, 255, size=(len(classes), 3))

# Use the CUDA execution provider if PyTorch reports a GPU
cuda = torch.cuda.is_available()
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']


def calculate_iou(box, other_boxes):
    other_boxes = np.array(other_boxes)
    # Top-left corner of the intersection (boxes are [x, y, w, h])
    x1 = np.maximum(box[0], other_boxes[:, 0])
    y1 = np.maximum(box[1], other_boxes[:, 1])
    # Bottom-right corner of the intersection
    x2 = np.minimum(box[0] + box[2], other_boxes[:, 0] + other_boxes[:, 2])
    y2 = np.minimum(box[1] + box[3], other_boxes[:, 1] + other_boxes[:, 3])
    # Area of the intersection
    intersection_area = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1)
    # Area of the given box
    box_area = box[2] * box[3]
    # Areas of the other boxes
    other_boxes_area = other_boxes[:, 2] * other_boxes[:, 3]
    # IoU = intersection / union
    iou = intersection_area / (box_area + other_boxes_area - intersection_area)
    return iou


def custom_NMSBoxes(boxes, scores, confidence_threshold, iou_threshold):
    # Return an empty list if there are no boxes
    if len(boxes) == 0:
        return []
    # Convert scores and boxes to NumPy arrays
    scores = np.array(scores)
    boxes = np.array(boxes)
    # Filter boxes by the confidence threshold, remembering their original
    # indices so the returned indices refer to the caller's lists
    mask = scores > confidence_threshold
    original_indices = np.where(mask)[0]
    filtered_boxes = boxes[mask]
    filtered_scores = scores[mask]
    # Return an empty list if nothing survives the filter
    if len(filtered_boxes) == 0:
        return []
    # Sort the remaining boxes by descending confidence
    sorted_indices = np.argsort(filtered_scores)[::-1]
    # Indices (into the original lists) of the selected boxes
    indices = []
    # Loop while there are unprocessed boxes
    while len(sorted_indices) > 0:
        # Keep the highest-scoring box
        current_index = sorted_indices[0]
        indices.append(original_indices[current_index])
        # Stop if only one box remains
        if len(sorted_indices) == 1:
            break
        # Compare the current box against the rest
        current_box = filtered_boxes[current_index]
        other_boxes = filtered_boxes[sorted_indices[1:]]
        iou = calculate_iou(current_box, other_boxes)
        # Keep only boxes whose IoU with the current box is below the
        # threshold, i.e. boxes that do not overlap it
        non_overlapping_indices = np.where(iou <= iou_threshold)[0]
        sorted_indices = sorted_indices[non_overlapping_indices + 1]
    # Return the indices of the selected boxes
    return indices


def draw_detections(img, box, score, class_id):
    # Unpack the bounding-box coordinates
    x1, y1, w, h = box
    # Look up the color for this class
    color = color_palette[class_id]
    # Draw the bounding box on the image
    cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)
    # Build the label text: class name and score
    label = f'{classes[class_id]}: {score:.2f}'
    # Measure the label text
    (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    # Position the label above the box, or below it if there is no room
    label_x = x1
    label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10
    # Draw a filled rectangle as the label background
    cv2.rectangle(img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED)
    # Draw the label text
    cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)


def preprocess(img, input_width, input_height):
    # Record the original image size
    img_height, img_width = img.shape[:2]
    # Convert from BGR to RGB
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Resize to the model's input size
    img = cv2.resize(img, (input_width, input_height))
    # Normalize pixel values to [0, 1]
    image_data = np.array(img) / 255.0
    # Move the channel dimension first (HWC -> CHW)
    image_data = np.transpose(image_data, (2, 0, 1))
    # Add a batch dimension to match the expected input shape
    image_data = np.expand_dims(image_data, axis=0).astype(np.float32)
    # Return the preprocessed tensor and the original image size
    return image_data, img_height, img_width


def postprocess(input_image, output, input_width, input_height, img_width, img_height):
    # Transpose and squeeze the output to shape (num_boxes, 4 + num_classes)
    outputs = np.transpose(np.squeeze(output[0]))
    # Number of candidate boxes
    rows = outputs.shape[0]
    # Lists for the detected boxes, scores, and class IDs
    boxes = []
    scores = []
    class_ids = []
    # Scale factors to map box coordinates back to the original image
    x_factor = img_width / input_width
    y_factor = img_height / input_height
    # Iterate over every candidate box
    for i in range(rows):
        # Class scores for this candidate
        classes_scores = outputs[i][4:]
        # Highest class score
        max_score = np.amax(classes_scores)
        # Keep the candidate if its best score clears the confidence threshold
        if max_score >= confidence_thresh:
            # Class ID with the highest score
            class_id = np.argmax(classes_scores)
            # Box center coordinates, width, and height
            x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]
            # Convert to top-left corner coordinates at the original image scale
            left = int((x - w / 2) * x_factor)
            top = int((y - h / 2) * y_factor)
            width = int(w * x_factor)
            height = int(h * y_factor)
            # Record the class ID, score, and box
            class_ids.append(class_id)
            scores.append(max_score)
            boxes.append([left, top, width, height])
    # Apply non-maximum suppression to remove overlapping boxes
    indices = custom_NMSBoxes(boxes, scores, confidence_thresh, iou_thresh)
    # Draw each detection that survived NMS
    for i in indices:
        box = boxes[i]
        score = scores[i]
        class_id = class_ids[i]
        draw_detections(input_image, box, score, class_id)
    # Return the annotated image
    return input_image


def init_detect_model(model_path):
    # Create an inference session from the ONNX model with the chosen execution providers
    session = ort.InferenceSession(model_path, providers=providers)
    # Get the model's input metadata
    model_inputs = session.get_inputs()
    # The input shape is NCHW: (batch, channels, height, width)
    input_shape = model_inputs[0].shape
    input_height = input_shape[2]
    input_width = input_shape[3]
    # Return the session, input metadata, and input size
    return session, model_inputs, input_width, input_height


def detect_object(image, session, model_inputs, input_width, input_height):
    # Convert PIL images to NumPy arrays
    if isinstance(image, Image.Image):
        result_image = np.array(image)
    else:
        # Otherwise assume the input is already a NumPy array
        result_image = image
    # Preprocess: resize, normalize, and reshape the image
    img_data, img_height, img_width = preprocess(result_image, input_width, input_height)
    # Run inference on the preprocessed image
    outputs = session.run(None, {model_inputs[0].name: img_data})
    # Postprocess: decode boxes, apply NMS, and draw the detections
    output_image = postprocess(result_image, outputs, input_width, input_height, img_width, img_height)
    # Return the annotated image
    return output_image
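For a quick single-image test that skips the mode dispatch in onnx_inference.py, the two entry points above can be called directly (a minimal sketch; test.jpg is a placeholder path):

import cv2
from tool import init_detect_model, detect_object

# Load the model once, then reuse the session for each image
session, model_inputs, input_width, input_height = init_detect_model("yolov8n.onnx")
image = cv2.imread("test.jpg")  # placeholder input image
result = detect_object(image, session, model_inputs, input_width, input_height)
cv2.imwrite("result.jpg", result)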
