当前位置:   article > 正文

如何为目标识别追踪项目mikel-brostrom/yolov8_tracking增加计数功能?_yolov8 track

yolov8 track

等待被改造的代码“mikel-brostrom/yolov8_tracking”地址GitHub - mikel-brostrom/yolo_tracking: A collection of SOTA real-time, multi-object tracking algorithms for object detectors

模仿的代码“Yolov5 + Deep Sort with PyTorch”地址:https://github.com/dongdv95/yolov5/tree/master/Yolov5_DeepSort_Pytorch

如果你想了解YOLOv8的模型细节和里面每个流程,可以阅读这篇博客YOLOv8详解全流程捋清楚-每个步骤_德彪稳坐倒骑驴的博客-CSDN博客

如果这篇博客对你有帮助,希望你 点赞、收藏、关注、评论,您的认可将是我创作下去最大的动力!

视频人群计数是基于目标追踪做的,目标追踪做好了,每个object都被分配了一个id以后,做counting很容易的,在此基础上统计一下不重复的id数量即可。

本文修改完以后的代码:

https://download.csdn.net/download/Albert233333/88320278

代码需要用到的模型文件

https://download.csdn.net/download/Albert233333/88320284

PandasCV编写的79页多目标跟踪入门教程

https://download.csdn.net/download/Albert233333/88320283

当然你也可以直接私信我,我会把这三个文件通过腾讯微云网盘分享给你,你就不用花积分购买了

其实真正的难点在于ReID和追踪这个过程,作者已经帮你做了这一步,最难的部分已经解决了。你只需要做

YOLOv5中和计数有关的代码只有下面四段

放在开头
# line39 数据初始化,车数和每个车的id组成的列表
count = 0
data = []



在目标检测(识别)的循环里面
# line 168 汽车识别的方框上面有一个数字,表示这是第几辆车
# draw boxes for visualization
if len(outputs) > 0:
for j, (output, conf) in enumerate(zip(outputs, confs)):

bboxes = output[0:4]
id = output[4]
cls = output[5]
#count
count_obj(bboxes,w,h,id)
c = int(cls) # integer class
label = f'{id} {names[c]} {conf:.2f}'
annotator.box_label(bboxes, label, color=colors(c, True))



在左上角显示累计有多少辆车过线
# line205, 将计数的结果 多少辆车这个数字 显示在视频的左上角
# Stream results
im0 = annotator.result()
if show_vid:
global count
color=(0,255,0)
start_point = (0, h-350)
end_point = (w, h-350)
cv2.line(im0, start_point, end_point, color, thickness=2)
thickness = 3
org = (150, 150)
font = cv2.FONT_HERSHEY_SIMPLEX
fontScale = 3
cv2.putText(im0, str(count), org, font,
fontScale, color, thickness, cv2.LINE_AA)
cv2.imshow(str(p), im0)
if cv2.waitKey(1) == ord('q'): # q to quit
raise StopIteration


放在结尾
# line 238 介绍如何实现计数
def count_obj(box,w,h,id):
global count,data
center_coordinates = (int(box[0]+(box[2]-box[0])/2) , int(box[1]+(box[3]-box[1])/2))
# 分别求出识别方框的横纵坐标的中心点,从拿到目标的中心点
# 如过目标的纵坐标大于图片的高减去350个像素(也就是绿色的线所在的位置)(图片的坐标是从左上角为中心,向下向右为正半轴)
# 也就是目标的纵坐标低于这条线,并且这个目标之前没有出现,就将id记录进data,总车数加1
if int(box[1]+(box[3]-box[1])/2) > (h -350):
if id not in data:
count += 1
data.append(id)

输出数字

