当前位置:   article > 正文

【目标跟踪网络训练 Market-1501 数据集】DeepSort 训练自己的跟踪网络模型_market1501数据集

market1501数据集

前言

Deepsort之所以可以大量避免IDSwitch,是因为Deepsort算法中特征提取网络可以将目标检测框中的特征提取出来并保存,在目标被遮挡后又从新出现后,利用前后的特征对比可以将遮挡的后又出现的目标和遮挡之前的追踪的目标重新找到,大大减少了目标在遮挡后,追踪失败的可能。

一、数据集简介

Market-1501 数据集在清华大学校园中采集,夏天拍摄,在 2015 年构建并公开。它包括由6个摄像头(其中5个高清摄像头和1个低清摄像头)拍摄到的 1501 个行人、32668 个检测到的行人矩形框。每个行人至少由2个摄像头捕获到,并且在一个摄像头中可能具有多张图像。训练集有 751 人,包含 12,936 张图像,平均每个人有 17.2 张训练数据;测试集有 750 人,包含 19,732 张图像,平均每个人有 26.3 张测试数据。3368 张查询图像的行人检测矩形框是人工绘制的,而 gallery 中的行人检测矩形框则是使用DPM检测器检测得到的。

该数据集提供的固定数量的训练集和测试集均可以在single-shot或multi-shot测试设置下使用。

目录结构

Market-1501-v15.09.15

  ├── bounding_box_test

       ├── 0000_c1s1_000151_01.jpg

       ├── 0000_c1s1_000376_03.jpg

       ├── 0000_c1s1_001051_02.jpg

  ├── bounding_box_train

       ├── 0002_c1s1_000451_03.jpg

       ├── 0002_c1s1_000551_01.jpg

       ├── 0002_c1s1_000801_01.jpg

  ├── gt_bbox

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c1s1_009376_00.jpg

       ├── 0001_c2s1_001976_00.jpg

  ├── gt_query

       ├── 0001_c1s1_001051_00_good.mat

       ├── 0001_c1s1_001051_00_junk.mat

  ├── query

       ├── 0001_c1s1_001051_00.jpg

       ├── 0001_c2s1_000301_00.jpg

       ├── 0001_c3s1_000551_00.jpg

  └── readme.txt

目录介绍

(1) “bounding_box_test”——用于测试集的 750 人,包含 19,732 张图像,前缀为 0000 表示在提取这 750 人的过程中DPM检测错的图(可能与query是同一个人),-1 表示检测出来其他人的图(不在这 750 人中)

(2) “bounding_box_train”——用于训练集的 751 人,包含 12,936 张图像

(3) “query”——为 750 人在每个摄像头中随机选择一张图像作为query,因此一个人的query最多有 6 个,共有 3,368 张图像

(4) “gt_query”——matlab格式,用于判断一个query的哪些图片是好的匹配(同一个人不同摄像头的图像)和不好的匹配(同一个人同一个摄像头的图像或非同一个人的图像)

(5) “gt_bbox”——手工标注的bounding box,用于判断DPM检测的bounding box是不是一个好的box

命名规则

以 0001_c1s1_000151_01.jpg 为例

1) 0001 表示每个人的标签编号,从0001到1501;

2) c1 表示第一个摄像头(camera1),共有6个摄像头;

3) s1 表示第一个录像片段(sequece1),每个摄像机都有数个录像段;

4) 000151 表示 c1s1 的第000151帧图片,视频帧率25fps;

5) 01 表示 c1s1_001051 这一帧上的第1个检测框,由于采用DPM检测器,对于每一帧上的行人可能会框出好几个bbox。00 表示手工标注框

二、跟踪模型介绍

特征提取的模型有很多,可以替换特征提取模型网络。

