当前位置:   article > 正文

yolov8 strongSORT多目标跟踪工具箱BOXMOT

boxmot

1 引言

多目标跟踪MOT项目在Github中比较完整有:BOXMOT , 由mikel brostrom提供。在以前的版本中,有yolov5+deepsort(版本v3-v5), yolov8+strongsort(版本v6-v9),直至演变到v10,名称BOXMOT。
BOXMOT提供三种对象检测器:yolov8, yolo_nas, yolox; 支持多个跟踪器:BoTSORT, DeepOCSORT, OCSORT, Hybridsort, ByteTrack, StrongSORT 。以前常见的DeepSort在此由增强型StrongSORT替代。

2 安装BOXMOT

boxmot安装
安装环境:Ubuntu18.04,python 3.8,已建有虚拟环境。

将boxmot从github克隆到本地,建立yolo_tracking目录:

git clone https://github.com/mikel-brostrom/yolo_tracking.git
cd yolo_tracking
pip install -v -e .
  • 1
  • 2
  • 3

第三个命令pip install -v -e. 相当于 python setup.py develop,即根据setup.py执行安装,其中“develop”参数将软件包以开发模式安装到Python环境中,以便在开发过程中能够即时反映源代码的修改。
完成BOXMOT安装后,看看yolo_tracking下面有什么:
在这里插入图片描述
图1 yolo_tracking目录
见图1, BOXMOT支持的三个对象检测器定义文件yolonas.py, yolov8.py, yolox.py在examples/detectors目录。跟踪器tracker在boxmot/trackers目录下。 boxmot/appearance目录是tracker所需使用的外观识别REID模块。boxmot/configs为tracker的参数构造文件。boxmot/motion目录是运动预测用Kalman滤波器。
下一步根据需要,安装三个对象检测器。

2.1 安装ultralytics

对象检测器yolov8需要安装ultralytics python库,需注意,BOXMOT适用ultralytics v8.0.146,而最新的版本不适用。
安装ultralytics到yolo_tracking目录,操作如下:
先删除虚拟环境下和系统中可能安装的ultralytics模块:

pip uninstall ultralytics
  • 1

克隆ultralytics v8.0.146

git clone https://github.com/mikel-brostrom/ultralytics.git
  • 1

此操作在home目录下产生ultralytics目录。我们需要将ultralytics二级目录:~/ultralytics/ultralytics移动到yolo_tracking目录下,完成安装ultralytics。这样在python程序调试时,可以跟踪到ultralytics模块。为了防止混淆,将examples/track.py和val.py中安装ultralytics语句注释掉:

#__tr = TestRequirements()
#__tr.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', ))  # install  ultralytics
  • 1
  • 2

现在,已实现基础对象检测器yolov8的运行环境,下面其他对象检测器yoloNAS,yolox可根据您的需要进行安装,不是必须项。据博主试验,yoloNAS和yolox并无多大优势,还不及yolov8性能好。
运行yolov8s+strongsort对输入视频进行车辆跟踪示例:

python examples/track.py  \
   --yolo-model yolov8s    \
   --reid-mode  osnet_x0_25_market1501.pt   \
   --source     ~/yolo_tracking/MOT16-13-h264.mp4  \ 
   --save         \
   --show         \
   --classes 2     \
   --tracking-method strongsort 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

若已下载yolov8s权重文件yolov8s.pt,可在–yolo-model 变量中指定文件路径,若没下载,则track.py根据"yolov8s"自动从网上下载。

2.2 安装yoloNAS

yoloNAS需安装super-gradients:

pip install super-gradients
  • 1

运行yoloNAS,实现的示例:

python examples/track.py  \
   --yolo-model yolo_nas_s    \
   --reid-mode  osnet_x0_25_market1501.pt   \
   --source     ~/yolo_tracking/MOT16-13-h264.mp4  \ 
   --save         \
   --show         \
   --classes 2     \
   --tracking-method strongsort  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

BOXMOT使用yoloNAS不能对跟踪对象类进行筛选,而把图像中所有符合COCO数据集类型的对象都提取,大数量目标对跟踪器造成很大负担,运行速度慢。
在yolonas.py的postprocess函数中增加类型过滤:

def postprocess(self, path, preds, im, im0s):

        results = []
        for i, pred in enumerate(preds):

            if pred is None:
                pred = torch.empty((0, 6))
                r = Results(
                    path=path,
                    boxes=pred,
                    orig_img=im0s[i],
                    names=self.names
                )
                results.append(r)
            else:

                pred[:, :4] = ops.scale_boxes(im.shape[2:], pred[:, :4], im0s[i].shape)
                # filter boxes by classes   ############################################
                pred = pred[torch.isin(pred[:, 5].cpu(), torch.as_tensor(self.args.classes))]   # added by someone

                r = Results(
                    path=path,
                    boxes=pred,
                    orig_img=im0s[i],
                    names=self.names
                )
            results.append(r)
        return results
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28