我们在yolov8里面按照下面这样的步骤进行修改和增加内容如下

  1. # 变量初始化
  2. # line22
  3. # ——@@@@@ 新加的代码(1) @@@@@——
  4. # ,车数和每个车的id组成的列表
  5. count = 0
  6. obj_id_list = []
  7. # 完成计数
  8. # line271
  9. # ——@@@@@ 新加的代码(2) @@@@@——
  10. w, h = im0.shape[1], im0.shape[0] # 获得汽车目标这个小图的宽和高
  11. # 将所有出现过的目标的id全部
  12. count_obj(bbox, w, h, id)
  13. # ——@@@@@ 新添加的代码 @@@@@——
  14. # 每帧图片累计汽车的数字显示在左上角
  15. # line320
  16. # ——@@@@@ 新添加的代码(3) @@@@@——
  17. # ——用于在左上角显示每一帧累计通过的汽车输出量
  18. global count # 把全局变量中那个count的数据拿过来这里要显示在窗口里
  19. org = (150, 150) # 数字的大小吧,我猜
  20. font = cv2.FONT_HERSHEY_SIMPLEX # 字体
  21. fontScale = 3 # 字号
  22. color = (0, 255, 0)
  23. thickness = 3 # 厚度?
  24. cv2.putText(im0, str(count), org, font,
  25. fontScale, color, thickness, cv2.LINE_AA)
  26. # ——@@@@@ 新添加的代码 @@@@@——

关于tab是这样放

这样修改完毕以后,识别的视频的左上角就有了实时波动的数字,非常的直观

  1. # 定义实现计数的函数
  2. # line373
  3. # ——@@@@@ 新加的代码(4) @@@@@——
  4. def count_obj(box,w,h,id):
  5. global count,obj_id_list
  6. center_coordinates = (int(box[0]+(box[2]-box[0])/2) , int(box[1]+(box[3]-box[1])/2))
  7. # 分别求出识别方框的横纵坐标的中心点,从拿到目标的中心点
  8. # 如过目标的纵坐标大于图片的高减去350个像素(也就是绿色的线所在的位置)(图片的坐标是从左上角为中心,向下向右为正半轴)
  9. # 也就是目标的纵坐标低于这条线,并且这个目标之前没有出现,就将id记录进data,总车数加1
  10. if int(box[1]+(box[3]-box[1])/2) > (h -350):
  11. if id not in obj_id_list:
  12. count += 1
  13. obj_id_list.append(id)
  14. # ——@@@@@ 新加的代码(4) @@@@@——

因为line373 count_obj()这个函数,它统计了所有追踪的object的id的列表统计下来了(obj_id_list),还把id unique的数量也统计出来了(count),而且这两个变量已经全局化了,所以直接调用即可

  1. # 将人数和id列表打印出来
  2. # line365
  3. # ——@@@@@ 新加的代码(5) @@@@@——
  4. print("*"*50)
  5. print("There are %s Cars in total"%(count))
  6. print("The id list : %s"%(obj_id_list))
  7. print("*" * 50)
  8. # ——@@@@@ 新加的代码(5) @@@@@——

是的,只统计id就可把车的数量数出来,输出结果是这样


<code class="language-plaintext hljs">**************************************************
There are 11 Cars in total
The id list : [1, 2, 3, 4, 5, 8, 7, 14, 12, 13, 20]
**************************************************</code>

加上绿色的横线

在yolo8 tracking里面这样修改

  1. # line304
  2. else:
  3. # pass
  4. #tracker_list[i].tracker.pred_n_update_all_tracks() # 这是作者注释掉的,不是我注释掉的
  5. # 防止后面因为参数未定义引发的报错
  6. # ——@@@@@ 新添加的代码(7) @@@@@——
  7. # 当视频刚刚开始识别的时候,画面中没有一辆车,所以line252这个判断无法通过“ if len(outputs[i]) > 0:”
  8. # 所判断为True里面,line272的 “w, h = im0.shape[1], im0.shape[0]”,自然就没有执行,所以w 和 h就没有定义,会报错
  9. # 为了不报这个错,这里定义一次。。后面画面中有车了以后走 true那个部分也就不会重复定义了
  10. w, h = im0.shape[1], im0.shape[0]
  11. # ——@@@@@ 新添加的代码(7) @@@@@——
  1. # line343
  2. # 给在视频没一帧的图片的最先有这样一条绿色的横线
  3. # ——@@@@@ 新添加的代码(6) @@@@@——
  4. # w, h = w_copy, h_copy
  5. start_point = (0, h-350)
  6. end_point = (w, h-350)
  7. cv2.line(im0, start_point, end_point, color, thickness=2)
  8. # ——@@@@@ 新添加的代码(6) @@@@@——

