SAHI 资料
yolov8示例代码: https://github.com/obss/sahi/blob/main/demo/inference_for_yolov8.ipynb
测试图像: https://github.com/obss/sahi/blob/main/tests/data/small-vehicles1.jpeg
原理介绍: https://learnopencv.com/slicing-aided-hyper-inference/
sahi命令行使用说明: https://github.com/obss/sahi/blob/main/docs/cli.md#predict-command-usage
步骤1: 模型初始化
SAHI 内置支持 YOLOv5、YOLOv8、MMDetection 等多种检测模型, 因此我们可以直接使用 YOLOv8 的预训练权重文件. 下面是集成 YOLOv8 模型的示例代码:
# Step 1: wrap the YOLOv8 weights file in SAHI's detection-model interface.
# confidence_threshold is the minimum prediction score SAHI keeps;
# device may be switched to 'cuda:0' for GPU inference.
detection_model = AutoDetectionModel.from_pretrained(
    model_type='yolov8',
    model_path=yolov8_model_path,
    device="cpu",
    confidence_threshold=0.3,
)
步骤2: 进行推理
SAHI 既提供了切片 (slice) 版推理函数 get_sliced_prediction(), 也提供了对原始 YOLO 推理的简单封装 get_prediction(). 这两个函数的返回类型统一为 sahi.prediction.PredictionResult, 因此我们可以方便地在不同的预测函数之间切换.
步骤3: 使用推理结果对象做进一步处理
预测函数返回的 sahi.prediction.PredictionResult 对象包含以下成员:
- export_visuals() 方法: 可以将推理结果可视化并保存为 PNG 图片.
- object_prediction_list 成员: 检测结果列表 (detection object list), 其中每个元素都是 ObjectPrediction 类的对象.

ObjectPrediction 类的成员 (示例值):
- bbox: BoundingBox: <(321.0, 322.0, 383.0, 363.0), w: 62.0, h: 41.0>
- mask: None
- score: PredictionScore: <value: 0.9093314409255981>
- category: Category: <id: 2, name: car>
代码
- import os
- from IPython import display
- import ultralytics
- from ultralytics import YOLO, settings
- from os import path
- from sahi import AutoDetectionModel
- from sahi.utils.cv import read_image
- from sahi.predict import get_prediction, get_sliced_prediction
- from IPython.display import Image
-
def yolov8_predict(
    image_file=r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg",
    model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt",
):
    """Run plain (non-sliced) Ultralytics YOLOv8 inference on one image.

    The prediction call passes ``save=True, save_txt=True, save_conf=True``,
    so Ultralytics also writes annotated images and label text files (with
    confidences) to its output directory. For each result, the source path,
    the class-name mapping and the JSON-serialized detections are printed.

    Args:
        image_file: Path of the image to run inference on.
        model_path: Path of the YOLOv8 ``.pt`` weights file.

    Returns:
        The list returned by ``YOLO.predict`` (one ``Results`` per source).
    """
    model = YOLO(model_path)
    results_list = model.predict(
        source=[image_file],
        show=False,
        save=True,
        save_conf=True,
        save_txt=True,
    )
    for results in results_list:
        print("====")
        print(results.path)
        print(results.names)
        print(results.tojson())
    return results_list
-
def sahi_orginal_predict(
    image_file=r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg",
    model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt",
    *,
    export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
    confidence_threshold=0.2,
    device="cpu",
):
    """Run SAHI's non-sliced ``get_prediction`` on one image and export a visual.

    The visualization is written to ``export_dir`` under the file name
    ``prediction_visual3`` with labels and confidences shown.

    Args:
        image_file: Path of the image to run inference on.
        model_path: Path of the YOLOv8 ``.pt`` weights file.
        export_dir: Directory the visualization is exported to.
        confidence_threshold: Forwarded to ``AutoDetectionModel.from_pretrained``.
        device: Device string, e.g. ``"cpu"`` or ``"cuda:0"``.

    Returns:
        The ``sahi.prediction.PredictionResult`` from ``get_prediction``.
    """
    # The SAHI model factory exposes very few YOLO settings; further tuning
    # has to go through site-packages\ultralytics\cfg\default.yaml, e.g.
    # ``classes = [2]`` to output only the "car" class. (A Python-side
    # alternative: filter result.object_prediction_list by obj.category.id.)
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=model_path,
        confidence_threshold=confidence_threshold,
        device=device,  # or 'cuda:0'
    )

    result = get_prediction(
        image=image_file,
        detection_model=detection_model,
    )

    result.export_visuals(
        export_dir=export_dir,
        file_name="prediction_visual3",
        hide_labels=False,
        hide_conf=False,
    )
    return result
-
-
def sahi_sliced_predict(
    image_file=r"C:\Users\dorothy\Downloads\small-vehicles1.jpeg",
    model_path=r"D:\my_workspace\py_code\yolo8\Scripts\yolov8m.pt",
    *,
    export_dir=r"D:\my_workspace\source\opencv\yolov8\WinFormsApp1",
    confidence_threshold=0.2,
    device="cpu",
    slice_height=256,
    slice_width=256,
    overlap_height_ratio=0.25,
    overlap_width_ratio=0.25,
):
    """Run SAHI sliced inference (``get_sliced_prediction``) on one image.

    The image is cut into overlapping ``slice_height`` x ``slice_width``
    tiles, each tile is run through YOLOv8, and the per-tile detections are
    merged with ``postprocess_type="NMS"``. The visualization is exported to
    ``export_dir`` under the file name ``prediction_visual4``.

    Args:
        image_file: Path of the image to run inference on.
        model_path: Path of the YOLOv8 ``.pt`` weights file.
        export_dir: Directory the visualization is exported to.
        confidence_threshold: Forwarded to ``AutoDetectionModel.from_pretrained``.
        device: Device string, e.g. ``"cpu"`` or ``"cuda:0"``.
        slice_height: Tile height in pixels.
        slice_width: Tile width in pixels.
        overlap_height_ratio: Vertical overlap between tiles, as a fraction.
        overlap_width_ratio: Horizontal overlap between tiles, as a fraction.

    Returns:
        The ``sahi.prediction.PredictionResult`` from ``get_sliced_prediction``.
    """
    # The SAHI model factory exposes very few YOLO settings; further tuning
    # has to go through site-packages\ultralytics\cfg\default.yaml, e.g.
    # ``classes = [2]`` to output only the "car" class. (A Python-side
    # alternative: filter result.object_prediction_list by obj.category.id.)
    detection_model = AutoDetectionModel.from_pretrained(
        model_type='yolov8',
        model_path=model_path,
        confidence_threshold=confidence_threshold,
        device=device,  # or 'cuda:0'
    )

    result = get_sliced_prediction(
        image=image_file,
        detection_model=detection_model,
        slice_height=slice_height,
        slice_width=slice_width,
        overlap_height_ratio=overlap_height_ratio,
        overlap_width_ratio=overlap_width_ratio,
        # NOTE(review): explicitly NMS rather than SAHI's default merge
        # strategy (GREEDYNMM in recent releases — confirm for your version).
        postprocess_type="NMS",
        verbose=2,
    )
    result.export_visuals(
        export_dir=export_dir,
        file_name="prediction_visual4",
        hide_labels=False,
        hide_conf=False,
    )
    return result
-
if __name__ == '__main__':
    # Clear any previous IPython output and run ultralytics' environment
    # check before starting a demo.
    display.clear_output()
    ultralytics.checks()

    # Pick which demo to run; yolov8_predict and sahi_orginal_predict are
    # the alternatives to sahi_sliced_predict.
    selected_demo = sahi_sliced_predict
    selected_demo()