赞
踩
【博主使用的Python版本:3.9.7】
【博主使用的 OpenCV版本:4.5.0】
本文所使用的资料已上传到百度网盘【https://pan.baidu.com/s/1-OyW8kGbfV58bO4q3GK0tA?pwd=j7u9】,提取码:j7u9。
OpenCV的全称是:Open Source Computer Vision Library, OpenCV是一个基于Apache2.0许可(开源)发行的跨平台计算机视觉和机器学习软件库, 其采用 C/C++ 编写,同时提供了Python、Ruby、MATLAB等语言的接口,实现了图像处理和计算机视觉方面的很多通用算法。其主要关注的是实时应用,同时,OpenCV 的另一个目标是构建一个简单易用的计算机视觉框架,以帮助开发人员更便捷地设计更复杂的计算机视觉相关的应用程序。可以从http://opencv.org获取。
我们要做的事是搭建一个【多目标跟踪】的简单的框架,你可以跟随我的步骤在PyCharm中一步步地把代码填进去,也可以直接复制完整代码,完整代码在本文最底部。
在开始之前,我们需要引入的库:
如果你没有以上的库,请自行安装。
import cv2
import numpy as np
from object_detection import ObjectDetection
import math
object_detection.py代码如下:
import cv2 import numpy as np class ObjectDetection: def __init__(self, weights_path="dnn_model/yolov4.weights", cfg_path="dnn_model/yolov4.cfg"): print("Loading Object Detection") print("Running opencv dnn with YOLOv4") self.nmsThreshold = 0.4 self.confThreshold = 0.5 self.image_size = 608 # Load Network net = cv2.dnn.readNet(weights_path, cfg_path) # Enable GPU CUDA net.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA) net.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA) self.model = cv2.dnn_DetectionModel(net) self.classes = [] self.load_class_names() self.colors = np.random.uniform(0, 255, size=(80, 3)) self.model.setInputParams(size=(self.image_size, self.image_size), scale=1/255) def load_class_names(self, classes_path="dnn_model/classes.txt"): with open(classes_path, "r") as file_object: for class_name in file_object.readlines(): class_name = class_name.strip() self.classes.append(class_name) self.colors = np.random.uniform(0, 255, size=(80, 3)) return self.classes def detect(self, frame): return self.model.detect(frame, nmsThreshold=self.nmsThreshold, confThreshold=self.confThreshold)
在开始之前首先进行第一个测试,以确保后面所做的都是正确的。
cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段
_, frame = cap.read() # 从视频中获取帧
cv2.imshow("Frame", frame) # 显示帧
cv2.waitKey(0) # 保持窗口打开
运行结果如下:
现在我们成功获取并加载显示出了视频的第一帧,这是一个好的开始。有了第一帧,接下来,继续加载整个视频。
什么是视频?
通俗的来讲,视频就是一个接一个的大量图像。如果你检查相机的规格,例如相机可以以30fps的速度录制,这就意味着相机每秒记录30帧,说明在一秒钟内有30张图像。
所以现在我们将获取帧的步骤放入循环中,在循环内一个接一个地获取帧。注:这里要判断视频是否播放完,即判断是否存在帧,如果不存在帧,则退出循环。
cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段
while True: # 获取连续的帧
ret, frame = cap.read() # 从视频中获取帧
if not ret: # 是否存在帧
break # 如果不存在帧,则退出
cv2.imshow("Frame", frame) # 显示帧
key = cv2.waitKey(1) # 1:每帧延迟1ms
if key == 27: # 注:ESC键
break
cap.release() # 释放视频文件
cv2.destroyAllWindows() # 关闭所有窗口
运行结果如下(按下ESC键,退出循环):
(这里只能上传5M以内的图,所以上传的动图压缩了,略模糊)
已经确保了视频每帧都能成功获取,现在调用object_detection.py中目标检测函数,获取每帧中的包含的目标信息,代码如下:
cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段 while True: # 获取连续的帧 ret, frame = cap.read() # 从视频中获取帧 if not ret: # 是否存在帧 break # 如果不存在帧,则退出 (class_ids, scores, boxes) = od.detect(frame) # 当前帧中检测目标中的信息 # class_id:what object is (car / track / person) # score: how confident is about the detection and # box: bounding box of the location of each object for box in boxes: # 不区分类别,只画框 print(box) # 打印框,确保提取目标正确 cv2.imshow("Frame", frame) # 显示帧 key = cv2.waitKey(1) # 1:每帧延迟1ms if key == 27: # 注:ESC键 break cap.release() # 释放视频文件 cv2.destroyAllWindows() # 关闭所有窗口
运行结果如下(这里只放了部分):
[505 802 133 178]
[376 683 122 118]
[1671 603 159 64]
[727 605 68 87]
[972 610 92 74]
[898 508 61 52]
[826 531 61 69]
[592 457 40 32]
[861 457 39 32]
[1214 880 244 199]
[735 445 36 37]
[1100 424 37 25]
[1835 560 85 91]
以上结果中,每一行代表一个目标的信息,其中前两个数字为目标框的左上角点坐标(x,y),第三个数字为目标框的宽度,第四个数字为目标框的高度。知道这些信息后,我们就可以在每帧中画出矩形框,框出目标。代码如下:
(class_ids, scores, boxes) = od.detect(frame) # 当前帧中检测目标中的信息
# class_id:what object is (car / track / person)
# score: how confident is about the detection and
# box: bounding box of the location of each object
for box in boxes: # 不区分类别,只画框
(x, y, w, h) = box # 矩形的左上角坐标,及宽度和高度
cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) # 在当前帧中根据box绘制矩形
运行结果:
接下来可以给每个跟踪目标分配ID,确保跟踪的为同一个目标。
for box in boxes: # 不区分类别,只画框 # print(box) # 打印框,确保提取目标正确 (x, y, w, h) = box # 矩形的左上角坐标,及宽度和高度 cx = int((x + x + w) / 2) # 中心点的x坐标 cy = int((y + y + h) / 2) # 中心点的y坐标 center_points_cur_frame.append((cx, cy)) # 添加新的中心点到数组中 # print("FRAME N°", count, " ", x, y, w, h) # 打印每一帧中的框 # cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # 在当前帧中以框的中心点为中心画圆,半径为5,红色,用所有颜色填充圆圈 cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) # 在当前帧中根据box绘制矩形 # only at the beginning we compare previous and current frame if count <= 2: for pt in center_points_cur_frame: # cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点 for pt2 in center_points_prev_frame: distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离 if distance < 20: # 距离小于20像素 tracking_objects[track_id] = pt # 当前帧中的目标中心 track_id += 1 else: tracking_objects_copy = tracking_objects.copy() # 建立跟踪目标字典副本 center_points_cur_frame_copy = center_points_cur_frame.copy() for object_id, pt2 in tracking_objects.copy().items(): # 首先遍历新的数组 object_exists = False # 首先假设当前帧中不存在目标 for pt in center_points_cur_frame: # 遍历当前帧中的目标 distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离 # Update IDs position if distance < 20: tracking_objects[object_id] = pt # 更新目标位置 object_exists = True # 目标存在 if pt in center_points_cur_frame: center_points_cur_frame.remove(pt) continue # 继续下一帧 # Remove IDs lost if not object_exists: # 目标不存在 tracking_objects.pop(object_id) # 则移除目标ID # Add new IDs found for pt in center_points_cur_frame: tracking_objects[track_id] = pt track_id += 1 for object_id, pt in tracking_objects.items(): cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点 # 在帧中显示目标id,文本位置,文本字体类型,字体大小,颜色,粗细(注:需>=0) cv2.putText(frame, str(object_id), (pt[0], pt[1] - 7), 0, 1, (0, 0, 255), 0)
运行结果:
object_tracking.py完整代码如下:
import cv2 import numpy as np from object_detection import ObjectDetection import math # Initialize Object Detection od = ObjectDetection() # 加载目标 cap = cv2.VideoCapture("los_angeles.mp4") # 加载视频片段 frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) # 通过属性获取帧数 # Initialize count count = 0 # 用于计算视频的实际帧数 # center_points = [] # 用于储存所有的中心点 center_points_prev_frame = [] # 空数组用于储存第一帧前的空帧 tracking_objects = {} # 用于储存跟踪目标 track_id = 0 # 跟踪目标初始序号 print(cv2.__version__) while True: # 获取连续的帧 ret, frame = cap.read() # 从视频中获取帧 count += 1 if not ret: # 是否存在帧 break # 如果不存在帧,则退出 # point current frame center_points_cur_frame = [] # 用于储存当前帧的目标的中心点 # Detect objects on frame (class_ids, scores, boxes) = od.detect(frame) # 当前帧中检测目标中的信息 # class_id:what object is (car / track / person) # score: how confident is about the detection and # box: bounding box of the location of each object for box in boxes: # 不区分类别,只画框 # print(box) # 打印框,确保提取目标正确 (x, y, w, h) = box # 矩形的左上角坐标,及宽度和高度 cx = int((x + x + w) / 2) # 中心点的x坐标 cy = int((y + y + h) / 2) # 中心点的y坐标 center_points_cur_frame.append((cx, cy)) # 添加新的中心点到数组中 # print("FRAME N°", count, " ", x, y, w, h) # 打印每一帧中的框 # cv2.circle(frame, (cx, cy), 5, (0, 0, 255), -1) # 在当前帧中以框的中心点为中心画圆,半径为5,红色,用所有颜色填充圆圈 cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2) # 在当前帧中根据box绘制矩形 # only at the beginning we compare previous and current frame if count <= 2: for pt in center_points_cur_frame: # cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点 for pt2 in center_points_prev_frame: distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离 if distance < 20: # 距离小于20像素 tracking_objects[track_id] = pt # 当前帧中的目标中心 track_id += 1 else: tracking_objects_copy = tracking_objects.copy() # 建立跟踪目标字典副本 center_points_cur_frame_copy = center_points_cur_frame.copy() for object_id, pt2 in tracking_objects.copy().items(): # 首先遍历新的数组 object_exists = False # 首先假设当前帧中不存在目标 for pt in center_points_cur_frame: # 遍历当前帧中的目标 distance = math.hypot(pt2[0] - pt[0], pt2[1] - pt[1]) # 当前帧与前一帧的目标中心点的距离 # Update IDs position if distance < 20: tracking_objects[object_id] = pt # 更新目标位置 object_exists = True # 目标存在 if pt in center_points_cur_frame: center_points_cur_frame.remove(pt) continue # 继续下一帧 # Remove IDs lost if not object_exists: # 目标不存在 tracking_objects.pop(object_id) # 则移除目标ID # Add new IDs found for pt in center_points_cur_frame: tracking_objects[track_id] = pt track_id += 1 for object_id, pt in tracking_objects.items(): cv2.circle(frame, pt, 5, (0, 0, 255), -1) # 画出所有中心点 # 在帧中显示目标id,文本位置,文本字体类型,字体大小,颜色,粗细(注:需>=0) cv2.putText(frame, str(object_id), (pt[0], pt[1] - 7), 0, 1, (0, 0, 255), 0) print("Tracking Objects :") print(tracking_objects) print("CUR FRAME LEFT PTS :") print(center_points_cur_frame) # print("PREV FRAME :") # print(center_points_prev_frame) cv2.imshow("Frame", frame) # 显示帧 # Make a copy of the points center_points_prev_frame = center_points_cur_frame.copy() key = cv2.waitKey(0) # 1:每帧延迟1ms 0:保持当前帧不动 if key == 27: # 注:ESC键 break cap.release() # 释放视频文件 cv2.destroyAllWindows() # 关闭所有窗口 print("通过属性获取的视频帧数 :", frames) print("实际遍历整个视频的帧数 :", count-1)
Copyright © 2003-2013 www.wpsshop.cn 版权所有,并保留所有权利。