当前位置:   article > 正文

yolov8目标检测onnx推理及后处理实现_yolov8 onnx

yolov8 onnx

         使用onnx进行yolov8模型推理测试。首先从YOLOv8开源地址下载预训练模型,由于测试在CPU上进行,就只下载最小的YOLOv8n模型。

 YOLOv8n预训练模型为pytorch的pt格式,大小为6.2M,下载完成后,通过pytorch转换为onnx。转换脚本:

  1. import torch
  2. net = torch.load('yolov8n.pt', map_location='cpu')
  3. net.eval()
  4. dummpy_input = torch.randn(1, 3, 640, 640)
  5. torch.onnx.export(net, dummpy_input, 'yolov8n.onnx', export_params=True,
  6. input_names=['input'],
  7. output_names=['output'])

        完成模型转换后,接下来进行onnx推理测试。编写推理脚本前可以通过netron工具查看模型输入输出,可以看到yolov8输入为[1,3,640,640],输入为[1,84,8400]。

         YOLOv8输出shape跟yolo之前系列模型(YOLOv5输出为[25200,85]),有较大差异,查找一番后,发现yolov8在两个方面做了调整,一是取消了anchor(因为每个anchor对应3个bbox),因此总的bbox数降低三倍;二是取消了bbox的置信度,将bbox置信度与分类融合。

         为了复用之前YOLO系列的后处理代码(非极大值抑制),需要将YOLOv8输出结果进行处理,将分类预测中的最大值提取出来作为bbox置信度。将推理结果转换为[1,8400,85]形式。

 

  1. pred_class = pred[..., 4:]
  2. pred_conf = np.max(pred_class, axis=-1)
  3. pred = np.insert(pred, 4, pred_conf, axis=-1)

测试图片:

 测试结果:

完整的推理脚本:

  1. import onnxruntime as rt
  2. import numpy as np
  3. import cv2
  4. import matplotlib.pyplot as plt
  5. def nms(pred, conf_thres, iou_thres):
  6. conf = pred[..., 4] > conf_thres
  7. box = pred[conf == True]
  8. cls_conf = box[..., 5:]
  9. cls = []
  10. for i in range(len(cls_conf)):
  11. cls.append(int(np.argmax(cls_conf[i])))
  12. total_cls = list(set(cls))
  13. output_box = []
  14. for i in range(len(total_cls)):
  15. clss = total_cls[i]
  16. cls_box = []
  17. for j in range(len(cls)):
  18. if cls[j] == clss:
  19. box[j][5] = clss
  20. cls_box.append(box[j][:6])
  21. cls_box = np.array(cls_box)
  22. box_conf = cls_box[..., 4]
  23. box_conf_sort = np.argsort(box_conf)
  24. max_conf_box = cls_box[box_conf_sort[len(box_conf) - 1]]
  25. output_box.append(max_conf_box)
  26. cls_box = np.delete(cls_box, 0, 0)
  27. while len(cls_box) > 0:
  28. max_conf_box = output_box[len(output_box) - 1]
  29. del_index = []
  30. for j in range(len(cls_box)):
  31. current_box = cls_box[j]
  32. interArea = getInter(max_conf_box, current_box)
  33. iou = getIou(max_conf_box, current_box, interArea)
  34. if iou > iou_thres:
  35. del_index.append(j)
  36. cls_box = np.delete(cls_box, del_index, 0)
  37. if len(cls_box) > 0:
  38. output_box.append(cls_box[0])
  39. cls_box = np.delete(cls_box, 0, 0)
  40. return output_box
  41. def getIou(box1, box2, inter_area):
  42. box1_area = box1[2] * box1[3]
  43. box2_area = box2[2] * box2[3]
  44. union = box1_area + box2_area - inter_area
  45. iou = inter_area / union
  46. return iou
  47. def getInter(box1, box2):
  48. box1_x1, box1_y1, box1_x2, box1_y2 = box1[0] - box1[2] / 2, box1[1] - box1[3] / 2, \
  49. box1[0] + box1[2] / 2, box1[1] + box1[3] / 2
  50. box2_x1, box2_y1, box2_x2, box2_y2 = box2[0] - box2[2] / 2, box2[1] - box1[3] / 2, \
  51. box2[0] + box2[2] / 2, box2[1] + box2[3] / 2
  52. if box1_x1 > box2_x2 or box1_x2 < box2_x1:
  53. return 0
  54. if box1_y1 > box2_y2 or box1_y2 < box2_y1:
  55. return 0
  56. x_list = [box1_x1, box1_x2, box2_x1, box2_x2]
  57. x_list = np.sort(x_list)
  58. x_inter = x_list[2] - x_list[1]
  59. y_list = [box1_y1, box1_y2, box2_y1, box2_y2]
  60. y_list = np.sort(y_list)
  61. y_inter = y_list[2] - y_list[1]
  62. inter = x_inter * y_inter
  63. return inter
  64. def draw(img, xscale, yscale, pred):
  65. img_ = img.copy()
  66. if len(pred):
  67. for detect in pred:
  68. detect = [int((detect[0] - detect[2] / 2) * xscale), int((detect[1] - detect[3] / 2) * yscale),
  69. int((detect[0]+detect[2] / 2) * xscale), int((detect[1]+detect[3] / 2) * yscale)]
  70. img_ = cv2.rectangle(img, (detect[0], detect[1]), (detect[2], detect[3]), (0, 255, 0), 1)
  71. return img_
  72. if __name__ == '__main__':
  73. height, width = 640, 640
  74. img0 = cv2.imread('1.jpg')
  75. x_scale = img0.shape[1] / width
  76. y_scale = img0.shape[0] / height
  77. img = img0 / 255.
  78. img = cv2.resize(img, (width, height))
  79. img = np.transpose(img, (2, 0, 1))
  80. data = np.expand_dims(img, axis=0)
  81. sess = rt.InferenceSession('yolov8n.onnx')
  82. input_name = sess.get_inputs()[0].name
  83. label_name = sess.get_outputs()[0].name
  84. pred = sess.run([label_name], {input_name: data.astype(np.float32)})[0]
  85. pred = np.squeeze(pred)
  86. pred = np.transpose(pred, (1, 0))
  87. pred_class = pred[..., 4:]
  88. pred_conf = np.max(pred_class, axis=-1)
  89. pred = np.insert(pred, 4, pred_conf, axis=-1)
  90. result = nms(pred, 0.3, 0.45)
  91. ret_img = draw(img0, x_scale, y_scale, result)
  92. ret_img = ret_img[:, :, ::-1]
  93. plt.imshow(ret_img)
  94. plt.show()

 

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/Gausst松鼠会/article/detail/162296
推荐阅读
相关标签
  

闽ICP备14008679号