In the previous post, we took a close look at combining YOLOv8 with SAHI to improve small object detection and computed the relevant evaluation metrics. We also demonstrated the visualization features, but those required comparing against ground truth (GT) data, which can be inconvenient in practice.
To simplify things, this post shares code that runs inference directly on images and videos, so you can use SAHI for small object detection without repeatedly preparing and comparing GT data.
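As a quick refresher on what the slicing parameters below mean: SAHI cuts the input into overlapping tiles, and the stride between neighboring tiles is the tile size times (1 - overlap ratio). Here is a minimal sketch of that arithmetic (my own illustration of the tiling math, not SAHI's internal code):

import math

def count_slices(img_w, img_h, slice_w=256, slice_h=256, overlap_w=0.2, overlap_h=0.2):
    # Stride between neighboring tiles: tile size minus the overlapping part
    step_w = int(slice_w * (1 - overlap_w))
    step_h = int(slice_h * (1 - overlap_h))
    # Tiles needed to cover each axis (at least one per axis)
    cols = math.ceil(max(img_w - slice_w, 0) / step_w) + 1
    rows = math.ceil(max(img_h - slice_h, 0) / step_h) + 1
    return cols * rows

# A 1920x1080 frame with 256x256 tiles and 0.2 overlap -> 10 x 6 = 60 tiles
print(count_slices(1920, 1080))

Each of those tiles goes through the model separately, which is why sliced inference is slower than a single forward pass but much better at picking up small objects.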
Without further ado, here is the complete code example for your reference.
Bilibili link: Enhancing YOLOv8 inference with SAHI for better small object detection (tutorial included)
import os
import time
import argparse

import cv2
from tqdm import tqdm
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

parser = argparse.ArgumentParser(description="YOLOv8 + SAHI object detection inference script")
parser.add_argument('--filepath', type=str, default='test/images', help='Path to the images folder or video file')
parser.add_argument('--output_dir', type=str, default='output', help='Directory to save the output images or video')
parser.add_argument('--model_type', type=str, default='yolov8', help='Type of the detection model')
parser.add_argument('--model_path', type=str, default='yolov8n.pt', help='Path to the model weights')
parser.add_argument('--confidence_threshold', type=float, default=0.5, help='Confidence threshold for the model')
parser.add_argument('--device', type=str, default="cuda:0", help='Device to run the model on')
parser.add_argument('--slice_height', type=int, default=256, help='Height of the image slices')
parser.add_argument('--slice_width', type=int, default=256, help='Width of the image slices')
parser.add_argument('--overlap_height_ratio', type=float, default=0.2, help='Overlap height ratio for slicing')
parser.add_argument('--overlap_width_ratio', type=float, default=0.2, help='Overlap width ratio for slicing')
parser.add_argument('--visualize_predictions', action='store_true', default=False, help='Visualize prediction results')
parser.add_argument('--images_format', type=str, nargs='+', default=['.png', '.jpg', '.jpeg'],
                    help='List of acceptable image formats')
parser.add_argument('--videos_format', type=str, nargs='+', default=['.mp4', '.avi'],
                    help='List of acceptable video formats')
args = parser.parse_args()
def get_color(idx):
    # Deterministic pseudo-random BGR color for each class id
    idx = int(idx) + 5
    return ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)

def load_detection_model():
    # Wrap the weights in SAHI's unified detection-model interface
    return AutoDetectionModel.from_pretrained(
        model_type=args.model_type,
        model_path=args.model_path,
        confidence_threshold=args.confidence_threshold,
        device=args.device
    )
def process_image(image_name, model):
    img_path = os.path.join(args.filepath, image_name)
    img_vis = cv2.imread(img_path)
    # Sliced inference: the image is cut into overlapping tiles, each tile is
    # run through the model, and the per-tile detections are merged back
    result = get_sliced_prediction(
        img_path,
        model,
        slice_height=args.slice_height,
        slice_width=args.slice_width,
        overlap_height_ratio=args.overlap_height_ratio,
        overlap_width_ratio=args.overlap_width_ratio,
    )
    # Draw every merged prediction: box, class name, and confidence
    for pred in result.object_prediction_list:
        bbox = pred.bbox
        cls = pred.category.id
        score = pred.score.value
        name = pred.category.name
        xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
        cv2.rectangle(img_vis, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                      get_color(cls), 2, cv2.LINE_AA)
        cv2.putText(img_vis, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                    cv2.FONT_HERSHEY_COMPLEX, 0.5, get_color(cls), thickness=2)
    if args.visualize_predictions:
        cv2.imshow(image_name, img_vis)
        cv2.waitKey(0)
        cv2.destroyAllWindows()
    # Save the annotated image
    output_path = os.path.join(args.output_dir, image_name)
    os.makedirs(args.output_dir, exist_ok=True)
    cv2.imwrite(output_path, img_vis)
def process_video(video_path, model):
    cap = cv2.VideoCapture(video_path)
    output_path = os.path.join(args.output_dir, os.path.basename(video_path))
    os.makedirs(args.output_dir, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = None
    fps_list = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        start_time = time.time()
        result = get_sliced_prediction(
            frame,
            model,
            slice_height=args.slice_height,
            slice_width=args.slice_width,
            overlap_height_ratio=args.overlap_height_ratio,
            overlap_width_ratio=args.overlap_width_ratio,
        )
        for pred in result.object_prediction_list:
            bbox = pred.bbox
            cls = pred.category.id
            score = pred.score.value
            name = pred.category.name
            xmin_pd, ymin_pd, xmax_pd, ymax_pd = bbox.minx, bbox.miny, bbox.maxx, bbox.maxy
            cv2.rectangle(frame, (int(xmin_pd), int(ymin_pd)), (int(xmax_pd), int(ymax_pd)),
                          get_color(cls), 2, cv2.LINE_AA)
            cv2.putText(frame, f"{name} {score:.2f}", (int(xmin_pd), int(ymin_pd - 5)),
                        cv2.FONT_HERSHEY_COMPLEX, 0.8, get_color(cls), thickness=2)
        # Per-frame throughput, measured around sliced inference plus drawing
        end_time = time.time()
        fps = 1 / (end_time - start_time)
        fps_list.append(fps)
        cv2.putText(frame, f"FPS: {fps:.2f}", (30, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 4)
        # Create the writer lazily, once the first frame's size is known
        if out is None:
            frame_height, frame_width = frame.shape[:2]
            out = cv2.VideoWriter(output_path, fourcc, cap.get(cv2.CAP_PROP_FPS), (frame_width, frame_height))
        out.write(frame)
        if args.visualize_predictions:
            cv2.imshow('Video', frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    cap.release()
    # Guard against videos that never produced a frame
    if out is not None:
        out.release()
    cv2.destroyAllWindows()
    if fps_list:
        avg_fps = sum(fps_list) / len(fps_list)
        print(f"Average FPS: {avg_fps:.2f}")
def main():
    detection_model = load_detection_model()
    # A video file is processed frame by frame; anything else is treated as an image folder
    if os.path.isfile(args.filepath) and os.path.splitext(args.filepath)[1].lower() in args.videos_format:
        process_video(args.filepath, detection_model)
    else:
        image_names = [name for name in os.listdir(args.filepath)
                       if os.path.splitext(name)[1].lower() in args.images_format]
        for image_name in tqdm(image_names, desc="Processing images"):
            process_image(image_name, detection_model)

if __name__ == "__main__":
    main()
Run it on a folder of images like this:
python yolov8_sahi_inference.py --filepath test/images --output_dir output --model_type yolov8 --model_path yolov8n.pt
Key parameters: --filepath is the input image folder, --output_dir is where the annotated images are written, and --model_type / --model_path select the detector and its weights.
Result: every image in the folder is saved to the output directory with predicted boxes, class names, and confidence scores drawn on it.
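If you want to see how much the slicing actually helps, you can run the same model on the full image with SAHI's standard get_prediction and compare the detection counts. A quick sketch (the image path here is a placeholder):

from sahi import AutoDetectionModel
from sahi.predict import get_prediction, get_sliced_prediction

model = AutoDetectionModel.from_pretrained(
    model_type='yolov8', model_path='yolov8n.pt',
    confidence_threshold=0.5, device='cuda:0',
)
# Baseline: one forward pass over the full image, no slicing
full = get_prediction('test/images/example.jpg', model)
# Sliced inference with the same model and the script's default tiling
sliced = get_sliced_prediction('test/images/example.jpg', model,
                               slice_height=256, slice_width=256,
                               overlap_height_ratio=0.2, overlap_width_ratio=0.2)
print(f"full-image detections: {len(full.object_prediction_list)}, "
      f"sliced detections: {len(sliced.object_prediction_list)}")

On scenes with many small objects, the sliced count is typically noticeably higher.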
For a video, pass the file path and optionally enable the live preview window:
python yolov8_sahi_inference.py --filepath inputvideo.mp4 --output_dir output --model_type yolov8 --model_path yolov8n.pt --visualize_predictions
Key parameters: --visualize_predictions shows each annotated frame as it is processed (press q to quit early); the slicing parameters work the same as in image mode.
Result: the annotated video is saved to the output directory with the per-frame FPS overlaid, and the average FPS is printed when processing finishes.
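And if you do eventually want to compare against GT as in the previous post, the prediction result can be exported as COCO-style annotation dicts via SAHI's to_coco_annotations(). A small sketch of how you might dump them inside process_image (the results.json filename is just an example):

import json

# after result = get_sliced_prediction(...) in process_image:
detections = result.to_coco_annotations()
with open(os.path.join(args.output_dir, 'results.json'), 'w') as f:
    json.dump(detections, f, indent=2)

The resulting JSON can then be fed into the evaluation workflow from the previous post without re-running inference.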
That's all for this post. If you found it helpful, please follow. Thanks!
I've been posting object detection videos on Bilibili lately; feel free to check them out: https://b23.tv/1upjbcG
Study group (QQ): 995760755