当前位置:   article > 正文

YOLOv8结合SAHI推理图像和视频_sahi的视频检测

sahi的视频检测


前言

在上一篇文章中,我们深入探讨了如何通过结合YOLOv8和SAHI来增强小目标检测效果
,并计算了相关评估指标,虽然我们也展示了可视化功能,但是这些功能往往需要结合实际的ground truth(GT)数据进行对比,这在实际操作中可能会稍显不便。

为了进一步简化操作,这篇文章将直接分享可以用来推理图像和视频的代码,通过这段代码,我们能够更加方便地使用SAHI进行小目标检测,而不需要反复处理和对比GT数据。

不多说啦,以下是完整的代码示例,供大家参考使用
在这里插入图片描述


视频效果

b站链接:使用SAHI增强YOLOv8推理,提升小目标检测效果(附教程)


必要环境

  1. 参考上期博客
    地址:使用YOLOv8+SAHI增强小目标检测效果并计算评估指标

一、完整代码

import os
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
import argparse
from tqdm import tqdm
import time

parser = argparse.ArgumentParser(description="Object Detection Evaluation Script")
parser.add_argument('--filepath', type=str, default='test/images', help='Path to the images folder or video file')
parser.add_argument('--output_dir', type=str, default='output', help='Directory to save the output images or video')

parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='yolov8n.pt', help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.5, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')
parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')
parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')
parser.add_argument('--videos_format', type=str, nargs='+', default=['.mp4', '.avi'],
                    help='List of acceptable video formats')
args = parser.parse_args()


def get_color(idx):
    idx = int(idx) + 5
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)


def load_detection_model():
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )


def process_image(image_name, model):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)

    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
    )

    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        name = pred.category.name
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy

        cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                      get_color(cls), 2, cv2.LINE_AA)
        cv2.putText(img_vis, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, get_color(cls), thickness=2)

    if args.visualize_predictions:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    # 保存结果图像
    output_path = os.path.join(args.output_dir, image_name)
    os.makedirs(args.output_dir, exist_ok=True)
    cv2.imwrite(output_path, img_vis)


def process_video(video_path, model):
    cap = cv2.VideoCapture(video_path)
    output_path = os.path.join(args.output_dir, os.path.basename(video_path))
    os.makedirs(args.output_dir, exist_ok=True)

    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = None
    fps_list = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        start_time = time.time()

        result = get_sliced_prediction(
            frame,
            model,
            slice_height=args.slice_height,
            slice_width=args.slice_width,
            overlap_height_ratio=args.overlap_height_ratio,
            overlap_width_ratio=args.overlap_width_ratio,
        )

        for pred in result.object_prediction_list:
            bbox = pred.bbox
            cls = pred.category.id
            score = pred.score.value
            name = pred.category.name
            xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy

            cv2.rectangle(frame, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls), 2, cv2.LINE_AA)
            cv2.putText(frame, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls), thickness=2)

        end_time = time.time()
        fps = 1 / (end_time - start_time)
        fps_list.append(fps)
        cv2.putText(frame, f"FPS: {fps:.2f}", (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 4)

        if out is None:
            frame_height, frame_width = frame.shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, cap.get(cv2.CAP_PROP_FPS), (frame_width, frame_height))

        out.write(frame)

        if args.visualize_predictions:
            cv2.imshow('Video', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break

    cap.release()
    out.release()
    cv2.destroyAllWindows()

    avg_fps = sum(fps_list) / len(fps_list)
    print(f"Average FPS: {avg_fps:.2f}")


def main():
    detection_model = load_detection_model()

    if os.path.isfile(args.filepath) and os.path.splitext(args.filepath)[1].lower() in args.videos_format:
        process_video(args.filepath, detection_model)
    else:
        image_names = [name for name in os.listdir(args.filepath) if
                       os.path.splitext(name)[1].lower() in args.images_format]

        for i, image_name in enumerate(tqdm(image_names, desc="Processing images")):
            process_image(image_name, detection_model)


if __name__ == "__main__":
    main()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154

二、运行方法

1、 推理图像

  • 将文件命名为 yolov8_sahi_inference.py
  • 运行如下命令,输出结果将保存到output文件夹下
    python yolov8_sahi_inference.py --filepath test/images  --output_dir output --model_type yolov8 --model_path yolov8n.pt
    
    • 1

关键参数详解:

  • –filepath: 指定要推理的图像文件夹的路径
  • –output_dir: 指定保存推理结果的路径
  • –model_type: 指定检测模型的类型 (默认为yolov8)
  • –model_path: 指定模型权重文件的路径
  • –visualize_predictions: 如果设置True,将在运行代码过程中可视化推理结果

效果:
在这里插入图片描述

2、 推理视频

  • 将文件命名为 yolov8_sahi_inference.py
  • 运行如下命令,输出结果将保存到output文件夹下
    python yolov8_sahi_inference.py --filepath inputvideo.mp4  --output_dir output --model_type yolov8 --model_path yolov8n.pt --visualize_predictions
    
    • 1

关键参数详解:

  • –filepath: 指定要推理的视频路径
  • –output_dir: 指定保存推理结果的路径
  • –model_type: 指定检测模型的类型 (默认为yolov8)
  • –model_path: 指定模型权重文件的路径
  • –visualize_predictions: 如果设置True,将在运行代码过程中可视化推理结果

效果:
在这里插入图片描述


总结

本期博客就到这里啦,喜欢的小伙伴们可以点点关注,感谢!

最近经常在b站上更新一些有关目标检测的视频,大家感兴趣可以来看看 https://b23.tv/1upjbcG

学习交流群:995760755

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/知新_RL/article/detail/859643
推荐阅读
相关标签
  

闽ICP备14008679号