下述给出的是 deep_sort/deep/model.py 里面的模型代码

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class BasicBlock(nn.Module):
  5. def __init__(self, c_in, c_out, is_downsample=False):
  6. super(BasicBlock, self).__init__()
  7. self.is_downsample = is_downsample
  8. if is_downsample:
  9. self.conv1 = nn.Conv2d(
  10. c_in, c_out, 3, stride=2, padding=1, bias=False)
  11. else:
  12. self.conv1 = nn.Conv2d(
  13. c_in, c_out, 3, stride=1, padding=1, bias=False)
  14. self.bn1 = nn.BatchNorm2d(c_out)
  15. self.relu = nn.ReLU(True)
  16. self.conv2 = nn.Conv2d(c_out, c_out, 3, stride=1,
  17. padding=1, bias=False)
  18. self.bn2 = nn.BatchNorm2d(c_out)
  19. if is_downsample:
  20. self.downsample = nn.Sequential(
  21. nn.Conv2d(c_in, c_out, 1, stride=2, bias=False),
  22. nn.BatchNorm2d(c_out)
  23. )
  24. elif c_in != c_out:
  25. self.downsample = nn.Sequential(
  26. nn.Conv2d(c_in, c_out, 1, stride=1, bias=False),
  27. nn.BatchNorm2d(c_out)
  28. )
  29. self.is_downsample = True
  30. def forward(self, x):
  31. y = self.conv1(x)
  32. y = self.bn1(y)
  33. y = self.relu(y)
  34. y = self.conv2(y)
  35. y = self.bn2(y)
  36. if self.is_downsample:
  37. x = self.downsample(x)
  38. return F.relu(x.add(y), True)
  39. def make_layers(c_in, c_out, repeat_times, is_downsample=False):
  40. blocks = []
  41. for i in range(repeat_times):
  42. if i == 0:
  43. blocks += [BasicBlock(c_in, c_out, is_downsample=is_downsample), ]
  44. else:
  45. blocks += [BasicBlock(c_out, c_out), ]
  46. return nn.Sequential(*blocks)
  47. class Net(nn.Module):
  48. def __init__(self, num_classes=751, reid=False):
  49. super(Net, self).__init__()
  50. # 3 128 64
  51. self.conv = nn.Sequential(
  52. nn.Conv2d(3, 64, 3, stride=1, padding=1),
  53. nn.BatchNorm2d(64),
  54. nn.ReLU(inplace=True),
  55. # nn.Conv2d(32,32,3,stride=1,padding=1),
  56. # nn.BatchNorm2d(32),
  57. # nn.ReLU(inplace=True),
  58. nn.MaxPool2d(3, 2, padding=1),
  59. )
  60. # 32 64 32
  61. self.layer1 = make_layers(64, 64, 2, False)
  62. # 32 64 32
  63. self.layer2 = make_layers(64, 128, 2, True)
  64. # 64 32 16
  65. self.layer3 = make_layers(128, 256, 2, True)
  66. # 128 16 8
  67. self.layer4 = make_layers(256, 512, 2, True)
  68. # 256 8 4
  69. self.avgpool = nn.AvgPool2d((8, 4), 1)
  70. # 256 1 1
  71. self.reid = reid
  72. self.classifier = nn.Sequential(
  73. nn.Linear(512, 256),
  74. nn.BatchNorm1d(256),
  75. nn.ReLU(inplace=True),
  76. nn.Dropout(),
  77. nn.Linear(256, num_classes),
  78. )
  79. def forward(self, x):
  80. x = self.conv(x)
  81. x = self.layer1(x)
  82. x = self.layer2(x)
  83. x = self.layer3(x)
  84. x = self.layer4(x)
  85. x = self.avgpool(x)
  86. x = x.view(x.size(0), -1)
  87. # B x 128
  88. if self.reid:
  89. x = x.div(x.norm(p=2, dim=1, keepdim=True))
  90. return x
  91. # classifier
  92. x = self.classifier(x)
  93. return x
  94. if __name__ == '__main__':
  95. net = Net()
  96. x = torch.randn(4, 3, 128, 64)
  97. y = net(x)
  98. import ipdb
  99. ipdb.set_trace()

三、数据集处理

splitDataset.py