2.3 安装YOLOX

YOLOX官网克隆到本地,产生YOLOX目录。

git clone https://github.com/Megvii-BaseDetection/YOLOX.git
  • 1

将YOLOX目录下三个子目录:yolox, tools, exps 复制到yolo_tracking,完成YOLOX环境。
实现YOLOX对象检测器的跟踪示例

python examples/track.py  \
   --yolo-model yolox_s    \
   --reid-mode  osnet_x0_25_market1501.pt   \
   --source     ~/yolo_tracking/MOT16-13-h264.mp4  \ 
   --save         \
   --show         \
   --tracking-method strongsort  
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7

此BOXMOT版本v10.043在使用YOLOX上有限制,从track.py程序下载的权重文件yolox_s仅支持一个类型“person”的对象检测,不支持其他对象检测。从YOLOX Github下载权重文件yolox_s.pth支持COCO数据集的80个类型,这里需要修改:
1 YOLOX官网下载的权重文件yolox_s.pth,80个类型; BOXMOT,track.py下载的yolox_s.pt,1个person类型。在examples/detectors/yolox.py中如下修改:

def __init__(self, model, device, args):
        self.args = args
        self.pt = False
        self.stride = 32  # max stride in YOLOX

        # model_type one of: 'yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x'
        model_type = self.get_model_from_weigths(YOLOX_ZOO.keys(), model)

        if model_type == 'yolox_n':
            exp = get_exp(None, 'yolox_nano')
        else:
            exp = get_exp(None, model_type)

        LOGGER.info(f'Loading {model_type} with {str(model)}')

        # download crowdhuman bytetrack models
        if not model.exists() and model.stem == model_type:
            LOGGER.info('Downloading pretrained weights...')
            gdown.download(
                url=YOLOX_ZOO[model_type + '.pt'],
                output=str(model),
                quiet=False
            )
            # "yolox_s.pt" 表示num_classes =1,  "yolox_s.pth"表示num_classes = 80。
            #    boxmot下载的ckpt只处理“person”类型,
            #     github.com/Megvii-BaseDetection/YOLOX提供的ckpt用于num_classes,可处理多种类型。
                  
            # needed for bytetrack yolox people models
            # update with your custom model needs
            exp.num_classes = 1             
            self.num_classes = 1
        elif model.stem == model_type:
            exp.num_classes = 1
            self.num_classes = 1
            _, file_extension = os.path.splitext(str(model))            
            if file_extension == ".pth":
                exp.num_classes = 80           
                self.num_classes = 80
            if exp.num_classes ==1:    
                self.img_normal = True  #num_classes = 1,    BOXMOT format: 0.0-1.0
            else:
                self.img_normal = False   #num_classes = 80, YOLOX website ckpt format:  0-255 

        ckpt = torch.load(
            str(model),
            map_location=torch.device('cpu')
        )

        self.model = exp.get_model()
        self.model.eval()
        self.model.load_state_dict(ckpt["model"])
        self.model = fuse_model(self.model)
        self.model.to(device)
        self.model.eval()
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54

这里,对 YoloXStrategy(YoloInterface)类增加类变量:self.num_classes,适应来自不同网站的yolox权重文件。

2 YOLOX官网yolox detector所处理的图像为0-255数据,而BOXMOT yolox detector所处理图像数据为0.0 - 1.0,需要针对不同权重文件对图像输入进行变换。对 YoloXStrategy(YoloInterface)类增加类变量:self.img_normal。

更改yolox.py postprocess

def postprocess(self, path, preds, im, im0s):

        results = []
        for i, pred in enumerate(preds):

            pred = postprocess(
                pred.unsqueeze(0),  # YOLOX postprocessor expects 3D arary
                self.num_classes,                                             #  1   num_classes  
                conf_thre=0.5,              #0.1
                nms_thre=0.7,            # 0.45
                class_agnostic=True,   
            )[0]
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

更改ultralytics/engine/predictor.py preprocess函数

