当前位置:   article > 正文

【OpenPCDet】自定义数据集(kitti格式)训练PointPillars并评估&可视化,全过程debug_openpcdet可视化

openpcdet可视化

引言:

主要参考博客:

【3D目标检测】OpenPCDet自定义数据集训练_openpcdet 自己数据集-CSDN博客

 OpenPCDet环境搭建参考:【3D目标检测】环境搭建(OpenPCDet、MMdetection3d):https://blog.csdn.net/qq_44703886/article/details/131732662?spm=1001.2014.3001.5502

源码地址:OpenPCDet:https://github.com/open-mmlab/OpenPCDet

PS: 如果网速差进不去github官网,可以换镜像网址:

https://dgithub.xyz 或 https://bgithub.xyz

1.详细过程(未完待续):

1.1 自定义数据集准备

首先,看一下需要的数据集格式,按照KITTI数据结构规划自定义数据集如下:

  1. custom
  2. ├── testing
  3. │ ├── velodyne # 点云数据
  4. ├── training
  5. │ ├── label_2 # 标签文件
  6. │ ├── velodyne

其中点云数据的格式是bin文件。如果采集到的数据是rosbag包,则先进行bag2pcd,然后pcd2bin。得到bin格式的点云数据。代码如下:

bag2pcd.py:

  1. import rospy
  2. import rosbag
  3. from sensor_msgs.msg import PointCloud2
  4. from sensor_msgs import point_cloud2
  5. import struct
  6. def pointcloud2_to_pcd(point_cloud2_msg, filename):
  7. # 更新头部信息以包含intensity
  8. header = f"""# .PCD v0.7 - Point Cloud Data file format
  9. VERSION 0.7
  10. FIELDS x y z intensity
  11. SIZE 4 4 4 4
  12. TYPE F F F F
  13. COUNT 1 1 1 1
  14. WIDTH {point_cloud2_msg.width}
  15. HEIGHT {point_cloud2_msg.height}
  16. VIEWPOINT 0 0 0 1 0 0 0
  17. POINTS {point_cloud2_msg.width * point_cloud2_msg.height}
  18. DATA ascii
  19. """
  20. # 将点云数据(包括intensity)转换为ASCII格式并保存到PCD文件
  21. with open(filename, 'w') as f:
  22. f.write(header)
  23. for p in point_cloud2.read_points(point_cloud2_msg, field_names=("x", "y", "z", "intensity"), skip_nans=True):
  24. f.write(f"{' '.join(str(value) for value in p)}\n")
  25. def main():
  26. bag_file = '/home/dxy/SUSTechPOINTS/备份/rosbag/train.bag'
  27. topic = '/velodyne_points'
  28. output_directory = '/home/dxy/SUSTechPOINTS/备份/lidar/'
  29. frame_count = 0
  30. with rosbag.Bag(bag_file, 'r') as bag:
  31. for topic, msg, t in bag.read_messages(topics=[topic]):
  32. if 1:
  33. filename = f"{output_directory}{t.to_nsec()}.pcd"
  34. pointcloud2_to_pcd(msg, filename)
  35. frame_count += 1
  36. print(f"Processed frame {frame_count}: Saved {filename}")
  37. else:
  38. print(f"Message is not of type PointCloud2: {type(msg)}")
  39. print(f"Total frames processed: {frame_count}")
  40. if __name__ == "__main__":
  41. main()

pcd2bin.py:

  1. import os
  2. import numpy as np
  3. def read_pcd(filepath):
  4. lidar = []
  5. header_passed = False
  6. with open(filepath, 'r') as f:
  7. for line in f:
  8. line = line.strip()
  9. if line.startswith('DATA'):
  10. header_passed = True
  11. continue
  12. if header_passed:
  13. linestr = line.split()
  14. if len(linestr) == 3:
  15. linestr_convert = list(map(float, linestr)) + [1.0]
  16. linestr_convert[2] += 0
  17. print("!!!!!!!!!!!!!!!!!!!!!!ERROR")
  18. lidar.append(linestr_convert)
  19. elif len(linestr) == 4:
  20. linestr_convert = list(map(float, linestr))
  21. linestr_convert[2] += 0
  22. lidar.append(linestr_convert)
  23. return np.array(lidar)
  24. def pcd2bin(pcdfolder, binfolder, start_idx, end_idx):
  25. ori_path = pcdfolder
  26. des_path = binfolder
  27. if not os.path.exists(des_path):
  28. os.makedirs(des_path)
  29. for idx in range(start_idx, end_idx + 1):
  30. filename = f"{idx:06d}" # 格式化文件名,确保是六位数字,例如000001
  31. velodyne_file = os.path.join(ori_path, filename + '.pcd')
  32. if os.path.exists(velodyne_file): # 确保文件存在
  33. pl = read_pcd(velodyne_file)
  34. pl = pl.reshape(-1, 4).astype(np.float32)
  35. velodyne_file_new = os.path.join(des_path, filename + '.bin')
  36. pl.tofile(velodyne_file_new)
  37. else:
  38. print(f"File not found: {velodyne_file}")
  39. if __name__ == "__main__":
  40. pcdfolder = "/home/dxy/SUSTechPOINTS/data/备份/lidar_copy"
  41. binfolder = "/home/dxy/SUSTechPOINTS/data/备份/lidar_bin"
  42. # 可以在这里设置开始和结束的帧
  43. start_frame = 1
  44. end_frame = 35
  45. pcd2bin(pcdfolder, binfolder, start_idx=start_frame, end_idx=end_frame)
'
运行

经过上面的步骤,就可以得到velodyne文件夹下所需的bin格式点云数据。

1.2 接下来是标签文件label的格式转换。成败的关键,很重要!!!