只要id出现过不管过没过线都统计上

上面那个计数的函数count_obj(),是过了绿线,才计入数字。我后面想,写一个计数函数,只要出现在画面里,有分配过id就算数。统计所有出现过的人的数量

原来的那个计数的函数改一下名字,好区分

  1. # line 399
  2. # 定义实现计数的函数
  3. # ——@@@@@ 新加的代码(4) @@@@@——
  4. def count_obj_cross_line(box,w,h,id): # ——@@@@@ 修该过(7) @@@@@——
  5. # 特别强调这是跨越线的
  6. global count,obj_id_list
  7. center_coordinates = (int(box[0]+(box[2]-box[0])/2) , int(box[1]+(box[3]-box[1])/2))
  8. # 分别求出识别方框的横纵坐标的中心点,从拿到目标的中心点
  9. # 如过目标的纵坐标大于图片的高减去350个像素(也就是绿色的线所在的位置)(图片的坐标是从左上角为中心,向下向右为正半轴)
  10. # ——整个图片的高度为h,350是绿色线距离最底端的距离
  11. # ——这个车的左上角的纵坐标是box[3],整个车的纵坐标的中心点是 box[3]+车高的一半,也就是下面这个柿子
  12. # 也就是目标的“纵坐标中心点”低于这条横着的绿色线,并且这个目标之前没有出现,就将id记录进data,总车数加1
  13. if int(box[1]+(box[3]-box[1])/2) > (h -350):
  14. if id not in obj_id_list:
  15. count += 1
  16. obj_id_list.append(id)
  17. # ——@@@@@ 新加的代码(4) @@@@@——
  18. # line276
  19. count_obj_cross_line(bbox, w, h, id) # ——@@@@@ 修该过(8) @@@@@——
  20. # 调用跨线的那个计数
  21. # 新增一个计数的函数,只要id不重复就计入数字
  22. # line 421
  23. def count_obj_all(id):
  24. if id not in obj_id_list:
  25. count += 1
  26. obj_id_list.append(id)
  27. # 并调用“只要id不重复就计入数字”的函数实现计数
  28. # line281
  29. # count_obj_cross_line(bbox, w, h, id) # ——@@@@@ 修该过(8) @@@@@——
  30. # # 调用跨线的那个计数
  31. # ——@@@@@ 新添加的代码(2) @@@@@——
  32. # ——@@@@@ 新加的代码(9) @@@@@——
  33. count_obj_all(id)
  34. # ——@@@@@ 新加的代码(9) @@@@@——

它确实实现了,只要出现就计数,不管过不过线

让他运行的时候只识别 人类,不识别其他的

python track.py --source ./val_data/dandong.mp4 --classes 0

  • 把绿的线也去掉

  1. # line354注释掉
  2. # cv2.line(im0, start_point, end_point, color, thickness=2) # 非过线的方法,就注释掉这句

绿色的线确实被我去掉了

  • 加一个开关,

  • 开关一开就是(1)根据是否过线来计数(2)有绿色的线;

  • 开关一关就是(1)根据id是否出现来计数(2)没有绿色的线

  1. # line 27
  2. # 开关决定是否采用 过线计数的方法
  3. is_count_by_line = True
  4. # 计数的时候启用哪种,可以做个判断。判断使用哪个函数进行计数
  5. # line280
  6. if is_count_by_line:
  7. count_obj_cross_line(bbox, w, h, id) # ——@@@@@ 修该过(8) @@@@@——
  8. # 调用跨线的那个计数;关闭它
  9. # ——@@@@@ 新添加的代码(2) @@@@@——
  10. else:
  11. # ——@@@@@ 新加的代码(9) @@@@@——
  12. count_obj_all(id)
  13. # ——@@@@@ 新加的代码(9) @@@@@——
  14. # 是否加横着的绿色的线,做个判断
  15. # line 359
  1. 原来这个划线的标准是距离一帧截图的最下面350个像素。这仅仅适用于Traffic.mp4这个视频的截图。对于那些视频一帧截图的图片尺寸很小的视频(视频画面的长和宽很小),你要求绿线距离一帧截图画面最底端350像素,这条绿线就到截图的最上面甚至在截图上面了。

