当前位置:   article > 正文

调用yolov5模型基于opencv实现区域检测功能_opencv调用yolov5的pt文件进行实时检测代码

opencv调用yolov5的pt文件进行实时检测代码

调用yolov5模型基于opencv实现检测指定种类的物体是否在固定区域内

简介:

文件结构:

utils和models中的文件在yolov5官方文件中查找即可。这些是必需的。

本方法适用于yolov5的预训练模型和自己训练的pt模型,既可检测图片视频,也可实现实时检测摄像头的功能。

代码分析:

首先导入需要的包:

import cv2
import torch
from models.experimental import attempt_load
from utils.general import is_ascii, non_max_suppression, scale_coords, set_logging
from utils.plots import Annotator, colors
from utils.torch_utils import select_device

然后定义一些yolo需要使用到的参数

# Detection settings used throughout the article.
# NOTE: no trailing commas here -- as standalone statements a trailing comma
# would turn the value into a 1-tuple (e.g. ('pretrained/yolov5s.pt',)).
weights = 'pretrained/yolov5s.pt'  # path to the network weights
conf_thres = 0.6  # confidence threshold; detections scoring below it are not shown
iou_thres = 0.2  # NMS IoU threshold
max_det = 1  # maximum number of detections to keep (up to 1000); set to 1 here
device = ''  # cuda device, e.g. 0 or 0,1,2,3, or cpu
classes = 0  # class index to detect; see the dataset yaml for the index of each class
agnostic_nms = False  # class-agnostic NMS
line_thickness = 2  # bounding-box line thickness in pixels
half = False  # whether to use FP16 half-precision inference

在这里,我只用到了这些参数,如有其他需要,可参照yolov5官方源代码中的detect.py文件修改。

模型导入:

# Resolve the device string (e.g. '', 'cpu', '0') into a torch device first,
# exactly as the complete listing does, instead of passing the raw string.
device = select_device(device)
model = attempt_load(weights, device=device)  # load the FP32 model onto that device

读取输入(视频或摄像头):

cap = cv2.VideoCapture('video2.mp4')
# For a live camera use: cap = cv2.VideoCapture(0)
if not cap.isOpened():
    # Fail early: otherwise width/height below are 0 and the region maths is meaningless.
    raise RuntimeError("Cannot open video source 'video2.mp4'")
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))  # frame width of the video
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))  # frame height of the video

获取视频的宽度与高度是为了方便后面在原视频上划定检测区域。下面是检测区域的规定,(x,y)是检测区域的左上角坐标,w、h为检测区域的宽度和高度。

# Detection region: a w x h rectangle centred in the frame.
w = 500  # region width in pixels
h = 700  # region height in pixels
x = int(width / 2 - w / 2)  # x of the region's top-left corner
y = int(height / 2 - h / 2)  # y of the region's top-left corner

font = cv2.FONT_HERSHEY_SIMPLEX  # font used for the on-screen result text

调用模型检测:

# Inference: run the model on `img` (augmentation and feature visualisation
# disabled) and keep the first element of what it returns.
pred = model(img, augment=False, visualize=False)[0]

# NMS: filter the raw predictions with the confidence / IoU thresholds, the
# class filter and the detection cap defined earlier.
pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)

其中,augment 表示是否启用测试时增强(Test-Time Augmentation,TTA)推理,开启后可以提升检测效果,但会增加推理耗时,对硬件要求较高。

获得检测框的坐标、置信度、标签,并与划定检测区域进行比较,判断是否在框内,并反馈出判断结果:

# Walk over the detections of the current frame: draw each box, then test
# whether the box lies completely inside the detection region (x, y, w, h).
for *xyxy, conf, cls in reversed(det):
    c = int(cls)  # integer class index
    # print(names[c])  # uncomment to print the class name
    label = f'{names[c]} {conf:.2f}'  # label text: "<class name> <confidence>"
    annotator.box_label(xyxy, label, color=colors(c, True))

    # int() on a one-element tensor works on CPU and CUDA alike, whereas
    # .numpy() raises for tensors that live on the GPU.
    x1 = int(xyxy[0])  # top-left x
    y1 = int(xyxy[1])  # top-left y
    x2 = int(xyxy[2])  # bottom-right x
    y2 = int(xyxy[3])  # bottom-right y

    if x1 >= x and y1 >= y and x2 <= x + w and y2 <= y + h:
        result = "True"
        color = (0, 255, 0)  # green (BGR): box is fully inside the region
    else:
        result = "False"
        color = (0, 0, 255)  # red (BGR): box is at least partly outside
    cv2.putText(frame, result, (10, 30), font, 1.0, color, 2)

在原视频上绘制出检测区域:

