OpenCV与AI深度学习 | 实战 | 基于YOLOv9和OpenCV实现车辆跟踪计数(步骤 + 源码)



导  读

    本文主要介绍使用YOLOv9和OpenCV实现车辆跟踪计数(步骤 + 源码)。 



【1】安装 ultralytics,它提供了直接加载 YOLOv9 预训练模型的接口。

pip install ultralytics


  1. import math
  2. class CustomTracker:
  3. def __init__(self):
  4. # Store the center positions of the objects
  5. self.custom_center_points = {}
  6. # Keep the count of the IDs
  7. # each time a new object id detected, the count will increase by one
  8. self.custom_id_count = 0
  9. def custom_update(self, custom_objects_rect):
  10. # Objects boxes and ids
  11. custom_objects_bbs_ids = []
  12. # Get center point of new object
  13. for custom_rect in custom_objects_rect:
  14. x, y, w, h = custom_rect
  15. cx = (x + x + w) // 2
  16. cy = (y + y + h) // 2
  17. # Find out if that object was detected already
  18. same_object_detected = False
  19. for custom_id, pt in self.custom_center_points.items():
  20. dist = math.hypot(cx - pt[0], cy - pt[1])
  21. if dist < 35:
  22. self.custom_center_points[custom_id] = (cx, cy)
  23. custom_objects_bbs_ids.append([x, y, w, h, custom_id])
  24. same_object_detected = True
  25. break
  26. # New object is detected we assign the ID to that object
  27. if same_object_detected is False:
  28. self.custom_center_points[self.custom_id_count] = (cx, cy)
  29. custom_objects_bbs_ids.append([x, y, w, h, self.custom_id_count])
  30. self.custom_id_count += 1
  31. # Clean the dictionary by center points to remove IDS not used anymore
  32. new_custom_center_points = {}
  33. for custom_obj_bb_id in custom_objects_bbs_ids:
  34. _, _, _, _, custom_object_id = custom_obj_bb_id
  35. center = self.custom_center_points[custom_object_id]
  36. new_custom_center_points[custom_object_id] = center
  37. # Update dictionary with IDs not used removed
  38. self.custom_center_points = new_custom_center_points.copy()
  39. return custom_objects_bbs_ids


  1. # Import the Libraries
  2. import cv2
  3. import pandas as pd
  4. from ultralytics import YOLO
  5. from tracker import *

    导入所有必要的库后,就可以加载模型了。我们无需手动从任何仓库下载模型权重:Ultralytics 会在首次使用时自动下载,只需一行代码 `model = YOLO('yolov9c.pt')`。


    加载模型时会自动将 yolov9c.pt 下载到当前目录。该模型已在包含 80 个类别的 COCO 数据集上训练。接下来定义类别名称列表,其顺序必须与模型输出的类别索引一致:

# COCO class names. A model prediction's class index is looked up in this list,
# so the order must match the model's class indices.
class_list = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
    'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush',
]


  1. tracker=CustomTracker()
  2. count=0
  3. cap = cv2.VideoCapture('traffictrim.mp4')
  4. # Get video properties
  5. fps = int(cap.get(cv2.CAP_PROP_FPS))
  6. width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
  7. height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
  8. # Create VideoWriter object to save the modified frames
  9. output_video_path = 'output_video.mp4'
  10. fourcc = cv2.VideoWriter_fourcc(*'mp4v') # You can use other codecs like 'XVID' based on your system
  11. out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))


  1. # Looping over each frame and Performing the Detection
  2. down = {}
  3. counter_down = set()
  4. while True:
  5. ret, frame = cap.read()
  6. if not ret:
  7. break
  8. count += 1
  9. results = model.predict(frame)
  10. a = results[0].boxes.data
  11. a = a.detach().cpu().numpy()
  12. px = pd.DataFrame(a).astype("float")
  13. # print(px)
  14. list = []
  15. for index, row in px.iterrows():
  16. # print(row)
  17. x1 = int(row[0])
  18. y1 = int(row[1])
  19. x2 = int(row[2])
  20. y2 = int(row[3])
  21. d = int(row[5])
  22. c = class_list[d]
  23. if 'car' in c:
  24. list.append([x1, y1, x2, y2])
  25. bbox_id = tracker.custom_update(list)
  26. # print(bbox_id)
  27. for bbox in bbox_id:
  28. x3, y3, x4, y4, id = bbox
  29. cx = int(x3 + x4) // 2
  30. cy = int(y3 + y4) // 2
  31. # cv2.circle(frame,(cx,cy),4,(0,0,255),-1) #draw ceter points of bounding box
  32. # cv2.rectangle(frame, (x3, y3), (x4, y4), (0, 255, 0), 2) # Draw bounding box
  33. # cv2.putText(frame,str(id),(cx,cy),cv2.FONT_HERSHEY_COMPLEX,0.8,(0,255,255),2)
  34. y = 308
  35. offset = 7
  36. ''' condition for red line '''
  37. if y < (cy + offset) and y > (cy - offset):
  38. ''' this if condition is putting the id and the circle on the object when the center of the object touched the red line.'''
  39. down[id] = cy # cy is current position. saving the ids of the cars which are touching the red line first.
  40. # This will tell us the travelling direction of the car.
  41. if id in down:
  42. cv2.circle(frame, (cx, cy), 4, (0, 0, 255), -1)
  43. #cv2.putText(frame, str(id), (cx, cy), cv2.FONT_HERSHEY_COMPLEX, 0.8, (0, 255, 255), 2)
  44. counter_down.add(id)
  45. # # line
  46. text_color = (255, 255, 255) # white color for text
  47. red_color = (0, 0, 255) # (B, G, R)
  48. # print(down)
  49. cv2.line(frame, (282, 308), (1004, 308), red_color, 3) # starting cordinates and end of line cordinates
  50. cv2.putText(frame, ('red line'), (280, 308), cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 1, cv2.LINE_AA)
  51. downwards = (len(counter_down))
  52. cv2.putText(frame, ('Vehicle Counter - ') + str(downwards), (60, 40), cv2.FONT_HERSHEY_SIMPLEX, 0.5, red_color, 1,
  53. cv2.LINE_AA)
  54. cv2.line(frame,(282,308),(1004,308),red_color,3) # starting cordinates and end of line cordinates
  55. cv2.putText(frame,('red line'),(280,308),cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 1, cv2.LINE_AA)
  56. # This will write the Output Video to the location specified above
  57. out.write(frame)