因此我们需要将过线计数的那个方法从固定的350,改成一个比例比如30%

因为这个start_point和end_point要传进后面cv2.line()的一个函数里,传进去的值必须是int整数。你弄个0.3*h就变成float不满足人家的类型要求,所以前面要加个 int()转换一下。

  1. # line357
  2. # 让横线距离最底端留个350像素,是仅仅适用于
  3. # 之前的
  4. # start_point = (0, h-350) # 横坐标顶着照片的最左边,纵坐标是照片的高度 减去 底下空出来的350个像素
  5. # 之后修改的
  6. start_point = (0, int(h-0.3*h)) # 距离照片最底端留下30%
  7. # 之前的
  8. # end_point = (w, h-350) # 横坐标顶着照片的最右边,纵坐标,同上
  9. # 之后修改的
  10. end_point = (w, int(h-0.3*h)) # 横坐标顶着照片的最右边,纵坐标,同上
  11. # line435
  12. # ——@@@@@ 新加的代码(4) @@@@@——
  13. def count_obj_cross_line(box,w,h,id): # ——@@@@@ 修该过(7) @@@@@——
  14. # 特别强调这是跨越线的
  15. # ——bbox是识别出来的小车的方框的 四个数字表示坐标
  16. # ——w 和 h是视频每一帧图片的长和宽
  17. global count,obj_id_list
  18. center_coordinates = (int(box[0]+(box[2]-box[0])/2) , int(box[1]+(box[3]-box[1])/2))
  19. # 分别求出识别方框的横纵坐标的中心点,从拿到目标的中心点
  20. # 如过目标的纵坐标大于图片的高减去350个像素(也就是绿色的线所在的位置)(图片的坐标是从左上角为中心,向下向右为正半轴)
  21. # ——整个图片的高度为h,350是绿色线距离最底端的距离
  22. # ——这个车的左上角的纵坐标是box[3],整个车的纵坐标的中心点是 box[3]+车高的一半,也就是下面这个柿子
  23. # 也就是目标的“纵坐标中心点”低于这条横着的绿色线,并且这个目标之前没有出现,就将id记录进data,总车数加1
  24. # 原来的
  25. # if int(box[1]+(box[3]-box[1])/2) > (h -350):
  26. # 修改后的
  27. if int(box[1]+(box[3]-box[1])/2) > (int(h-0.3*h)): # 低于70%就计数
  28. if id not in obj_id_list:
  29. count += 1
  30. obj_id_list.append(id)
  31. # ——@@@@@ 新加的代码(4) @@@@@——

人体中有一点点碰触到低端30%,就计入数字一次

你得首先知道box[0] 1 2 3 中间哪一个是上端纵坐标?哪一个是下端纵坐标?或者怎样组合可以得出上下端纵坐标?

我暂时觉得0 和 2 横向左右两个横坐标,1和3是纵向上下两个坐标——确认你去 box 0 1 2 3定义的地方去寻找他们如何定义,定义的函数网上应该有讲解

以zidian这个图为例,图片的尺寸是,宽1280 pixels width 高 720 pixels

识别出来识别框是下面这样打的

我将box 0 1 2 3分别打印出来,对照图片推测这个四个东西分别代表什么意思

  1. print('我猜是横坐标',box[0],'\t',box[2]) # 横坐标
  2. print('我猜是纵坐标',box[1],'\t',box[3]) # 纵坐标
  3. print("#"*20)

