
Training OpenPCDet on the Waymo Dataset

I. Preface

The official instructions:

https://github.com/open-mmlab/OpenPCDet/blob/master/docs/GETTING_STARTED.md

This post generates the dataset in two ways, the official method and a custom-path method; the latter exists to keep from running out of disk space.


II. Generating and Training on the Dataset with the Official Method

1. Organizing the data

(1) Download the official dataset

Waymo Open Dataset, including the training data training_0000.tar~training_0031.tar and the validation data validation_0000.tar~validation_0007.tar (if you only want to train on part of the data, downloading just a few of the tar archives is fine; the count does not matter).

(2) Extract all of the above .tar files into the data/waymo/raw_data directory (this yields 798 training tfrecords and 202 validation tfrecords).

(3) Every sequence file needs to be renamed to segment-xxxxxxxx.tfrecord.
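If your downloaded files carry a training_/validation_ prefix, a small rename script along these lines can normalize them (a sketch; adjust the raw_data path to your own layout):

```python
from pathlib import Path

raw_data = Path('data/waymo/raw_data')  # adjust to your layout
for f in raw_data.glob('*segment-*.tfrecord'):
    # strip everything before the 'segment-' marker, e.g.
    # training_segment-xxx.tfrecord -> segment-xxx.tfrecord
    new_name = 'segment-' + f.name.split('segment-', 1)[1]
    if new_name != f.name:
        f.rename(raw_data / new_name)
```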

(4) Organize the directory

Under OpenPCDet-master/data, the files need to be laid out as follows:

├── data
│   ├── waymo
│   │   ├── ImageSets
│   │   ├── raw_data
│   │   │   ├── segment-xxxxxxxx.tfrecord
│   │   │   ├── ...


Note: with the stock code, the dataset can only be generated under the OpenPCDet root directory; otherwise the source has to be modified, as the second method below shows.


Note: raw_data must hold the train and val tfrecords together. Do not put only one of them there, or eval will fail later and the results cannot be evaluated.

2. Installing the toolkit

pip install waymo-open-dataset-tf-2.11.0==1.5.0

I installed waymo-open-dataset-tf-2.11.0; pick the version your environment needs. The 2-1-0 build did not work for me.
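A quick sanity check that the toolkit and the TensorFlow/numpy it pins actually import together (a minimal sketch):

```python
import numpy as np
import tensorflow as tf
from waymo_open_dataset import dataset_pb2  # import fails if the wheel is broken

print('tensorflow:', tf.__version__, '| numpy:', np.__version__)
```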


3. Running the conversion

Extract the point cloud data from the tfrecords and generate the data infos:

python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml

The final dataset directory has the following form:

├── data
│   ├── waymo
│   │   ├── ImageSets
│   │   ├── raw_data
│   │   │   ├── segment-xxxxxxxx.tfrecord
│   │   │   ├── ...
│   │   ├── waymo_processed_data
│   │   │   ├── segment-xxxxxxxx/
│   │   │   ├── ...
│   │   ├── pcdet_gt_database_train_sampled_xx/
│   │   ├── pcdet_waymo_dbinfos_train_sampled_xx.pkl
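To sanity-check the generated infos before training, something like this works (a sketch; the pickle name depends on --processed_data_tag, which recent versions default to waymo_processed_data_v0_5_0, while older versions used plain waymo_processed_data as in the tree above):

```python
import pickle

tag = 'waymo_processed_data_v0_5_0'  # match your --processed_data_tag
with open(f'data/waymo/{tag}_infos_train.pkl', 'rb') as f:
    infos = pickle.load(f)
print('train samples:', len(infos))
print(infos[0]['point_cloud'])  # e.g. lidar_sequence / sample_idx of the first frame
```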


4. Training

Taking PV-RCNN++ as an example:

python train.py --cfg_file /home/xd/xyy/OpenPCDet-master/tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml
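For multi-GPU training, OpenPCDet also ships a distributed wrapper under tools/scripts; with 8 GPUs, for example:

sh scripts/dist_train.sh 8 --cfg_file cfgs/waymo_models/pv_rcnn_plusplus.yaml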

III. Generating and Training on the Dataset in a Custom Directory

Generating the dataset in a custom directory is mainly about saving disk space: different versions of pcdet and related code can then share a single copy of the dataset.


1. Organizing the data

(1) Download the official dataset:

Waymo Open Dataset, including the training data training_0000.tar~training_0031.tar and the validation data validation_0000.tar~validation_0007.tar.

(2) Extract all of the above .tar files into the data/waymo/raw_data directory (this yields 798 training tfrecords and 202 validation tfrecords).

(3) Every sequence file needs to be renamed to segment-xxxxxxxx.tfrecord (needed if you downloaded the individual files).

(4) Organize the directory

The first three steps are the same as above.

This time the data can live under any xxx/data directory:

├── xxx
│   ├── data
│   │   ├── waymo
│   │   │   ├── ImageSets
│   │   │   ├── raw_data
│   │   │   │   ├── segment-xxxxxxxx.tfrecord
│   │   │   │   ├── ...

As an example, I put the dataset in a custom folder, hpc/data/waymo_pcdet_mini, with raw_data laid out as above.

2. Installing the toolkit

pip install waymo-open-dataset-tf-2.11.0==1.5.0

3. Modifying the waymo source code

Note: the stock code can only generate the data under the OpenPCDet-master directory, so a few new command-line arguments are added here to support conversion elsewhere.


(1) Modify the data-generation source — waymo_dataset.py

Location:

/OpenPCDet-master/pcdet/datasets/waymo/waymo_dataset.py

The changes are mainly the parser and ROOT_DIR.

The modified source:

```python
# OpenPCDet PyTorch Dataloader and Evaluation Tools for Waymo Open Dataset
# Reference https://github.com/open-mmlab/OpenPCDet
# Written by Shaoshuai Shi, Chaoxu Guo
# All Rights Reserved.

import os
import pickle
import copy
import numpy as np
import torch
import multiprocessing
import SharedArray
import torch.distributed as dist
from tqdm import tqdm
from pathlib import Path
from functools import partial

from ...ops.roiaware_pool3d import roiaware_pool3d_utils
from ...utils import box_utils, common_utils
from ..dataset import DatasetTemplate


class WaymoDataset(DatasetTemplate):
    def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
        super().__init__(
            dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
        )
        self.data_path = self.root_path / self.dataset_cfg.PROCESSED_DATA_TAG
        self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        self.sample_sequence_list = [x.strip() for x in open(split_dir).readlines()]

        self.infos = []
        self.seq_name_to_infos = self.include_waymo_data(self.mode)

        self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
        if self.use_shared_memory:
            self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
            self.load_data_to_shared_memory()

        if self.dataset_cfg.get('USE_PREDBOX', False):
            self.pred_boxes_dict = self.load_pred_boxes_to_dict(
                pred_boxes_path=self.dataset_cfg.ROI_BOXES_PATH[self.mode]
            )
        else:
            self.pred_boxes_dict = {}

    def set_split(self, split):
        super().__init__(
            dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
            root_path=self.root_path, logger=self.logger
        )
        self.split = split
        split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
        self.sample_sequence_list = [x.strip() for x in open(split_dir).readlines()]
        self.infos = []
        self.seq_name_to_infos = self.include_waymo_data(self.mode)

    def include_waymo_data(self, mode):
        self.logger.info('Loading Waymo dataset')
        waymo_infos = []
        seq_name_to_infos = {}

        num_skipped_infos = 0
        for k in range(len(self.sample_sequence_list)):
            sequence_name = os.path.splitext(self.sample_sequence_list[k])[0]
            info_path = self.data_path / sequence_name / ('%s.pkl' % sequence_name)
            info_path = self.check_sequence_name_with_all_version(info_path)
            if not info_path.exists():
                num_skipped_infos += 1
                continue
            with open(info_path, 'rb') as f:
                infos = pickle.load(f)
                waymo_infos.extend(infos)
            seq_name_to_infos[infos[0]['point_cloud']['lidar_sequence']] = infos

        self.infos.extend(waymo_infos[:])
        self.logger.info('Total skipped info %s' % num_skipped_infos)
        self.logger.info('Total samples for Waymo dataset: %d' % (len(waymo_infos)))

        if self.dataset_cfg.SAMPLED_INTERVAL[mode] > 1:
            sampled_waymo_infos = []
            for k in range(0, len(self.infos), self.dataset_cfg.SAMPLED_INTERVAL[mode]):
                sampled_waymo_infos.append(self.infos[k])
            self.infos = sampled_waymo_infos
            self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))

        use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
        if not use_sequence_data:
            seq_name_to_infos = None
        return seq_name_to_infos

    def load_pred_boxes_to_dict(self, pred_boxes_path):
        self.logger.info(f'Loading and reorganizing pred_boxes to dict from path: {pred_boxes_path}')
        with open(pred_boxes_path, 'rb') as f:
            pred_dicts = pickle.load(f)

        pred_boxes_dict = {}
        for index, box_dict in enumerate(pred_dicts):
            seq_name = box_dict['frame_id'][:-4].replace('training_', '').replace('validation_', '')
            sample_idx = int(box_dict['frame_id'][-3:])

            if seq_name not in pred_boxes_dict:
                pred_boxes_dict[seq_name] = {}

            pred_labels = np.array([self.class_names.index(box_dict['name'][k]) + 1 for k in range(box_dict['name'].shape[0])])
            pred_boxes = np.concatenate((box_dict['boxes_lidar'], box_dict['score'][:, np.newaxis], pred_labels[:, np.newaxis]), axis=-1)
            pred_boxes_dict[seq_name][sample_idx] = pred_boxes

        self.logger.info(f'Predicted boxes has been loaded, total sequences: {len(pred_boxes_dict)}')
        return pred_boxes_dict

    def load_data_to_shared_memory(self):
        self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')

        cur_rank, num_gpus = common_utils.get_dist_info()
        all_infos = self.infos[:self.shared_memory_file_limit] \
            if self.shared_memory_file_limit < len(self.infos) else self.infos
        cur_infos = all_infos[cur_rank::num_gpus]
        for info in cur_infos:
            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']

            sa_key = f'{sequence_name}___{sample_idx}'
            if os.path.exists(f"/dev/shm/{sa_key}"):
                continue

            points = self.get_lidar(sequence_name, sample_idx)
            common_utils.sa_create(f"shm://{sa_key}", points)

        dist.barrier()
        self.logger.info('Training data has been saved to shared memory')

    def clean_shared_memory(self):
        self.logger.info(f'Clean training data from shared memory (file limit={self.shared_memory_file_limit})')

        cur_rank, num_gpus = common_utils.get_dist_info()
        all_infos = self.infos[:self.shared_memory_file_limit] \
            if self.shared_memory_file_limit < len(self.infos) else self.infos
        cur_infos = all_infos[cur_rank::num_gpus]
        for info in cur_infos:
            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']

            sa_key = f'{sequence_name}___{sample_idx}'
            if not os.path.exists(f"/dev/shm/{sa_key}"):
                continue

            SharedArray.delete(f"shm://{sa_key}")

        if num_gpus > 1:
            dist.barrier()
        self.logger.info('Training data has been deleted from shared memory')

    @staticmethod
    def check_sequence_name_with_all_version(sequence_file):
        if not sequence_file.exists():
            found_sequence_file = sequence_file
            for pre_text in ['training', 'validation', 'testing']:
                if not sequence_file.exists():
                    temp_sequence_file = Path(str(sequence_file).replace('segment', pre_text + '_segment'))
                    if temp_sequence_file.exists():
                        found_sequence_file = temp_sequence_file
                        break
            if not found_sequence_file.exists():
                found_sequence_file = Path(str(sequence_file).replace('_with_camera_labels', ''))
            if found_sequence_file.exists():
                sequence_file = found_sequence_file
        return sequence_file

    def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True, sampled_interval=1, update_info_only=False):
        from . import waymo_utils
        print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
              % (sampled_interval, len(self.sample_sequence_list)))

        process_single_sequence = partial(
            waymo_utils.process_single_sequence,
            save_path=save_path, sampled_interval=sampled_interval, has_label=has_label, update_info_only=update_info_only
        )
        sample_sequence_file_list = [
            self.check_sequence_name_with_all_version(raw_data_path / sequence_file)
            for sequence_file in self.sample_sequence_list
        ]

        # process_single_sequence(sample_sequence_file_list[0])
        with multiprocessing.Pool(num_workers) as p:
            sequence_infos = list(tqdm(p.imap(process_single_sequence, sample_sequence_file_list),
                                       total=len(sample_sequence_file_list)))

        all_sequences_infos = [item for infos in sequence_infos for item in infos]
        return all_sequences_infos

    def get_lidar(self, sequence_name, sample_idx):
        lidar_file = self.data_path / sequence_name / ('%04d.npy' % sample_idx)
        point_features = np.load(lidar_file)  # (N, 7): [x, y, z, intensity, elongation, NLZ_flag]

        points_all, NLZ_flag = point_features[:, 0:5], point_features[:, 5]
        if not self.dataset_cfg.get('DISABLE_NLZ_FLAG_ON_POINTS', False):
            points_all = points_all[NLZ_flag == -1]
        points_all[:, 3] = np.tanh(points_all[:, 3])
        return points_all

    @staticmethod
    def transform_prebox_to_current(pred_boxes3d, pose_pre, pose_cur):
        """
        Args:
            pred_boxes3d (N, 9 or 11): [x, y, z, dx, dy, dz, raw, <vx, vy,> score, label]
            pose_pre (4, 4):
            pose_cur (4, 4):
        Returns:
        """
        assert pred_boxes3d.shape[-1] in [9, 11]
        pred_boxes3d = pred_boxes3d.copy()
        expand_bboxes = np.concatenate([pred_boxes3d[:, :3], np.ones((pred_boxes3d.shape[0], 1))], axis=-1)

        bboxes_global = np.dot(expand_bboxes, pose_pre.T)[:, :3]
        expand_bboxes_global = np.concatenate([bboxes_global[:, :3], np.ones((bboxes_global.shape[0], 1))], axis=-1)
        bboxes_pre2cur = np.dot(expand_bboxes_global, np.linalg.inv(pose_cur.T))[:, :3]
        pred_boxes3d[:, 0:3] = bboxes_pre2cur

        if pred_boxes3d.shape[-1] == 11:
            expand_vels = np.concatenate([pred_boxes3d[:, 7:9], np.zeros((pred_boxes3d.shape[0], 1))], axis=-1)
            vels_global = np.dot(expand_vels, pose_pre[:3, :3].T)
            vels_pre2cur = np.dot(vels_global, np.linalg.inv(pose_cur[:3, :3].T))[:, :2]
            pred_boxes3d[:, 7:9] = vels_pre2cur

        pred_boxes3d[:, 6] = pred_boxes3d[..., 6] + np.arctan2(pose_pre[..., 1, 0], pose_pre[..., 0, 0])
        pred_boxes3d[:, 6] = pred_boxes3d[..., 6] - np.arctan2(pose_cur[..., 1, 0], pose_cur[..., 0, 0])
        return pred_boxes3d

    @staticmethod
    def reorder_rois_for_refining(pred_bboxes):
        num_max_rois = max([len(bbox) for bbox in pred_bboxes])
        num_max_rois = max(1, num_max_rois)  # at least one faked rois to avoid error
        ordered_bboxes = np.zeros([len(pred_bboxes), num_max_rois, pred_bboxes[0].shape[-1]], dtype=np.float32)

        for bs_idx in range(ordered_bboxes.shape[0]):
            ordered_bboxes[bs_idx, :len(pred_bboxes[bs_idx])] = pred_bboxes[bs_idx]
        return ordered_bboxes

    def get_sequence_data(self, info, points, sequence_name, sample_idx, sequence_cfg, load_pred_boxes=False):
        """
        Args:
            info:
            points:
            sequence_name:
            sample_idx:
            sequence_cfg:
        Returns:
        """

        def remove_ego_points(points, center_radius=1.0):
            mask = ~((np.abs(points[:, 0]) < center_radius) & (np.abs(points[:, 1]) < center_radius))
            return points[mask]

        def load_pred_boxes_from_dict(sequence_name, sample_idx):
            """
            boxes: (N, 11)  [x, y, z, dx, dy, dn, raw, vx, vy, score, label]
            """
            sequence_name = sequence_name.replace('training_', '').replace('validation_', '')
            load_boxes = self.pred_boxes_dict[sequence_name][sample_idx]
            assert load_boxes.shape[-1] == 11
            load_boxes[:, 7:9] = -0.1 * load_boxes[:, 7:9]  # transfer speed to negtive motion from t to t-1
            return load_boxes

        pose_cur = info['pose'].reshape((4, 4))
        num_pts_cur = points.shape[0]
        sample_idx_pre_list = np.clip(sample_idx + np.arange(sequence_cfg.SAMPLE_OFFSET[0], sequence_cfg.SAMPLE_OFFSET[1]), 0, 0x7FFFFFFF)
        sample_idx_pre_list = sample_idx_pre_list[::-1]

        if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
            onehot_cur = np.zeros((points.shape[0], len(sample_idx_pre_list) + 1)).astype(points.dtype)
            onehot_cur[:, 0] = 1
            points = np.hstack([points, onehot_cur])
        else:
            points = np.hstack([points, np.zeros((points.shape[0], 1)).astype(points.dtype)])
        points_pre_all = []
        num_points_pre = []

        pose_all = [pose_cur]
        pred_boxes_all = []
        if load_pred_boxes:
            pred_boxes = load_pred_boxes_from_dict(sequence_name, sample_idx)
            pred_boxes_all.append(pred_boxes)

        sequence_info = self.seq_name_to_infos[sequence_name]

        for idx, sample_idx_pre in enumerate(sample_idx_pre_list):
            points_pre = self.get_lidar(sequence_name, sample_idx_pre)
            pose_pre = sequence_info[sample_idx_pre]['pose'].reshape((4, 4))
            expand_points_pre = np.concatenate([points_pre[:, :3], np.ones((points_pre.shape[0], 1))], axis=-1)
            points_pre_global = np.dot(expand_points_pre, pose_pre.T)[:, :3]
            expand_points_pre_global = np.concatenate([points_pre_global, np.ones((points_pre_global.shape[0], 1))], axis=-1)
            points_pre2cur = np.dot(expand_points_pre_global, np.linalg.inv(pose_cur.T))[:, :3]
            points_pre = np.concatenate([points_pre2cur, points_pre[:, 3:]], axis=-1)
            if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
                onehot_vector = np.zeros((points_pre.shape[0], len(sample_idx_pre_list) + 1))
                onehot_vector[:, idx + 1] = 1
                points_pre = np.hstack([points_pre, onehot_vector])
            else:
                # add timestamp
                points_pre = np.hstack([points_pre, 0.1 * (sample_idx - sample_idx_pre) * np.ones((points_pre.shape[0], 1)).astype(points_pre.dtype)])  # one frame 0.1s
            points_pre = remove_ego_points(points_pre, 1.0)
            points_pre_all.append(points_pre)
            num_points_pre.append(points_pre.shape[0])
            pose_all.append(pose_pre)

            if load_pred_boxes:
                pose_pre = sequence_info[sample_idx_pre]['pose'].reshape((4, 4))
                pred_boxes = load_pred_boxes_from_dict(sequence_name, sample_idx_pre)
                pred_boxes = self.transform_prebox_to_current(pred_boxes, pose_pre, pose_cur)
                pred_boxes_all.append(pred_boxes)

        points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32)
        num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int32)
        poses = np.concatenate(pose_all, axis=0).astype(np.float32)

        if load_pred_boxes:
            temp_pred_boxes = self.reorder_rois_for_refining(pred_boxes_all)
            pred_boxes = temp_pred_boxes[:, :, 0:9]
            pred_scores = temp_pred_boxes[:, :, 9]
            pred_labels = temp_pred_boxes[:, :, 10]
        else:
            pred_boxes = pred_scores = pred_labels = None

        return points, num_points_all, sample_idx_pre_list, poses, pred_boxes, pred_scores, pred_labels

    def __len__(self):
        if self._merge_all_iters_to_one_epoch:
            return len(self.infos) * self.total_epochs

        return len(self.infos)

    def __getitem__(self, index):
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.infos)

        info = copy.deepcopy(self.infos[index])
        pc_info = info['point_cloud']
        sequence_name = pc_info['lidar_sequence']
        sample_idx = pc_info['sample_idx']
        input_dict = {
            'sample_idx': sample_idx
        }
        if self.use_shared_memory and index < self.shared_memory_file_limit:
            sa_key = f'{sequence_name}___{sample_idx}'
            points = SharedArray.attach(f"shm://{sa_key}").copy()
        else:
            points = self.get_lidar(sequence_name, sample_idx)

        if self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED:
            points, num_points_all, sample_idx_pre_list, poses, pred_boxes, pred_scores, pred_labels = self.get_sequence_data(
                info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG,
                load_pred_boxes=self.dataset_cfg.get('USE_PREDBOX', False)
            )
            input_dict['poses'] = poses
            if self.dataset_cfg.get('USE_PREDBOX', False):
                input_dict.update({
                    'roi_boxes': pred_boxes,
                    'roi_scores': pred_scores,
                    'roi_labels': pred_labels,
                })

        input_dict.update({
            'points': points,
            'frame_id': info['frame_id'],
        })

        if 'annos' in info:
            annos = info['annos']
            annos = common_utils.drop_info_with_name(annos, name='unknown')

            if self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False):
                gt_boxes_lidar = box_utils.boxes3d_kitti_fakelidar_to_lidar(annos['gt_boxes_lidar'])
            else:
                gt_boxes_lidar = annos['gt_boxes_lidar']

            if self.dataset_cfg.get('TRAIN_WITH_SPEED', False):
                assert gt_boxes_lidar.shape[-1] == 9
            else:
                gt_boxes_lidar = gt_boxes_lidar[:, 0:7]

            if self.training and self.dataset_cfg.get('FILTER_EMPTY_BOXES_FOR_TRAIN', False):
                mask = (annos['num_points_in_gt'] > 0)  # filter empty boxes
                annos['name'] = annos['name'][mask]
                gt_boxes_lidar = gt_boxes_lidar[mask]
                annos['num_points_in_gt'] = annos['num_points_in_gt'][mask]

            input_dict.update({
                'gt_names': annos['name'],
                'gt_boxes': gt_boxes_lidar,
                'num_points_in_gt': annos.get('num_points_in_gt', None)
            })

        data_dict = self.prepare_data(data_dict=input_dict)
        data_dict['metadata'] = info.get('metadata', info['frame_id'])
        data_dict.pop('num_points_in_gt', None)
        return data_dict

    def evaluation(self, det_annos, class_names, **kwargs):
        if 'annos' not in self.infos[0].keys():
            return 'No ground-truth boxes for evaluation', {}

        def kitti_eval(eval_det_annos, eval_gt_annos):
            from ..kitti.kitti_object_eval_python import eval as kitti_eval
            from ..kitti import kitti_utils

            map_name_to_kitti = {
                'Vehicle': 'Car',
                'Pedestrian': 'Pedestrian',
                'Cyclist': 'Cyclist',
                'Sign': 'Sign',
                'Car': 'Car'
            }
            kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti)
            kitti_utils.transform_annotations_to_kitti_format(
                eval_gt_annos, map_name_to_kitti=map_name_to_kitti,
                info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            kitti_class_names = [map_name_to_kitti[x] for x in class_names]
            ap_result_str, ap_dict = kitti_eval.get_official_eval_result(
                gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names
            )
            return ap_result_str, ap_dict

        def waymo_eval(eval_det_annos, eval_gt_annos):
            from .waymo_eval import OpenPCDetWaymoDetectionMetricsEstimator
            eval = OpenPCDetWaymoDetectionMetricsEstimator()

            ap_dict = eval.waymo_evaluation(
                eval_det_annos, eval_gt_annos, class_name=class_names,
                distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            ap_result_str = '\n'
            for key in ap_dict:
                ap_dict[key] = ap_dict[key][0]
                ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])

            return ap_result_str, ap_dict

        eval_det_annos = copy.deepcopy(det_annos)
        eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.infos]

        if kwargs['eval_metric'] == 'kitti':
            ap_result_str, ap_dict = kitti_eval(eval_det_annos, eval_gt_annos)
        elif kwargs['eval_metric'] == 'waymo':
            ap_result_str, ap_dict = waymo_eval(eval_det_annos, eval_gt_annos)
        else:
            raise NotImplementedError

        return ap_result_str, ap_dict

    def create_groundtruth_database(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
                                    processed_data_tag=None):
        use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED

        if use_sequence_data:
            st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
            self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0] = min(-4, st_frame)  # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
            st_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0]
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
            db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_global.npy' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame))
        else:
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d' % (processed_data_tag, split, sampled_interval))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d.pkl' % (processed_data_tag, split, sampled_interval))
            db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_global.npy' % (processed_data_tag, split, sampled_interval))

        database_save_path.mkdir(parents=True, exist_ok=True)
        all_db_infos = {}
        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        point_offset_cnt = 0
        stacked_gt_points = []
        for k in tqdm(range(0, len(infos), sampled_interval)):
            # print('gt_database sample: %d/%d' % (k + 1, len(infos)))
            info = infos[k]

            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']
            points = self.get_lidar(sequence_name, sample_idx)

            if use_sequence_data:
                points, num_points_all, sample_idx_pre_list, _, _, _, _ = self.get_sequence_data(
                    info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
                )

            annos = info['annos']
            names = annos['name']
            difficulty = annos['difficulty']
            gt_boxes = annos['gt_boxes_lidar']

            if k % 4 != 0 and len(names) > 0:
                mask = (names == 'Vehicle')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            if k % 2 != 0 and len(names) > 0:
                mask = (names == 'Pedestrian')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            num_obj = gt_boxes.shape[0]
            if num_obj == 0:
                continue

            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
                torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
            ).long().squeeze(dim=0).cpu().numpy()

            for i in range(num_obj):
                filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
                filepath = database_save_path / filename
                gt_points = points[box_idxs_of_pts == i]
                gt_points[:, :3] -= gt_boxes[i, :3]

                if (used_classes is None) or names[i] in used_classes:
                    gt_points = gt_points.astype(np.float32)
                    assert gt_points.dtype == np.float32
                    with open(filepath, 'w') as f:
                        gt_points.tofile(f)

                    db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                    db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                               'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                               'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}

                    # it will be used if you choose to use shared memory for gt sampling
                    stacked_gt_points.append(gt_points)
                    db_info['global_data_offset'] = [point_offset_cnt, point_offset_cnt + gt_points.shape[0]]
                    point_offset_cnt += gt_points.shape[0]

                    if names[i] in all_db_infos:
                        all_db_infos[names[i]].append(db_info)
                    else:
                        all_db_infos[names[i]] = [db_info]

        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)

        # it will be used if you choose to use shared memory for gt sampling
        stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
        np.save(db_data_save_path, stacked_gt_points)

    def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False, used_classes=None,
                                           total_samples=0, use_cuda=False, crop_gt_with_tail=False):
        info, info_idx = info_with_idx
        print('gt_database sample: %d/%d' % (info_idx, total_samples))

        all_db_infos = {}
        pc_info = info['point_cloud']
        sequence_name = pc_info['lidar_sequence']
        sample_idx = pc_info['sample_idx']
        points = self.get_lidar(sequence_name, sample_idx)

        if use_sequence_data:
            points, num_points_all, sample_idx_pre_list, _, _, _, _ = self.get_sequence_data(
                info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
            )

        annos = info['annos']
        names = annos['name']
        difficulty = annos['difficulty']
        gt_boxes = annos['gt_boxes_lidar']

        if info_idx % 4 != 0 and len(names) > 0:
            mask = (names == 'Vehicle')
            names = names[~mask]
            difficulty = difficulty[~mask]
            gt_boxes = gt_boxes[~mask]

        if info_idx % 2 != 0 and len(names) > 0:
            mask = (names == 'Pedestrian')
            names = names[~mask]
            difficulty = difficulty[~mask]
            gt_boxes = gt_boxes[~mask]

        num_obj = gt_boxes.shape[0]
        if num_obj == 0:
            return {}

        if use_sequence_data and crop_gt_with_tail:
            assert gt_boxes.shape[1] == 9
            speed = gt_boxes[:, 7:9]
            sequence_cfg = self.dataset_cfg.SEQUENCE_CONFIG
            assert sequence_cfg.SAMPLE_OFFSET[1] == 0
            assert sequence_cfg.SAMPLE_OFFSET[0] < 0
            num_frames = sequence_cfg.SAMPLE_OFFSET[1] - sequence_cfg.SAMPLE_OFFSET[0] + 1
            assert num_frames > 1
            latest_center = gt_boxes[:, 0:2]
            oldest_center = latest_center - speed * (num_frames - 1) * 0.1
            new_center = (latest_center + oldest_center) * 0.5
            new_length = gt_boxes[:, 3] + np.linalg.norm(latest_center - oldest_center, axis=-1)
            gt_boxes_crop = gt_boxes.copy()
            gt_boxes_crop[:, 0:2] = new_center
            gt_boxes_crop[:, 3] = new_length
        else:
            gt_boxes_crop = gt_boxes

        if use_cuda:
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
                torch.from_numpy(gt_boxes_crop[:, 0:7]).unsqueeze(dim=0).float().cuda()
            ).long().squeeze(dim=0).cpu().numpy()
        else:
            box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
                torch.from_numpy(points[:, 0:3]).float(),
                torch.from_numpy(gt_boxes_crop[:, 0:7]).float()
            ).long().numpy()  # (num_boxes, num_points)

        for i in range(num_obj):
            filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
            filepath = database_save_path / filename
            if use_cuda:
                gt_points = points[box_idxs_of_pts == i]
            else:
                gt_points = points[box_point_mask[i] > 0]

            gt_points[:, :3] -= gt_boxes[i, :3]

            if (used_classes is None) or names[i] in used_classes:
                gt_points = gt_points.astype(np.float32)
                assert gt_points.dtype == np.float32
                with open(filepath, 'w') as f:
                    gt_points.tofile(f)

                db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                           'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                           'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i],
                           'box3d_crop': gt_boxes_crop[i]}

                if names[i] in all_db_infos:
                    all_db_infos[names[i]].append(db_info)
                else:
                    all_db_infos[names[i]] = [db_info]

        return all_db_infos

    def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
                                             processed_data_tag=None, num_workers=16, crop_gt_with_tail=False):
        use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
        if use_sequence_data:
            st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
            self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0] = min(-4, st_frame)  # at least we use 5 frames for generating gt database to support various sequence configs (<= 5 frames)
            st_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0]
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_%sparallel' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_%sparallel.pkl' % (processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
        else:
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))

        database_save_path.mkdir(parents=True, exist_ok=True)

        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        print(f'Number workers: {num_workers}')
        create_gt_database_of_single_scene = partial(
            self.create_gt_database_of_single_scene,
            use_sequence_data=use_sequence_data, database_save_path=database_save_path,
            used_classes=used_classes, total_samples=len(infos), use_cuda=False,
            crop_gt_with_tail=crop_gt_with_tail
        )
        # create_gt_database_of_single_scene((infos[300], 0))
        with multiprocessing.Pool(num_workers) as p:
            all_db_infos_list = list(p.map(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))))

        all_db_infos = {}

        for cur_db_infos in all_db_infos_list:
            for key, val in cur_db_infos.items():
                if key not in all_db_infos:
                    all_db_infos[key] = val
                else:
                    all_db_infos[key].extend(val)

        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)


def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
                       raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
                       workers=min(16, multiprocessing.cpu_count()), update_info_only=False):
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split, val_split = 'train', 'val'

    train_filename = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, train_split))
    val_filename = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, val_split))

    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    print('---------------Start to generate data infos---------------')

    dataset.set_split(train_split)
    waymo_infos_train = dataset.get_infos(
        raw_data_path=data_path / raw_data_tag,
        save_path=save_path / processed_data_tag, num_workers=workers, has_label=True,
        sampled_interval=1, update_info_only=update_info_only
    )
    with open(train_filename, 'wb') as f:
        pickle.dump(waymo_infos_train, f)
    print('----------------Waymo info train file is saved to %s----------------' % train_filename)

    dataset.set_split(val_split)
    waymo_infos_val = dataset.get_infos(
        raw_data_path=data_path / raw_data_tag,
        save_path=save_path / processed_data_tag, num_workers=workers, has_label=True,
        sampled_interval=1, update_info_only=update_info_only
    )
    with open(val_filename, 'wb') as f:
        pickle.dump(waymo_infos_val, f)
    print('----------------Waymo info val file is saved to %s----------------' % val_filename)

    if update_info_only:
        return

    print('---------------Start create groundtruth database for data augmentation---------------')
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dataset.set_split(train_split)
    dataset.create_groundtruth_database(
        info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
        used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
    )
    print('---------------Data preparation Done---------------')


def create_waymo_gt_database(
        dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
        workers=min(16, multiprocessing.cpu_count()), use_parallel=False, crop_gt_with_tail=False):
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split = 'train'
    train_filename = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, train_split))

    print('---------------Start create groundtruth database for data augmentation---------------')
    dataset.set_split(train_split)

    if use_parallel:
        dataset.create_groundtruth_database_parallel(
            info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
            used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag,
            num_workers=workers, crop_gt_with_tail=crop_gt_with_tail
        )
    else:
        dataset.create_groundtruth_database(
            info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
            used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
        )
    print('---------------Data preparation Done---------------')


if __name__ == '__main__':
    import argparse
    import yaml
    from easydict import EasyDict

    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--cfg_file', type=str, default=None, help='specify the config of dataset')
    parser.add_argument('--func', type=str, default='create_waymo_infos', help='')
    parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
    parser.add_argument('--update_info_only', action='store_true', default=False, help='')
    parser.add_argument('--use_parallel', action='store_true', default=False, help='')
    parser.add_argument('--wo_crop_gt_with_tail', action='store_true', default=False, help='')
    parser.add_argument('--data_path', default=None, help='')
    args = parser.parse_args()

    if args.data_path is not None:
        ROOT_DIR = (Path(args.data_path)).resolve()
    else:
        ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve() / 'data' / 'waymo'
    # ROOT_DIR = (Path(self.dataset_cfg.DATA_PATH)).resolve()

    if args.func == 'create_waymo_infos':
        try:
            yaml_config = yaml.safe_load(open(args.cfg_file), Loader=yaml.FullLoader)
        except:
            yaml_config = yaml.safe_load(open(args.cfg_file))
        dataset_cfg = EasyDict(yaml_config)
        dataset_cfg.PROCESSED_DATA_TAG = args.processed_data_tag
        create_waymo_infos(
            dataset_cfg=dataset_cfg,
            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
            data_path=ROOT_DIR,
            save_path=ROOT_DIR,
            raw_data_tag='raw_data',
            processed_data_tag=args.processed_data_tag,
            update_info_only=args.update_info_only
        )
    elif args.func == 'create_waymo_gt_database':
        try:
            yaml_config = yaml.safe_load(open(args.cfg_file), Loader=yaml.FullLoader)
        except:
            yaml_config = yaml.safe_load(open(args.cfg_file))
        dataset_cfg = EasyDict(yaml_config)
        dataset_cfg.PROCESSED_DATA_TAG = args.processed_data_tag
        create_waymo_gt_database(
            dataset_cfg=dataset_cfg,
            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
            data_path=ROOT_DIR,
            save_path=ROOT_DIR,
            processed_data_tag=args.processed_data_tag,
            use_parallel=args.use_parallel,
            crop_gt_with_tail=not args.wo_crop_gt_with_tail
        )
    else:
        raise NotImplementedError
```

(2) Modify the dataset config

Location: OpenPCDet-master/tools/cfgs/dataset_configs/waymo_dataset.yaml

Change DATA_PATH to the path of your own dataset:
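For example, the top of the config would then look something like this (my path below is just an illustration; keep the other keys as they are):

```yaml
DATASET: 'WaymoDataset'
DATA_PATH: '/media/xd/hpc/data/waymo_pcdet_mini'   # was '../data/waymo'
PROCESSED_DATA_TAG: 'waymo_processed_data_v0_5_0'
```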

4. Running the conversion

Extract the point cloud data from the tfrecords and generate the data infos:

python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml --data_path /media/xd/hpc/data/waymo_pcdet_mini/

The general form is the same command with --data_path pointing at your own waymo dataset directory.

5. Training

Again with PV-RCNN++ as the example:

python train.py --cfg_file /home/xd/xyy/OpenPCDet-master/tools/cfgs/waymo_models/pv_rcnn_plusplus.yaml

IV. Tests on Dataset Split Sizes

The model I tested is PV-RCNN++, compared against the official accuracy (result screenshots omitted).

I first tried 1/8 of the dataset (1/8 each of training and validation), then half of it (half of each). The half-split results were already very close to the official ones, so in the end I still used the full set.

V. Problems Encountered

1. ModuleNotFoundError: No module named 'numpy.typing'

Answer: your numpy is too old; install numpy>=1.21.0.


2. NotImplementedError: Cannot convert a symbolic Tensor (strided_slice:0) to a numpy array. This error may indicate that you're trying to pass a Tensor to a NumPy call, which is not supported

Answer:

This one is nasty: it comes from a version mismatch between tensorflow and numpy, and it is easy to hit if you follow the official instruction:

pip install waymo_open_dataset-tf-2-1-0

What happens: the tf-2-1-0 suffix means the package runs against tensorflow 2.1, while parts of the code require numpy>=1.21.0 (av2, for example, needs numpy>=1.21.0).

Raising or lowering numpy: raising it produces exactly this error; lowering it brings back problem 1.

Raising or lowering tensorflow: in either direction, av2 then fails during evaluation.

You end up going in circles, so the way out is to upgrade the waymo_open_dataset-tf toolkit itself.

Solution: move to a newer tensorflow with:

pip install waymo_open_dataset-tf-2.11.0==1.5.0

This installs the latest waymo_open_dataset-tf-2.11.0 toolkit together with the matching tensorflow version.


Note: the correspondence between the waymo toolkit, tensorflow, and numpy versions is:

| tensorflow version | waymo_open_dataset version | numpy version |
| --- | --- | --- |
| tensorflow 2.1 / tensorflow 2.6 | waymo_open_dataset-tf-2-1-0 / waymo_open_dataset-tf-2-6-0 | 1.19.2 and below |
| tensorflow 2.11 | waymo_open_dataset-tf-2.11.0 | 1.21.5 |

The in-between tensorflow 2.7–2.10 builds were never developed. See the latest manual and examples:

https://github.com/waymo-research/waymo-open-dataset/blob/master/tutorial/tutorial_v2.ipynb


3. ImportError: cannot import name 'ParamSpec' from 'typing_extensions'

Answer: downgrade typing_extensions:

pip install typing-extensions==4.3.0

If the lower version refuses to install, uninstall the package and install a different version instead.

Reference: ImportError: cannot import name 'ParamSpec' from 'typing_extensions' — CSDN Q&A: https://ask.csdn.net/questions/7829856


4. ImportError: cannot import name 'TypeGuard' from 'typing_extensions'

Solution: same as problem 3.


5. AttributeError: module 'numpy.typing' has no attribute 'NDArray'

Answer: same as problem 1; numpy is too old, install numpy>=1.21.0.


6. AttributeError: module 'spconv' has no attribute 'SparseModule'

Answer: spconv 2.x is too new for this code; it expects spconv 1.2.1.

That said, with CUDA 11.x or newer you should stay on spconv 2.x. Unless you are on CUDA 10.2, I do not recommend downgrading to spconv 1.2.1: the build process is exceptionally painful and very likely to fail with subprocess.CalledProcessError.

For details, see my earlier post:

"九天毕昇" cloud platform: installing OpenPCDet with python3.7 + CUDA10.1 + torch1.6.0 + spconv1.2.1 — https://blog.csdn.net/weixin_44013732/article/details/126030624


7. ValueError: need at least one array to concatenate

Answer: this has many possible causes: a wrong path, wrong file names, or converting with the train and val sets not placed together.


8. tensorflow.python.framework.errors_impl.NotFoundError: /home/xd/anaconda3/lib/python3.8/site-packages/waymo_open_dataset/metrics/ops/metrics_ops.so: undefined symbol: _ZNK10tensorflow8OpKernel11TraceStringERKNS_15OpKernelContextEb

Answer: the waymo_open_dataset-tf version does not match the installed tensorflow.


9. ImportError: cannot import name 'detection_metrics' from 'waymo_open_dataset.metrics.python' (unknown location)

Answer: reinstall waymo_open_dataset-tf.


10. ZeroDivisionError: float division by zero

Full error:

Traceback (most recent call last):
  File "train.py", line 229, in <module>
    main()
  File "train.py", line 219, in main
    repeat_eval_ckpt(
  File "/home/xyy/OpenPCDet-master/tools/test.py", line 123, in repeat_eval_ckpt
    tb_dict = eval_utils.eval_one_epoch(
  File "/home/xyy/OpenPCDet-master/tools/eval_utils/eval_utils.py", line 95, in eval_one_epoch
    sec_per_example = (time.time() - start_time) / len(dataloader.dataset)
ZeroDivisionError: float division by zero

Solution:

Put the val dataset under /waymo/raw_data and convert everything together.

The explanation in detail:

Start with the error itself: ZeroDivisionError: float division by zero means the denominator is zero.

Going to the failing line and printing len(dataloader.dataset) confirms that the dataloader loaded no data at all, so look further up the log for lines like these:

2023-08-25 19:44:58,443   INFO  Total skipped info 202
2023-08-25 19:44:58,443   INFO  Total samples for Waymo dataset: 0

202 sequences were skipped and 0 waymo samples were loaded.

The "Total skipped info" counter lives in pcdet/datasets/waymo/waymo_dataset.py: num_skipped_infos is incremented whenever a sequence listed in val.txt (under data/waymo/ImageSets) has no info file at info_path, so print info_path there.

The skipped entries turn out to be exactly the files from val.txt; checking /raw_data showed the val dataset had indeed never been put there.

So the conclusion is clear: extract the val archives and place them in raw_data together with the training data, then rerun the conversion.
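A quick way to confirm that every sequence listed in ImageSets actually exists in raw_data before converting (a sketch assuming the default layout; the loader also tolerates training_/validation_ prefixed names, which this simple check does not):

```python
from pathlib import Path

root = Path('data/waymo')  # or your custom --data_path
for split in ('train', 'val'):
    names = [line.strip() for line in open(root / 'ImageSets' / f'{split}.txt')]
    missing = [n for n in names if not (root / 'raw_data' / n).exists()]
    print(f'{split}: {len(missing)} of {len(names)} sequences missing from raw_data')
```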


11. Corrupted tfrecord file

Problem description:

While converting the waymo tfrecords I hit an infuriating problem that took a long time to track down:

multiprocessing.pool.RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/data/conda/envs/pcdet/lib/python3.8/multiprocessing/pool.py", line 125, in worker
    result = (True, func(*args, **kwds))
  File "/home/xyy/OpenPCDet-master/pcdet/datasets/waymo/waymo_utils.py", line 225, in process_single_sequence
    for cnt, data in enumerate(dataset):

  File "/home/user/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 787, in __next__
    return self._next_internal()
  File "/home/user/.local/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_ops.py", line 770, in _next_internal
    ret = gen_dataset_ops.iterator_get_next(
  File "/home/user/.local/lib/python3.8/site-packages/tensorflow/python/ops/gen_dataset_ops.py", line 3017, in iterator_get_next
    _ops.raise_from_not_ok_status(e, name)
  File "/home/user/.local/lib/python3.8/site-packages/tensorflow/python/framework/ops.py", line 7215, in raise_from_not_ok_status
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.DataLossError: {{function_node __wrapped__IteratorGetNext_output_types_1_device_/job:localhost/replica:0/task:0/device:CPU:0}} corrupted record at 969164982 [Op:IteratorGetNext]
"""

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data/conda/envs/pcdet/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/data/conda/envs/pcdet/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/xyy/OpenPCDet-master/pcdet/datasets/waymo/waymo_dataset.py", line 831, in <module>
    create_waymo_infos(

  File "/home/xyy/OpenPCDet-master/pcdet/datasets/waymo/waymo_dataset.py", line 744, in create_waymo_infos
    waymo_infos_train = dataset.get_infos(
  File "/home/xyy/OpenPCDet-master/pcdet/datasets/waymo/waymo_dataset.py", line 199, in get_infos
    sequence_infos = list(tqdm(p.imap(process_single_sequence, sample_sequence_file_list),
  
File "/data/conda/envs/pcdet/lib/python3.8/site-packages/tqdm/std.py", line 1178, in __iter__
    for obj in iterable:
  File "/data/conda/envs/pcdet/lib/python3.8/multiprocessing/pool.py", line 868, in next
    raise value
tensorflow.python.framework.errors_impl.DataLossError: {{function_node __wrapped__IteratorGetNext_output_types_1_device_/job:localhost/replica:0/task:0/device:CPU:0}} corrupted record at 969164982 [Op:IteratorGetNext]

Cause: the official dataset itself contains badly written tfrecords; this is especially common in v1.4.1. The whole job is therefore to find the offending file and delete it.

Quick fix:

Delete the file segment-9175749307679169289_5933_260_5953_260_with_camera_labels.tfrecord (or move it into a separate folder first).


Detailed steps:

Following the traceback, go to line 225 of /OpenPCDet-master/pcdet/datasets/waymo/waymo_utils.py and, just above that loop, print the tfrecord file currently being converted.

The traceback also points at the progress-bar loop, so go to line 199 of /OpenPCDet-master/pcdet/datasets/waymo/waymo_dataset.py and set num_workers to 1, so the sequences convert one at a time and the failing file is unambiguous.

Then rerun the conversion:

python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml

The printout identifies the broken file; move it out of raw_data and the conversion carries on cleanly.
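Instead of patching num_workers by hand, a standalone scan can also locate corrupted tfrecords directly; reading every record forces TensorFlow's CRC checks (a sketch):

```python
import tensorflow as tf
from pathlib import Path

for f in sorted(Path('data/waymo/raw_data').glob('*.tfrecord')):
    try:
        for _ in tf.data.TFRecordDataset(str(f)):
            pass  # iterating validates each record's checksum
        print('OK     ', f.name)
    except tf.errors.DataLossError:
        print('CORRUPT', f.name)
```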

修改源码如下:

/OpenPCDet-master/pcdet/datasets/waymo/waymo_dataset.py:

  1. # OpenPCDet PyTorch Dataloader and Evaluation Tools for Waymo Open Dataset
  2. # Reference https://github.com/open-mmlab/OpenPCDet
  3. # Written by Shaoshuai Shi, Chaoxu Guo
  4. # All Rights Reserved.
  5. import os
  6. import pickle
  7. import copy
  8. import numpy as np
  9. import torch
  10. import multiprocessing
  11. import SharedArray
  12. import torch.distributed as dist
  13. from tqdm import tqdm
  14. from pathlib import Path
  15. from functools import partial
  16. from ...ops.roiaware_pool3d import roiaware_pool3d_utils
  17. from ...utils import box_utils, common_utils
  18. from ..dataset import DatasetTemplate
  19. class WaymoDataset(DatasetTemplate):
  20. def __init__(self, dataset_cfg, class_names, training=True, root_path=None, logger=None):
  21. super().__init__(
  22. dataset_cfg=dataset_cfg, class_names=class_names, training=training, root_path=root_path, logger=logger
  23. )
  24. self.data_path = self.root_path / self.dataset_cfg.PROCESSED_DATA_TAG
  25. self.split = self.dataset_cfg.DATA_SPLIT[self.mode]
  26. split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
  27. self.sample_sequence_list = [x.strip() for x in open(split_dir).readlines()]
  28. self.infos = []
  29. self.seq_name_to_infos = self.include_waymo_data(self.mode)
  30. self.use_shared_memory = self.dataset_cfg.get('USE_SHARED_MEMORY', False) and self.training
  31. if self.use_shared_memory:
  32. self.shared_memory_file_limit = self.dataset_cfg.get('SHARED_MEMORY_FILE_LIMIT', 0x7FFFFFFF)
  33. self.load_data_to_shared_memory()
  34. if self.dataset_cfg.get('USE_PREDBOX', False):
  35. self.pred_boxes_dict = self.load_pred_boxes_to_dict(
  36. pred_boxes_path=self.dataset_cfg.ROI_BOXES_PATH[self.mode]
  37. )
  38. else:
  39. self.pred_boxes_dict = {}
  40. def set_split(self, split):
  41. super().__init__(
  42. dataset_cfg=self.dataset_cfg, class_names=self.class_names, training=self.training,
  43. root_path=self.root_path, logger=self.logger
  44. )
  45. self.split = split
  46. split_dir = self.root_path / 'ImageSets' / (self.split + '.txt')
  47. self.sample_sequence_list = [x.strip() for x in open(split_dir).readlines()]
  48. self.infos = []
  49. self.seq_name_to_infos = self.include_waymo_data(self.mode)
  50. def include_waymo_data(self, mode):
  51. self.logger.info('Loading Waymo dataset')
  52. waymo_infos = []
  53. seq_name_to_infos = {}
  54. num_skipped_infos = 0
  55. for k in range(len(self.sample_sequence_list)):
  56. sequence_name = os.path.splitext(self.sample_sequence_list[k])[0]
  57. info_path = self.data_path / sequence_name / ('%s.pkl' % sequence_name)
  58. info_path = self.check_sequence_name_with_all_version(info_path)
  59. # print(info_path)
  60. if not info_path.exists():
  61. num_skipped_infos += 1
  62. continue
  63. with open(info_path, 'rb') as f:
  64. infos = pickle.load(f)
  65. waymo_infos.extend(infos)
  66. seq_name_to_infos[infos[0]['point_cloud']['lidar_sequence']] = infos
  67. self.infos.extend(waymo_infos[:])
  68. self.logger.info('Total skipped info %s' % num_skipped_infos)
  69. self.logger.info('Total samples for Waymo dataset: %d' % (len(waymo_infos)))
  70. if self.dataset_cfg.SAMPLED_INTERVAL[mode] > 1:
  71. sampled_waymo_infos = []
  72. for k in range(0, len(self.infos), self.dataset_cfg.SAMPLED_INTERVAL[mode]):
  73. sampled_waymo_infos.append(self.infos[k])
  74. self.infos = sampled_waymo_infos
  75. self.logger.info('Total sampled samples for Waymo dataset: %d' % len(self.infos))
  76. use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG',
  77. None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
  78. if not use_sequence_data:
  79. seq_name_to_infos = None
  80. return seq_name_to_infos
  81. def load_pred_boxes_to_dict(self, pred_boxes_path):
  82. self.logger.info(f'Loading and reorganizing pred_boxes to dict from path: {pred_boxes_path}')
  83. with open(pred_boxes_path, 'rb') as f:
  84. pred_dicts = pickle.load(f)
  85. pred_boxes_dict = {}
  86. for index, box_dict in enumerate(pred_dicts):
  87. seq_name = box_dict['frame_id'][:-4].replace('training_', '').replace('validation_', '')
  88. sample_idx = int(box_dict['frame_id'][-3:])
  89. if seq_name not in pred_boxes_dict:
  90. pred_boxes_dict[seq_name] = {}
  91. pred_labels = np.array(
  92. [self.class_names.index(box_dict['name'][k]) + 1 for k in range(box_dict['name'].shape[0])])
  93. pred_boxes = np.concatenate(
  94. (box_dict['boxes_lidar'], box_dict['score'][:, np.newaxis], pred_labels[:, np.newaxis]), axis=-1)
  95. pred_boxes_dict[seq_name][sample_idx] = pred_boxes
  96. self.logger.info(f'Predicted boxes has been loaded, total sequences: {len(pred_boxes_dict)}')
  97. return pred_boxes_dict
  98. def load_data_to_shared_memory(self):
  99. self.logger.info(f'Loading training data to shared memory (file limit={self.shared_memory_file_limit})')
  100. cur_rank, num_gpus = common_utils.get_dist_info()
  101. all_infos = self.infos[:self.shared_memory_file_limit] \
  102. if self.shared_memory_file_limit < len(self.infos) else self.infos
  103. cur_infos = all_infos[cur_rank::num_gpus]
  104. for info in cur_infos:
  105. pc_info = info['point_cloud']
  106. sequence_name = pc_info['lidar_sequence']
  107. sample_idx = pc_info['sample_idx']
  108. sa_key = f'{sequence_name}___{sample_idx}'
  109. if os.path.exists(f"/dev/shm/{sa_key}"):
  110. continue
  111. points = self.get_lidar(sequence_name, sample_idx)
  112. common_utils.sa_create(f"shm://{sa_key}", points)
  113. dist.barrier()
  114. self.logger.info('Training data has been saved to shared memory')
  115. def clean_shared_memory(self):
  116. self.logger.info(f'Clean training data from shared memory (file limit={self.shared_memory_file_limit})')
  117. cur_rank, num_gpus = common_utils.get_dist_info()
  118. all_infos = self.infos[:self.shared_memory_file_limit] \
  119. if self.shared_memory_file_limit < len(self.infos) else self.infos
  120. cur_infos = all_infos[cur_rank::num_gpus]
  121. for info in cur_infos:
  122. pc_info = info['point_cloud']
  123. sequence_name = pc_info['lidar_sequence']
  124. sample_idx = pc_info['sample_idx']
  125. sa_key = f'{sequence_name}___{sample_idx}'
  126. if not os.path.exists(f"/dev/shm/{sa_key}"):
  127. continue
  128. SharedArray.delete(f"shm://{sa_key}")
  129. if num_gpus > 1:
  130. dist.barrier()
  131. self.logger.info('Training data has been deleted from shared memory')
  132. @staticmethod
  133. def check_sequence_name_with_all_version(sequence_file):
  134. if not sequence_file.exists():
  135. found_sequence_file = sequence_file
  136. for pre_text in ['training', 'validation', 'testing']:
  137. if not sequence_file.exists():
  138. temp_sequence_file = Path(str(sequence_file).replace('segment', pre_text + '_segment'))
  139. if temp_sequence_file.exists():
  140. found_sequence_file = temp_sequence_file
  141. break
  142. if not found_sequence_file.exists():
  143. found_sequence_file = Path(str(sequence_file).replace('_with_camera_labels', ''))
  144. if found_sequence_file.exists():
  145. sequence_file = found_sequence_file
  146. return sequence_file
  147. def get_infos(self, raw_data_path, save_path, num_workers=multiprocessing.cpu_count(), has_label=True,
  148. sampled_interval=1, update_info_only=False):
  149. from . import waymo_utils
  150. print('---------------The waymo sample interval is %d, total sequecnes is %d-----------------'
  151. % (sampled_interval, len(self.sample_sequence_list)))
  152. process_single_sequence = partial(
  153. waymo_utils.process_single_sequence,
  154. save_path=save_path, sampled_interval=sampled_interval, has_label=has_label,
  155. update_info_only=update_info_only
  156. )
  157. sample_sequence_file_list = [
  158. self.check_sequence_name_with_all_version(raw_data_path / sequence_file)
  159. for sequence_file in self.sample_sequence_list
  160. ]
  161. # process_single_sequence(sample_sequence_file_list[0])
  162. # 读条
  163. num_workers=1
  164. with multiprocessing.Pool(num_workers) as p:
  165. sequence_infos = list(tqdm(p.imap(process_single_sequence, sample_sequence_file_list),
  166. total=len(sample_sequence_file_list)))
  167. all_sequences_infos = [item for infos in sequence_infos for item in infos]
  168. return all_sequences_infos
  169. def get_lidar(self, sequence_name, sample_idx):
  170. lidar_file = self.data_path / sequence_name / ('%04d.npy' % sample_idx)
  171. point_features = np.load(lidar_file) # (N, 7): [x, y, z, intensity, elongation, NLZ_flag]
  172. points_all, NLZ_flag = point_features[:, 0:5], point_features[:, 5]
  173. if not self.dataset_cfg.get('DISABLE_NLZ_FLAG_ON_POINTS', False):
  174. points_all = points_all[NLZ_flag == -1]
  175. points_all[:, 3] = np.tanh(points_all[:, 3])
  176. return points_all

    @staticmethod
    def transform_prebox_to_current(pred_boxes3d, pose_pre, pose_cur):
        """
        Args:
            pred_boxes3d (N, 9 or 11): [x, y, z, dx, dy, dz, heading, <vx, vy,> score, label]
            pose_pre (4, 4):
            pose_cur (4, 4):
        Returns:
        """
        assert pred_boxes3d.shape[-1] in [9, 11]
        pred_boxes3d = pred_boxes3d.copy()
        expand_bboxes = np.concatenate([pred_boxes3d[:, :3], np.ones((pred_boxes3d.shape[0], 1))], axis=-1)

        bboxes_global = np.dot(expand_bboxes, pose_pre.T)[:, :3]
        expand_bboxes_global = np.concatenate([bboxes_global[:, :3], np.ones((bboxes_global.shape[0], 1))], axis=-1)
        bboxes_pre2cur = np.dot(expand_bboxes_global, np.linalg.inv(pose_cur.T))[:, :3]
        pred_boxes3d[:, 0:3] = bboxes_pre2cur

        if pred_boxes3d.shape[-1] == 11:
            expand_vels = np.concatenate([pred_boxes3d[:, 7:9], np.zeros((pred_boxes3d.shape[0], 1))], axis=-1)
            vels_global = np.dot(expand_vels, pose_pre[:3, :3].T)
            vels_pre2cur = np.dot(vels_global, np.linalg.inv(pose_cur[:3, :3].T))[:, :2]
            pred_boxes3d[:, 7:9] = vels_pre2cur

        pred_boxes3d[:, 6] = pred_boxes3d[..., 6] + np.arctan2(pose_pre[..., 1, 0], pose_pre[..., 0, 0])
        pred_boxes3d[:, 6] = pred_boxes3d[..., 6] - np.arctan2(pose_cur[..., 1, 0], pose_cur[..., 0, 0])
        return pred_boxes3d

    @staticmethod
    def reorder_rois_for_refining(pred_bboxes):
        num_max_rois = max([len(bbox) for bbox in pred_bboxes])
        num_max_rois = max(1, num_max_rois)  # at least one fake roi to avoid errors
        ordered_bboxes = np.zeros([len(pred_bboxes), num_max_rois, pred_bboxes[0].shape[-1]], dtype=np.float32)

        for bs_idx in range(ordered_bboxes.shape[0]):
            ordered_bboxes[bs_idx, :len(pred_bboxes[bs_idx])] = pred_bboxes[bs_idx]
        return ordered_bboxes

    def get_sequence_data(self, info, points, sequence_name, sample_idx, sequence_cfg, load_pred_boxes=False):
        """
        Args:
            info:
            points:
            sequence_name:
            sample_idx:
            sequence_cfg:
        Returns:
        """

        def remove_ego_points(points, center_radius=1.0):
            mask = ~((np.abs(points[:, 0]) < center_radius) & (np.abs(points[:, 1]) < center_radius))
            return points[mask]

        def load_pred_boxes_from_dict(sequence_name, sample_idx):
            """
            boxes: (N, 11)  [x, y, z, dx, dy, dz, heading, vx, vy, score, label]
            """
            sequence_name = sequence_name.replace('training_', '').replace('validation_', '')
            load_boxes = self.pred_boxes_dict[sequence_name][sample_idx]
            assert load_boxes.shape[-1] == 11
            load_boxes[:, 7:9] = -0.1 * load_boxes[:, 7:9]  # transfer speed to negative motion from t to t-1
            return load_boxes

        pose_cur = info['pose'].reshape((4, 4))
        num_pts_cur = points.shape[0]
        sample_idx_pre_list = np.clip(
            sample_idx + np.arange(sequence_cfg.SAMPLE_OFFSET[0], sequence_cfg.SAMPLE_OFFSET[1]), 0, 0x7FFFFFFF)
        sample_idx_pre_list = sample_idx_pre_list[::-1]

        if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
            onehot_cur = np.zeros((points.shape[0], len(sample_idx_pre_list) + 1)).astype(points.dtype)
            onehot_cur[:, 0] = 1
            points = np.hstack([points, onehot_cur])
        else:
            points = np.hstack([points, np.zeros((points.shape[0], 1)).astype(points.dtype)])

        points_pre_all = []
        num_points_pre = []
        pose_all = [pose_cur]
        pred_boxes_all = []
        if load_pred_boxes:
            pred_boxes = load_pred_boxes_from_dict(sequence_name, sample_idx)
            pred_boxes_all.append(pred_boxes)

        sequence_info = self.seq_name_to_infos[sequence_name]

        for idx, sample_idx_pre in enumerate(sample_idx_pre_list):
            points_pre = self.get_lidar(sequence_name, sample_idx_pre)
            pose_pre = sequence_info[sample_idx_pre]['pose'].reshape((4, 4))
            expand_points_pre = np.concatenate([points_pre[:, :3], np.ones((points_pre.shape[0], 1))], axis=-1)
            points_pre_global = np.dot(expand_points_pre, pose_pre.T)[:, :3]
            expand_points_pre_global = np.concatenate([points_pre_global, np.ones((points_pre_global.shape[0], 1))],
                                                      axis=-1)
            points_pre2cur = np.dot(expand_points_pre_global, np.linalg.inv(pose_cur.T))[:, :3]
            points_pre = np.concatenate([points_pre2cur, points_pre[:, 3:]], axis=-1)
            if sequence_cfg.get('ONEHOT_TIMESTAMP', False):
                onehot_vector = np.zeros((points_pre.shape[0], len(sample_idx_pre_list) + 1))
                onehot_vector[:, idx + 1] = 1
                points_pre = np.hstack([points_pre, onehot_vector])
            else:
                # add timestamp
                points_pre = np.hstack([points_pre,
                                        0.1 * (sample_idx - sample_idx_pre) * np.ones((points_pre.shape[0], 1)).astype(
                                            points_pre.dtype)])  # one frame 0.1s
            points_pre = remove_ego_points(points_pre, 1.0)
            points_pre_all.append(points_pre)
            num_points_pre.append(points_pre.shape[0])
            pose_all.append(pose_pre)

            if load_pred_boxes:
                pose_pre = sequence_info[sample_idx_pre]['pose'].reshape((4, 4))
                pred_boxes = load_pred_boxes_from_dict(sequence_name, sample_idx_pre)
                pred_boxes = self.transform_prebox_to_current(pred_boxes, pose_pre, pose_cur)
                pred_boxes_all.append(pred_boxes)

        points = np.concatenate([points] + points_pre_all, axis=0).astype(np.float32)
        num_points_all = np.array([num_pts_cur] + num_points_pre).astype(np.int32)
        poses = np.concatenate(pose_all, axis=0).astype(np.float32)

        if load_pred_boxes:
            temp_pred_boxes = self.reorder_rois_for_refining(pred_boxes_all)
            pred_boxes = temp_pred_boxes[:, :, 0:9]
            pred_scores = temp_pred_boxes[:, :, 9]
            pred_labels = temp_pred_boxes[:, :, 10]
        else:
            pred_boxes = pred_scores = pred_labels = None

        return points, num_points_all, sample_idx_pre_list, poses, pred_boxes, pred_scores, pred_labels

    def __len__(self):
        if self._merge_all_iters_to_one_epoch:
            return len(self.infos) * self.total_epochs
        return len(self.infos)

    def __getitem__(self, index):
        if self._merge_all_iters_to_one_epoch:
            index = index % len(self.infos)

        info = copy.deepcopy(self.infos[index])
        pc_info = info['point_cloud']
        sequence_name = pc_info['lidar_sequence']
        sample_idx = pc_info['sample_idx']
        input_dict = {
            'sample_idx': sample_idx
        }
        if self.use_shared_memory and index < self.shared_memory_file_limit:
            sa_key = f'{sequence_name}___{sample_idx}'
            points = SharedArray.attach(f"shm://{sa_key}").copy()
        else:
            points = self.get_lidar(sequence_name, sample_idx)

        if self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED:
            points, num_points_all, sample_idx_pre_list, poses, pred_boxes, pred_scores, pred_labels = self.get_sequence_data(
                info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG,
                load_pred_boxes=self.dataset_cfg.get('USE_PREDBOX', False)
            )
            input_dict['poses'] = poses
            if self.dataset_cfg.get('USE_PREDBOX', False):
                input_dict.update({
                    'roi_boxes': pred_boxes,
                    'roi_scores': pred_scores,
                    'roi_labels': pred_labels,
                })

        input_dict.update({
            'points': points,
            'frame_id': info['frame_id'],
        })

        if 'annos' in info:
            annos = info['annos']
            annos = common_utils.drop_info_with_name(annos, name='unknown')

            if self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False):
                gt_boxes_lidar = box_utils.boxes3d_kitti_fakelidar_to_lidar(annos['gt_boxes_lidar'])
            else:
                gt_boxes_lidar = annos['gt_boxes_lidar']

            if self.dataset_cfg.get('TRAIN_WITH_SPEED', False):
                assert gt_boxes_lidar.shape[-1] == 9
            else:
                gt_boxes_lidar = gt_boxes_lidar[:, 0:7]

            if self.training and self.dataset_cfg.get('FILTER_EMPTY_BOXES_FOR_TRAIN', False):
                mask = (annos['num_points_in_gt'] > 0)  # filter empty boxes
                annos['name'] = annos['name'][mask]
                gt_boxes_lidar = gt_boxes_lidar[mask]
                annos['num_points_in_gt'] = annos['num_points_in_gt'][mask]

            input_dict.update({
                'gt_names': annos['name'],
                'gt_boxes': gt_boxes_lidar,
                'num_points_in_gt': annos.get('num_points_in_gt', None)
            })

        data_dict = self.prepare_data(data_dict=input_dict)
        data_dict['metadata'] = info.get('metadata', info['frame_id'])
        data_dict.pop('num_points_in_gt', None)
        return data_dict

    def evaluation(self, det_annos, class_names, **kwargs):
        if 'annos' not in self.infos[0].keys():
            return 'No ground-truth boxes for evaluation', {}

        def kitti_eval(eval_det_annos, eval_gt_annos):
            from ..kitti.kitti_object_eval_python import eval as kitti_eval
            from ..kitti import kitti_utils

            map_name_to_kitti = {
                'Vehicle': 'Car',
                'Pedestrian': 'Pedestrian',
                'Cyclist': 'Cyclist',
                'Sign': 'Sign',
                'Car': 'Car'
            }
            kitti_utils.transform_annotations_to_kitti_format(eval_det_annos, map_name_to_kitti=map_name_to_kitti)
            kitti_utils.transform_annotations_to_kitti_format(
                eval_gt_annos, map_name_to_kitti=map_name_to_kitti,
                info_with_fakelidar=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            kitti_class_names = [map_name_to_kitti[x] for x in class_names]
            ap_result_str, ap_dict = kitti_eval.get_official_eval_result(
                gt_annos=eval_gt_annos, dt_annos=eval_det_annos, current_classes=kitti_class_names
            )
            return ap_result_str, ap_dict

        def waymo_eval(eval_det_annos, eval_gt_annos):
            from .waymo_eval import OpenPCDetWaymoDetectionMetricsEstimator
            eval = OpenPCDetWaymoDetectionMetricsEstimator()

            ap_dict = eval.waymo_evaluation(
                eval_det_annos, eval_gt_annos, class_name=class_names,
                distance_thresh=1000, fake_gt_infos=self.dataset_cfg.get('INFO_WITH_FAKELIDAR', False)
            )
            ap_result_str = '\n'
            for key in ap_dict:
                ap_dict[key] = ap_dict[key][0]
                ap_result_str += '%s: %.4f \n' % (key, ap_dict[key])

            return ap_result_str, ap_dict

        eval_det_annos = copy.deepcopy(det_annos)
        eval_gt_annos = [copy.deepcopy(info['annos']) for info in self.infos]

        if kwargs['eval_metric'] == 'kitti':
            ap_result_str, ap_dict = kitti_eval(eval_det_annos, eval_gt_annos)
        elif kwargs['eval_metric'] == 'waymo':
            ap_result_str, ap_dict = waymo_eval(eval_det_annos, eval_gt_annos)
        else:
            raise NotImplementedError

        return ap_result_str, ap_dict

    def create_groundtruth_database(self, info_path, save_path, used_classes=None, split='train', sampled_interval=10,
                                    processed_data_tag=None):
        use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None \
            and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED

        if use_sequence_data:
            st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], \
                self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
            # use at least 5 frames for generating the gt database to support various sequence configs (<= 5 frames)
            self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0] = min(-4, st_frame)
            st_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0]
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s' % (
                processed_data_tag, split, sampled_interval, st_frame, ed_frame))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s.pkl' % (
                processed_data_tag, split, sampled_interval, st_frame, ed_frame))
            db_data_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_global.npy' % (
                processed_data_tag, split, sampled_interval, st_frame, ed_frame))
        else:
            database_save_path = save_path / (
                '%s_gt_database_%s_sampled_%d' % (processed_data_tag, split, sampled_interval))
            db_info_save_path = save_path / (
                '%s_waymo_dbinfos_%s_sampled_%d.pkl' % (processed_data_tag, split, sampled_interval))
            db_data_save_path = save_path / (
                '%s_gt_database_%s_sampled_%d_global.npy' % (processed_data_tag, split, sampled_interval))

        database_save_path.mkdir(parents=True, exist_ok=True)
        all_db_infos = {}
        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        point_offset_cnt = 0
        stacked_gt_points = []
        for k in tqdm(range(0, len(infos), sampled_interval)):
            # print('gt_database sample: %d/%d' % (k + 1, len(infos)))
            info = infos[k]

            pc_info = info['point_cloud']
            sequence_name = pc_info['lidar_sequence']
            sample_idx = pc_info['sample_idx']
            points = self.get_lidar(sequence_name, sample_idx)

            if use_sequence_data:
                points, num_points_all, sample_idx_pre_list, _, _, _, _ = self.get_sequence_data(
                    info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
                )

            annos = info['annos']
            names = annos['name']
            difficulty = annos['difficulty']
            gt_boxes = annos['gt_boxes_lidar']

            if k % 4 != 0 and len(names) > 0:
                mask = (names == 'Vehicle')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            if k % 2 != 0 and len(names) > 0:
                mask = (names == 'Pedestrian')
                names = names[~mask]
                difficulty = difficulty[~mask]
                gt_boxes = gt_boxes[~mask]

            num_obj = gt_boxes.shape[0]
            if num_obj == 0:
                continue

            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
                torch.from_numpy(gt_boxes[:, 0:7]).unsqueeze(dim=0).float().cuda()
            ).long().squeeze(dim=0).cpu().numpy()

            for i in range(num_obj):
                filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
                filepath = database_save_path / filename
                gt_points = points[box_idxs_of_pts == i]
                gt_points[:, :3] -= gt_boxes[i, :3]

                if (used_classes is None) or names[i] in used_classes:
                    gt_points = gt_points.astype(np.float32)
                    assert gt_points.dtype == np.float32
                    with open(filepath, 'w') as f:
                        gt_points.tofile(f)

                    db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                    db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                               'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                               'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i]}

                    # it will be used if you choose to use shared memory for gt sampling
                    stacked_gt_points.append(gt_points)
                    db_info['global_data_offset'] = [point_offset_cnt, point_offset_cnt + gt_points.shape[0]]
                    point_offset_cnt += gt_points.shape[0]

                    if names[i] in all_db_infos:
                        all_db_infos[names[i]].append(db_info)
                    else:
                        all_db_infos[names[i]] = [db_info]

        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)

        # it will be used if you choose to use shared memory for gt sampling
        stacked_gt_points = np.concatenate(stacked_gt_points, axis=0)
        np.save(db_data_save_path, stacked_gt_points)

    def create_gt_database_of_single_scene(self, info_with_idx, database_save_path=None, use_sequence_data=False,
                                           used_classes=None, total_samples=0, use_cuda=False,
                                           crop_gt_with_tail=False):
        info, info_idx = info_with_idx
        print('gt_database sample: %d/%d' % (info_idx, total_samples))

        all_db_infos = {}
        pc_info = info['point_cloud']
        sequence_name = pc_info['lidar_sequence']
        sample_idx = pc_info['sample_idx']
        points = self.get_lidar(sequence_name, sample_idx)

        if use_sequence_data:
            points, num_points_all, sample_idx_pre_list, _, _, _, _ = self.get_sequence_data(
                info, points, sequence_name, sample_idx, self.dataset_cfg.SEQUENCE_CONFIG
            )

        annos = info['annos']
        names = annos['name']
        difficulty = annos['difficulty']
        gt_boxes = annos['gt_boxes_lidar']

        if info_idx % 4 != 0 and len(names) > 0:
            mask = (names == 'Vehicle')
            names = names[~mask]
            difficulty = difficulty[~mask]
            gt_boxes = gt_boxes[~mask]

        if info_idx % 2 != 0 and len(names) > 0:
            mask = (names == 'Pedestrian')
            names = names[~mask]
            difficulty = difficulty[~mask]
            gt_boxes = gt_boxes[~mask]

        num_obj = gt_boxes.shape[0]
        if num_obj == 0:
            return {}

        if use_sequence_data and crop_gt_with_tail:
            assert gt_boxes.shape[1] == 9
            speed = gt_boxes[:, 7:9]
            sequence_cfg = self.dataset_cfg.SEQUENCE_CONFIG
            assert sequence_cfg.SAMPLE_OFFSET[1] == 0
            assert sequence_cfg.SAMPLE_OFFSET[0] < 0
            num_frames = sequence_cfg.SAMPLE_OFFSET[1] - sequence_cfg.SAMPLE_OFFSET[0] + 1
            assert num_frames > 1
            latest_center = gt_boxes[:, 0:2]
            oldest_center = latest_center - speed * (num_frames - 1) * 0.1
            new_center = (latest_center + oldest_center) * 0.5
            new_length = gt_boxes[:, 3] + np.linalg.norm(latest_center - oldest_center, axis=-1)
            gt_boxes_crop = gt_boxes.copy()
            gt_boxes_crop[:, 0:2] = new_center
            gt_boxes_crop[:, 3] = new_length
        else:
            gt_boxes_crop = gt_boxes

        if use_cuda:
            box_idxs_of_pts = roiaware_pool3d_utils.points_in_boxes_gpu(
                torch.from_numpy(points[:, 0:3]).unsqueeze(dim=0).float().cuda(),
                torch.from_numpy(gt_boxes_crop[:, 0:7]).unsqueeze(dim=0).float().cuda()
            ).long().squeeze(dim=0).cpu().numpy()
        else:
            box_point_mask = roiaware_pool3d_utils.points_in_boxes_cpu(
                torch.from_numpy(points[:, 0:3]).float(),
                torch.from_numpy(gt_boxes_crop[:, 0:7]).float()
            ).long().numpy()  # (num_boxes, num_points)

        for i in range(num_obj):
            filename = '%s_%04d_%s_%d.bin' % (sequence_name, sample_idx, names[i], i)
            filepath = database_save_path / filename
            if use_cuda:
                gt_points = points[box_idxs_of_pts == i]
            else:
                gt_points = points[box_point_mask[i] > 0]

            gt_points[:, :3] -= gt_boxes[i, :3]

            if (used_classes is None) or names[i] in used_classes:
                gt_points = gt_points.astype(np.float32)
                assert gt_points.dtype == np.float32
                with open(filepath, 'w') as f:
                    gt_points.tofile(f)

                db_path = str(filepath.relative_to(self.root_path))  # gt_database/xxxxx.bin
                db_info = {'name': names[i], 'path': db_path, 'sequence_name': sequence_name,
                           'sample_idx': sample_idx, 'gt_idx': i, 'box3d_lidar': gt_boxes[i],
                           'num_points_in_gt': gt_points.shape[0], 'difficulty': difficulty[i],
                           'box3d_crop': gt_boxes_crop[i]}

                if names[i] in all_db_infos:
                    all_db_infos[names[i]].append(db_info)
                else:
                    all_db_infos[names[i]] = [db_info]
        return all_db_infos

    def create_groundtruth_database_parallel(self, info_path, save_path, used_classes=None, split='train',
                                             sampled_interval=10, processed_data_tag=None, num_workers=16,
                                             crop_gt_with_tail=False):
        use_sequence_data = self.dataset_cfg.get('SEQUENCE_CONFIG', None) is not None \
            and self.dataset_cfg.SEQUENCE_CONFIG.ENABLED
        if use_sequence_data:
            st_frame, ed_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0], \
                self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[1]
            # use at least 5 frames for generating the gt database to support various sequence configs (<= 5 frames)
            self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0] = min(-4, st_frame)
            st_frame = self.dataset_cfg.SEQUENCE_CONFIG.SAMPLE_OFFSET[0]
            database_save_path = save_path / ('%s_gt_database_%s_sampled_%d_multiframe_%s_to_%s_%sparallel' % (
                processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
            db_info_save_path = save_path / ('%s_waymo_dbinfos_%s_sampled_%d_multiframe_%s_to_%s_%sparallel.pkl' % (
                processed_data_tag, split, sampled_interval, st_frame, ed_frame, 'tail_' if crop_gt_with_tail else ''))
        else:
            database_save_path = save_path / (
                '%s_gt_database_%s_sampled_%d_parallel' % (processed_data_tag, split, sampled_interval))
            db_info_save_path = save_path / (
                '%s_waymo_dbinfos_%s_sampled_%d_parallel.pkl' % (processed_data_tag, split, sampled_interval))

        database_save_path.mkdir(parents=True, exist_ok=True)
        with open(info_path, 'rb') as f:
            infos = pickle.load(f)

        print(f'Number workers: {num_workers}')
        create_gt_database_of_single_scene = partial(
            self.create_gt_database_of_single_scene,
            use_sequence_data=use_sequence_data, database_save_path=database_save_path,
            used_classes=used_classes, total_samples=len(infos), use_cuda=False,
            crop_gt_with_tail=crop_gt_with_tail
        )
        # create_gt_database_of_single_scene((infos[300], 0))
        with multiprocessing.Pool(num_workers) as p:
            all_db_infos_list = list(p.map(create_gt_database_of_single_scene, zip(infos, np.arange(len(infos)))))

        all_db_infos = {}
        for cur_db_infos in all_db_infos_list:
            for key, val in cur_db_infos.items():
                if key not in all_db_infos:
                    all_db_infos[key] = val
                else:
                    all_db_infos[key].extend(val)

        for k, v in all_db_infos.items():
            print('Database %s: %d' % (k, len(v)))

        with open(db_info_save_path, 'wb') as f:
            pickle.dump(all_db_infos, f)


def create_waymo_infos(dataset_cfg, class_names, data_path, save_path,
                       raw_data_tag='raw_data', processed_data_tag='waymo_processed_data',
                       workers=min(16, multiprocessing.cpu_count()), update_info_only=False):
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split, val_split = 'train', 'val'

    train_filename = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, train_split))
    val_filename = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, val_split))

    os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
    print('---------------Start to generate data infos---------------')

    dataset.set_split(train_split)
    waymo_infos_train = dataset.get_infos(
        raw_data_path=data_path / raw_data_tag,
        save_path=save_path / processed_data_tag, num_workers=workers, has_label=True,
        sampled_interval=1, update_info_only=update_info_only
    )
    with open(train_filename, 'wb') as f:
        pickle.dump(waymo_infos_train, f)
    print('----------------Waymo info train file is saved to %s----------------' % train_filename)

    dataset.set_split(val_split)
    waymo_infos_val = dataset.get_infos(
        raw_data_path=data_path / raw_data_tag,
        save_path=save_path / processed_data_tag, num_workers=workers, has_label=True,
        sampled_interval=1, update_info_only=update_info_only
    )
    with open(val_filename, 'wb') as f:
        pickle.dump(waymo_infos_val, f)
    print('----------------Waymo info val file is saved to %s----------------' % val_filename)

    if update_info_only:
        return

    print('---------------Start create groundtruth database for data augmentation---------------')
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    dataset.set_split(train_split)
    dataset.create_groundtruth_database(
        info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
        used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
    )
    print('---------------Data preparation Done---------------')


def create_waymo_gt_database(
        dataset_cfg, class_names, data_path, save_path, processed_data_tag='waymo_processed_data',
        workers=min(16, multiprocessing.cpu_count()), use_parallel=False, crop_gt_with_tail=False):
    dataset = WaymoDataset(
        dataset_cfg=dataset_cfg, class_names=class_names, root_path=data_path,
        training=False, logger=common_utils.create_logger()
    )
    train_split = 'train'
    train_filename = save_path / ('%s_infos_%s.pkl' % (processed_data_tag, train_split))

    print('---------------Start create groundtruth database for data augmentation---------------')
    dataset.set_split(train_split)
    # print("train_split:")
    # print(train_split)

    if use_parallel:
        dataset.create_groundtruth_database_parallel(
            info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
            used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag,
            num_workers=workers, crop_gt_with_tail=crop_gt_with_tail
        )
    else:
        dataset.create_groundtruth_database(
            info_path=train_filename, save_path=save_path, split='train', sampled_interval=1,
            used_classes=['Vehicle', 'Pedestrian', 'Cyclist'], processed_data_tag=processed_data_tag
        )
    print('---------------Data preparation Done---------------')


if __name__ == '__main__':
    import argparse
    import yaml
    from easydict import EasyDict

    parser = argparse.ArgumentParser(description='arg parser')
    parser.add_argument('--cfg_file', type=str, default=None, help='specify the config of dataset')
    parser.add_argument('--func', type=str, default='create_waymo_infos', help='')
    parser.add_argument('--processed_data_tag', type=str, default='waymo_processed_data_v0_5_0', help='')
    parser.add_argument('--update_info_only', action='store_true', default=False, help='')
    parser.add_argument('--use_parallel', action='store_true', default=False, help='')
    parser.add_argument('--wo_crop_gt_with_tail', action='store_true', default=False, help='')
    # added argument: lets the dataset live outside the OpenPCDet repository
    parser.add_argument('--data_path', default=None, help='optional custom dataset root directory')
    args = parser.parse_args()

    # added logic: use the custom root if given, otherwise fall back to <repo>/data/waymo
    if args.data_path is not None:
        ROOT_DIR = (Path(args.data_path)).resolve()
    else:
        ROOT_DIR = (Path(__file__).resolve().parent / '../../../').resolve() / 'data' / 'waymo'
    # ROOT_DIR = (Path(self.dataset_cfg.DATA_PATH)).resolve()

    if args.func == 'create_waymo_infos':
        try:
            yaml_config = yaml.safe_load(open(args.cfg_file), Loader=yaml.FullLoader)
        except:
            yaml_config = yaml.safe_load(open(args.cfg_file))
        dataset_cfg = EasyDict(yaml_config)
        dataset_cfg.PROCESSED_DATA_TAG = args.processed_data_tag
        create_waymo_infos(
            dataset_cfg=dataset_cfg,
            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
            data_path=ROOT_DIR,
            save_path=ROOT_DIR,
            raw_data_tag='raw_data',
            processed_data_tag=args.processed_data_tag,
            update_info_only=args.update_info_only
        )
    elif args.func == 'create_waymo_gt_database':
        try:
            yaml_config = yaml.safe_load(open(args.cfg_file), Loader=yaml.FullLoader)
        except:
            yaml_config = yaml.safe_load(open(args.cfg_file))
        dataset_cfg = EasyDict(yaml_config)
        dataset_cfg.PROCESSED_DATA_TAG = args.processed_data_tag
        create_waymo_gt_database(
            dataset_cfg=dataset_cfg,
            class_names=['Vehicle', 'Pedestrian', 'Cyclist'],
            data_path=ROOT_DIR,
            save_path=ROOT_DIR,
            processed_data_tag=args.processed_data_tag,
            use_parallel=args.use_parallel,
            crop_gt_with_tail=not args.wo_crop_gt_with_tail
        )
    else:
        raise NotImplementedError
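With the changes above, the conversion script accepts an optional --data_path argument and falls back to the default OpenPCDet-master/data/waymo directory when it is omitted. As a sketch of the intended usage (the custom root below is the waymo_pcdet_mini folder from my setup shown earlier; substitute your own path), the dataset can then be generated outside the repository with:

python -m pcdet.datasets.waymo.waymo_dataset --func create_waymo_infos --cfg_file tools/cfgs/dataset_configs/waymo_dataset.yaml --data_path /hpc/data/waymo_pcdet_mini

Note that when training against a custom root you will most likely also need to point DATA_PATH in tools/cfgs/dataset_configs/waymo_dataset.yaml at the same directory, since train.py reads the dataset location from the config rather than from this conversion script.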

(2) waymo_utils.py

Location:

/OpenPCDet-master/pcdet/datasets/waymo/waymo_utils.py

For completeness, the conversion helper that waymo_dataset.py calls is listed below; it extracts the point clouds and labels from each tfrecord.

# OpenPCDet PyTorch Dataloader and Evaluation Tools for Waymo Open Dataset
# Reference https://github.com/open-mmlab/OpenPCDet
# Written by Shaoshuai Shi, Chaoxu Guo
# All Rights Reserved 2019-2020.

import os
import shutil
import pickle
import numpy as np
from ...utils import common_utils
import tensorflow as tf
from waymo_open_dataset.utils import frame_utils, transform_utils, range_image_utils
from waymo_open_dataset import dataset_pb2

try:
    tf.enable_eager_execution()
except:
    pass

WAYMO_CLASSES = ['unknown', 'Vehicle', 'Pedestrian', 'Sign', 'Cyclist']


def generate_labels(frame, pose):
    obj_name, difficulty, dimensions, locations, heading_angles = [], [], [], [], []
    tracking_difficulty, speeds, accelerations, obj_ids = [], [], [], []
    num_points_in_gt = []
    laser_labels = frame.laser_labels

    for i in range(len(laser_labels)):
        box = laser_labels[i].box
        class_ind = laser_labels[i].type
        loc = [box.center_x, box.center_y, box.center_z]
        heading_angles.append(box.heading)
        obj_name.append(WAYMO_CLASSES[class_ind])
        difficulty.append(laser_labels[i].detection_difficulty_level)
        tracking_difficulty.append(laser_labels[i].tracking_difficulty_level)
        dimensions.append([box.length, box.width, box.height])  # lwh in unified coordinate of OpenPCDet
        locations.append(loc)
        obj_ids.append(laser_labels[i].id)
        num_points_in_gt.append(laser_labels[i].num_lidar_points_in_box)
        speeds.append([laser_labels[i].metadata.speed_x, laser_labels[i].metadata.speed_y])
        accelerations.append([laser_labels[i].metadata.accel_x, laser_labels[i].metadata.accel_y])

    annotations = {}
    annotations['name'] = np.array(obj_name)
    annotations['difficulty'] = np.array(difficulty)
    annotations['dimensions'] = np.array(dimensions)
    annotations['location'] = np.array(locations)
    annotations['heading_angles'] = np.array(heading_angles)

    annotations['obj_ids'] = np.array(obj_ids)
    annotations['tracking_difficulty'] = np.array(tracking_difficulty)
    annotations['num_points_in_gt'] = np.array(num_points_in_gt)
    annotations['speed_global'] = np.array(speeds)
    annotations['accel_global'] = np.array(accelerations)

    annotations = common_utils.drop_info_with_name(annotations, name='unknown')
    if annotations['name'].__len__() > 0:
        global_speed = np.pad(annotations['speed_global'], ((0, 0), (0, 1)),
                              mode='constant', constant_values=0)  # (N, 3)
        speed = np.dot(global_speed, np.linalg.inv(pose[:3, :3].T))
        speed = speed[:, :2]

        gt_boxes_lidar = np.concatenate([
            annotations['location'], annotations['dimensions'], annotations['heading_angles'][..., np.newaxis], speed],
            axis=1
        )
    else:
        gt_boxes_lidar = np.zeros((0, 9))
    annotations['gt_boxes_lidar'] = gt_boxes_lidar
    return annotations


def convert_range_image_to_point_cloud(frame, range_images, camera_projections, range_image_top_pose, ri_index=(0, 1)):
    """
    Modified from the codes of Waymo Open Dataset.
    Convert range images to point cloud.
    Args:
        frame: open dataset frame
        range_images: A dict of {laser_name, [range_image_first_return, range_image_second_return]}.
        camera_projections: A dict of {laser_name,
            [camera_projection_from_first_return, camera_projection_from_second_return]}.
        range_image_top_pose: range image pixel pose for top lidar.
        ri_index: 0 for the first return, 1 for the second return.

    Returns:
        points: {[N, 3]} list of 3d lidar points of length 5 (number of lidars).
        cp_points: {[N, 6]} list of camera projections of length 5 (number of lidars).
    """
    calibrations = sorted(frame.context.laser_calibrations, key=lambda c: c.name)
    points = []
    cp_points = []
    points_NLZ = []
    points_intensity = []
    points_elongation = []

    frame_pose = tf.convert_to_tensor(np.reshape(np.array(frame.pose.transform), [4, 4]))
    # [H, W, 6]
    range_image_top_pose_tensor = tf.reshape(
        tf.convert_to_tensor(range_image_top_pose.data), range_image_top_pose.shape.dims
    )
    # [H, W, 3, 3]
    range_image_top_pose_tensor_rotation = transform_utils.get_rotation_matrix(
        range_image_top_pose_tensor[..., 0], range_image_top_pose_tensor[..., 1],
        range_image_top_pose_tensor[..., 2])
    range_image_top_pose_tensor_translation = range_image_top_pose_tensor[..., 3:]
    range_image_top_pose_tensor = transform_utils.get_transform(
        range_image_top_pose_tensor_rotation,
        range_image_top_pose_tensor_translation)

    for c in calibrations:
        points_single, cp_points_single, points_NLZ_single, points_intensity_single, points_elongation_single \
            = [], [], [], [], []
        for cur_ri_index in ri_index:
            range_image = range_images[c.name][cur_ri_index]
            if len(c.beam_inclinations) == 0:  # pylint: disable=g-explicit-length-test
                beam_inclinations = range_image_utils.compute_inclination(
                    tf.constant([c.beam_inclination_min, c.beam_inclination_max]),
                    height=range_image.shape.dims[0])
            else:
                beam_inclinations = tf.constant(c.beam_inclinations)

            beam_inclinations = tf.reverse(beam_inclinations, axis=[-1])
            extrinsic = np.reshape(np.array(c.extrinsic.transform), [4, 4])

            range_image_tensor = tf.reshape(
                tf.convert_to_tensor(range_image.data), range_image.shape.dims)
            pixel_pose_local = None
            frame_pose_local = None
            if c.name == dataset_pb2.LaserName.TOP:
                pixel_pose_local = range_image_top_pose_tensor
                pixel_pose_local = tf.expand_dims(pixel_pose_local, axis=0)
                frame_pose_local = tf.expand_dims(frame_pose, axis=0)
            range_image_mask = range_image_tensor[..., 0] > 0
            range_image_NLZ = range_image_tensor[..., 3]
            range_image_intensity = range_image_tensor[..., 1]
            range_image_elongation = range_image_tensor[..., 2]
            range_image_cartesian = range_image_utils.extract_point_cloud_from_range_image(
                tf.expand_dims(range_image_tensor[..., 0], axis=0),
                tf.expand_dims(extrinsic, axis=0),
                tf.expand_dims(tf.convert_to_tensor(beam_inclinations), axis=0),
                pixel_pose=pixel_pose_local,
                frame_pose=frame_pose_local)

            range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0)
            points_tensor = tf.gather_nd(range_image_cartesian,
                                         tf.where(range_image_mask))
            points_NLZ_tensor = tf.gather_nd(range_image_NLZ, tf.compat.v1.where(range_image_mask))
            points_intensity_tensor = tf.gather_nd(range_image_intensity, tf.compat.v1.where(range_image_mask))
            points_elongation_tensor = tf.gather_nd(range_image_elongation, tf.compat.v1.where(range_image_mask))
            cp = camera_projections[c.name][0]
            cp_tensor = tf.reshape(tf.convert_to_tensor(cp.data), cp.shape.dims)
            cp_points_tensor = tf.gather_nd(cp_tensor, tf.where(range_image_mask))

            points_single.append(points_tensor.numpy())
            cp_points_single.append(cp_points_tensor.numpy())
            points_NLZ_single.append(points_NLZ_tensor.numpy())
            points_intensity_single.append(points_intensity_tensor.numpy())
            points_elongation_single.append(points_elongation_tensor.numpy())

        points.append(np.concatenate(points_single, axis=0))
        cp_points.append(np.concatenate(cp_points_single, axis=0))
        points_NLZ.append(np.concatenate(points_NLZ_single, axis=0))
        points_intensity.append(np.concatenate(points_intensity_single, axis=0))
        points_elongation.append(np.concatenate(points_elongation_single, axis=0))

    return points, cp_points, points_NLZ, points_intensity, points_elongation


def save_lidar_points(frame, cur_save_path, use_two_returns=True):
    ret_outputs = frame_utils.parse_range_image_and_camera_projection(frame)
    if len(ret_outputs) == 4:
        range_images, camera_projections, seg_labels, range_image_top_pose = ret_outputs
    else:
        assert len(ret_outputs) == 3
        range_images, camera_projections, range_image_top_pose = ret_outputs

    points, cp_points, points_in_NLZ_flag, points_intensity, points_elongation = convert_range_image_to_point_cloud(
        frame, range_images, camera_projections, range_image_top_pose, ri_index=(0, 1) if use_two_returns else (0,)
    )

    # 3d points in vehicle frame.
    points_all = np.concatenate(points, axis=0)
    points_in_NLZ_flag = np.concatenate(points_in_NLZ_flag, axis=0).reshape(-1, 1)
    points_intensity = np.concatenate(points_intensity, axis=0).reshape(-1, 1)
    points_elongation = np.concatenate(points_elongation, axis=0).reshape(-1, 1)

    num_points_of_each_lidar = [point.shape[0] for point in points]
    save_points = np.concatenate([
        points_all, points_intensity, points_elongation, points_in_NLZ_flag
    ], axis=-1).astype(np.float32)

    np.save(cur_save_path, save_points)
    # print('saving to ', cur_save_path)
    return num_points_of_each_lidar


def process_single_sequence(sequence_file, save_path, sampled_interval, has_label=True, use_two_returns=True,
                            update_info_only=False):
    sequence_name = os.path.splitext(os.path.basename(sequence_file))[0]

    # print('Load record (sampled_interval=%d): %s' % (sampled_interval, sequence_name))
    if not sequence_file.exists():
        print('NotFoundError: %s' % sequence_file)
        return []

    dataset = tf.data.TFRecordDataset(str(sequence_file), compression_type='')
    cur_save_dir = save_path / sequence_name
    cur_save_dir.mkdir(parents=True, exist_ok=True)
    pkl_file = cur_save_dir / ('%s.pkl' % sequence_name)

    sequence_infos = []
    if pkl_file.exists():
        sequence_infos = pickle.load(open(pkl_file, 'rb'))
        sequence_infos_old = None
        if not update_info_only:
            print('Skip sequence since it has been processed before: %s' % pkl_file)
            return sequence_infos
        else:
            sequence_infos_old = sequence_infos
            sequence_infos = []
            # shutil.move(pkl_file, pkl_file + '/1')

    # debug output added in this modified version: show which tfrecord is being processed
    print("##########################sequence_file#############################")
    print(sequence_file)

    for cnt, data in enumerate(dataset):
        if cnt % sampled_interval != 0:
            continue
        # print(sequence_name, cnt)
        frame = dataset_pb2.Frame()
        frame.ParseFromString(bytearray(data.numpy()))

        info = {}
        pc_info = {'num_features': 5, 'lidar_sequence': sequence_name, 'sample_idx': cnt}
        info['point_cloud'] = pc_info

        info['frame_id'] = sequence_name + ('_%03d' % cnt)
        info['metadata'] = {
            'context_name': frame.context.name,
            'timestamp_micros': frame.timestamp_micros
        }
        image_info = {}
        for j in range(5):
            width = frame.context.camera_calibrations[j].width
            height = frame.context.camera_calibrations[j].height
            image_info.update({'image_shape_%d' % j: (height, width)})
        info['image'] = image_info

        pose = np.array(frame.pose.transform, dtype=np.float32).reshape(4, 4)
        info['pose'] = pose

        if has_label:
            annotations = generate_labels(frame, pose=pose)
            info['annos'] = annotations

        if update_info_only and sequence_infos_old is not None:
            assert info['frame_id'] == sequence_infos_old[cnt]['frame_id']
            num_points_of_each_lidar = sequence_infos_old[cnt]['num_points_of_each_lidar']
        else:
            num_points_of_each_lidar = save_lidar_points(
                frame, cur_save_dir / ('%04d.npy' % cnt), use_two_returns=use_two_returns
            )
        info['num_points_of_each_lidar'] = num_points_of_each_lidar

        sequence_infos.append(info)

    with open(pkl_file, 'wb') as f:
        pickle.dump(sequence_infos, f)

    print('Infos are saved to (sampled_interval=%d): %s' % (sampled_interval, pkl_file))
    return sequence_infos
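After conversion finishes, it is worth sanity-checking one processed sequence before launching training. A minimal sketch (the dataset root and segment name below are placeholders; substitute a real directory from your waymo_processed_data folder):

import pickle
from pathlib import Path

import numpy as np

# Placeholder paths: replace with your own dataset root and a real segment name.
seq_dir = Path('/hpc/data/waymo_pcdet_mini/waymo_processed_data/segment-xxxxxxxx')

# process_single_sequence() writes one <sequence_name>.pkl of frame infos and
# one %04d.npy point cloud per sampled frame into each sequence folder.
with open(seq_dir / ('%s.pkl' % seq_dir.name), 'rb') as f:
    infos = pickle.load(f)

points = np.load(seq_dir / '0000.npy')  # (N, 6): [x, y, z, intensity, elongation, NLZ_flag]
print('%d frames in info pkl' % len(infos))
print('frame 0: id=%s, %d points' % (infos[0]['frame_id'], points.shape[0]))

If the frame count matches the tfrecord and the point arrays load with six columns, the conversion step worked.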
