当前位置:   article > 正文

YOLOv5 Inference with ONNX Runtime (Python)

yolov5 onnxruntime
  1. import cv2
  2. import numpy as np
  3. import onnxruntime as ort
  4. def readClassesNames(file_path):
  5. with open(file_path, encoding='utf-8') as f:
  6. class_names = f.readlines()
  7. class_names = [c.strip() for c in class_names]
  8. return class_names
  9. classes_names = 'coco.names'
  10. classes = readClassesNames(classes_names)
  11. image = cv2.imread('bus.jpg')
  12. image_height, image_width = image.shape[:2]
  13. model_path = 'yolov5n.onnx'
  14. start_time = cv2.getTickCount()
  15. session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
  16. conf_thresold = 0.25
  17. iou_threshold = 0.45
  18. score_thresold = 0.25
  19. model_inputs = session.get_inputs()
  20. input_names = [model_inputs[i].name for i in range(len(model_inputs))]
  21. input_shape = model_inputs[0].shape
  22. model_output = session.get_outputs()
  23. output_names = [model_output[i].name for i in range(len(model_output))]
  24. input_height, input_width = input_shape[2:]
  25. image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  26. resized = cv2.resize(image_rgb, (input_width, input_height))
  27. input_image = resized / 255.0
  28. input_image = input_image.transpose(2,0,1)
  29. input_tensor = input_image[np.newaxis, :, :, :].astype(np.float32)
  30. outputs = session.run(output_names, {input_names[0]: input_tensor})[0]
  31. predictions = np.squeeze(outputs)
  32. scores = np.max(predictions[:, 4:5], axis=1)
  33. predictions = predictions[scores > score_thresold, :]
  34. scores = scores[scores > score_thresold]
  35. class_ids = np.argmax(predictions[:, 5:], axis=1)
  36. boxes = predictions[:, :4]
  37. input_shape = np.array([input_width, input_height, input_width, input_height])
  38. boxes = np.divide(boxes, input_shape, dtype=np.float32)
  39. boxes *= np.array([image_width, image_height, image_width, image_height])
  40. boxes = boxes.astype(np.int32)
  41. indices = cv2.dnn.NMSBoxes(boxes, scores, score_threshold=conf_thresold, nms_threshold=iou_threshold)
  42. detections = []
  43. def xywh2xyxy(x):
  44. y = np.copy(x)
  45. y[..., 0] = x[..., 0] - x[..., 2] / 2
  46. y[..., 1] = x[..., 1] - x[..., 3] / 2
  47. y[..., 2] = x[..., 0] + x[..., 2] / 2
  48. y[..., 3] = x[..., 1] + x[..., 3] / 2
  49. return y
  50. for (bbox, score, label) in zip(xywh2xyxy(boxes[indices]), scores[indices], class_ids[indices]):
  51. bbox = bbox.round().astype(np.int32).tolist()
  52. cls_id = int(label)
  53. cls = classes[cls_id]
  54. cv2.rectangle(image, tuple(bbox[:2]), tuple(bbox[2:]), (0,0,255), 2, 8)
  55. cv2.rectangle(image, (bbox[0], (bbox[1]-20)), (bbox[2], bbox[1]), (0,255,255), -1)
  56. cv2.putText(image, f'{cls}', (bbox[0], bbox[1] - 5),
  57. cv2.FONT_HERSHEY_PLAIN,2, [225, 0, 0], thickness=2)
  58. end_time = cv2.getTickCount()
  59. t = (end_time - start_time) / cv2.getTickFrequency()
  60. fps = 1 / t
  61. print(f"EStimated FPS: {fps:.2f}")
  62. cv2.putText(image, 'FPS: {:.2f}'.format(fps), (20, 40), cv2.FONT_HERSHEY_PLAIN, 2, [225, 0, 0], 2, 8);
  63. cv2.imshow("ONNXRUNTIME", image)
  64. cv2.waitKey(0)

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/从前慢现在也慢/article/detail/821679
推荐阅读
相关标签
  

闽ICP备14008679号