# Outline the detection region on the frame with a 5-pixel border; `color` is
# the value assigned by the inside/outside check.
cv2.rectangle(frame, (x, y), (x + w, y + h), color, 5)

完整代码:

  1. import cv2
  2. import torch
  3. from models.experimental import attempt_load
  4. from utils.general import is_ascii, non_max_suppression, scale_coords, set_logging
  5. from utils.plots import Annotator, colors
  6. from utils.torch_utils import select_device
  7. @torch.no_grad()
  8. def run(weights='pretrained/yolov5s.pt', # model.pt path(s)
  9. conf_thres=0.6, # confidence threshold
  10. iou_thres=0, # NMS IOU threshold
  11. max_det=1, # maximum detections per image
  12. device='', # cuda device, i.e. 0 or 0,1,2,3 or cpu
  13. classes=0, # filter by class: --class 0, or --class 0 2 3
  14. agnostic_nms=False, # class-agnostic NMS
  15. line_thickness=2, # bounding box thickness (pixels)
  16. half=False, # use FP16 half-precision inference
  17. ):
  18. # Initialize
  19. global color
  20. set_logging()
  21. device = select_device(device)
  22. print(device)
  23. half &= device.type != 'cpu' # half precision only supported on CUDA
  24. model = attempt_load(weights, device=device) # load FP32 model
  25. names = model.module.names if hasattr(model, 'module') else model.names # get class names
  26. ascii = is_ascii(names) # names are ascii (use PIL for UTF-8)
  27. cap = cv2.VideoCapture('video2.mp4')
  28. # cap = cv2.VideoCapture(0)
  29. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 获取视频的宽度
  30. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 获取视频的高度
  31. w = 500
  32. h = 700
  33. x = int(width / 2 - w / 2)
  34. y = int(height / 2 - h / 2)
  35. font = cv2.FONT_HERSHEY_SIMPLEX # 设置字体样式
  36. while True:
  37. # 获取一帧q
  38. ret, frame = cap.read()
  39. # frame = cv2.resize(frame, (width, height)) # 设置画面宽长
  40. img = torch.from_numpy(frame).to(device)
  41. img = img.half() if half else img.float() # uint8 to fp16/32
  42. img = img / 255 # 0 - 255 to 0.0 - 1.0
  43. if len(img.shape) == 3:
  44. img = img[None] # expand for batch dim
  45. img = img.transpose(2, 3)
  46. img = img.transpose(1, 2)
  47. # Inference
  48. pred = model(img, augment=False, visualize=False)[0]
  49. # NMS
  50. pred = non_max_suppression(pred, conf_thres, iou_thres, classes, agnostic_nms, max_det=max_det)
  51. # Process predictions
  52. for i, det in enumerate(pred): # detections per image
  53. s = ''
  54. annotator = Annotator(frame, line_width=line_thickness, pil=not ascii)
  55. if len(det):
  56. # Rescale boxes from img_size to im0 size
  57. det[:, :4] = scale_coords(img.shape[2:], det[:, :4], frame.shape).round()
  58. # Print results
  59. for c in det[:, -1].unique():
  60. n = (det[:, -1] == c).sum() # detections per class
  61. s += str(n.item()) + ' ' + str(names[int(c)]) + ' ' # add to string
  62. # Write results
  63. for *xyxy, conf, cls in reversed(det):
  64. c = int(cls) # integer class
  65. # print(names[c]) # 输出类别名字
  66. label = f'{names[c]} {conf:.2f}' # 标签
  67. annotator.box_label(xyxy, label, color=colors(c, True))
  68. # print(xyxy)
  69. # print(int(xyxy[0].numpy()), int(xyxy[1].numpy()), int(xyxy[2].numpy()), int(xyxy[3].numpy()))
  70. x1 = int(xyxy[0].numpy())
  71. y1 = int(xyxy[1].numpy())
  72. x2 = int(xyxy[2].numpy())
  73. y2 = int(xyxy[3].numpy())
  74. if x1 >= x and y1 >= y and x2 <= x + w and y2 <= y + h:
  75. result = "True"
  76. color = (0, 255, 0)
  77. else:
  78. result = "False"
  79. color = (0, 0, 255)
  80. cv2.putText(frame, result, (10, 30), font, 1.0, color, 2)
  81. cv2.rectangle(frame, (x, y), (x + w, y + h), color, 5) # frame要绘制的帧,四个坐标点,颜色,线宽
  82. # print('result:' + s)
  83. cv2.imshow('frame', frame)
  84. k = cv2.waitKey(1) & 0xFF
  85. if k == 27:
  86. break
  87. def main():
  88. run()
  89. if __name__ == "__main__":
  90. main()

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/575821
推荐阅读
相关标签
  

闽ICP备14008679号