标注工具:sustechpoints

SUSTechPOINTS三维点云标注工具使用-CSDN博客

 sustechpoints标注工具安装:

https://zhuanlan.zhihu.com/p/687518464

标注完成后会在label文件夹中生成.json文件.

sustech标注数据向kitti数据格式转换,很重要!!!

附上转换文件代码,sus2kitti.py

  1. import os
  2. import json
  3. import math
  4. import numpy as np
  5. import sys
  6. def trans_detection_label(src_label_path, tgt_label_path, start_idx=None, end_idx=None):
  7. files = os.listdir(src_label_path)
  8. files.sort() # 确保文件按名称排序
  9. # 初始化最大ID为0
  10. max_id = 0
  11. # 如果指定了 start_idx 和 end_idx,过滤出范围内的文件
  12. if start_idx is not None and end_idx is not None:
  13. files = [f for f in files if f.split('.')[0].isdigit() and start_idx <= int(f.split('.')[0]) <= end_idx]
  14. for fname in files:
  15. frame, _ = os.path.splitext(fname)
  16. print(frame)
  17. kitti_lines = []
  18. with open(os.path.join(src_label_path, fname), encoding='utf-8') as f:
  19. labels = json.load(f, strict=False)
  20. for label in labels:
  21. obj_type = label["obj_type"]
  22. # if label.get('obj_attr') == 'static':
  23. # continue # 跳过当前对象
  24. # 根据条件修改 obj_type
  25. if obj_type == 'Scooter':
  26. obj_type = 'Bicycle'
  27. elif obj_type == 'Bus':
  28. obj_type = 'Truck'
  29. if obj_type == 'Bicycle':
  30. obj_type = 'Cyclist'
  31. box_id = int(label["obj_id"])
  32. box_id += 282
  33. # 更新最大ID
  34. if int(box_id) > max_id:
  35. max_id = int(box_id)
  36. box_position_x = label['psr']['position']['x']
  37. box_position_y = label['psr']['position']['y']
  38. box_position_z = label['psr']['position']['z']
  39. box_scale_x = label['psr']['scale']['x']
  40. box_scale_y = label['psr']['scale']['y']
  41. box_scale_z = label['psr']['scale']['z']
  42. box_position_z_kitti = float(box_position_z) + 0 - float(box_scale_z / 2)
  43. rotation_yaw = -float(label['psr']['rotation']['z']) - math.pi / 2
  44. kitti_lines.append(f'{obj_type} 1.0 0 0.0 -1 -1 -1 -1 {box_scale_z:.4f} {box_scale_y:.4f} {box_scale_x:.4f} '
  45. f'{box_position_x:.4f} {box_position_y:.4f} {box_position_z_kitti:.4f} {rotation_yaw:.4f}\n')
  46. with open(os.path.join(tgt_label_path, frame + ".txt"), 'w') as outfile:
  47. outfile.writelines(kitti_lines)
  48. # 在处理完所有文件后打印最大ID
  49. print(f"The maximum ID in the sequence is: {max_id}")
  50. if __name__ == "__main__":
  51. src_label = "/home/dxy/SUSTechPOINTS/data/备份/label_copy" # 替换成自己的路径
  52. tgt_label = "/home/dxy/SUSTechPOINTS/data/备份/label_kitti/"
  53. # 这里你可以指定开始和结束的索引,例如处理000001到001000范围内的文件
  54. start_idx = 1
  55. end_idx = 100
  56. trans_detection_label(src_label, tgt_label, start_idx, end_idx)

运行分割数据集代码:

  1. """
  2. 2024.03.21
  3. author:alian
  4. 数据预处理操作
  5. 1.数据集分割
  6. """
  7. import os
  8. import random
  9. import shutil
  10. import numpy as np
  11. def get_train_val_txt_kitti(src_path):
  12. """
  13. 数据格式:KITTI
  14. # For KITTI Dataset
  15. └── KITTI_DATASET_ROOT
  16. ├── training <-- 7481 train data
  17. | ├── image_2 <-- for visualization
  18. | ├── calib
  19. | ├── label_2
  20. | └── velodyne
  21. └── testing <-- 7580 test data
  22. ├── image_2 <-- for visualization
  23. ├── calib
  24. └── velodyne
  25. src_path: KITTI_DATASET_ROOT kitti文件夹
  26. """
  27. # 1.自动生成数据集划分文件夹ImageSets
  28. set_path = "%s/ImageSets/"%src_path
  29. if os.path.exists(set_path): # 如果文件存在
  30. shutil.rmtree(set_path) # 清空原始数据
  31. os.makedirs(set_path) # 重新创建
  32. else:
  33. os.makedirs(set_path) # 自动新建文件夹
  34. # 2.训练样本分割 生成train.txt val.txt trainval.txt
  35. train_list = os.listdir(os.path.join(src_path,'training','velodyne'))
  36. random.shuffle(train_list) # 打乱顺序,随机采样
  37. # 设置训练和验证的比例
  38. train_p = 0.8
  39. # 开始写入分割文件
  40. f_train = open(os.path.join(set_path, "train.txt"), 'w')
  41. f_val = open(os.path.join(set_path, "val.txt"), 'w')
  42. f_trainval = open(os.path.join(set_path, "trainval.txt"), 'w')
  43. for i,src in enumerate(train_list):
  44. if i<int(len(train_list)*train_p): # 训练集的数量
  45. f_train.write(src[:-4] + '\n')
  46. f_trainval.write(src[:-4] + '\n')
  47. else:
  48. f_val.write(src[:-4] + '\n')
  49. f_trainval.write(src[:-4] + '\n')
  50. # 3.测试样本分割 生成test.txt
  51. test_list = os.listdir(os.path.join(src_path,'testing','velodyne'))
  52. f_test = open(os.path.join(set_path, "test.txt"), 'w')
  53. for i,src in enumerate(test_list):
  54. f_test.write(src[:-4] + '\n')
  55. if __name__=='__main__':
  56. """
  57. src_path: 数据目录
  58. """
  59. src_path = '/home/dxy/Openpcdet-Test-master/data/custom'
  60. get_train_val_txt_kitti(src_path)

 分割数据集后,如图:

