当前位置:   article > 正文

2023.5--YOLOV5 版本6.2简化推理代码!打开即用!_yolov5推理代码简化

yolov5推理代码简化

这是一篇基于YOLOv5的对象检测代码的介绍。该代码是由Python编写的,用于管理和操作YOLOv5模型。这个类库的主要功能是提供一个方便的接口用于加载训练好的模型,处理输入的图像,并进行推理。此外,它还可以将检测结果绘制到原始图像上,以便于进行可视化。

首先,我们需要导入一些必要的库,包括OpenCV,Torch,以及YOLOv5的相关模块。这些库用于图像处理,深度学习模型操作,以及一些工具函数。

  1. #!/usr/bin/python3
  2. # -*- coding: utf-8 -*-
  3. import glob
  4. import sys
  5. from pathlib import Path
  6. import os
  7. FILE = Path(__file__).resolve()
  8. ROOT = FILE.parents[0] # YOLOv5 root directory
  9. if str(ROOT) not in sys.path:
  10. sys.path.append(str(ROOT)) # add ROOT to PATH
  11. ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
  12. import cv2
  13. import torch
  14. import torch.backends.cudnn as cudnn
  15. from models.common import DetectMultiBackend
  16. from utils.general import (check_img_size, non_max_suppression, scale_coords)
  17. from utils.torch_utils import select_device
  18. import numpy as np
  19. from utils.augmentations import letterbox
  20. from time import sleep

在Yolov5Manager类的初始化函数中,我们设置了模型的参数,包括权重文件的路径,标签名,图像大小,置信度阈值,IOU阈值,设备类型,以及是否使用半精度计算。然后,我们调用了DetectMultiBackend函数(Yolo原版)来加载模型,并根据所选择的设备(CPU或GPU)以及计算精度来设置模型。

  1. class Yolov5Manager(object):
  2. def __init__(self, weights=r'', names=[], imgsz=[640, 640], conf_thres=0.3,
  3. half=True, iou_thres=0.2,
  4. device='0',
  5. dnn=False, data=None):
  6. self.names = names
  7. self.half = half
  8. self.conf_thres = conf_thres
  9. self.iou_thres = iou_thres
  10. self.device = select_device(device)
  11. self.model = DetectMultiBackend(weights, device=self.device, dnn=dnn, data=data)
  12. self.stride, pt, jit, onnx, engine = self.model.stride, self.model.pt, self.model.jit, self.model.onnx, self.model.engine # endine:False onnx:False pt:True jit:False
  13. if self.names is None or len(self.names) == 0:
  14. self.names = self.model.names
  15. self.imgsz = check_img_size(imgsz, s=self.stride)
  16. self.auto = True #
  17. self.half &= (
  18. pt or jit or onnx or engine) and self.device.type != 'cpu' # FP16 supported on limited backends with CUDA
  19. if pt or jit:
  20. self.model.model.half() if self.half else self.model.model.float()
  21. cudnn.benchmark = True # set True to speed up constant image size inference
  22. self.model.warmup(imgsz=(1, 3, *imgsz), half=self.half) # warmup

我们还定义了一个内部函数__draw_image,它接受一个OpenCV格式的图像,一个表示检测框位置的列表,以及一些可选的参数,如标签,线条宽度,和颜色。这个函数会在图像上绘制检测框和标签。

  1. def __draw_image(self, opencv_img, box, label='', line_width=None, box_color=(255, 255, 255),
  2. txt_box_color=(200, 200, 200),
  3. txt_color=(0, 0, 255)):
  4. lw = line_width or max(round(sum(opencv_img.shape) / 2 * 0.005), 2) # line width
  5. p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
  6. cv2.rectangle(opencv_img, p1, p2, box_color, thickness=lw, lineType=cv2.LINE_AA)
  7. if label:
  8. tf = max(lw - 1, 1) # font thickness
  9. w, h = cv2.getTextSize(label, 0, fontScale=lw / 4, thickness=tf)[0] # text width, height
  10. outside = p1[1] - h - 1 >= 0 # label fits outside bo
  11. p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
  12. cv2.rectangle(opencv_img, p1, p2, txt_box_color, -1, cv2.LINE_AA) # filled 背景
  13. label = label.split(',')[0]
  14. cv2.putText(opencv_img, label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), 0, lw / 4, txt_color,
  15. thickness=tf, lineType=cv2.LINE_AA)
  16. return opencv_img

