
[Project Share] Point Cloud Segmentation with PointNet


Introduction

A "point cloud" is an important type of data structure for storing geometric shape data. Due to its irregular format, it is often transformed into regular 3D voxel grids or collections of images before being used in deep learning applications, a step which makes the data unnecessarily large. The PointNet family of models solves this problem by directly consuming point clouds, respecting the permutation-invariance property of the point data. The PointNet family of models provides a simple, unified architecture for applications ranging from object classification and part segmentation to scene semantic parsing.
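To make this concrete, here is a minimal NumPy sketch (not part of the tutorial code) of what a point cloud looks like in memory: an unordered set of (x, y, z) coordinates, with no grid or voxel structure attached.

```python
import numpy as np

# A toy point cloud of four points; axis 0 indexes the points,
# axis 1 holds the (x, y, z) coordinates.
point_cloud = np.array(
    [
        [0.0, 0.0, 0.0],
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.0, 1.0],
    ]
)
print(point_cloud.shape)  # (4, 3)

# Reordering the rows yields a different array that describes the very
# same shape -- the motivation for permutation-invariant architectures.
reordered = point_cloud[::-1]
print(sorted(map(tuple, reordered)) == sorted(map(tuple, point_cloud)))  # True
```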

In this example, we demonstrate the implementation of the PointNet architecture for shape segmentation.

References

  • PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation (https://arxiv.org/abs/1612.00593)


Imports

import os
import json
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
from glob import glob

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

Downloading the dataset

The ShapeNet dataset is an ongoing effort to establish a richly-annotated, large-scale dataset of 3D shapes. ShapeNetCore is a subset of the full ShapeNet dataset with clean single 3D models and manually verified category and alignment annotations. It covers 55 common object categories, with about 51,300 unique 3D models.

For this example, we use one of the 12 object categories of PASCAL 3D+, included as part of the ShapeNetCore dataset.

dataset_url = "https://git.io/JiY4i"

dataset_path = keras.utils.get_file(
    fname="shapenet.zip",
    origin=dataset_url,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=True,
    archive_format="auto",
    cache_dir="datasets",
)

Loading the dataset

We parse the dataset metadata in order to easily map model categories to their respective directories, and segmentation classes to colors for the purpose of visualization.

with open("/tmp/.keras/datasets/PartAnnotation/metadata.json") as json_file:
    metadata = json.load(json_file)

print(metadata)
{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '03642806', 'lables': ['keyboard'], 'colors': ['blue']}, 'Motorbike': {'directory': '03790512', 'lables': ['wheel', 'handle', 'gas_tank', 'light', 'seat'], 'colors': ['blue', 'green', 'red', 'pink', 'yellow']}, 'Mug': {'directory': '03797390', 'lables': ['handle'], 'colors': ['blue']}, 'Pistol': {'directory': '03948459', 'lables': ['trigger_and_guard', 'handle', 'barrel'], 'colors': ['blue', 'green', 'red']}, 'Rocket': {'directory': '04099429', 'lables': ['nose', 'body', 'fin'], 'colors': ['blue', 'green', 'red']}, 'Skateboard': {'directory': '04225987', 'lables': ['wheel', 'deck'], 'colors': ['blue', 'green']}, 'Table': {'directory': '04379243', 'lables': ['leg', 'top'], 'colors': ['blue', 'green']}} 

In this example, we train PointNet to segment the parts of an Airplane model.

points_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points".format(
    metadata["Airplane"]["directory"]
)
labels_dir = "/tmp/.keras/datasets/PartAnnotation/{}/points_label".format(
    metadata["Airplane"]["directory"]
)
LABELS = metadata["Airplane"]["lables"]
COLORS = metadata["Airplane"]["colors"]

VAL_SPLIT = 0.2
NUM_SAMPLE_POINTS = 1024
BATCH_SIZE = 32
EPOCHS = 60
INITIAL_LR = 1e-3

Structuring the dataset

We generate the following in-memory data structures from the Airplane point clouds and their labels:

  • point_clouds is a list of np.array objects that represent the point cloud data in the form of x, y and z coordinates. Axis 0 represents the number of points in the point cloud, while axis 1 represents the coordinates. all_labels is the list that represents the label of each coordinate as a string (needed mainly for visualization purposes).
  • test_point_clouds is in the same format as point_clouds, but doesn't have corresponding point cloud labels.
  • all_labels is a list of np.array objects that represent the point cloud labels for each coordinate, corresponding to the point_clouds list.
  • point_cloud_labels is a list of np.array objects that represent the point cloud labels for each coordinate in one-hot encoded form, corresponding to the point_clouds list.
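To illustrate the one-hot encoding used for point_cloud_labels, here is a small NumPy sketch with hypothetical dense labels (the tutorial itself performs this step with keras.utils.to_categorical):

```python
import numpy as np

# Hypothetical dense labels for four points of an Airplane model:
# 0 = wing, 1 = body, 2 = tail, 3 = engine, 4 = "none".
dense_labels = np.array([1, 0, 2, 4])
num_classes = 5

# One-hot encoding: each row gets a single 1 at the index of its class.
one_hot = np.eye(num_classes, dtype="float32")[dense_labels]
print(one_hot[0])  # [0. 1. 0. 0. 0.]
```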
point_clouds, test_point_clouds = [], []
point_cloud_labels, all_labels = [], []

points_files = glob(os.path.join(points_dir, "*.pts"))
for point_file in tqdm(points_files):
    point_cloud = np.loadtxt(point_file)
    if point_cloud.shape[0] < NUM_SAMPLE_POINTS:
        continue

    # Get the file-id of the current point cloud for parsing its
    # labels.
    file_id = point_file.split("/")[-1].split(".")[0]
    label_data, num_labels = {}, 0
    for label in LABELS:
        label_file = os.path.join(labels_dir, label, file_id + ".seg")
        if os.path.exists(label_file):
            label_data[label] = np.loadtxt(label_file).astype("float32")
            num_labels = len(label_data[label])

    # Point clouds having labels will be our training samples.
    try:
        label_map = ["none"] * num_labels
        for label in LABELS:
            for i, data in enumerate(label_data[label]):
                label_map[i] = label if data == 1 else label_map[i]
        label_data = [
            LABELS.index(label) if label != "none" else len(LABELS)
            for label in label_map
        ]
        # Apply one-hot encoding to the dense label representation.
        label_data = keras.utils.to_categorical(label_data, num_classes=len(LABELS) + 1)

        point_clouds.append(point_cloud)
        point_cloud_labels.append(label_data)
        all_labels.append(label_map)
    except KeyError:
        test_point_clouds.append(point_cloud)
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4045/4045 [03:35<00:00, 18.76it/s] 

Next, we take a look at some samples from the in-memory arrays we just generated:

for _ in range(5):
    i = random.randint(0, len(point_clouds) - 1)
    print(f"point_clouds[{i}].shape:", point_clouds[i].shape)
    print(f"point_cloud_labels[{i}].shape:", point_cloud_labels[i].shape)
    for j in range(5):
        print(
            f"all_labels[{i}][{j}]:",
            all_labels[i][j],
            f"\tpoint_cloud_labels[{i}][{j}]:",
            point_cloud_labels[i][j],
            "\n",
        )
point_clouds[475].shape: (2602, 3) point_cloud_labels[475].shape: (2602, 5) all_labels[475][0]: body point_cloud_labels[475][0]: [0. 1. 0. 0. 0.] 
all_labels[475][1]: engine point_cloud_labels[475][1]: [0. 0. 0. 1. 0.] 
all_labels[475][2]: body point_cloud_labels[475][2]: [0. 1. 0. 0. 0.] 
all_labels[475][3]: body point_cloud_labels[475][3]: [0. 1. 0. 0. 0.] 
all_labels[475][4]: wing point_cloud_labels[475][4]: [1. 0. 0. 0. 0.] 
point_clouds[2712].shape: (2602, 3) point_cloud_labels[2712].shape: (2602, 5) all_labels[2712][0]: tail point_cloud_labels[2712][0]: [0. 0. 1. 0. 0.] 
all_labels[2712][1]: wing point_cloud_labels[2712][1]: [1. 0. 0. 0. 0.] 
all_labels[2712][2]: engine point_cloud_labels[2712][2]: [0. 0. 0. 1. 0.] 
all_labels[2712][3]: wing point_cloud_labels[2712][3]: [1. 0. 0. 0. 0.] 
all_labels[2712][4]: wing point_cloud_labels[2712][4]: [1. 0. 0. 0. 0.] 
point_clouds[1413].shape: (2602, 3) point_cloud_labels[1413].shape: (2602, 5) all_labels[1413][0]: body point_cloud_labels[1413][0]: [0. 1. 0. 0. 0.] 
all_labels[1413][1]: tail point_cloud_labels[1413][1]: [0. 0. 1. 0. 0.] 
all_labels[1413][2]: tail point_cloud_labels[1413][2]: [0. 0. 1. 0. 0.] 
all_labels[1413][3]: tail point_cloud_labels[1413][3]: [0. 0. 1. 0. 0.] 
all_labels[1413][4]: tail point_cloud_labels[1413][4]: [0. 0. 1. 0. 0.] 
point_clouds[1207].shape: (2602, 3) point_cloud_labels[1207].shape: (2602, 5) all_labels[1207][0]: tail point_cloud_labels[1207][0]: [0. 0. 1. 0. 0.] 
all_labels[1207][1]: wing point_cloud_labels[1207][1]: [1. 0. 0. 0. 0.] 
all_labels[1207][2]: wing point_cloud_labels[1207][2]: [1. 0. 0. 0. 0.] 
all_labels[1207][3]: body point_cloud_labels[1207][3]: [0. 1. 0. 0. 0.] 
all_labels[1207][4]: body point_cloud_labels[1207][4]: [0. 1. 0. 0. 0.] 
point_clouds[2492].shape: (2602, 3) point_cloud_labels[2492].shape: (2602, 5) all_labels[2492][0]: engine point_cloud_labels[2492][0]: [0. 0. 0. 1. 0.] 
all_labels[2492][1]: body point_cloud_labels[2492][1]: [0. 1. 0. 0. 0.] 
all_labels[2492][2]: body point_cloud_labels[2492][2]: [0. 1. 0. 0. 0.] 
all_labels[2492][3]: body point_cloud_labels[2492][3]: [0. 1. 0. 0. 0.] 
all_labels[2492][4]: engine point_cloud_labels[2492][4]: [0. 0. 0. 1. 0.] 

Now, let's visualize some of the point clouds along with their labels.

def visualize_data(point_cloud, labels):
    df = pd.DataFrame(
        data={
            "x": point_cloud[:, 0],
            "y": point_cloud[:, 1],
            "z": point_cloud[:, 2],
            "label": labels,
        }
    )
    fig = plt.figure(figsize=(15, 10))
    ax = plt.axes(projection="3d")
    for index, label in enumerate(LABELS):
        c_df = df[df["label"] == label]
        try:
            ax.scatter(
                c_df["x"], c_df["y"], c_df["z"], label=label, alpha=0.5, c=COLORS[index]
            )
        except IndexError:
            pass
    ax.legend()
    plt.show()


visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

Preprocessing

Note that all the point clouds that we have loaded consist of a variable number of points, which makes it difficult for us to batch them together. In order to overcome this problem, we randomly sample a fixed number of points from each point cloud. We also normalize the point clouds in order to make the data scale-invariant.

for index in tqdm(range(len(point_clouds))):
    current_point_cloud = point_clouds[index]
    current_label_cloud = point_cloud_labels[index]
    current_labels = all_labels[index]
    num_points = len(current_point_cloud)
    # Randomly sampling respective indices.
    sampled_indices = random.sample(list(range(num_points)), NUM_SAMPLE_POINTS)
    # Sampling points corresponding to sampled indices.
    sampled_point_cloud = np.array([current_point_cloud[i] for i in sampled_indices])
    # Sampling corresponding one-hot encoded labels.
    sampled_label_cloud = np.array([current_label_cloud[i] for i in sampled_indices])
    # Sampling corresponding labels for visualization.
    sampled_labels = np.array([current_labels[i] for i in sampled_indices])
    # Normalizing sampled point cloud.
    norm_point_cloud = sampled_point_cloud - np.mean(sampled_point_cloud, axis=0)
    norm_point_cloud /= np.max(np.linalg.norm(norm_point_cloud, axis=1))
    point_clouds[index] = norm_point_cloud
    point_cloud_labels[index] = sampled_label_cloud
    all_labels[index] = sampled_labels
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3694/3694 [00:07<00:00, 478.67it/s] 

Let's visualize the sampled and normalized point clouds along with their corresponding labels.

visualize_data(point_clouds[0], all_labels[0])
visualize_data(point_clouds[300], all_labels[300])

Creating TensorFlow datasets

We create tf.data.Dataset objects for the training and validation data. We also augment the training point clouds by applying random jitter to them.

def load_data(point_cloud_batch, label_cloud_batch):
    point_cloud_batch.set_shape([NUM_SAMPLE_POINTS, 3])
    label_cloud_batch.set_shape([NUM_SAMPLE_POINTS, len(LABELS) + 1])
    return point_cloud_batch, label_cloud_batch


def augment(point_cloud_batch, label_cloud_batch):
    noise = tf.random.uniform(
        tf.shape(label_cloud_batch), -0.005, 0.005, dtype=tf.float64
    )
    point_cloud_batch += noise[:, :, :3]
    return point_cloud_batch, label_cloud_batch


def generate_dataset(point_clouds, label_clouds, is_training=True):
    dataset = tf.data.Dataset.from_tensor_slices((point_clouds, label_clouds))
    dataset = dataset.shuffle(BATCH_SIZE * 100) if is_training else dataset
    dataset = dataset.map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size=BATCH_SIZE)
    dataset = (
        dataset.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
        if is_training
        else dataset
    )
    return dataset


split_index = int(len(point_clouds) * (1 - VAL_SPLIT))
train_point_clouds = point_clouds[:split_index]
train_label_cloud = point_cloud_labels[:split_index]
total_training_examples = len(train_point_clouds)

val_point_clouds = point_clouds[split_index:]
val_label_cloud = point_cloud_labels[split_index:]

print("Num train point clouds:", len(train_point_clouds))
print("Num train point cloud labels:", len(train_label_cloud))
print("Num val point clouds:", len(val_point_clouds))
print("Num val point cloud labels:", len(val_label_cloud))

train_dataset = generate_dataset(train_point_clouds, train_label_cloud)
val_dataset = generate_dataset(val_point_clouds, val_label_cloud, is_training=False)

print("Train Dataset:", train_dataset)
print("Validation Dataset:", val_dataset)
Num train point clouds: 2955 Num train point cloud labels: 2955 Num val point clouds: 739 Num val point cloud labels: 739 Train Dataset: <ParallelMapDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)> Validation Dataset: <BatchDataset shapes: ((None, 1024, 3), (None, 1024, 5)), types: (tf.float64, tf.float32)> 

PointNet model

The figure below depicts the internals of the PointNet model family:

Given that PointNet is meant to consume an unordered set of coordinates as its input data, its architecture needs to match the following characteristic properties of point cloud data:

Permutation invariance

Given the unstructured nature of point cloud data, a scan made up of n points has n! permutations. The subsequent data processing must be invariant to the different representations. In order to make PointNet invariant to input permutations, we use a symmetric function (such as max-pooling) once the n input points are mapped to a higher-dimensional space. The result is a global feature vector that aims to capture an aggregate signature of the n input points. The global feature vector is used alongside local point features for segmentation.
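A small NumPy sketch (not part of the tutorial code) makes this property concrete: max-pooling over the point axis yields the same global feature vector no matter how the points are ordered.

```python
import numpy as np

rng = np.random.default_rng(7)
# One point cloud: 8 points, each already embedded in a 16-dimensional space.
features = rng.normal(size=(8, 16))

# Shuffle the points. Max-pooling over the point axis is a symmetric
# function, so the pooled "global feature vector" does not change.
shuffled = rng.permutation(features)

pooled = features.max(axis=0)
pooled_shuffled = shuffled.max(axis=0)
print(np.allclose(pooled, pooled_shuffled))  # True
```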

Transformation invariance

Segmentation outputs should be unchanged if the object undergoes certain transformations, such as translation or scaling. For a given input point cloud, we apply an appropriate rigid or affine transformation to achieve pose normalization. Because each of the n input points is represented as a vector and is mapped to the embedding space independently, applying a geometric transformation simply amounts to matrix-multiplying each point with a transformation matrix. This is motivated by the concept of Spatial Transformer Networks.
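As a sanity check of that last point, the following NumPy sketch (with a random matrix standing in for the transform a T-Net would predict) confirms that transforming each point independently is the same as one matrix multiplication of the whole cloud:

```python
import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(1024, 3))  # n points, one 3-vector each

# A random 3x3 matrix standing in for the transformation a T-Net predicts.
transform = rng.normal(size=(3, 3))

# Transforming every point independently...
per_point = np.stack([p @ transform for p in points])
# ...equals a single matrix multiplication of the whole cloud.
batched = points @ transform
print(np.allclose(per_point, batched))  # True
```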

The operations comprising the T-Net are motivated by the higher-level architecture of PointNet. MLPs (or fully-connected layers) are used to map the input points independently and identically to a higher-dimensional space; max-pooling is used to encode a global feature vector whose dimensionality is then reduced with fully-connected layers. The input-dependent features at the final fully-connected layer are then combined with globally trainable weights and biases, resulting in a 3-by-3 transformation matrix.

Point interactions

The interaction between neighboring points often carries useful information (i.e., a single point should not be treated in isolation). Whereas classification need only make use of global features, segmentation must be able to leverage local point features along with global point features.

Note: The figures presented in this section are taken from the original paper.

Now that we know the pieces that compose the PointNet model, we can implement it. We start by implementing the basic blocks, i.e., the convolutional block and the multi-layer perceptron block.

def conv_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Conv1D(filters, kernel_size=1, padding="valid", name=f"{name}_conv")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)


def mlp_block(x: tf.Tensor, filters: int, name: str) -> tf.Tensor:
    x = layers.Dense(filters, name=f"{name}_dense")(x)
    x = layers.BatchNormalization(momentum=0.0, name=f"{name}_batch_norm")(x)
    return layers.Activation("relu", name=f"{name}_relu")(x)

We implement a regularizer (taken from this example) to enforce orthogonality in the feature space. This is needed to ensure that the magnitudes of the transformed features do not vary too much.

class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.identity = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.identity))

    def get_config(self):
        # The base `Regularizer` class has no meaningful config, so we
        # return our own constructor arguments directly.
        return {"num_features": self.num_features, "l2reg": self.l2reg}

The next piece is the transformation network we explained earlier.

def transformation_net(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    """
    Reference: https://keras.io/examples/vision/pointnet/#build-a-model.

    The `filters` values come from the original paper:
    https://arxiv.org/abs/1612.00593.
    """
    x = conv_block(inputs, filters=64, name=f"{name}_1")
    x = conv_block(x, filters=128, name=f"{name}_2")
    x = conv_block(x, filters=1024, name=f"{name}_3")
    x = layers.GlobalMaxPooling1D()(x)
    x = mlp_block(x, filters=512, name=f"{name}_1_1")
    x = mlp_block(x, filters=256, name=f"{name}_2_1")
    return layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=keras.initializers.Constant(np.eye(num_features).flatten()),
        activity_regularizer=OrthogonalRegularizer(num_features),
        name=f"{name}_final",
    )(x)


def transformation_block(inputs: tf.Tensor, num_features: int, name: str) -> tf.Tensor:
    transformed_features = transformation_net(inputs, num_features, name=name)
    transformed_features = layers.Reshape((num_features, num_features))(
        transformed_features
    )
    return layers.Dot(axes=(2, 1), name=f"{name}_mm")([inputs, transformed_features])

Finally, we piece the above blocks together and implement the segmentation model.

def get_shape_segmentation_model(num_points: int, num_classes: int) -> keras.Model:
    input_points = keras.Input(shape=(None, 3))

    # PointNet Classification Network.
    transformed_inputs = transformation_block(
        input_points, num_features=3, name="input_transformation_block"
    )
    features_64 = conv_block(transformed_inputs, filters=64, name="features_64")
    features_128_1 = conv_block(features_64, filters=128, name="features_128_1")
    features_128_2 = conv_block(features_128_1, filters=128, name="features_128_2")
    transformed_features = transformation_block(
        features_128_2, num_features=128, name="transformed_features"
    )
    features_512 = conv_block(transformed_features, filters=512, name="features_512")
    features_2048 = conv_block(features_512, filters=2048, name="pre_maxpool_block")
    global_features = layers.MaxPool1D(pool_size=num_points, name="global_features")(
        features_2048
    )
    global_features = tf.tile(global_features, [1, num_points, 1])

    # Segmentation head.
    segmentation_input = layers.Concatenate(name="segmentation_input")(
        [
            features_64,
            features_128_1,
            features_128_2,
            transformed_features,
            features_512,
            global_features,
        ]
    )
    segmentation_features = conv_block(
        segmentation_input, filters=128, name="segmentation_features"
    )
    outputs = layers.Conv1D(
        num_classes, kernel_size=1, activation="softmax", name="segmentation_head"
    )(segmentation_features)
    return keras.Model(input_points, outputs)

Instantiating the model

x, y = next(iter(train_dataset))

num_points = x.shape[1]
num_classes = y.shape[-1]

segmentation_model = get_shape_segmentation_model(num_points, num_classes)
segmentation_model.summary()
2021-10-25 01:26:33.563133: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2) Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, None, 3)] 0 __________________________________________________________________________________________________ input_transformation_block_1_co (None, None, 64) 256 input_1[0][0] __________________________________________________________________________________________________ input_transformation_block_1_ba (None, None, 64) 256 input_transformation_block_1_conv __________________________________________________________________________________________________ input_transformation_block_1_re (None, None, 64) 0 input_transformation_block_1_batc __________________________________________________________________________________________________ input_transformation_block_2_co (None, None, 128) 8320 input_transformation_block_1_relu __________________________________________________________________________________________________ input_transformation_block_2_ba (None, None, 128) 512 input_transformation_block_2_conv __________________________________________________________________________________________________ input_transformation_block_2_re (None, None, 128) 0 input_transformation_block_2_batc __________________________________________________________________________________________________ input_transformation_block_3_co (None, None, 1024) 132096 input_transformation_block_2_relu __________________________________________________________________________________________________ input_transformation_block_3_ba (None, None, 1024) 4096 input_transformation_block_3_conv 
__________________________________________________________________________________________________ input_transformation_block_3_re (None, None, 1024) 0 input_transformation_block_3_batc __________________________________________________________________________________________________ global_max_pooling1d (GlobalMax (None, 1024) 0 input_transformation_block_3_relu __________________________________________________________________________________________________ input_transformation_block_1_1_ (None, 512) 524800 global_max_pooling1d[0][0] __________________________________________________________________________________________________ input_transformation_block_1_1_ (None, 512) 2048 input_transformation_block_1_1_de __________________________________________________________________________________________________ input_transformation_block_1_1_ (None, 512) 0 input_transformation_block_1_1_ba __________________________________________________________________________________________________ input_transformation_block_2_1_ (None, 256) 131328 input_transformation_block_1_1_re __________________________________________________________________________________________________ input_transformation_block_2_1_ (None, 256) 1024 input_transformation_block_2_1_de __________________________________________________________________________________________________ input_transformation_block_2_1_ (None, 256) 0 input_transformation_block_2_1_ba __________________________________________________________________________________________________ input_transformation_block_fina (None, 9) 2313 input_transformation_block_2_1_re __________________________________________________________________________________________________ reshape (Reshape) (None, 3, 3) 0 input_transformation_block_final[ __________________________________________________________________________________________________ input_transformation_block_mm ( (None, None, 3) 0 input_1[0][0] reshape[0][0] 
__________________________________________________________________________________________________ features_64_conv (Conv1D) (None, None, 64) 256 input_transformation_block_mm[0][ __________________________________________________________________________________________________ features_64_batch_norm (BatchNo (None, None, 64) 256 features_64_conv[0][0] __________________________________________________________________________________________________ features_64_relu (Activation) (None, None, 64) 0 features_64_batch_norm[0][0] __________________________________________________________________________________________________ features_128_1_conv (Conv1D) (None, None, 128) 8320 features_64_relu[0][0] __________________________________________________________________________________________________ features_128_1_batch_norm (Batc (None, None, 128) 512 features_128_1_conv[0][0] __________________________________________________________________________________________________ features_128_1_relu (Activation (None, None, 128) 0 features_128_1_batch_norm[0][0] __________________________________________________________________________________________________ features_128_2_conv (Conv1D) (None, None, 128) 16512 features_128_1_relu[0][0] __________________________________________________________________________________________________ features_128_2_batch_norm (Batc (None, None, 128) 512 features_128_2_conv[0][0] __________________________________________________________________________________________________ features_128_2_relu (Activation (None, None, 128) 0 features_128_2_batch_norm[0][0] __________________________________________________________________________________________________ transformed_features_1_conv (Co (None, None, 64) 8256 features_128_2_relu[0][0] __________________________________________________________________________________________________ transformed_features_1_batch_no (None, None, 64) 256 transformed_features_1_conv[0][0] 
__________________________________________________________________________________________________ transformed_features_1_relu (Ac (None, None, 64) 0 transformed_features_1_batch_norm __________________________________________________________________________________________________ transformed_features_2_conv (Co (None, None, 128) 8320 transformed_features_1_relu[0][0] __________________________________________________________________________________________________ transformed_features_2_batch_no (None, None, 128) 512 transformed_features_2_conv[0][0] __________________________________________________________________________________________________ transformed_features_2_relu (Ac (None, None, 128) 0 transformed_features_2_batch_norm __________________________________________________________________________________________________ transformed_features_3_conv (Co (None, None, 1024) 132096 transformed_features_2_relu[0][0] __________________________________________________________________________________________________ transformed_features_3_batch_no (None, None, 1024) 4096 transformed_features_3_conv[0][0] __________________________________________________________________________________________________ transformed_features_3_relu (Ac (None, None, 1024) 0 transformed_features_3_batch_norm __________________________________________________________________________________________________ global_max_pooling1d_1 (GlobalM (None, 1024) 0 transformed_features_3_relu[0][0] __________________________________________________________________________________________________ transformed_features_1_1_dense (None, 512) 524800 global_max_pooling1d_1[0][0] __________________________________________________________________________________________________ transformed_features_1_1_batch_ (None, 512) 2048 transformed_features_1_1_dense[0] __________________________________________________________________________________________________ transformed_features_1_1_relu ( (None, 512) 0 
transformed_features_1_1_batch_no __________________________________________________________________________________________________ transformed_features_2_1_dense (None, 256) 131328 transformed_features_1_1_relu[0][ __________________________________________________________________________________________________ transformed_features_2_1_batch_ (None, 256) 1024 transformed_features_2_1_dense[0] __________________________________________________________________________________________________ transformed_features_2_1_relu ( (None, 256) 0 transformed_features_2_1_batch_no __________________________________________________________________________________________________ transformed_features_final (Den (None, 16384) 4210688 transformed_features_2_1_relu[0][ __________________________________________________________________________________________________ reshape_1 (Reshape) (None, 128, 128) 0 transformed_features_final[0][0] __________________________________________________________________________________________________ transformed_features_mm (Dot) (None, None, 128) 0 features_128_2_relu[0][0] reshape_1[0][0] __________________________________________________________________________________________________ features_512_conv (Conv1D) (None, None, 512) 66048 transformed_features_mm[0][0] __________________________________________________________________________________________________ features_512_batch_norm (BatchN (None, None, 512) 2048 features_512_conv[0][0] __________________________________________________________________________________________________ features_512_relu (Activation) (None, None, 512) 0 features_512_batch_norm[0][0] __________________________________________________________________________________________________ pre_maxpool_block_conv (Conv1D) (None, None, 2048) 1050624 features_512_relu[0][0] __________________________________________________________________________________________________ pre_maxpool_block_batch_norm (B (None, None, 2048) 
8192 pre_maxpool_block_conv[0][0] __________________________________________________________________________________________________ pre_maxpool_block_relu (Activat (None, None, 2048) 0 pre_maxpool_block_batch_norm[0][0 __________________________________________________________________________________________________ global_features (MaxPooling1D) (None, None, 2048) 0 pre_maxpool_block_relu[0][0] __________________________________________________________________________________________________ tf.tile (TFOpLambda) (None, None, 2048) 0 global_features[0][0] __________________________________________________________________________________________________ segmentation_input (Concatenate (None, None, 3008) 0 features_64_relu[0][0] features_128_1_relu[0][0] features_128_2_relu[0][0] transformed_features_mm[0][0] features_512_relu[0][0] tf.tile[0][0] __________________________________________________________________________________________________ segmentation_features_conv (Con (None, None, 128) 385152 segmentation_input[0][0] __________________________________________________________________________________________________ segmentation_features_batch_nor (None, None, 128) 512 segmentation_features_conv[0][0] __________________________________________________________________________________________________ segmentation_features_relu (Act (None, None, 128) 0 segmentation_features_batch_norm[ __________________________________________________________________________________________________ segmentation_head (Conv1D) (None, None, 5) 645 segmentation_features_relu[0][0] ================================================================================================== Total params: 7,370,062 Trainable params: 7,356,110 Non-trainable params: 13,952 __________________________________________________________________________________________________ 

Training

For the training, the authors recommend using a learning rate schedule that decays the initial learning rate by half every 20 epochs. In this example, we use 15 epochs.

training_step_size = total_training_examples // BATCH_SIZE
total_training_steps = training_step_size * EPOCHS
print(f"Total training steps: {total_training_steps}.")

# Halve the learning rate after 15 epochs and halve it again after 30.
lr_schedule = keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[training_step_size * 15, training_step_size * 30],
    values=[INITIAL_LR, INITIAL_LR * 0.5, INITIAL_LR * 0.25],
)

steps = tf.range(total_training_steps, dtype=tf.int32)
lrs = [lr_schedule(step) for step in steps]

plt.plot(lrs)
plt.xlabel("Steps")
plt.ylabel("Learning Rate")
plt.show()
Total training steps: 5520. 

Finally, we implement a utility for running our experiments and launching model training.

def run_experiment(epochs):
    segmentation_model = get_shape_segmentation_model(num_points, num_classes)
    segmentation_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
    )

    checkpoint_filepath = "/tmp/checkpoint"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_loss",
        save_best_only=True,
        save_weights_only=True,
    )

    history = segmentation_model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=[checkpoint_callback],
    )

    segmentation_model.load_weights(checkpoint_filepath)
    return segmentation_model, history


segmentation_model, history = run_experiment(epochs=EPOCHS)
Epoch 1/60
93/93 [==============================] - 28s 127ms/step - loss: 5.3556 - accuracy: 0.7448 - val_loss: 5.8386 - val_accuracy: 0.7471
Epoch 2/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7077 - accuracy: 0.8181 - val_loss: 5.2614 - val_accuracy: 0.7793
Epoch 3/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6566 - accuracy: 0.8301 - val_loss: 4.7907 - val_accuracy: 0.8269
Epoch 4/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6059 - accuracy: 0.8406 - val_loss: 4.6031 - val_accuracy: 0.8482
Epoch 5/60
93/93 [==============================] - 11s 118ms/step - loss: 4.5828 - accuracy: 0.8444 - val_loss: 4.7692 - val_accuracy: 0.8220
Epoch 6/60
93/93 [==============================] - 11s 118ms/step - loss: 4.6150 - accuracy: 0.8408 - val_loss: 5.4460 - val_accuracy: 0.8192
Epoch 7/60
93/93 [==============================] - 11s 117ms/step - loss: 67.5943 - accuracy: 0.7378 - val_loss: 1617.1846 - val_accuracy: 0.5191
Epoch 8/60
93/93 [==============================] - 11s 117ms/step - loss: 15.2910 - accuracy: 0.6651 - val_loss: 8.1014 - val_accuracy: 0.7046
Epoch 9/60
93/93 [==============================] - 11s 117ms/step - loss: 6.8878 - accuracy: 0.7368 - val_loss: 14.2311 - val_accuracy: 0.6949
Epoch 10/60
93/93 [==============================] - 11s 117ms/step - loss: 5.8362 - accuracy: 0.7549 - val_loss: 14.6942 - val_accuracy: 0.6350
Epoch 11/60
93/93 [==============================] - 11s 117ms/step - loss: 5.4777 - accuracy: 0.7648 - val_loss: 44.1037 - val_accuracy: 0.6422
Epoch 12/60
93/93 [==============================] - 11s 117ms/step - loss: 5.2688 - accuracy: 0.7712 - val_loss: 4.9977 - val_accuracy: 0.7692
Epoch 13/60
93/93 [==============================] - 11s 117ms/step - loss: 5.1041 - accuracy: 0.7837 - val_loss: 6.0642 - val_accuracy: 0.7577
Epoch 14/60
93/93 [==============================] - 11s 117ms/step - loss: 5.0011 - accuracy: 0.7862 - val_loss: 4.9313 - val_accuracy: 0.7840
Epoch 15/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8910 - accuracy: 0.7953 - val_loss: 5.8368 - val_accuracy: 0.7725
Epoch 16/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8698 - accuracy: 0.8074 - val_loss: 73.0260 - val_accuracy: 0.7251
Epoch 17/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8299 - accuracy: 0.8109 - val_loss: 17.1503 - val_accuracy: 0.7415
Epoch 18/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8147 - accuracy: 0.8111 - val_loss: 62.2765 - val_accuracy: 0.7344
Epoch 19/60
93/93 [==============================] - 11s 117ms/step - loss: 4.8316 - accuracy: 0.8141 - val_loss: 5.2200 - val_accuracy: 0.7890
Epoch 20/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7853 - accuracy: 0.8142 - val_loss: 5.7062 - val_accuracy: 0.7719
Epoch 21/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7753 - accuracy: 0.8157 - val_loss: 6.2089 - val_accuracy: 0.7839
Epoch 22/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7681 - accuracy: 0.8161 - val_loss: 5.1077 - val_accuracy: 0.8021
Epoch 23/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7554 - accuracy: 0.8187 - val_loss: 4.7912 - val_accuracy: 0.7912
Epoch 24/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7355 - accuracy: 0.8197 - val_loss: 4.9164 - val_accuracy: 0.7978
Epoch 25/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7483 - accuracy: 0.8197 - val_loss: 13.4724 - val_accuracy: 0.7631
Epoch 26/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7200 - accuracy: 0.8218 - val_loss: 8.3074 - val_accuracy: 0.7596
Epoch 27/60
93/93 [==============================] - 11s 118ms/step - loss: 4.7192 - accuracy: 0.8231 - val_loss: 12.4468 - val_accuracy: 0.7591
Epoch 28/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7151 - accuracy: 0.8241 - val_loss: 23.8681 - val_accuracy: 0.7689
Epoch 29/60
93/93 [==============================] - 11s 117ms/step - loss: 4.7096 - accuracy: 0.8237 - val_loss: 4.9069 - val_accuracy: 0.8104
Epoch 30/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6991 - accuracy: 0.8257 - val_loss: 4.9858 - val_accuracy: 0.7950
Epoch 31/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6852 - accuracy: 0.8260 - val_loss: 5.0130 - val_accuracy: 0.7678
Epoch 32/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6630 - accuracy: 0.8286 - val_loss: 4.8523 - val_accuracy: 0.7676
Epoch 33/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6837 - accuracy: 0.8281 - val_loss: 5.4347 - val_accuracy: 0.8095
Epoch 34/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6571 - accuracy: 0.8296 - val_loss: 10.4595 - val_accuracy: 0.7410
Epoch 35/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6460 - accuracy: 0.8321 - val_loss: 4.9189 - val_accuracy: 0.8083
Epoch 36/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6430 - accuracy: 0.8327 - val_loss: 5.8674 - val_accuracy: 0.7911
Epoch 37/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6530 - accuracy: 0.8309 - val_loss: 4.7946 - val_accuracy: 0.8032
Epoch 38/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6391 - accuracy: 0.8318 - val_loss: 5.0111 - val_accuracy: 0.8024
Epoch 39/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6521 - accuracy: 0.8336 - val_loss: 8.1558 - val_accuracy: 0.7727
Epoch 40/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6443 - accuracy: 0.8329 - val_loss: 42.8513 - val_accuracy: 0.7688
Epoch 41/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6316 - accuracy: 0.8342 - val_loss: 5.0960 - val_accuracy: 0.8066
Epoch 42/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6322 - accuracy: 0.8335 - val_loss: 5.0634 - val_accuracy: 0.8158
Epoch 43/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8370 - val_loss: 6.0642 - val_accuracy: 0.8062
Epoch 44/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6175 - accuracy: 0.8371 - val_loss: 11.1805 - val_accuracy: 0.7790
Epoch 45/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6056 - accuracy: 0.8377 - val_loss: 4.7359 - val_accuracy: 0.8145
Epoch 46/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6108 - accuracy: 0.8383 - val_loss: 5.7125 - val_accuracy: 0.7713
Epoch 47/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6103 - accuracy: 0.8377 - val_loss: 6.3271 - val_accuracy: 0.8105
Epoch 48/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6020 - accuracy: 0.8383 - val_loss: 14.2876 - val_accuracy: 0.7529
Epoch 49/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6035 - accuracy: 0.8382 - val_loss: 4.8244 - val_accuracy: 0.8143
Epoch 50/60
93/93 [==============================] - 11s 117ms/step - loss: 4.6076 - accuracy: 0.8381 - val_loss: 8.2636 - val_accuracy: 0.7528
Epoch 51/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8399 - val_loss: 4.6473 - val_accuracy: 0.8266
Epoch 52/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5927 - accuracy: 0.8408 - val_loss: 4.6443 - val_accuracy: 0.8276
Epoch 53/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5852 - accuracy: 0.8413 - val_loss: 5.1300 - val_accuracy: 0.7768
Epoch 54/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5787 - accuracy: 0.8426 - val_loss: 8.9590 - val_accuracy: 0.7582
Epoch 55/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5837 - accuracy: 0.8410 - val_loss: 5.1501 - val_accuracy: 0.8117
Epoch 56/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5875 - accuracy: 0.8422 - val_loss: 31.3518 - val_accuracy: 0.7590
Epoch 57/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5821 - accuracy: 0.8427 - val_loss: 4.8853 - val_accuracy: 0.8144
Epoch 58/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5751 - accuracy: 0.8446 - val_loss: 4.6653 - val_accuracy: 0.8222
Epoch 59/60
93/93 [==============================] - 11s 117ms/step - loss: 4.5752 - accuracy: 0.8447 - val_loss: 6.0078 - val_accuracy: 0.8014
Epoch 60/60
93/93 [==============================] - 11s 118ms/step - loss: 4.5695 - accuracy: 0.8452 - val_loss: 4.8178 - val_accuracy: 0.8192
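Note that `val_loss` occasionally spikes sharply (for example at epochs 7 and 16) before settling back down. In such cases it is common to keep the weights from the best validation epoch rather than the last one, e.g. with `keras.callbacks.EarlyStopping(restore_best_weights=True)` or a `ModelCheckpoint` with `save_best_only=True` (neither is part of the `run_experiment` used here). The underlying bookkeeping is simply a running minimum over `val_loss`; a minimal sketch, using the first eight `val_loss` values copied from the log above:

```python
# Minimal sketch of "keep the best epoch" bookkeeping: find the running
# minimum of val_loss. The values below are the first eight val_loss
# entries from the training log above.
val_losses = [5.8386, 5.2614, 4.7907, 4.6031, 4.7692, 5.4460, 1617.1846, 8.1014]

# Index of the epoch with the lowest validation loss (0-based).
best_epoch = min(range(len(val_losses)), key=lambda i: val_losses[i])
print(f"Best epoch: {best_epoch + 1} (val_loss={val_losses[best_epoch]})")  # → Best epoch: 4 (val_loss=4.6031)
```

A `ModelCheckpoint` callback does exactly this comparison after each epoch and writes the weights to disk whenever a new minimum is reached.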

Visualize the training landscape

def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("accuracy")


Inference

validation_batch = next(iter(val_dataset))
val_predictions = segmentation_model.predict(validation_batch[0])
print(f"Validation prediction shape: {val_predictions.shape}")


def visualize_single_point_cloud(point_clouds, label_clouds, idx):
    label_map = LABELS + ["none"]
    point_cloud = point_clouds[idx]
    label_cloud = label_clouds[idx]
    visualize_data(point_cloud, [label_map[np.argmax(label)] for label in label_cloud])


idx = np.random.choice(len(validation_batch[0]))
print(f"Index selected: {idx}")

# Plotting with ground-truth.
visualize_single_point_cloud(validation_batch[0], validation_batch[1], idx)

# Plotting with predicted labels.
visualize_single_point_cloud(validation_batch[0], val_predictions, idx)
Validation prediction shape: (32, 1024, 5)
Index selected: 24
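The prediction tensor has shape `(batch, num_points, num_classes)` — here `(32, 1024, 5)` — so each point receives one score per part label, and `np.argmax` over the last axis recovers the predicted part, exactly as `visualize_single_point_cloud` does per point. The same reduction also gives a quick point-wise accuracy against the one-hot ground truth. A self-contained sketch with random stand-in arrays (the real `val_predictions` and `validation_batch[1]` would be used in practice):

```python
import numpy as np

# Stand-ins shaped like the real tensors: one-hot ground-truth labels and
# per-point class scores for a batch of 32 clouds, 1024 points, 5 classes.
rng = np.random.default_rng(42)
true_idx = rng.integers(0, 5, size=(32, 1024))
label_clouds = np.eye(5)[true_idx]                                # (32, 1024, 5), one-hot
val_predictions = label_clouds + 0.1 * rng.random((32, 1024, 5))  # noisy scores

# argmax over the class axis turns scores / one-hot vectors into part indices.
pred_idx = val_predictions.argmax(axis=-1)  # (32, 1024)
gt_idx = label_clouds.argmax(axis=-1)       # (32, 1024)

# Point-wise accuracy over the whole batch.
accuracy = (pred_idx == gt_idx).mean()
print(f"Batch point accuracy: {accuracy:.2f}")
```

For segmentation, per-part IoU is usually a more informative metric than raw point accuracy, since small parts contribute few points; the same `argmax` indices are the starting point for computing it.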


Final notes

If you are interested in learning more about this topic, you may find this repository useful.