1.3 生成标准数据格式

建议复制kitti_dataset.py、kitti_dataset.yaml,重命名为custom_dataset.py、kitti_custom_dataset.yaml,修改文件路径如下:

    Openpcdet-Test-master/pcdet/datasets/custom/custom_dataset.py
    Openpcdet-Test-master/tools/cfgs/dataset_configs/kitti_custom_dataset.yaml

Openpcdet-Test-master/tools/cfgs/dataset_configs/kitti_custom_dataset.yaml

  1. DATASET: 'CustomDataset'
  2. DATA_PATH: '/home/dxy/Openpcdet-Test-master/data/custom' # 1.绝对路径
  3. # If this config file is modified then pcdet/models/detectors/detector3d_template.py:
  4. # Detector3DTemplate::build_networks:model_info_dict needs to be modified.
  5. POINT_CLOUD_RANGE: [-70.4, -40, -3, 70.4, 40, 1] # x=[-70.4, 70.4], y=[-40,40], z=[-3,1] 根据自己的标注框进行调整
  6. DATA_SPLIT: {
  7. 'train': train,
  8. 'test': val
  9. }
  10. INFO_PATH: {
  11. 'train': [custom_infos_train.pkl],
  12. 'test': [custom_infos_val.pkl],
  13. }
  14. GET_ITEM_LIST: ["points"]
  15. FOV_POINTS_ONLY: True
  16. POINT_FEATURE_ENCODING: {
  17. encoding_type: absolute_coordinates_encoding,
  18. used_feature_list: ['x', 'y', 'z', 'intensity'],
  19. src_feature_list: ['x', 'y', 'z', 'intensity'],
  20. }
  21. # Same to pv_rcnn[DATA_AUGMENTOR]
  22. DATA_AUGMENTOR:
  23. DISABLE_AUG_LIST: ['placeholder']
  24. AUG_CONFIG_LIST:
  25. - NAME: gt_sampling
  26. # Notice that 'USE_ROAD_PLANE'
  27. USE_ROAD_PLANE: False
  28. DB_INFO_PATH:
  29. - custom_dbinfos_train.pkl # pcdet/datasets/augmentor/database_sampler.py:line 26
  30. PREPARE: {
  31. filter_by_min_points: ['Car:5'], # 2.修改类别
  32. # filter_by_difficulty: [-1], # 注释掉,防止训练报错
  33. }
  34. SAMPLE_GROUPS: ['Car:15'] # 3. 修改类别
  35. NUM_POINT_FEATURES: 4
  36. DATABASE_WITH_FAKELIDAR: False
  37. REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
  38. LIMIT_WHOLE_SCENE: True
  39. - NAME: random_world_flip
  40. ALONG_AXIS_LIST: ['x']
  41. - NAME: random_world_rotation
  42. WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
  43. - NAME: random_world_scaling
  44. WORLD_SCALE_RANGE: [0.95, 1.05]
  45. DATA_PROCESSOR:
  46. - NAME: mask_points_and_boxes_outside_range
  47. REMOVE_OUTSIDE_BOXES: True
  48. - NAME: shuffle_points
  49. SHUFFLE_ENABLED: {
  50. 'train': True,
  51. 'test': False
  52. }
  53. - NAME: transform_points_to_voxels
  54. VOXEL_SIZE: [0.05, 0.05, 0.1]
  55. MAX_POINTS_PER_VOXEL: 5
  56. MAX_NUMBER_OF_VOXELS: {
  57. 'train': 16000,
  58. 'test': 40000
  59. }