此外,inference_image函数接受一个OpenCV格式的图像,并将其预处理为模型可以接受的格式,然后进行推理。最后,它调用non_max_suppression函数来进行非极大值抑制,并返回检测结果。

  1. def inference_image(self, opencv_img):
  2. img = letterbox(opencv_img, self.imgsz, stride=self.stride, auto=self.auto)[0]
  3. img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
  4. img = np.ascontiguousarray(img)
  5. img = torch.from_numpy(img).to(self.device)
  6. img = img.half() if self.half else img.float() # uint8 to fp16/32
  7. img /= 255 # 0 - 255 to 0.0 - 1.0
  8. if len(img.shape) == 3:
  9. img = img[None] # expand for batch dim
  10. pred = self.model(img, augment=False, visualize=False)
  11. pred = non_max_suppression(pred, self.conf_thres, self.iou_thres, None, False, max_det=100)
  12. result_list = []
  13. # Process predictions
  14. for i, det in enumerate(pred): # per image
  15. if len(det):
  16. # Rescale boxes from img_size to im0 size
  17. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], opencv_img.shape).round()
  18. for *xyxy, conf, cls in reversed(det):
  19. result_list.append(
  20. [self.names[int(cls)], round(float(conf), 2), int(xyxy[0]), int(xyxy[1]), int(xyxy[2]),
  21. int(xyxy[3])])
  22. return result_list

我们还提供了一些实用的函数,如start_camerastart_video,和start_video_and_save。这些函数可以分别从摄像头,视频文件,或者保存视频文件中读取图像,并进行推理和绘图。

  1. @torch.no_grad()
  2. def start_camera(self, camera_index=0):
  3. cap = cv2.VideoCapture(camera_index)
  4. while True:
  5. ret, frame = cap.read()
  6. if not ret:
  7. break
  8. result_list = self.inference_image(frame)
  9. frame = self.draw_image(result_list, frame)
  10. cv2.imshow('frame', frame)
  11. if cv2.waitKey(1) & 0xFF == ord('q'):
  12. break
  13. cap.release()
  14. cv2.destroyAllWindows()
  1. @torch.no_grad()
  2. def start_video(self, video_file):
  3. cap = cv2.VideoCapture(video_file)
  4. while cap.isOpened():
  5. ret, frame = cap.read()
  6. if not ret:
  7. print('ret is False')
  8. break
  9. result_list = self.inference_image(frame)
  10. frame = self.draw_image(result_list, frame)
  11. cv2.imshow('frame', frame)
  12. if cv2.waitKey(1) & 0xFF == ord('q'):
  13. break
  14. cap.release()
  15. cv2.destroyAllWindows()
  1. @torch.no_grad()
  2. def start_video_and_save(self, video_file, save_file, show=True):
  3. cap = cv2.VideoCapture(video_file)
  4. # 获取视频帧速率 FPS
  5. frame_fps = int(cap.get(cv2.CAP_PROP_FPS))
  6. # 获取视频帧宽度和高度
  7. frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  8. frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  9. print("video fps={},width={},height={}".format(frame_fps, frame_width, frame_height))
  10. fourcc = cv2.VideoWriter_fourcc(*'XVID')
  11. out = cv2.VideoWriter(save_file, fourcc, frame_fps, (frame_width, frame_height))
  12. count = 0
  13. while cap.isOpened():
  14. ret, frame = cap.read()
  15. if not ret:
  16. print("read over or error!")
  17. break
  18. result_list = self.inference_image(frame)
  19. frame = self.draw_image(result_list, frame)
  20. out.write(frame)
  21. if show:
  22. cv2.imshow("result", frame)
  23. if cv2.waitKey(1) & 0xFF == ord('q'):
  24. break
  25. out.release()
  26. cap.release()
  27. cv2.destroyAllWindows()

load_labels函数可以从一个文本文件中加载标签名。

  1. @classmethod
  2. def load_labels(cls, name_file):
  3. with open(name_file, 'r') as f:
  4. lines = f.read().rstrip('\n').split('\n')
  5. return lines

最后,在主程序中,我们实例化了一个Yolov5Manager对象,并使用它来进行一些实际的检测任务。例如,我们可以从摄像头中读取图像,并实时进行检测和绘图。我们也可以从视频文件中读取图像,进行检测和绘图,并将结果保存为一个新的视频文件。

  1. if __name__ == '__main__':
  2. infer = Yolov5Manager(weights=r'yolov5s.pt',conf_thres=0.3,half=True,
  3. iou_thres=0.2,device='0',)
  4. beg = time.time()
  5. img = r'cccccc.png'
  6. img = cv2.imread(img)
  7. result_list = infer.inference_image(img)
  8. infer.imshow(img, result_list)
  9. print(result_list)

这个代码库提供了一个非常方便的接口,使得我们可以轻松地使用YOLOv5模型进行对象检测。我们可以通过修改和扩展这个代码库来满足我们的特定需求。

在此,我想推荐大家加入我们的YOLO目标检测交流学习群。群号是732818397。在这个群里,我们可以一起学习和探讨关于YOLO目标检测的各种问题和挑战。无论你是初学者还是有经验的专业人士,我们都欢迎你的加入。希望我们能在学习和交流的过程中共同进步,共同提高。期待在群里遇见你。

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/菜鸟追梦旅行/article/detail/371485
推荐阅读
相关标签
  

闽ICP备14008679号