打印出来的结果如下

  1. id is 1
  2. 我猜是横坐标 123.0 1111.0
  3. 我猜是纵坐标 197.0 711.0
  4. ####################
  5. id is 2
  6. 我猜是横坐标 747.0 1142.0
  7. 我猜是纵坐标 41.0 712.0
  8. ####################
  9. Video window size is 1280 pixels width and 720 pixels height
  • 横坐标和图中情况完美匹配。id为1的用户,横坐标最小,为123px。id=1 的目标和id=2的目标的右侧横坐标比较靠右,而且和接近,分别为1111和1142。图片最右边的px值为1280,他们确实很靠右了,都一千多了。

  • 纵坐标有点奇怪,明明这两个人的低端纵坐标都很接近最底下了,为什么二者的纵坐标都没有接近0的数字呢?

  • 原因是原点不是在左下角,而是在左上角。纵坐标是从左上角,向下和向右延伸的。

  • 这样就说得通了box[1]和box[3]表示object上下两端的纵坐标。

  • box[1]表示目标的上侧的纵坐标。id=1的object向下多一点,是197px;id=2的object距离上端原点很接近,纵坐标很小,为41px。

  • box[3]表示目标的下册的纵坐标。id=1和id=2的两个object的下端都距离低端很近,低到很逼近低端了,因此应该和画面的高(720px)很接近了,分别是711和712.

知道了谁是纵坐标以后,只要二者有一个低于设定的那条低端30%的线就cout一次即可

line424这个函数count_obj_cross_line() 更名为这样count_obj_cross_line_center()

track.py line451新建这个函数

  1. # 只要object纵向上有一点碰到了低端的30%,就count进去
  2. def count_obj_cross_line_any(box,w,h,id):
  3. global count,obj_id_list
  4. y_upper = box[1] # object的上端纵坐标
  5. y_lower = box[3] # object的下端纵坐标
  6. # 这个行人可以从下端进来,y_upper先大于 低端30%的纵坐标;行人也有可能从上端进来,即y_lower大于低端30%的坐标
  7. # 两个条件满足其中之一即可,所以用or
  8. # 行人从远处走近的可能性比从镜头后面走入然后进入画面的可能性高,所以放在or前面
  9. if (y_lower > int(h-0.3*h)) or (y_upper > int(h-0.3*h)): # 低于70%就计数
  10. if id not in obj_id_list:
  11. count += 1
  12. obj_id_list.append(id)

其他设定

  • 如何只识别人,不识别其他东西的运行方法,记录一下。.(其实就是加了了个--classes 0)

python track.py --source ./val_data/dandong.mp4 --classes 0

如果你希望选择模型tracking使用某个特定的ReID或者tracking的方法,可以用下面这句代码。就能达成这个效果


<code class="language-plaintext hljs">python track.py --source ./val_data/Traffic.mp4 --reid-weights osnet_x0_25_market1501.pt --tracking-method ocsort</code>

其实你可以使用model和 tracking method不是一个算法,比如下面这个。tracking是的算法是最严格的那个strongsort,使用的模型是ocsort(可能不是我想的那样,可能这个模型不是ocsort专用的,而是任何追踪算法都能用的)

如果--tracking-method选择strongsort,识别的时候会特别慢,但是最后的结果数的最仔细,精度更高


<code class="language-plaintext hljs">python track.py --source ./val_data/Traffic.mp4 --reid-weights osnet_x0_25_market1501.pt --tracking-method strongsort</code>