Openpcdet-Test-master/pcdet/datasets/custom/custom_dataset.py

  1. import copy
  2. import pickle
  3. import os
  4. import numpy as np
  5. from skimage import io
  6. from ...ops.roiaware_pool3d import roiaware_pool3d_utils
  7. from ...utils import box_utils, common_utils, object3d_custom
  8. from ..dataset import DatasetTemplate
  9. # 定义属于自己的数据集,集成数据集模板
  10. class CustomDataset(DatasetTemplate):
  11. def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None, ext='.bin'):
  12. """
  13. Args:
  14. root_path:
  15. dataset_cfg:
  16. class_names:
  17. training:
  18. logger:
  19. """
  20. super().__init__(
  21. dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
  22. )
  23. self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
  24. self.root_split_path = os.path.join(self.root_path, ('training' if self.split != 'test' else 'testing'))
  25. split_dir = os.path.join(self.root_path, 'ImageSets',(self.split + '.txt'))
  26. self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if os.path.exists(split_dir) else None
  27. self.custom_infos = []
  28. self.include_custom_data(self.mode)
  29. self.ext = ext
  30. # 用于导入自定义数据
  31. def include_custom_data(self, mode):
  32. if self.logger is not None:
  33. self.logger.info('Loading Custom dataset.')
  34. custom_infos = []
  35. for info_path in self.dataset_cfg.INFO_PATH[mode]:
  36. info_path = self.root_path / info_path
  37. if not info_path.exists():
  38. continue
  39. with open(info_path, 'rb') as f:
  40. infos = pickle.load(f)
  41. custom_infos.extend(infos)
  42. self.custom_infos.extend(custom_infos)
  43. if self.logger is not None:
  44. self.logger.info('Total samples for CUSTOM dataset: %d' % (len(custom_infos)))
  45. # 用于获取标签的标注信息
  46. def get_infos(self, num_workers=4, has_label=True, count_inside_pts=True, sample_id_list=None):
  47. import concurrent.futures as futures
  48. # 线程函数,主要是为了多线程读取数据,加快处理速度
  49. # 处理一帧
  50. def process_single_scene(sample_idx):
  51. print('%s sample_idx: %s' % (self.split, sample_idx))
  52. # 创建一个用于存储一帧信息的空字典
  53. info = {}
  54. # 定义该帧点云信息,pointcloud_info
  55. pc_info = {'num_features': 4, 'lidar_idx': sample_idx}
  56. # 将pc_info这个字典作为info字典里的一个键值对的值,其键名为‘point_cloud’添加到info里去
  57. info['point_cloud'] = pc_info
  58. '''
  59. # image信息和calib信息都暂时不需要
  60. # image_info = {'image_idx': sample_idx, 'image_shape': self.get_image_shape(sample_idx)}
  61. # info['image'] = image_info
  62. # calib = self.get_calib(sample_idx)
  63. # P2 = np.concatenate([calib.P2, np.array([[0., 0., 0., 1.]])], axis=0)
  64. # R0_4x4 = np.zeros([4, 4], dtype=calib.R0.dtype)
  65. # R0_4x4[3, 3] = 1.
  66. # R0_4x4[:3, :3] = calib.R0
  67. # V2C_4x4 = np.concatenate([calib.V2C, np.array([[0., 0., 0., 1.]])], axis=0)
  68. # calib_info = {'P2': P2, 'R0_rect': R0_4x4, 'Tr_velo_to_cam': V2C_4x4}
  69. # info['calib'] = calib_info
  70. '''
  71. if has_label:
  72. # 通过get_label函数,读取出该帧的标签标注信息
  73. obj_list = self.get_label(sample_idx)
  74. # 创建用于存储该帧标注信息的空字典
  75. annotations = {}
  76. # 下方根据标注文件里的属性将对应的信息加入到annotations的键值对,可以根据自己的需求取舍
  77. annotations['name'] = np.array([obj.cls_type for obj in obj_list])
  78. # annotations['truncated'] = np.array([obj.truncation for obj in obj_list])
  79. # annotations['occluded'] = np.array([obj.occlusion for obj in obj_list])
  80. # annotations['alpha'] = np.array([obj.alpha for obj in obj_list])
  81. # annotations['bbox'] = np.concatenate([obj.box2d.reshape(1, 4) for obj in obj_list], axis=0)
  82. annotations['dimensions'] = np.array([[obj.l, obj.h, obj.w] for obj in obj_list]) # lhw(camera) format
  83. annotations['location'] = np.concatenate([obj.loc.reshape(1, 3) for obj in obj_list], axis=0)
  84. annotations['rotation_y'] = np.array([obj.ry for obj in obj_list])
  85. annotations['score'] = np.array([obj.score for obj in obj_list])
  86. # annotations['difficulty'] = np.array([obj.level for obj in obj_list], np.int32)
  87. # 统计有效物体的个数,即去掉类别名称为“Dontcare”以外的
  88. num_objects = len([obj.cls_type for obj in obj_list if obj.cls_type != 'DontCare'])
  89. # 统计物体的总个数,包括了Dontcare
  90. num_gt = len(annotations['name'])
  91. # 获得当前的index信息
  92. index = list(range(num_objects)) + [-1] * (num_gt - num_objects)
  93. annotations['index'] = np.array(index, dtype=np.int32)
  94. # 从annotations里提取出从标注信息里获取的location、dims、rots等信息,赋值给对应的变量
  95. loc = annotations['location'][:num_objects]
  96. dims = annotations['dimensions'][:num_objects]
  97. rots = annotations['rotation_y'][:num_objects]
  98. # 由于我们的数据集本来就是基于雷达坐标系标注,所以无需坐标转换
  99. #loc_lidar = calib.rect_to_lidar(loc)
  100. loc_lidar = self.get_calib(loc)
  101. # 原来的dims排序是高宽长hwl,现在转到pcdet的统一坐标系下,按lhw排布
  102. l, h, w = dims[:, 0:1], dims[:, 1:2], dims[:, 2:3]
  103. # 由于我们基于雷达坐标系标注,所以获取的中心点本来就是空间中心,所以无需从底面中心转到空间中心
  104. # bottom center -> object center: no need for loc_lidar[:, 2] += h[:, 0] / 2
  105. # print("sample_idx: ", sample_idx, "loc: ", loc, "loc_lidar: " , sample_idx, loc_lidar)
  106. # get gt_boxes_lidar see https://zhuanlan.zhihu.com/p/152120636
  107. # loc_lidar[:, 2] += h[:, 0] / 2
  108. gt_boxes_lidar = np.concatenate([loc_lidar, l, w, h, -(np.pi / 2 + rots[..., np.newaxis])], axis=1)
  109. # 将雷达坐标系下的真值框信息存入annotations中
  110. annotations['gt_boxes_lidar'] = gt_boxes_lidar
  111. # 将annotations这整个字典作为info字典里的一个键值对的值
  112. info['annos'] = annotations
  113. return info
  114. # 后续的由于没有calib信息和image信息,所以可以直接注释
  115. '''
  116. # if count_inside_pts:
  117. # points = self.get_lidar(sample_idx)
  118. # calib = self.get_calib(sample_idx)
  119. # pts_rect = calib.lidar_to_rect(points[:, 0:3])
  120. # fov_flag = self.get_fov_flag(pts_rect, info['image']['image_shape'], calib)
  121. # pts_fov = points[fov_flag]
  122. # corners_lidar = box_utils.boxes_to_corners_3d(gt_boxes_lidar)
  123. # num_points_in_gt = -np.ones(num_gt, dtype=np.int32)
  124. # for k in range(num_objects):
  125. # flag = box_utils.in_hull(pts_fov[:, 0:3], corners_lidar[k])
  126. # num_points_in_gt[k] = flag.sum()
  127. # annotations['num_points_in_gt'] = num_points_in_gt
  128. # return info
  129. '''
  130. sample_id_list = sample_id_list if sample_id_list is not None else self.sample_id_list
  131. with futures.ThreadPoolExecutor(num_workers) as executor:
  132. infos = executor.map(process_single_scene, sample_id_list)
  133. return list(infos)
  134. # 此时返回值infos是列表,列表元素为字典类型
  135. # 用于获取标定信息
  136. def get_calib(self, loc):
  137. # calib_file = self.root_split_path / 'calib' / ('%s.txt' % idx)
  138. # assert calib_file.exists()
  139. # return calibration_kitti.Calibration(calib_file)
  140. # loc_lidar = np.concatenate([np.array((float(loc_obj[2]),float(-loc_obj[0]),float(loc_obj[1]-2.3)),dtype=np.float32).reshape(1,3) for loc_obj in loc])
  141. # return loc_lidar
  142. # 这里做了一个由相机坐标系到雷达坐标系翻转(都遵从右手坐标系),但是 -2.3这个数值具体如何得来需要再看下
  143. # 我们的label中的xyz就是在雷达坐标系下,不用转变,直接赋值
  144. loc_lidar = np.concatenate([np.array((float(loc_obj[0]),float(loc_obj[1]),float(loc_obj[2])),dtype=np.float32).reshape(1,3) for loc_obj in loc])
  145. return loc_lidar
  146. # 用于获取标签
  147. def get_label(self, idx):
  148. # 从指定路径中提取txt内容
  149. label_file = self.root_split_path / 'label_2' / ('%s.txt' % idx)
  150. assert label_file.exists()
  151. # 主要就是从这个函数里获取具体的信息
  152. return object3d_custom.get_objects_from_label(label_file)
  153. # 用于获取雷达点云信息
  154. def get_lidar(self, idx, getitem):
  155. """
  156. Loads point clouds for a sample
  157. Args:
  158. index (int): Index of the point cloud file to get.
  159. Returns:
  160. np.array(N, 4): point cloud.
  161. """
  162. # get lidar statistics
  163. if getitem == True:
  164. lidar_file = self.root_split_path + '/velodyne/' + ('%s.bin' % idx)
  165. else:
  166. lidar_file = self.root_split_path / 'velodyne' / ('%s.bin' % idx)
  167. return np.fromfile(str(lidar_file), dtype=np.float32).reshape(-1, 4)
  168. # 用于数据集划分
  169. def set_split(self, split):
  170. super().__init__(
  171. dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training, root_path=self.root_path, logger=self.logger
  172. )
  173. self.split = split
  174. self.root_split_path = self.root_path / ('training' if self.split != 'test' else 'testing')
  175. split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
  176. self.sample_id_list = [x.strip() for x in open(split_dir).readlines()] if split_dir.exists() else None
  177. # 创建真值数据库
  178. # Create gt database for data augmentation
  179. def create_groundtruth_database(self, info_path=None, used_classes=None, split='train'):
  180. import torch
  181. database_save_path = Path(self.root_path) / ('gt_database' if split == 'train' else ('gt_database_%s' % split))
  182. db_info_save_path = Path(self.root_path) / ('custom_dbinfos_%s.pkl' % split)
  183. database_save_path.mkdir(parents=True, exist_ok=True)
  184. all_db_infos = {}
  185. with open(info_path, 'rb') as f:
  186. infos = pickle.load(f)
  187. for k in range(len(infos)):
  188. print('gt_database sample: %d/%d' % (k + 1, len(infos)))
  189. info = infos[k]
  190. sample_idx = info['point_cloud']['lidar_idx']
  191. points = self.get_lidar(sample_idx,False)
  192. annos = info['annos']
  193. names = annos['name']
  194. # difficulty = annos['difficulty']
  195. # bbox = annos['bbox']
  196. gt_boxes = annos['gt_boxes_lidar']
  197. num_obj = gt_boxes.shape[0]
  198. point_indices = roiaware_pool3d_utils.points_in_boxes_cpu(
  199. torch.from_numpy(points[:, 0:3]), torch.from_numpy(gt_boxes)
  200. ).numpy() # (nboxes, npoints)
  201. for i in range(num_obj):
  202. filename = '%s_%s_%d.bin' % (sample_idx, names[i], i)
  203. filepath = database_save_path / filename
  204. gt_points = points[point_indices[i] > 0]
  205. gt_points[:, :3] -= gt_boxes[i, :3]
  206. with open(filepath, 'w') as f:
  207. gt_points.tofile(f)
  208. if (used_classes is None) or names[i] in used_classes:
  209. db_path = str(filepath.relative_to(self.root_path)) # gt_database/xxxxx.bin
  210. # db_info = {'name': names[i], 'path': db_path, 'image_idx': sample_idx, 'gt_idx': i,
  211. # 'box3d_lidar': gt_boxes[i], 'num_points_in_gt': gt_points.shape[0],
  212. # 'difficulty': difficulty[i], 'bbox': bbox[i], 'score': annos['score'][i]}
  213. db_info = {'name': names[i], 'path': db_path, 'gt_idx': i,
  214. 'box3d_lidar': gt_boxes[i], 'num_points_in_gt': gt_points.shape[0], 'score': annos['score'][i]}
  215. if names[i] in all_db_infos:
  216. all_db_infos[names[i]].append(db_info)
  217. else:
  218. all_db_infos[names[i]] = [db_info]
  219. for k, v in all_db_infos.items():
  220. print('Database %s: %d' % (k, len(v)))
  221. with open(db_info_save_path, 'wb') as f:
  222. pickle.dump(all_db_infos, f)
  223. # 生成预测字典信息
  224. @staticmethod
  225. def generate_prediction_dicts(batch_dict, pred_dicts, class_names, output_path=None):
  226. """
  227. Args:
  228. batch_dict:
  229. frame_id:
  230. pred_dicts: list of pred_dicts
  231. pred_boxes: (N,7), Tensor
  232. pred_scores: (N), Tensor
  233. pred_lables: (N), Tensor
  234. class_names:
  235. output_path:
  236. Returns:
  237. """
  238. def get_template_prediction(num_smaples):
  239. ret_dict = {
  240. 'name': np.zeros(num_smaples), 'alpha' : np.zeros(num_smaples),
  241. 'dimensions': np.zeros([num_smaples, 3]), 'location': np.zeros([num_smaples, 3]),
  242. 'rotation_y': np.zeros(num_smaples), 'score': np.zeros(num_smaples),
  243. 'boxes_lidar': np.zeros([num_smaples, 7])
  244. }
  245. return ret_dict
  246. def generate_single_sample_dict(batch_index, box_dict):
  247. pred_scores = box_dict['pred_scores'].cpu().numpy()
  248. pred_boxes = box_dict['pred_boxes'].cpu().numpy()
  249. pred_labels = box_dict['pred_labels'].cpu().numpy()
  250. # Define an empty template dict to store the prediction information, 'pred_scores.shape[0]' means 'num_samples'
  251. pred_dict = get_template_prediction(pred_scores.shape[0])
  252. # If num_samples equals zero then return the empty dict
  253. if pred_scores.shape[0] == 0:
  254. return pred_dict
  255. # No calibration files
  256. # pred_boxes_camera = box_utils.boxes3d_lidar_to_kitti_camera(pred_boxes,None)
  257. pred_dict['name'] = np.array(class_names)[pred_labels - 1]
  258. # pred_dict['alpha'] = -np.arctan2(-pred_boxes[:, 1], pred_boxes[:, 0]) + pred_boxes_camera[:, 6]
  259. # pred_dict['dimensions'] = pred_boxes_camera[:, 3:6]
  260. # pred_dict['location'] = pred_boxes_camera[:, 0:3]
  261. # pred_dict['rotation_y'] = pred_boxes_camera[:, 6]
  262. pred_dict['score'] = pred_scores
  263. pred_dict['boxes_lidar'] = pred_boxes
  264. return pred_dict
  265. annos = []
  266. for index, box_dict in enumerate(pred_dicts):
  267. frame_id = batch_dict['frame_id'][index]
  268. single_pred_dict = generate_single_sample_dict(index, box_dict)
  269. single_pred_dict['frame_id'] = frame_id
  270. annos.append(single_pred_dict)
  271. # Output pred results to Output-path in .txt file
  272. if output_path is not None:
  273. cur_det_file = output_path / ('%s.txt' % frame_id)
  274. with open(cur_det_file, 'w') as f:
  275. bbox = single_pred_dict['bbox']
  276. loc = single_pred_dict['location']
  277. dims = single_pred_dict['dimensions'] # lhw -> hwl: lidar -> camera
  278. for idx in range(len(bbox)):
  279. print('%s -1 -1 %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f %.4f'
  280. % (single_pred_dict['name'][idx], single_pred_dict['alpha'][idx],
  281. bbox[idx][0], bbox[idx][1], bbox[idx][2], bbox[idx][3],
  282. dims[idx][1], dims[idx][2], dims[idx][0], loc[idx][0],
  283. loc[idx][1], loc[idx][2], single_pred_dict['rotation_y'][idx],
  284. single_pred_dict['score'][idx]), file=f)
  285. return annos
  286. def evaluation(self, det_annos, class_names, **kwargs):
  287. if 'annos' not in self.custom_infos[0].keys():
  288. return None, {}
  289. from .kitti_object_eval_python import eval as kitti_eval
  290. eval_det_annos = copy.deepcopy(det_annos)
  291. eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.custom_infos]
  292. ap_result_str, ap_dict = kitti_eval.get_official_eval_result(eval_gt_annos, eval_det_annos, class_names)
  293. return ap_result_str, ap_dict
  294. # 用于返回训练帧的总个数
  295. def __len__(self):
  296. if self._merge_all_iters_to_one_epoch:
  297. return len(self.sample_id_list) * self.total_epochs
  298. return len(self.custom_infos)
  299. # 用于将点云与3D标注框均转至前述统一坐标定义下,送入数据基类提供的self.prepare_data()
  300. def __getitem__(self, index): ## 修改如下
  301. if self._merge_all_iters_to_one_epoch:
  302. index = index % len(self.custom_infos)
  303. info = copy.deepcopy(self.custom_infos[index])
  304. sample_idx = info['point_cloud']['lidar_idx']
  305. points = self.get_lidar(sample_idx, True)
  306. input_dict = {
  307. 'frame_id': self.sample_id_list[index],
  308. 'points': points
  309. }
  310. if 'annos' in info:
  311. annos = info['annos']
  312. annos = common_utils.drop_info_with_name(annos, name='DontCare')
  313. gt_names = annos['name']
  314. gt_boxes_lidar = annos['gt_boxes_lidar']
  315. input_dict.update({
  316. 'gt_names': gt_names,
  317. 'gt_boxes': gt_boxes_lidar
  318. })
  319. data_dict = self.prepare_data(data_dict=input_dict)
  320. return data_dict
  321. # 用于创建自定义数据集的信息
  322. def create_custom_infos(dataset_cfg, class_names, data_path, save_path, workers=4):
  323. dataset = CustomDataset(dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path, training=False)
  324. train_split, val_split = 'train', 'val'
  325. # 定义文件的路径和名称
  326. train_filename = save_path / ('custom_infos_%s.pkl' % train_split)
  327. val_filename = save_path / ('custom_infos_%s.pkl' % val_split)
  328. trainval_filename = save_path / 'custom_infos_trainval.pkl'
  329. test_filename = save_path / 'custom_infos_test.pkl'
  330. print('---------------Start to generate data infos---------------')
  331. dataset.set_split(train_split)
  332. # 执行完上一步,得到train相关的保存文件,以及sample_id_list的值为train.txt文件下的数字
  333. # 下面是得到train.txt中序列相关的所有点云数据的信息,并且进行保存
  334. custom_infos_train = dataset.get_infos(num_workers=workers, has_label=True, count_inside_pts=True)
  335. with open(train_filename, 'wb') as f:
  336. pickle.dump(custom_infos_train, f)
  337. print('Custom info train file is saved to %s' % train_filename)
  338. dataset.set_split(val_split)
  339. # 对验证集的数据进行信息统计并保存
  340. custom_infos_val = dataset.get_infos(num_workers=workers, has_label=True, count_inside_pts=True)
  341. with open(val_filename, 'wb') as f:
  342. pickle.dump(custom_infos_val, f)
  343. print('Custom info val file is saved to %s' % val_filename)
  344. with open(trainval_filename, 'wb') as f:
  345. pickle.dump(custom_infos_train + custom_infos_val, f)
  346. print('Custom info trainval file is saved to %s' % trainval_filename)
  347. dataset.set_split('test')
  348. # kitti_infos_test = dataset.get_infos(num_workers=workers, has_label=False, count_inside_pts=False)
  349. custom_infos_test = dataset.get_infos(num_workers=workers, has_label=False, count_inside_pts=False)
  350. with open(test_filename, 'wb') as f:
  351. pickle.dump(custom_infos_test, f)
  352. print('Custom info test file is saved to %s' % test_filename)
  353. print('---------------Start create groundtruth database for data augmentation---------------')
  354. # 用trainfile产生groundtruth_database
  355. # 只保存训练数据中的gt_box及其包围点的信息,用于数据增强
  356. dataset.set_split(train_split)
  357. dataset.create_groundtruth_database(info_path=train_filename, split=train_split)
  358. print('---------------Data preparation Done---------------')
  359. if __name__=='__main__':
  360. import sys
  361. if sys.argv.__len__() > 1 and sys.argv[1] == 'create_custom_infos':
  362. import yaml
  363. from pathlib import Path
  364. from easydict import EasyDict
  365. dataset_cfg = EasyDict(yaml.safe_load(open(sys.argv[2])))
  366. ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve()
  367. create_custom_infos(
  368. dataset_cfg=dataset_cfg,
  369. class_names=['Car'], # 1.修改类别
  370. data_path=ROOT_DIR / 'data' / 'custom',
  371. save_path=ROOT_DIR / 'data' / 'custom'
  372. )