用于存放训练的图片

  1. # -*- coding:utf-8 -*-
  2. # @author: 牧锦程
  3. # @微信公众号: AI算法与电子竞赛
  4. # @Email: m21z50c71@163.com
  5. # @VX:fylaicai
  6. import os
  7. from shutil import copyfile
  8. # You only need to change this line to your dataset download path
  9. download_path = 'Market-1501-v15.09.15'
  10. if not os.path.isdir(download_path):
  11. print('please change the download_path')
  12. save_path = 'pytorch'
  13. if not os.path.isdir(save_path):
  14. os.mkdir(save_path)
  15. # ------------------- query ----------------------
  16. query_path = download_path + '/query'
  17. query_save_path = save_path + '/query'
  18. print("process: ", query_path)
  19. if not os.path.isdir(query_save_path):
  20. os.mkdir(query_save_path)
  21. for root, dirs, files in os.walk(query_path, topdown=True):
  22. for name in files:
  23. if not name[-3:] == 'jpg':
  24. continue
  25. ID = name.split('_')
  26. src_path = query_path + '/' + name
  27. dst_path = query_save_path + '/' + ID[0]
  28. if not os.path.isdir(dst_path):
  29. os.mkdir(dst_path)
  30. copyfile(src_path, dst_path + '/' + name)
  31. # ----------------- multi-query ------------------------
  32. query_path = download_path + '/gt_bbox'
  33. print("process: ", query_path)
  34. # for dukemtmc-reid, we do not need multi-query
  35. if os.path.isdir(query_path):
  36. query_save_path = save_path + '/multi-query'
  37. if not os.path.isdir(query_save_path):
  38. os.mkdir(query_save_path)
  39. for root, dirs, files in os.walk(query_path, topdown=True):
  40. for name in files:
  41. if not name[-3:] == 'jpg':
  42. continue
  43. ID = name.split('_')
  44. src_path = query_path + '/' + name
  45. dst_path = query_save_path + '/' + ID[0]
  46. if not os.path.isdir(dst_path):
  47. os.mkdir(dst_path)
  48. copyfile(src_path, dst_path + '/' + name)
  49. # ------------------- gallery ----------------------
  50. gallery_path = download_path + '/bounding_box_test'
  51. gallery_save_path = save_path + '/gallery'
  52. print("process: ", gallery_path)
  53. if not os.path.isdir(gallery_save_path):
  54. os.mkdir(gallery_save_path)
  55. for root, dirs, files in os.walk(gallery_path, topdown=True):
  56. for name in files:
  57. if not name[-3:] == 'jpg':
  58. continue
  59. ID = name.split('_')
  60. src_path = gallery_path + '/' + name
  61. dst_path = gallery_save_path + '/' + ID[0]
  62. if not os.path.isdir(dst_path):
  63. os.mkdir(dst_path)
  64. copyfile(src_path, dst_path + '/' + name)
  65. # ------------------ train ---------------------
  66. train_path = download_path + '/bounding_box_train'
  67. train_save_path = save_path + '/train'
  68. val_save_path = save_path + '/test'
  69. if not os.path.isdir(train_save_path):
  70. os.mkdir(train_save_path)
  71. os.mkdir(val_save_path)
  72. print("process: ", train_path)
  73. for root, dirs, files in os.walk(train_path, topdown=True):
  74. for name in files:
  75. if not name[-3:] == 'jpg':
  76. continue
  77. ID = name.split('_')
  78. src_path = train_path + '/' + name
  79. dst_path = train_save_path + '/' + ID[0]
  80. if not os.path.isdir(dst_path):
  81. os.mkdir(dst_path)
  82. # first image is used as val image
  83. dst_path = val_save_path + '/' + ID[0]
  84. os.mkdir(dst_path)
  85. copyfile(src_path, dst_path + '/' + name)

四、模型训练

修改数据集路径

修改 data-dir 参数为自己的数据集路径

修改数据增强

增加一个尺寸修改

修改类别数量

代码中通过dataloader来获取,因此可以不进行修改

num_classes = max(len(trainloader.dataset.classes), len(testloader.dataset.classes))

修改保存模型名称

这里的修改是为例与原始的模型进行区分,可以做对比

训练结果

查看精度

先运行test.py,生成 features.pth

在运行 evaluate.py,得到如下精度:

五、链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文内容由网友自发贡献,转载请注明出处:https://www.wpsshop.cn/w/Monodyee/article/detail/732795
推荐阅读
相关标签
  

闽ICP备14008679号