当前位置:   article > 正文

机器视觉-SAHI

机器视觉-SAHI

SAHI 资料

yolov8示例代码: https://github.com/obss/sahi/blob/main/demo/inference_for_yolov8.ipynb
测试图像: https://github.com/obss/sahi/blob/main/tests/data/small-vehicles1.jpeg
原理介绍: https://learnopencv.com/slicing-aided-hyper-inference/
sahi命令行使用说明: https://github.com/obss/sahi/blob/main/docs/cli.md#predict-command-usage

步骤1: 模型初始化

SAHI 默认支持yolov5/yolov8/mmdet等多种预测网络, 我们可以直接使用yolov8的预训练模型文件, 下面是集成yolov8模型的示例代码:

  1. detection_model = AutoDetectionModel.from_pretrained(
  2. model_type='yolov8',
  3. model_path=yolov8_model_path,
  4. confidence_threshold=0.3,
  5. device="cpu", # or 'cuda:0'
  6. )

步骤2: 进行推理:

SAHI 不仅提供了slice 版推理函数 get_sliced_prediction(), 而且也提供了原始Yolo的简单封装推理函数 get_prediction(), 这两个函数返回类型统一为 sahi.prediction.PredictionResult, 这样我们可以方便切换不同predict函数.

步骤3: 使用推理结果对象做进一步处理

预测函数返回类 sahi.prediction.PredictionResult 成员:

  • export_visuals()函数, 可以将推理结果保存为png图片
  • object_prediction_list 成员: 得到 detection object list, 每个detection object 类型都为 ObjectPrediction 类.
  • ObjectPrediction类成员:
    . bbox: BoundingBox: <(321.0, 322.0, 383.0, 363.0), w: 62.0, h: 41.0>,
    . mask: None,
    . score: PredictionScore: <value: 0.9093314409255981>,
    . category: Category: <id: 2, name: car>

代码

  1. import os
  2. from IPython import display
  3. import ultralytics
  4. from ultralytics import YOLO, settings
  5. from os import path
  6. from sahi import AutoDetectionModel
  7. from sahi.utils.cv import read_image
  8. from sahi.predict import get_prediction, get_sliced_prediction
  9. from IPython.display import Image
  10. def yolov8_predict():
  11. image_file1 = r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg"
  12. yolov8_model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt"
  13. model = YOLO(yolov8_model_path)
  14. results_list = model.predict(source=[image_file1], show=False, save=True, save_conf=True,
  15. save_txt=True)
  16. for results in results_list:
  17. boxes = results.boxes
  18. speed = results.speed
  19. names = results.names
  20. json = results.tojson()
  21. image_path = results.path
  22. print("====")
  23. print(image_path)
  24. print(names)
  25. print(json)
  26. def sahi_orginal_predict():
  27. image_file1 = r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg"
  28. yolov8_model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt"
  29. config_path=r"D:\my_workspace\py_code\yolo8\Lib\site-packages\ultralytics\cfg\default.yaml",
  30. # 模型生产函数可调控yolo的参数非常少, 我们只能通过 site-packages\ultralytics\cfg\default.yaml 做进一步设置,
  31. # 比如设置 classes =[2] , 仅仅输出 car 类型
  32. detection_model=AutoDetectionModel.from_pretrained(
  33. model_type='yolov8',
  34. model_path=yolov8_model_path,
  35. confidence_threshold=0.2,
  36. device="cpu", # or 'cuda:0'
  37. )
  38. result = get_prediction(
  39. image= image_file1,
  40. detection_model= detection_model,
  41. )
  42. for obj in result.object_prediction_list:
  43. category = obj.category
  44. #print("====")
  45. #print(category)
  46. result.export_visuals(
  47. export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
  48. file_name="prediction_visual3",
  49. hide_labels=False,
  50. hide_conf=False)
  51. #Image("demo_data/prediction_visual3.png")
  52. def sahi_sliced_predict():
  53. image_file1 = r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg"
  54. yolov8_model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt"
  55. # 模型生产函数可调控yolo的参数非常少, 我们只能通过 site-packages\ultralytics\cfg\default.yaml 做进一步设置,
  56. # 比如设置 classes =[2] , 仅仅输出 car 类型
  57. detection_model=AutoDetectionModel.from_pretrained(
  58. model_type='yolov8',
  59. model_path=yolov8_model_path,
  60. confidence_threshold=0.2,
  61. device="cpu", # or 'cuda:0'
  62. )
  63. result = get_sliced_prediction(
  64. image= image_file1,
  65. detection_model= detection_model,
  66. slice_height=256,
  67. slice_width=256,
  68. overlap_height_ratio=0.25,
  69. overlap_width_ratio=0.25,
  70. postprocess_type="NMS",
  71. verbose=2,
  72. )
  73. result.export_visuals(
  74. export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
  75. file_name="prediction_visual4",
  76. hide_labels=False,
  77. hide_conf=False)
  78. for obj in result.object_prediction_list:
  79. category = obj.category
  80. #print("====")
  81. #print(category)
  82. #Image("demo_data/prediction_visual4.png")
  83. if __name__ == '__main__':
  84. display.clear_output()
  85. ultralytics.checks()
  86. #yolov8_predict()
  87. #sahi_orginal_predict()
  88. sahi_sliced_predict()
声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号