注:源码中已存在custom的相关文件,因为数据标注格式以kitti为标准,所以笔者是基于kitti文件的格式进行修改

生成标注数据指令

python -m pcdet.datasets.custom.custom_dataset create_custom_infos tools/cfgs/dataset_configs/kitti_custom_dataset.yaml

没有报错的话,文件夹中会生成一些文件,如图:

2 模型训练

笔者选用模型为pointpillar,其他模型以此类推
修改文件如下:

Openpcdet-Test-master/tools/cfgs/kitti_models/pointpillar.yaml

  1. # CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] # 修改类别
  2. CLASS_NAMES: ['Car']
  3. DATA_CONFIG:
  4. _BASE_CONFIG_: /home/dxy/Openpcdet-Test-master/tools/cfgs/dataset_configs/kitti_custom_dataset.yaml
  5. POINT_CLOUD_RANGE: [0, -39.68, -3, 69.12, 39.68, 1]
  6. DATA_PROCESSOR:
  7. - NAME: mask_points_and_boxes_outside_range
  8. REMOVE_OUTSIDE_BOXES: True
  9. - NAME: shuffle_points
  10. SHUFFLE_ENABLED: {
  11. 'train': True,
  12. 'test': False
  13. }
  14. - NAME: transform_points_to_voxels
  15. VOXEL_SIZE: [0.16, 0.16, 4]
  16. MAX_POINTS_PER_VOXEL: 32
  17. MAX_NUMBER_OF_VOXELS: {
  18. 'train': 16000,
  19. 'test': 40000
  20. }
  21. DATA_AUGMENTOR:
  22. DISABLE_AUG_LIST: ['placeholder']
  23. AUG_CONFIG_LIST:
  24. - NAME: gt_sampling
  25. USE_ROAD_PLANE: False
  26. DB_INFO_PATH:
  27. - custom_dbinfos_train.pkl
  28. PREPARE: {
  29. filter_by_min_points: ['Car:5'],
  30. # filter_by_difficulty: [-1], # 注释掉:自定义 dbinfos 中没有 difficulty 字段,保留此项训练时会报 KeyError
  31. }
  32. SAMPLE_GROUPS: ['Car:5']
  33. NUM_POINT_FEATURES: 4
  34. DATABASE_WITH_FAKELIDAR: False
  35. REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0]
  36. LIMIT_WHOLE_SCENE: False
  37. - NAME: random_world_flip
  38. ALONG_AXIS_LIST: ['x']
  39. - NAME: random_world_rotation
  40. WORLD_ROT_ANGLE: [-0.78539816, 0.78539816]
  41. - NAME: random_world_scaling
  42. WORLD_SCALE_RANGE: [0.95, 1.05]
  43. MODEL:
  44. NAME: PointPillar
  45. VFE:
  46. NAME: PillarVFE
  47. WITH_DISTANCE: False
  48. USE_ABSLOTE_XYZ: True
  49. USE_NORM: True
  50. NUM_FILTERS: [64]
  51. MAP_TO_BEV:
  52. NAME: PointPillarScatter
  53. NUM_BEV_FEATURES: 64
  54. BACKBONE_2D:
  55. NAME: BaseBEVBackbone
  56. LAYER_NUMS: [3, 5, 5]
  57. LAYER_STRIDES: [2, 2, 2]
  58. NUM_FILTERS: [64, 128, 256]
  59. UPSAMPLE_STRIDES: [1, 2, 4]
  60. NUM_UPSAMPLE_FILTERS: [128, 128, 128]
  61. DENSE_HEAD:
  62. NAME: AnchorHeadSingle
  63. CLASS_AGNOSTIC: False
  64. USE_DIRECTION_CLASSIFIER: True
  65. DIR_OFFSET: 0.78539
  66. DIR_LIMIT_OFFSET: 0.0
  67. NUM_DIR_BINS: 2
  68. # anchor配置,需要适配自己的数据集
  69. ANCHOR_GENERATOR_CONFIG: [
  70. # {
  71. # 'class_name': 'Car',
  72. # 'anchor_sizes': [[3.9, 1.6, 1.56]],
  73. # 'anchor_rotations': [0, 1.57],
  74. # 'anchor_bottom_heights': [-1.78],
  75. # 'align_center': False,
  76. # 'feature_map_stride': 2,
  77. # 'matched_threshold': 0.6,
  78. # 'unmatched_threshold': 0.45
  79. # },
  80. {
  81. 'class_name': 'Car',
  82. 'anchor_sizes': [[0.8, 0.6, 1.0]],
  83. 'anchor_rotations': [0, 1.57],
  84. 'anchor_bottom_heights': [-0.4],
  85. 'align_center': False,
  86. 'feature_map_stride': 2,
  87. 'matched_threshold': 0.5,
  88. 'unmatched_threshold': 0.35
  89. },
  90. # {
  91. # 'class_name': 'stone',
  92. # 'anchor_sizes': [[1.0, 1.0, 0.73]],
  93. # 'anchor_rotations': [0, 1.57],
  94. # 'anchor_bottom_heights': [-0.6],
  95. # 'align_center': False,
  96. # 'feature_map_stride': 2,
  97. # 'matched_threshold': 0.5,
  98. # 'unmatched_threshold': 0.35
  99. # }
  100. ]
  101. TARGET_ASSIGNER_CONFIG:
  102. NAME: AxisAlignedTargetAssigner
  103. POS_FRACTION: -1.0
  104. SAMPLE_SIZE: 512
  105. NORM_BY_NUM_EXAMPLES: False
  106. MATCH_HEIGHT: False
  107. BOX_CODER: ResidualCoder
  108. LOSS_CONFIG:
  109. LOSS_WEIGHTS: {
  110. 'cls_weight': 1.0,
  111. 'loc_weight': 2.0,
  112. 'dir_weight': 0.2,
  113. 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
  114. }
  115. POST_PROCESSING:
  116. RECALL_THRESH_LIST: [0.3, 0.5, 0.7]
  117. SCORE_THRESH: 0.1
  118. OUTPUT_RAW_SCORE: False
  119. EVAL_METRIC: kitti
  120. NMS_CONFIG:
  121. MULTI_CLASSES_NMS: False
  122. NMS_TYPE: nms_gpu
  123. NMS_THRESH: 0.01
  124. NMS_PRE_MAXSIZE: 4096
  125. NMS_POST_MAXSIZE: 500
  126. OPTIMIZATION:
  127. BATCH_SIZE_PER_GPU: 4
  128. NUM_EPOCHS: 80
  129. OPTIMIZER: adam_onecycle
  130. LR: 0.003
  131. WEIGHT_DECAY: 0.01
  132. MOMENTUM: 0.9
  133. MOMS: [0.95, 0.85]
  134. PCT_START: 0.4
  135. DIV_FACTOR: 10
  136. DECAY_STEP_LIST: [35, 45]
  137. LR_DECAY: 0.1
  138. LR_CLIP: 0.0000001
  139. LR_WARMUP: False
  140. WARMUP_EPOCH: 1
  141. GRAD_NORM_CLIP: 10

训练指令:

python tools/train.py --cfg_file tools/cfgs/kitti_models/pointpillar.yaml --batch_size=2 --epochs=100
正常训练的话,应该是这样的:

测试指令:
python tools/demo.py --cfg_file /home/dxy/Openpcdet-Test-master/tools/cfgs/kitti_models/pointpillar.yaml  --data_path /home/dxy/Openpcdet-Test-master/data/custom/testing/velodyne/ --ckpt /home/dxy/Openpcdet-Test-master/output/cfgs/kitti_models/pointpillar/default/ckpt/checkpoint_epoch_100.pth

3.实验过程中遇到的一些问题:

报错:‘quaternion_to_rotation_matrix‘ is being compiled since it was called from ‘quat_to_mat‘

解决办法:

pip install kornia==0.6.8

解决print(torch.cuda.is_available())返回false的问题

重启一遍试试!!!

安装open3d报错:“libc++.so.1: cannot open shared object file: No such file or directory”

解决办法:

sudo apt install libc++-dev

题外话:

数据训练所需:

nomachine    https://downloads.nomachine.com/download/?id=40

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/795078
推荐阅读
相关标签
  

闽ICP备14008679号