当前位置:   article > 正文

目标跟踪JDE模型训练自己的数据集_jde算法进行训练

jde算法进行训练

前言

        之前训练CenterTrack的时候制作了一个mot格式的数据集,现在要训练JDE,下面是我的自制数据集的文件目录。基本上就是按照MOT17数据集整理的,除了视频序列的名字不一样。img目录下存放着视屏序列的图片,gt目录下存放着按照mot格式标注好的标签数据(使用Darklabel工具)。

  1. ─custom    //根目录
  2. ├─annotations
  3. ├─test
  4. │  ├─1
  5. │  │  ├─gt
  6. │  │  └─img
  7. │  ├─3
  8. │  │  ├─gt
  9. │  │  └─img
  10. │  └─....
  11. └─train
  12.     ├─1
  13.     │  ├─gt
  14.     │  └─img
  15.     ├─11
  16.     │  ├─gt
  17.     │  └─img
  18.     └─...

数据集格式调整

        将数据集从mot调整为JDE需要的格式。

        按照JDE官方提供Readme文件,完成数据集调整需要完成三个部分的任务。

1. 将数据集目录结构调整为如下结构

  1. Caltech
  2. |——————images
  3. | └——————00001.jpg
  4. | |—————— ...
  5. | └——————0000N.jpg
  6. └——————labels_with_ids
  7. └——————00001.txt
  8. |—————— ...
  9. └——————0000N.txt

        按照这个结构,图片数据不需要按照视频序列的方式来保存,所有图片统一保存到了images目录下。labels_with_ids下的每个txt和图片对应,其中保存着对应图片中的目标信息(类似于yolo格式),具体格式如下。class对应就是自己的类别,本博客使用的自制数据集仅有一个类别,identity则是在图片中对应目标的id,对应到mot格式的gt文件下则是id,后面bbox数据则是根据图像大小进行归一化后的。

[class] [identity] [x_center] [y_center] [width] [height]   //JDE标签文件格式
<frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>   //mot 标签格式

        所以首先要做的就是标签文件从下面的mot格式转为上面的JDE需要的格式。下面的代码则用来完成这个任务,并按照train_half的方式,将视频序列划分为前后两部分,分别用于训练和测试

2.生成用于训练的xxx.train和xxx.val文件

        xxx.train和xxx.val是用来记录训练数据的路径和测试数据的路径。如下图为JDE源码中打他目录下的mot17.train的样式。

