
Using YOLOv8 + SAHI to Improve Small Object Detection and Compute Evaluation Metrics



Preface

I've recently seen quite a few people struggling to evaluate YOLO + SAHI metrics without knowing the concrete steps. What I found online was either fairly complicated or behind a paywall, so I decided to write my own script covering the whole pipeline: loading the model, processing images, visualising detections, and computing evaluation metrics. The code is essentially plug-and-play and supports YOLOv5, YOLOv8 and other models. Without further ado, let's get started! (If you're in a hurry, jump straight to the end and copy the complete code.)



Requirements

We need the following libraries:
1. OpenCV (cv2)
2. SAHI
3. tabulate
4. podm
5. tqdm
(argparse and os are part of the Python standard library and do not need to be installed separately.)

Install them with:

pip install opencv-python sahi tabulate podm tqdm -i https://pypi.tuna.tsinghua.edu.cn/simple

If you trained your YOLOv8 model locally you will already have ultralytics installed; otherwise SAHI's YOLOv8 support needs it as well (pip install ultralytics).

I. Code Structure

1. Argument parsing

First, we use the argparse module to define command-line arguments so the script can be configured flexibly:

parser = argparse.ArgumentParser(description="Object Detection Evaluation Script")
parser.add_argument('--filepath', type=str, default='val/images', help='Path to the images folder')
parser.add_argument('--annotation_folder', type=str, default='val/labels', help='Path to the annotation folder')

parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='kitti_baseline/weights/best.pt',
                    help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.4, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')

parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')

parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--visualize_annotations', action='store_true', default=False, help='Visualize annotation results')

parser.add_argument('--class_list', type=str, nargs='+',
                    default=['Pedestrian', 'Car', 'Van', 'Truck', 'Person_sitting', 'Cyclist', 'Tram'],
                    help='List of class names')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')

args = parser.parse_args()

Key parameters (an example invocation follows the list):

  • --filepath: path to the images folder (images)

  • --annotation_folder: path to the annotation folder (labels)

  • --model_type: type of the detection model (default: yolov8)

  • --model_path: path to the model weights file

  • --confidence_threshold: confidence threshold for the model (only boxes scoring above this threshold are kept)

  • --device: device to run the model on (e.g. cuda:0 or cpu)

  • --slice_height: height of each image slice

  • --slice_width: width of each image slice

  • --overlap_height_ratio: overlap ratio between slices in the height direction

  • --overlap_width_ratio: overlap ratio between slices in the width direction

  • --visualize_predictions: if set, prediction results are visualised

  • --visualize_annotations: if set, annotation (ground-truth) results are visualised

  • --class_list: list of class names (you can copy the list after the names key in your dataset's .yaml file)

  • --images_format: list of accepted image file extensions
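
As a quick usage example, here is how an invocation that overrides the defaults might look. The script name eval_sahi.py and all paths and class names below are placeholders, not values from this article:

python eval_sahi.py \
    --filepath my_dataset/val/images \
    --annotation_folder my_dataset/val/labels \
    --model_path runs/train/exp/weights/best.pt \
    --confidence_threshold 0.25 \
    --slice_height 512 --slice_width 512 \
    --class_list person car bicycle \
    --visualize_predictions

Because --visualize_predictions and --visualize_annotations are store_true flags, simply passing them (without a value) turns the corresponding visualisation on.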

2. Core code walkthrough

1. Loading the detection model

We call AutoDetectionModel.from_pretrained to load the YOLOv8 model:

def load_detection_model():
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )

2. Processing images

The process_image function handles a single image. It:

  • reads the image from the given path
  • runs sliced prediction with SAHI
  • reads the ground-truth boxes from the annotation file
  • optionally visualises the predictions and annotations

def process_image(image_name, model, labels, detections):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)
    img_h, img_w, _ = img_vis.shape

    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
        verbose = 0
    )

    anno_file = os.path.join(args.annotation_folder, image_name[:-4] + '.txt')
    annotations = read_boxes(anno_file, img_w, img_h)

    for anno in annotations:
        label, xmin_gt, ymin_gt, xmax_gt, ymax_gt = anno
        labels.append(BoundingBox.of_bbox(image_name, label, xmin_gt, ymin_gt, xmax_gt, ymax_gt))
        if args.visualize_annotations:
            cv2.rectangle(img_vis, (int(xmin_gt), int(ymin_gt)), (int(xmax_gt), int(ymax_gt)), get_color(label), 2,
                          cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[label]}", (int(xmin_gt), int(ymin_gt - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(label), thickness=2)

    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
        detections.append(BoundingBox.of_bbox(image_name, cls, xmin_pd, ymin_pd, xmax_pd, ymax_pd, score))

        if args.visualize_predictions:
            cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls + len(args.class_list)), 2, cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[cls]} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls + len(args.class_list)), thickness=2)

    if args.visualize_predictions or args.visualize_annotations:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
2.1 Reading the image and annotation file
anno_file = os.path.join(args.annotation_folder, image_name[:-4] + '.txt')
annotations = read_boxes(anno_file, img_w, img_h)
  • anno_file: builds the path to the annotation file
  • annotations: reads the ground-truth boxes with the read_boxes function (shown below)
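
read_boxes converts YOLO-format labels (class cx cy w h, all normalised to [0, 1]) into absolute pixel coordinates (class, xmin, ymin, xmax, ymax). Its definition, reproduced from the complete code below with added comments:

def read_boxes(txt_file, img_w, img_h):
    boxes = []
    with open(txt_file, 'r') as f:
        for line in f:
            items = line.strip().split()
            # convert normalised centre/size to absolute corner coordinates
            box = [
                int(items[0]),                                     # class id
                (float(items[1]) - float(items[3]) / 2) * img_w,   # xmin
                (float(items[2]) - float(items[4]) / 2) * img_h,   # ymin
                (float(items[1]) + float(items[3]) / 2) * img_w,   # xmax
                (float(items[2]) + float(items[4]) / 2) * img_h    # ymax
            ]
            boxes.append(box)
    return boxes
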
2.2 Processing ground-truth boxes
for anno in annotations:
    label, xmin_gt, ymin_gt, xmax_gt, ymax_gt = anno
    labels.append(BoundingBox.of_bbox(image_name, label, xmin_gt, ymin_gt, xmax_gt, ymax_gt))
    if args.visualize_annotations:
        cv2.rectangle(img_vis, (int(xmin_gt), int(ymin_gt)), (int(xmax_gt), int(ymax_gt)), get_color(label), 2,
                      cv2.LINE_AA)
        cv2.putText(img_vis, f"{args.class_list[label]}", (int(xmin_gt), int(ymin_gt - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(label), thickness=2)
  • iterates over every ground-truth box in annotations
  • labels.append: adds the box to the labels list
  • if args.visualize_annotations is True, draws the box on the image with a colour from get_color (shown below) and writes the class name above it
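
get_color is a small helper (also part of the complete code below) that maps an integer index to a repeatable BGR tuple, so every class keeps the same colour across images:

def get_color(idx):
    idx = int(idx) + 5
    # pseudo-random but deterministic BGR colour derived from the index
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)

In the prediction branch (next subsection) the index is offset by len(args.class_list), so predicted boxes and ground-truth boxes of the same class are drawn in different colours.
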
2.3 Processing prediction results
for pred in result.object_prediction_list:
    bbox = pred.bbox
    cls = pred.category.id
    score = pred.score.value
    xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
    detections.append(BoundingBox.of_bbox(image_name, cls, xmin_pd, ymin_pd, xmax_pd, ymax_pd, score))

    if args.visualize_predictions:
        cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                      get_color(cls + len(args.class_list)), 2, cv2.LINE_AA)
        cv2.putText(img_vis, f"{args.class_list[cls]} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls + len(args.class_list)), thickness=2)
  • iterates over every prediction in result.object_prediction_list
  • detections.append: adds the prediction to the detections list
  • if args.visualize_predictions is True, draws the predicted box on the image and writes the class name and confidence score above it
2.4 Displaying the image
if args.visualize_predictions or args.visualize_annotations:
    cv2.imshow(image_name, img_vis)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
  • if args.visualize_predictions is True, the prediction results are shown
  • if args.visualize_annotations is True, the annotation results are shown
  • if both flags are True, predictions and annotations are drawn on the same image
  • if both flags are False, nothing is displayed and the script goes straight to computing the evaluation metrics
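
Note that cv2.imshow needs a GUI session; on a headless server (e.g. when running over SSH without a display) it will raise an error. A minimal workaround, assuming a hypothetical vis_output folder, is to save the visualisation to disk instead:

# instead of cv2.imshow / cv2.waitKey:
os.makedirs("vis_output", exist_ok=True)
cv2.imwrite(os.path.join("vis_output", image_name), img_vis)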

3. Evaluating the model

We use the podm library to compute PASCAL VOC metrics (per-class recall, precision and AP at an IoU threshold of 0.5, plus the overall mAP):

def evaluate_model(labels, detections):
    results = get_pascal_voc_metrics(labels, detections, 0.5)
    table = [
        [args.class_list[int(class_id)], m.recall[-1], m.precision[-1], m.ap]
        for class_id, m in results.items() if m.num_groundtruth > 0
    ]
    map_score = MetricPerClass.mAP(results)
    print(tabulate(table, headers=["ClassID", "Recall", "Precision", "AP"], floatfmt=".2f"))
    print(f"\nmAP: {map_score:.4f}")

4. Main function

The main function loads the model, iterates over the image folder, processes each image, and finally evaluates the model:

def main():
    detection_model = load_detection_model()
    image_names = [name for name in os.listdir(args.filepath) if
                   os.path.splitext(name)[1].lower() in args.images_format]
    labels, detections = [], []

    for i, image_name in enumerate(tqdm(image_names, desc="Processing images")):
        process_image(image_name, detection_model, labels, detections)

    evaluate_model(labels, detections)

if __name__ == "__main__":
    main()

II. Complete Code

import os
import cv2
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from tabulate import tabulate
from podm.metrics import BoundingBox, get_pascal_voc_metrics, MetricPerClass
import argparse
from tqdm import tqdm

parser = argparse.ArgumentParser(description="Object Detection Evaluation Script")
parser.add_argument('--filepath', type=str, default='val/images', help='Path to the images folder')
parser.add_argument('--annotation_folder', type=str, default='val/labels', help='Path to the annotation folder')

parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='kitti_baseline/weights/best.pt',
                    help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.4, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')

parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')

parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--visualize_annotations', action='store_true', default=False, help='Visualize annotation results')

parser.add_argument('--class_list', type=str, nargs='+',
                    default=['Pedestrian', 'Car', 'Van', 'Truck', 'Person_sitting', 'Cyclist', 'Tram'],
                    help='List of class names')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')

args = parser.parse_args()


def get_color(idx):
    idx = int(idx) + 5
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)


def read_boxes(txt_file, img_w, img_h):
    boxes = []
    with open(txt_file, 'r') as f:
        for line in f:
            items = line.strip().split()
            box = [
                int(items[0]),
                (float(items[1]) - float(items[3]) / 2) * img_w,
                (float(items[2]) - float(items[4]) / 2) * img_h,
                (float(items[1]) + float(items[3]) / 2) * img_w,
                (float(items[2]) + float(items[4]) / 2) * img_h
            ]
            boxes.append(box)
    return boxes


def load_detection_model():
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )


def process_image(image_name, model, labels, detections):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)
    img_h, img_w, _ = img_vis.shape

    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
        verbose = 0
    )

    anno_file = os.path.join(args.annotation_folder, image_name[:-4] + '.txt')
    annotations = read_boxes(anno_file, img_w, img_h)

    for anno in annotations:
        label, xmin_gt, ymin_gt, xmax_gt, ymax_gt = anno
        labels.append(BoundingBox.of_bbox(image_name, label, xmin_gt, ymin_gt, xmax_gt, ymax_gt))
        if args.visualize_annotations:
            cv2.rectangle(img_vis, (int(xmin_gt), int(ymin_gt)), (int(xmax_gt), int(ymax_gt)), get_color(label), 2,
                          cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[label]}", (int(xmin_gt), int(ymin_gt - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(label), thickness=2)

    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
        detections.append(BoundingBox.of_bbox(image_name, cls, xmin_pd, ymin_pd, xmax_pd, ymax_pd, score))

        if args.visualize_predictions:
            cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls + len(args.class_list)), 2, cv2.LINE_AA)
            cv2.putText(img_vis, f"{args.class_list[cls]} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls + len(args.class_list)), thickness=2)

    if args.visualize_predictions or args.visualize_annotations:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()


def evaluate_model(labels, detections):
    results = get_pascal_voc_metrics(labels, detections, 0.5)
    table = [
        [args.class_list[int(class_id)], m.recall[-1], m.precision[-1], m.ap]
        for class_id, m in results.items() if m.num_groundtruth > 0
    ]
    map_score = MetricPerClass.mAP(results)
    print(tabulate(table, headers=["ClassID", "Recall", "Precision", "AP"], floatfmt=".2f"))
    print(f"\nmAP: {map_score:.4f}")


def main():
    detection_model = load_detection_model()
    image_names = [name for name in os.listdir(args.filepath) if
                   os.path.splitext(name)[1].lower() in args.images_format]
    labels, detections = [], []

    for i, image_name in enumerate(tqdm(image_names, desc="Processing images")):
        process_image(image_name, detection_model, labels, detections)

    evaluate_model(labels, detections)


if __name__ == "__main__":
    main()

III. Results

Evaluation metrics


Visualised predictions


Visualised annotations


Predictions and annotations visualised together



Conclusion

That's all for this post. If you found it helpful, feel free to follow. Thanks!

I've recently been posting object detection videos on Bilibili; feel free to check them out: https://b23.tv/1upjbcG

Study group: 995760755