def preprocess(self, im):
        """Prepares input image before inference.

        Args:
            im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
        """
        not_tensor = not isinstance(im, torch.Tensor)
        if not_tensor:
            im = np.stack(self.pre_transform(im))
            im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
            im = np.ascontiguousarray(im)  # contiguous
            im = torch.from_numpy(im)

        img = im.to(self.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        if not_tensor:                                 # -------------------------------------------------------------------------
            if not self.model.img_normal:
                return img                          # yolox  num_classes = 80  , img 取值 0-255                                    
            img /= 255                              # 0 - 255 to 0.0 - 1.0   # yolov8, yolo_nas and yolox  num_classes = 1的ckpt,img 取值 0.0 - 1.0 。
        return img
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

注:如果更改BOXMOT中默认的yolox单一类型处理模式,由于对ultralytics/engine/predictor.py preprocess方法做了修改,会涉及到yolov8和yoloNAS,所以:
yolov8: ultralytics/nn/autobackend.py, 类 AutoBackend, 增加类变量 self.img_normal = True
yoloNAS:examples/detectors/yolonas.py, 类 YoloNASStrategy, 增加类变量 self.img_normal = True

当然,如果不改变BOXMOT默认的yolox,则无需改变yolov8和yoloNAS。

对比BOXMOT 默认yolox_s.pt和YOLOX官网yolo_s.pth,两者分别运行val.py

python examples/val.py   \
--yolo-model    examples/weights/yolox_s.pt  \   
 --tracking-method  deepocsort   \
 --benchmark  MOT17


python examples/val.py   \
--yolo-model    examples/weights/yolox_s.pth  \   
 --tracking-method  deepocsort   \
 --benchmark  MOT17
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

得到如下结果(取MOT17 - FRCNN每个序列的前10帧):

                                  HOTA           MOTA          MOTP          IDF1
yolox_s.pt             68.294         67.268         81.267       81.403   
yolox_s.pth           60.038         52.223         78.765       70.531
  • 1
  • 2
  • 3

这因为yolox_s.pt是针对拥挤行人情况的权重,而yolox_s.pth则适用于COCO数据集的多种对象类型,MOT17-FRCNN侧重评价对行人对象的多目标跟踪,所以yolox_s.pt权重获得的评价指标略有提高。

3 ultralytics对检测模型的定义

以下只针对对象检测,在ultralytics命名为:detect。
根据ultralytics, class YOLO在ultralytic/models/yolo/models.py中。以“detect”为例,有两个重要的对象:model和predictor
‘model’: DetectionModel,
‘predictor’: yolo.detect.DetectionPredictor,
这里,YOLO类继承Model类,Model类在ultralytics/enging/model.py中定义,predictor是Model类的属性,self.predictor是BasePredictor类的实例。而BasePredictor属性self.model,是AutoBackend类的实例,此self.model就是yolov8 DetectionModel。
DetectionModel是对象检测的核心神经网络。
yolo.detect.DetectionPredictor则是对象检测数据的处理过程,用于对检测模型输入数据的预处理,对象数据提取(推理),以及检测数据后处理等过程。
yolov8有detect, segment, classify, pose模型,载入权重文件时决定了所使用的哪种模型。
模型载入:

model = YOLO("yolov8s.pt")
  • 1

模型使用:

dets = model.predict(source="bus.jpg")
  • 1

第一步模型载入后,仅规定了核心神经网络部分,其中的predictor还未定义,执行
dets = model.predict(), 或者:
dets = model()
都可以完成检测模型初始化,实现对predictor定义,并提供对输入画面的对象检测。实际上,dets = model() 的执行函数

__call__()
  • 1

就是跳转到model.predict()方法。所以,调用检测模型可直接使用
dets = model.predict(source = image)
此时还可加入predictor需要变更的类变量,而predictor的类变量通过model()无法直接加入。

dets = model.predict(source = im0, save=True, imgsz=args.imgsz, classes = args.classes, conf = args.conf)
  • 1

4 根据boxmot跟踪框架的检测-跟踪简单程序

以下给出简化的跟踪程序track_yolov8.py,以便了解多目标跟踪的主要流程。这里借用examples/track.py的参数导入部分。
将权重文件yolov8s.pt, osnet_x0_25_msmt17.pt放入examples/weights,运行程序track_yolov8.py:

python examples/track_yolov8.py
  • 1

在这里插入图片描述

import cv2
import torch
import argparse
import numpy as np
from pathlib import Path
from boxmot import DeepOCSORT, StrongSORT
from ultralytics import YOLO
from boxmot.utils import ROOT, WEIGHTS

@torch.no_grad()    
def run(args):    
    # Load a model
    yolo = YOLO(args.yolo_model )
    tracker = StrongSORT(
        model_weights=Path('examples/weights/osnet_x0_25_msmt17.pt'), # which ReID model to use
        device='cuda:0',
        fp16=True,
    )
    video_path = "MOT16-13-raw.mp4"                                  
    vid = cv2.VideoCapture(video_path)
    color = (0, 0, 255)  # BGR
    thickness = 2
    fontscale = 0.5
    frame_number = 0
    
    if not vid.isOpened():
        print("无法打开视频文件")
    else:
        # 获取视频帧的宽度和高度,总帧数
        img_w = int(vid.get(cv2.CAP_PROP_FRAME_WIDTH))
        img_h = int(vid.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(vid.get(cv2.CAP_PROP_FRAME_COUNT))  

        while True:
            success, im0 = vid.read()
            if not success:
                break
            else:
                # 使用检测模型
                dets = yolo.predict(source = im0, save=True, imgsz=args.imgsz, classes = args.classes, conf = args.conf)
                for det in dets:
                    boxes = det.boxes.xyxy
                    confs = det.boxes.conf
                    cls = det.boxes.cls
                # 将PyTorch张量转换为NumPy数组
                    boxes_np = boxes.cpu().numpy()
                    confs_np = confs.cpu().numpy()
                    cls_np = cls.cpu().numpy()

                # 将boxes、confs和cls堆叠成一个数组
                    detection_results = np.column_stack((boxes_np, confs_np, cls_np))
                
                print(f"当前帧号: {frame_number}/{total_frames}")
                frame_number = frame_number +1

                tracks = tracker.update(detection_results, im0) # --> (x, y, x, y, id, conf, cls, ind)
                if tracks.shape[0] != 0:
                    xyxys = tracks[:, 0:4].astype('int') # float64 to int
                    ids = tracks[:, 4].astype('int') # float64 to int
                    confs = tracks[:, 5]
                    clss = tracks[:, 6].astype('int') # float64 to int
                # print bboxes with their associated id, cls and conf
                if tracks.shape[0] != 0:
                    for xyxy, id, conf, cls in zip(xyxys, ids, confs, clss):
                        im0 = cv2.rectangle(
                            im0,
                            (xyxy[0], xyxy[1]),
                            (xyxy[2], xyxy[3]),
                            color,
                            thickness
                        )
                        cv2.putText(
                            im0,
                            f'{id} ',                #f'id: {id}, conf: {conf}, c: {cls}',
                            (xyxy[0], xyxy[1]-10),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            fontscale,
                            color,
                            thickness
                        )
                # show image with bboxes, ids, classes and confidences
                origin_size = (img_w, img_h)
                im0  = cv2.resize(im0, origin_size)
                cv2.imshow('frame', im0)

                # break on pressing q
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        
    # Release the video capture object and close the display window
    vid.release()
    cv2.destroyAllWindows()

def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo-model', type=Path, default=WEIGHTS / 'yolov8s',
                        help='yolo model path')
    parser.add_argument('--reid-model', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt',
                        help='reid model path')
    parser.add_argument('--tracking-method', type=str, default='strongsort',
                        help='deepocsort, botsort, strongsort, ocsort, bytetrack')
    parser.add_argument('--source', type=str, default='0',
                        help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640],
                        help='inference size h,w')
    parser.add_argument('--conf', type=float, default=0.5,       # 0.5
                        help='confidence threshold')
    parser.add_argument('--iou', type=float, default=0.7,
                        help='intersection over union (IoU) threshold for NMS')
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show', action='store_true',
                        help='display tracking video results')
    parser.add_argument('--save', action='store_true',    
                        help='save video tracking results')
    # class 0 is person, 1 is bycicle, 2 is car... 79 is oven
    parser.add_argument('--classes', nargs='+', type=int, default=0,
                        help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--project', default=ROOT / 'runs' / 'track',
                        help='save results to project/name')
    parser.add_argument('--name', default='exp',
                        help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true',
                        help='existing project/name ok, do not increment')
    parser.add_argument('--half', action='store_true',
                        help='use FP16 half-precision inference')
    parser.add_argument('--vid-stride', type=int, default=1,
                        help='video frame-rate stride')
    parser.add_argument('--show-labels', action='store_false',       #  labels    store_false
                        help='either show all or only bboxes')
    parser.add_argument('--show-conf', action='store_true',                         #   conf    store_false
                        help='hide confidences when show')
    parser.add_argument('--save-txt', action='store_false',                #  store_true,    
                        help='save tracking results in a txt file')
    parser.add_argument('--save-id-crops', action='store_true',           #  id-crops    store_true
                        help='save each crop to its respective id folder')
    parser.add_argument('--save-mot', action='store_true',               # 保存mot txt文件,与输入视频同目录。
                        help='save tracking results in a single txt file')
    parser.add_argument('--line-width', default=1, type=int,   # default=None
                        help='The line width of the bounding boxes. If None, it is scaled to the image size.')
    parser.add_argument('--per-class', default=False, action='store_true',
                        help='not mix up classes when tracking')
    parser.add_argument('--verbose', default=True, action='store_true',
                        help='print results per frame')
    parser.add_argument('--vid_stride', default=1, type=int,
                        help='video frame-rate stride')
        
    opt = parser.parse_args()
    return opt

if __name__ == "__main__":
    opt = parse_opt()
    run(opt)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/盐析白兔/article/detail/424652
推荐阅读
相关标签
  

闽ICP备14008679号