3. 编写cfg目录下的ccmcpe.json文件

        ccmcpe.json文件记录有 数据集的目录、训练用.train文件和测试用.val文件。

        train对应编写.train文件

        test对应编写.val文件

        test_emb我不知道干嘛用的/(ㄒoㄒ)/~~

  1. {
  2. "root":"mydataset",
  3. "train":
  4. {
  5. "mot17":"./data/custom.train"
  6. },
  7. "test_emb":
  8. {
  9. },
  10. "test":
  11. {
  12. "mot17":"./data/custom.val"
  13. }
  14. }

 4.好了再加一步,助君一步到位。

        下面的代码可以帮你调整好数据集格式,并生成对应.train和.val文件。

  1. import os
  2. import random
  3. import shutil
  4. def convert(box, size=(1280, 800)):
  5. dw = 1. / size[0]
  6. dh = 1. / size[1]
  7. x = (box[0] + box[2]) / 2.0
  8. y = (box[1] + box[3]) / 2.0
  9. w = box[2] - box[0]
  10. h = box[3] - box[1]
  11. x = x * dw
  12. w = w * dw
  13. y = y * dh
  14. h = h * dh
  15. return x, y, w, h
  16. def modify_dataset(dataset_root, train_val_root, new_dataset_root):
  17. # 数据集根目录
  18. # dataset_root = 'custom/train/'
  19. # train_val_root = './data' # cfg_save_dir
  20. # new_dataset_root = "./mydataset" # new_dataset_root
  21. # 读取训练集图片路径
  22. train_image_dir = os.listdir(dataset_root)
  23. format_img, format_gt = "img", "gt"
  24. train_txt = os.path.join(train_val_root, "custom.train")
  25. val_txt = os.path.join(train_val_root, "custom.val")
  26. train_txt_fp = open(train_txt, "w")
  27. val_txt_fp = open(val_txt, "w")
  28. mydataset_image_dir = os.path.join(new_dataset_root, "images")
  29. mydataset_labels_dir = os.path.join(new_dataset_root, "labels_with_ids")
  30. os.makedirs(mydataset_image_dir, exist_ok=True)
  31. os.makedirs(mydataset_labels_dir, exist_ok=True)
  32. classes = 0
  33. image_cnt = 0
  34. new_id_cnt = 0
  35. for one_dir in train_image_dir:
  36. id_dict = {}
  37. image_dir = os.path.join(dataset_root, one_dir, format_img)
  38. label_txt_path = os.path.join(dataset_root, one_dir, format_gt, "gt.txt")
  39. label_fp = open(label_txt_path, "r")
  40. label_lines = label_fp.readlines()
  41. image_name_list = os.listdir(image_dir)
  42. image_num = len(image_name_list) // 2
  43. train_list = image_name_list[:image_num]
  44. val_list = image_name_list[image_num:]
  45. pre_image_id = -1
  46. for line in label_lines:
  47. line_list = line.strip().split(',')
  48. image_id = int(line_list[0])
  49. if image_id != pre_image_id:
  50. src_image_path = os.path.join(image_dir, "{:06d}.jpg".format(image_id))
  51. if not os.path.exists(src_image_path):
  52. continue
  53. image_cnt += 1
  54. dst_image_path = os.path.join(mydataset_image_dir, "{:06d}.jpg".format(image_cnt))
  55. shutil.copy(src_image_path, dst_image_path)
  56. if "{:06d}.jpg".format(image_id) in train_list:
  57. train_txt_fp.write(dst_image_path + '\n')
  58. elif "{:06d}.jpg".format(image_id) in val_list:
  59. val_txt_fp.write(dst_image_path + '\n')
  60. pre_image_id = image_id
  61. img_label_txt = os.path.join(mydataset_labels_dir, "{:06d}.txt".format(image_cnt))
  62. img_label_txt_fp = open(img_label_txt, "a")
  63. identity = int(line_list[1])
  64. if (not id_dict.get(identity)) or (id_dict.get(identity) == 0):
  65. id_dict[identity] = new_id_cnt
  66. new_id_cnt += 1
  67. identity = id_dict[identity]
  68. bbox = [int(i) for i in line_list[2:6]]
  69. bbox[2] = bbox[0] + bbox[2]
  70. bbox[3] = bbox[1] + bbox[3]
  71. center_x, center_y, w, h = convert(bbox)
  72. img_label_txt_fp.write(
  73. "{},{:d},{:.6f},{:.6f},{:.6f},{:.6f}\n".format(classes, identity, center_x, center_y, w, h))
  74. if __name__ == '__main__':
  75. custom_mot = "./custom/train" # 指定自己的数据集目录 如果需要可以把自己的测试集test下的视频序列也拷贝到train下进行划分
  76. train_val_root = "./data" # 这个train和val文件存放的目录,默认保存在jde源代码中data目录下
  77. new_dataset_root = "mydataset" # 这个目录用来保存转换后数据集
  78. modify_dataset(custom_mot, train_val_root, new_dataset_root)

训练

        数据集整理完便可以开始着手训练了。目前JDE的源码直接跑会报错,坑啊,参考了这篇博客做出了一些修改,才得以正常运行。主要就是两处修改。

1.在这行代码前添加一个判断

  1. mkdir_if_missing(weights_to+"/cfg") # 添加这行
  2. copyfile(cfg, weights_to + '/cfg/yolo3.cfg') # 找到这行

2.注释掉如下代码

 开始训练

python train.py --cfg cfg/yolov3_576x320.cfg --batch-size 16

测试

...未进行

声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:【wpsshop博客】
推荐阅读
相关标签
  

闽ICP备14008679号