这里你也可以选择其他的追踪模型,你只要在command line的--reid-weights这个参数写了这个模型的名字,代码就会去作者事先规定好的那个网址(比如对于模型文件osnet_x0_25_market1501.pt,下载网址是 https://drive.google.com/uc?id=1z1UghYvOTtjx7kEoRfmqSMu-z62J6MAj)去下载这个模型,放在这个位置(/home/albert/.cache/torch/checkpoints),每次调用的时候都会去这个位置去load进来。

其实也不是任何模型文件名丢进 --reid-weights参数中都可以下载下来,比如作者markdown文档中提及这个几个模型文件传进--reid-weights参数里面,都无法把模型文件下载下来。"osnet_x0_25_msmt17.pt",“mobilenetv2_x1_4_msmt17.engine”,“resnet50_msmt17.onnx”,“osnet_x1_0_msmt17.pt”。

实际上,你是可以自己实现在本地下载好,放到一个文件夹里面,写好路径,每次做tracking的时候进去读就行了。这样就不用存到这个位置了(/home/albert/.cache/torch/checkpoints),存到这个位置不便于管理。


<code class="language-plaintext hljs">python track.py --source ./val_data/Traffic.mp4 --reid-weights ./weights/osnet_x0_25_imagenet.pth --tracking-method ocsort</code>

tracking ReID这件事YOLOv8+Tracking肯定做了,我的原因如下

(1)代码里肯定有使用ReID这个算法的代码,只是我不知道这些代码被放在了什么位置

(2)识别这个阶段里:YOLOv8做识别的时候,官方提供了训练好的模型(炼好的丹);第二个阶段追踪,也使用的是大牛提供的训练好的追踪模型


我对比YOLOv5 +DeepSort和 YOLOv8 tracking的代码。后者只有detection和tracking、每个object分配ID的功能,前者多了一个功能,计数。通过对比二者的不同。我们可以模仿前者,一步步把后者改造成有计数功能的代码

YOLOv5

YOLOv8

detect()函数

run()函数

line113 : for frame_idx, (path, img, im0s, vid_cap, s) in enumerate(dataset):

Line 167 : for frame_idx, batch in enumerate(dataset):

line125 : pred = model(img, augment=opt.augment, visualize=visualize)

Line 179 : preds = model(im, augment=augment, visualize=visualize)

line134 : for i, det in enumerate(pred):

Line 197 : for i, det in enumerate(pred):

line 146: annotator = Annotator(im0, line_width=2, pil=not ascii)

Line 222 : annotator = Annotator(im0, line_width=line_thickness, example=str(names))

line148 : if det is not None and len(det):

if det is not None and len(det):

# Rescale boxes from img_size to im0 size

det[:, :4] = scale_coords(

img.shape[2:], det[:, :4], im0.shape).round()

Line 228 :

if det is not None and len(det):

if is_seg:

shape = im0.shape

# scale bbox first the crop masks

if retina_masks:

det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], shape).round() # rescale boxes to im0 size

masks.append(process_mask_native(proto[i], det[:, 6:], det[:, :4], im0.shape[:2])) # HWC

else:

masks.append(process_mask(proto[i], det[:, 6:], det[:, :4], im.shape[2:], upsample=True)) # HWC

det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], shape).round() # rescale boxes to im0 size

else:

det[:, :4] = scale_boxes(im.shape[2:], det[:, :4], im0.shape).round() # rescale boxes to im0 size

Line 153

for c in det[:, -1].unique():

n = (det[:, -1] == c).sum() # detections per class

s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

Line 242

for c in det[:, 5].unique():

n = (det[:, 5] == c).sum() # detections per class

s += f"{n} {names[int(c)]}{'s' * (n > 1)}, " # add to string

进行追踪,line163

# 进行追踪# pass detections to deepsortt4 = time_sync()

outputs = deepsort.update(xywhs.cpu(), confs.cpu(), clss.cpu(), im0)

t5 = time_sync()

dt[3] += t5 - t4

进行追踪, line 224

# pass detections to strongsort

with dt[3]:

outputs[i] = tracker_list[i].update(det.cpu(), im0)

line173,,画出bouding box

if len(outputs) > 0:

Line 174

for j, (output, conf) in enumerate(zip(outputs, confs)):

bboxes = output[0:4]

id = output[4]

cls = output[5]

line251

if len(outputs[i]) > 0:

Line 262

for j, (output) in enumerate(outputs[i]):

bbox = output[0:4]

id = output[4]

cls = output[5]

最佳插入点(它是这么插入的)

line 180

count_obj(bboxes,w,h,id)

最佳插入点(等待我插入)

Line 269

Line 147

w, h = im0.shape[1],im0.shape[0]

没有,需要你添加

Line 205

im0 = annotator.result()

Line 306

im0 = annotator.result()

line 218cv2.imshow(str(p), im0)

if cv2.waitKey(1) == ord('q'): # q to quit

line312

cv2.imshow(str(p), im0)

if cv2.waitKey(1) == ord('q'): # 1 millisecond

Line 216

cv2.putText(im0, str(count), org, font,

fontScale, color, thickness, cv2.LINE_AA)

没有,添加到line313

所有追踪算法,YOLOv8的分数

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/花生_TL007/article/detail/441960
推荐阅读
相关标签
  

闽ICP备14008679号