
Project (6): PointNet Point Cloud Classification (pointnet2)


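Contents

1. Basic Principles

2. Project Directory

3. Code Walkthrough

3.1 Classification

3.2 Semantic Segmentation

4. Testing and Running

4.1 Classification

4.2 Semantic Segmentation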


1. Basic Principles

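Compared with PointNet, PointNet++ downsamples the point cloud level by level. At each level it does three things:

- It selects evenly spread centroids with farthest point sampling (FPS).
- It gathers the neighbors of each centroid with a ball query, which collects all points within a radius.
- It feeds each neighborhood, together with the features from the previous level, into a small PointNet.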

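To remove the influence of absolute coordinates, a normalization step subtracts the centroid position from every point in its neighborhood. The learned features then depend only on the local shape, not on absolute distance. For example, a person is still a person whether standing 1 m or 20 m away.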
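The normalization step can be illustrated with a minimal NumPy sketch. The `normalize_group` helper is hypothetical; the repository performs the same subtraction inside `sample_and_group` with `grouped_xyz - new_xyz.view(B, S, 1, C)`. The sketch shows that, after subtracting the centroid, the same local shape yields the same coordinates at 1 m and at 20 m.

```python
import numpy as np


def normalize_group(group_xyz, centroid):
    """Return the group's coordinates relative to its centroid.

    group_xyz: (K, 3) points of one local region; centroid: (3,) its center point.
    """
    return group_xyz - centroid  # broadcasting subtracts the centroid from every point


# The same small four-point shape, placed 1 m and 20 m away along the x axis
shape = np.array([[0.0, 0.0, 0.0],
                  [0.1, 0.0, 0.0],
                  [0.0, 0.1, 0.0],
                  [0.0, 0.0, 0.1]])
near = shape + np.array([1.0, 0.0, 0.0])
far = shape + np.array([20.0, 0.0, 0.0])

# Use the first point of each group as its centroid (in PointNet++ the centroid is an FPS-selected point)
local_near = normalize_group(near, near[0])
local_far = normalize_group(far, far[0])
print(np.allclose(local_near, local_far))  # True
```

Only translation is removed this way. The points keep their metric scale, so a region's radius still means the same physical size everywhere in the scene.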

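The network architecture is shown in the figure below. PointNet++ can perform both point cloud classification and semantic segmentation.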

2. Project Directory

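  • test and train: the test and training scripts for the different PointNet++ tasks
  • visualizer: visualization of point cloud segmentation results
  • data: the folder that holds the datasets
  • data_utils: scripts that parse the datasets
  • log: log files, plus the model code and pretrained weights for each task

 The log folder covers three tasks: classification, part segmentation, and semantic segmentation. Opening one experiment folder shows:

      • checkpoints: pretrained weight files
      • logs: training log files
      • pointnet2_utils.py: the network building blocks
      • pointnet2_cls_msg.py: the network model file, identical to the one in models

  • models: the various network models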

3. Code Walkthrough

3.1 Classification

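First, open models/pointnet2_cls_ssg.py, one of the classification models. The annotated code is below. The classification network extracts features with three PointNetSetAbstraction layers and then classifies with fully connected layers.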

```python
import torch.nn as nn
import torch.nn.functional as F
from pointnet2_utils import PointNetSetAbstraction


class get_model(nn.Module):  # network structure
    def __init__(self, num_class, normal_channel=True):
        super(get_model, self).__init__()
        in_channel = 6 if normal_channel else 3  # xyz, plus 3 normal components if used
        self.normal_channel = normal_channel
        # Three set abstraction (feature extraction) layers
        self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=in_channel, mlp=[64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True)
        self.fc1 = nn.Linear(1024, 512)  # fully connected layer
        # BatchNorm1d normalizes activations over each mini-batch, which:
        #   1. speeds up training and convergence
        #   2. helps keep gradients from exploding or vanishing
        #   3. has a mild regularizing effect that can reduce overfitting
        self.bn1 = nn.BatchNorm1d(512)
        # nn.Dropout reduces overfitting and is usually applied after fully connected layers.
        # During training each activation is zeroed with probability p (0.4 here), so the
        # dropped units take no part in that step's computation and receive no weight update.
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_class)

    def forward(self, xyz):  # forward pass
        ...  # omitted in the original post


class get_loss(nn.Module):  # loss function
    ...  # omitted in the original post
```
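The forward pass is omitted above, but the shapes it produces follow from the constructor arguments. The pure-Python helper `trace_cls_ssg` below is hypothetical and not part of the repository. It traces channel counts and output shapes, assuming the test script's defaults of 1024 points without normals.

It also explains the `+ 3` in each `in_channel`: every set abstraction layer concatenates the 3 relative coordinates to the incoming features before its first convolution.

```python
def trace_cls_ssg(batch=2, num_point=1024, normal_channel=False, num_class=40):
    """Return (layer, input_channels, output_shape) rows for the SSG classifier.

    Shapes are derived from the __init__ arguments above; forward() itself is not reproduced.
    SA output shapes are those of new_points, [B, D', S].
    """
    sa_layers = [  # (name, npoint, mlp); npoint=None means group_all=True
        ("sa1", 512, [64, 64, 128]),
        ("sa2", 128, [128, 128, 256]),
        ("sa3", None, [256, 512, 1024]),
    ]
    in_channel = 6 if normal_channel else 3
    feat = in_channel - 3  # per-point features besides xyz (3 normals, or none)
    rows = []
    for name, npoint, mlp in sa_layers:
        conv_in = feat + 3                    # relative xyz is concatenated to the features
        regions = 1 if npoint is None else npoint
        feat = mlp[-1]                        # max pooling keeps the last MLP width
        rows.append((name, conv_in, (batch, feat, regions)))
    # Fully connected head: 1024 -> 512 -> 256 -> num_class
    for name, width in [("fc1", 512), ("fc2", 256), ("fc3", num_class)]:
        rows.append((name, feat, (batch, width)))
        feat = width
    return rows


for row in trace_cls_ssg():
    print(row)
```

The last SA layer groups the whole cloud into a single region, so its output is one 1024-d global feature per shape. That feature is what the fully connected layers classify.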

PointNetSetAbstraction is defined as a class in pointnet2_utils.py.

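The file contains three classes: PointNetSetAbstraction, PointNetSetAbstractionMsg (the multi-scale grouping variant), and PointNetFeaturePropagation. These are the layer types used by the different network models. The remaining functions in the file are helpers called by these classes.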

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# PointNet layer that turns all the points in each local region into one feature vector.
# Before entering the MLP, each region's coordinates are made relative to its centroid.
class PointNetSetAbstraction(nn.Module):
    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            # 1x1 convolutions act as one MLP shared by every point in every region
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points: sampled points feature data, [B, D', S]
        """
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)
        if self.group_all:
            new_xyz, new_points = sample_and_group_all(xyz, points)  # treat the whole cloud as one region
        else:
            new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, npoint]
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        new_points = torch.max(new_points, 2)[0]  # max over the nsample points -> [B, D', npoint]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points
```
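The layer is insensitive to the order of points inside a region. The following minimal NumPy sketch shows why by passing one region through a shared MLP and a max pool. It mirrors the `conv` → `relu` → `torch.max(new_points, 2)` steps above but is not the repository's code: random weights stand in for trained ones, bias and BatchNorm are omitted, and `shared_mlp_maxpool` is a hypothetical name.

```python
import numpy as np

rng = np.random.default_rng(0)


def shared_mlp_maxpool(group, weights):
    """PointNet on one local region: the same MLP on every point, then a max over points.

    group: (K, C_in) points of one region; weights: list of (C_in, C_out) matrices.
    A 1x1 Conv2d applies the same weights at every position, which the per-point
    matrix product below emulates (bias and BatchNorm omitted).
    """
    h = group
    for w in weights:
        h = np.maximum(h @ w, 0.0)  # linear layer + ReLU, applied to each point independently
    return h.max(axis=0)            # symmetric max over the K points -> (C_out,)


group = rng.normal(size=(32, 3))  # nsample=32 relative coordinates
weights = [rng.normal(size=(3, 64)), rng.normal(size=(64, 128))]

feat = shared_mlp_maxpool(group, weights)
shuffled = shared_mlp_maxpool(group[rng.permutation(32)], weights)
print(feat.shape, np.allclose(feat, shuffled))  # (128,) True
```

Max is a symmetric function, so shuffling the rows cannot change the result. That is why an unordered point set can be fed to this layer directly.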

sample_and_group selects the centroids and splits the point cloud into local regions around them:

```python
import torch


def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Input:
        npoint: number of centroids to sample
        radius: ball query radius
        nsample: max number of points per local region
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled points position data, [B, npoint, 3]
        new_points: sampled points data, [B, npoint, nsample, 3+D]
    """
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint)  # [B, npoint] centroid indices chosen by FPS
    new_xyz = index_points(xyz, fps_idx)  # [B, npoint, C] centroid coordinates
    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # local region around each centroid
    grouped_xyz = index_points(xyz, idx)  # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)  # coordinates relative to the centroid
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)  # [B, npoint, nsample, C+D]
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points
```
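The companion `sample_and_group_all` is used when `group_all=True`, as in the classifier's last layer, but the post does not show it. Below is a NumPy sketch of what it is expected to do. The assumption, labeled in the code too, is that it keeps every point in a single region centered at the origin instead of subtracting a centroid. `group_all_sketch` is a hypothetical name, not the repository's function.

```python
import numpy as np


def group_all_sketch(xyz, points=None):
    """Treat the whole cloud as one region (the job of group_all=True).

    xyz: (B, N, 3); points: (B, N, D) or None.
    Returns new_xyz (B, 1, 3) and new_points (B, 1, N, 3 + D).
    Assumption: the single centroid is the origin, so xyz is used as-is.
    """
    B, N, C = xyz.shape
    new_xyz = np.zeros((B, 1, C))
    grouped_xyz = xyz.reshape(B, 1, N, C)
    if points is None:
        return new_xyz, grouped_xyz
    new_points = np.concatenate([grouped_xyz, points.reshape(B, 1, N, -1)], axis=-1)
    return new_xyz, new_points


xyz = np.random.rand(2, 128, 3)      # e.g. the 128 centroids left after sa2
feats = np.random.rand(2, 128, 256)  # and their 256-d features
new_xyz, new_points = group_all_sketch(xyz, feats)
print(new_xyz.shape, new_points.shape)  # (2, 1, 3) (2, 1, 128, 259)
```

The 259 channels match `in_channel=256 + 3` of `sa3` in the classifier.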

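The helper functions it relies on (distance computation, indexing, farthest point sampling, and ball query) are as follows: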

```python
import torch


def square_distance(src, dst):
    """
    Calculate the squared Euclidean distance between each pair of points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


def index_points(points, idx):
    """
    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    device = points.device
    B = points.shape[0]
    view_shape = list(idx.shape)
    view_shape[1:] = [1] * (len(view_shape) - 1)
    repeat_shape = list(idx.shape)
    repeat_shape[0] = 1
    # batch_indices has idx's shape and holds the batch number, so each row gathers from its own cloud
    batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
    new_points = points[batch_indices, idx, :]
    return new_points


# Farthest point sampling (FPS): choose npoint well-spread centroids
def farthest_point_sample(xyz, npoint):
    """
    Input:
        xyz: pointcloud data, [B, N, 3]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)  # indices of the chosen centroids, (B, npoint)
    distance = torch.ones(B, N).to(device) * 1e10  # each point's squared distance to its nearest centroid so far, (B, N)
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)  # the first centroid is picked at random
    batch_indices = torch.arange(B, dtype=torch.long).to(device)  # batch indices 0..B-1
    for i in range(npoint):
        centroids[:, i] = farthest  # record the current farthest point as the i-th centroid
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)  # its coordinates
        dist = torch.sum((xyz - centroid) ** 2, -1)  # squared distance from every point to it
        mask = dist < distance
        distance[mask] = dist[mask]  # keep each point's distance to its nearest centroid
        farthest = torch.max(distance, -1)[1]  # next centroid: the point farthest from all chosen ones
    return centroids


# Build a local region around each of the S centroids. Two hyperparameters are used:
# the maximum number of points per region K (nsample) and the query radius r (radius).
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, 3]
        new_xyz: query points, [B, S, 3]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)
    group_idx[sqrdists > radius ** 2] = N  # mark points outside the ball with the invalid index N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # keep the nsample lowest valid indices
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    mask = group_idx == N
    group_idx[mask] = group_first[mask]  # pad short regions by repeating the first point found
    return group_idx  # one region of nsample indices per centroid
```
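The same logic is easier to follow in a single-cloud NumPy re-implementation. It is for illustration only, not the repository's code, and differs from the torch version in three ways: the batch dimension is dropped, the FPS start index is fixed instead of random, and `np.inf` replaces `1e10`.

The example makes two properties visible:

- FPS spreads the centroids out.
- The ball query keeps the lowest-indexed in-radius points rather than the nearest ones, and pads short regions with the first hit.

```python
import numpy as np


def square_distance(src, dst):
    """(N, C), (M, C) -> (N, M) squared distances via |a|^2 + |b|^2 - 2 a.b."""
    return (src ** 2).sum(-1)[:, None] + (dst ** 2).sum(-1)[None, :] - 2 * src @ dst.T


def farthest_point_sample(xyz, npoint, start=0):
    """Greedy FPS on one cloud (N, 3); the start index is fixed here for reproducibility."""
    n = xyz.shape[0]
    centroids = np.zeros(npoint, dtype=np.int64)
    distance = np.full(n, np.inf)  # squared distance to the nearest chosen centroid so far
    farthest = start
    for i in range(npoint):
        centroids[i] = farthest
        d = ((xyz - xyz[farthest]) ** 2).sum(-1)
        distance = np.minimum(distance, d)
        farthest = int(distance.argmax())  # next: the point farthest from all centroids
    return centroids


def query_ball_point(radius, nsample, xyz, new_xyz):
    """(S, nsample) indices of points within radius of each query, padded like the torch version."""
    n = xyz.shape[0]
    idx = np.tile(np.arange(n), (new_xyz.shape[0], 1))
    idx[square_distance(new_xyz, xyz) > radius ** 2] = n  # mark out-of-radius points
    idx = np.sort(idx, axis=-1)[:, :nsample]              # lowest indices first, not nearest
    first = np.repeat(idx[:, :1], nsample, axis=1)
    mask = idx == n
    idx[mask] = first[mask]                               # pad with the first in-radius point
    return idx


# Five points on a line: FPS spreads out, the ball query stays local
xyz = np.array([[0.0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0], [10, 0, 0]])
fps = farthest_point_sample(xyz, 3)
print(fps.tolist())  # [0, 4, 3]
groups = query_ball_point(1.5, 3, xyz, xyz[fps])
print(groups.tolist())  # [[0, 1, 0], [4, 4, 4], [2, 3, 2]]
```

The isolated point at x=10 gets a region made of itself repeated three times. This is how the padding keeps every region at exactly `nsample` entries.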

3.2 Semantic Segmentation

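Semantic segmentation differs from classification in one step. After the point cloud has been downsampled into feature vectors, the point set must be upsampled back to the original number of points so that every point gets a label. The network structure is as follows: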

```python
import torch.nn as nn
from pointnet2_utils import PointNetSetAbstraction, PointNetFeaturePropagation


class get_model(nn.Module):
    def __init__(self, num_classes):
        super(get_model, self).__init__()
        # Four downsampling feature extraction layers:
        # (npoint, radius, nsample, in_channel, mlp, group_all)
        self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 9 + 3, [32, 32, 64], False)
        self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False)
        self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False)
        self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False)
        # Four upsampling layers that restore the original number of points
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])
        # Per-point classification head built from 1x1 convolutions
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)
```
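The `in_channel` values of the FP layers (768, 384, 320, 128) are not arbitrary. The forward pass is not shown, so the helper below works under an assumption, stated again in its docstring: each `fp_k` concatenates the skip features of SA level k-1 with the features interpolated from the coarser level, and `fp1` receives no skip features. Under that assumption the numbers can be reproduced from the SA and FP output widths. `fp_in_channels` is a hypothetical helper.

```python
def fp_in_channels(sa_widths, fp_widths, skip_l0=0):
    """Input width of each feature propagation layer, from fp4 down to fp1.

    sa_widths: output widths of sa1..sa4; fp_widths: output widths of fp4..fp1.
    Assumption: fp_k concatenates the skip features of level k-1 with the features
    interpolated from the coarser level; level 0 contributes skip_l0 channels.
    """
    skips = [skip_l0] + sa_widths[:-1]  # skip features for fp1..fp4: levels 0, 1, 2, 3
    coarse = sa_widths[-1]              # fp4 interpolates from sa4's output
    result = {}
    for k, out in zip(range(4, 0, -1), fp_widths):
        result[f"fp{k}"] = skips[k - 1] + coarse
        coarse = out                    # the next, finer layer interpolates this output
    return result


print(fp_in_channels([64, 128, 256, 512], [256, 256, 128, 128]))
# {'fp4': 768, 'fp3': 384, 'fp2': 320, 'fp1': 128}
```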

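PointNetSetAbstraction works the same way as in classification. What semantic segmentation adds are the upsampling layers.

PointNetFeaturePropagation uses hierarchical interpolation. At each level, every point takes a weighted average of the features of its 3 nearest points in the coarser level, using inverse distance weights. The result is concatenated with skip features and passed through shared MLPs. The code is as follows: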

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N]
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)
        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape
        if S == 1:
            # Only one coarse point: copy its feature to every dense point
            interpolated_points = points2.repeat(1, N, 1)
        else:
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # 3 nearest coarse points, [B, N, 3]
            dist_recip = 1.0 / (dists + 1e-8)  # inverse squared distance; 1e-8 avoids division by zero
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm  # weights sum to 1 for each dense point
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
        if points1 is not None:
            # Skip connection: concatenate features from the matching downsampling level
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points
        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))
        return new_points
```
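Here is the interpolation step in isolation, as a NumPy sketch for a single cloud. The batch dimension is dropped, and `three_nn_interpolate` is a hypothetical name.

Each dense point averages the features of its 3 nearest coarse points. The weights are proportional to the inverse squared distance, mirroring `dist_recip / norm` above, because `square_distance` returns squared distances.

```python
import numpy as np


def three_nn_interpolate(xyz1, xyz2, points2, eps=1e-8):
    """Upsample features from S coarse points to N dense points (assumes S >= 3).

    xyz1: (N, 3) dense positions; xyz2: (S, 3) coarse positions; points2: (S, D) features.
    Weights are 1 / squared distance to the 3 nearest coarse points, normalized to sum to 1.
    """
    d = ((xyz1[:, None, :] - xyz2[None, :, :]) ** 2).sum(-1)  # (N, S) squared distances
    idx = np.argsort(d, axis=-1)[:, :3]                        # 3 nearest coarse points
    d3 = np.take_along_axis(d, idx, axis=-1)
    w = 1.0 / (d3 + eps)
    w = w / w.sum(-1, keepdims=True)
    return (points2[idx] * w[..., None]).sum(axis=1)           # (N, D)


xyz2 = np.array([[0.0, 0, 0], [1, 0, 0], [0, 1, 0], [5, 5, 5]])  # coarse points
points2 = np.array([[0.0], [3.0], [6.0], [100.0]])               # one feature each
xyz1 = np.array([[0.0, 0, 0], [0.5, 0.5, 0]])                    # dense query points
out = three_nn_interpolate(xyz1, xyz2, points2)
print(out.round(4).tolist())  # [[0.0], [3.0]]
```

The first query sits exactly on a coarse point, so it essentially copies that point's feature. The second query is equidistant from its three nearest coarse points, so it gets their plain average, (0 + 3 + 6) / 3 = 3. The far point at (5, 5, 5) is ignored.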

4. Testing and Running

4.1 Classification

test_classification.py is as follows:

  1. """
  2. Author: Benny
  3. Date: Nov 2019
  4. """
  5. from data_utils.ModelNetDataLoader import ModelNetDataLoader
  6. import argparse
  7. import numpy as np
  8. import os
  9. import torch
  10. import logging
  11. from tqdm import tqdm
  12. import sys
  13. import importlib
  14. BASE_DIR = os.path.dirname(os.path.abspath(__file__))
  15. ROOT_DIR = BASE_DIR
  16. sys.path.append(os.path.join(ROOT_DIR, 'models'))
  17. def parse_args(): #初始化设置
  18. '''PARAMETERS'''
  19. parser = argparse.ArgumentParser('Testing')
  20. parser.add_argument('--use_cpu', action='store_true', default=False, help='use cpu mode')
  21. parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
  22. parser.add_argument('--batch_size', type=int, default=24, help='batch size in training')
  23. parser.add_argument('--num_category', default=40, type=int, choices=[10, 40], help='training on ModelNet10/40')
  24. parser.add_argument('--num_point', type=int, default=1024, help='Point Number')
  25. parser.add_argument('--log_dir', type=str, required=True, help='Experiment root')
  26. parser.add_argument('--use_normals', action='store_true', default=False, help='use normals')
  27. parser.add_argument('--use_uniform_sample', action='store_true', default=False, help='use uniform sampiling')
  28. parser.add_argument('--num_votes', type=int, default=3, help='Aggregate classification scores with voting')
  29. return parser.parse_args()
  30. def test(model, loader, num_class=40, vote_num=1): #5.测试
  31. mean_correct = []
  32. classifier = model.eval()
  33. class_acc = np.zeros((num_class, 3))
  34. for j, (points, target) in tqdm(enumerate(loader), total=len(loader)):
  35. if not args.use_cpu:
  36. points, target = points.cuda(), target.cuda()
  37. points = points.transpose(2, 1)
  38. vote_pool = torch.zeros(target.size()[0], num_class).cuda()
  39. for _ in range(vote_num):
  40. pred, _ = classifier(points)
  41. vote_pool += pred
  42. pred = vote_pool / vote_num
  43. pred_choice = pred.data.max(1)[1]
  44. for cat in np.unique(target.cpu()):
  45. classacc = pred_choice[target == cat].eq(target[target == cat].long().data).cpu().sum()
  46. class_acc[cat, 0] += classacc.item() / float(points[target == cat].size()[0])
  47. class_acc[cat, 1] += 1
  48. correct = pred_choice.eq(target.long().data).cpu().sum()
  49. mean_correct.append(correct.item() / float(points.size()[0]))
  50. class_acc[:, 2] = class_acc[:, 0] / class_acc[:, 1]
  51. class_acc = np.mean(class_acc[:, 2])
  52. instance_acc = np.mean(mean_correct)
  53. return instance_acc, class_acc
  54. def main(args):
  55. def log_string(str):
  56. logger.info(str)
  57. print(str)
  58. '''HYPER PARAMETER''' #1.选择设备
  59. os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
  60. '''CREATE DIR''' #2.创建日志保存路径
  61. experiment_dir = 'log/classification/' + args.log_dir
  62. '''LOG'''
  63. args = parse_args()
  64. logger = logging.getLogger("Model")
  65. logger.setLevel(logging.INFO)
  66. formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
  67. file_handler = logging.FileHandler('%s/eval.txt' % experiment_dir) #将命令保存到这个文件
  68. file_handler.setLevel(logging.INFO)
  69. file_handler.setFormatter(formatter)
  70. logger.addHandler(file_handler)
  71. log_string('PARAMETER ...')
  72. log_string(args)
  73. '''DATA LOADING''' #3.数据加载
  74. log_string('Load dataset ...')
  75. data_path = 'data/modelnet40_normal_resampled/'
  76. test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=False)
  77. testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10)
  78. '''MODEL LOADING''' #4.加载模型
  79. num_class = args.num_category
  80. model_name = os.listdir(experiment_dir + '/logs')[0].split('.')[0] #获取网络结构文件
  81. print(model_name)
  82. model = importlib.import_module(model_name) #加载网络
  83. classifier = model.get_model(num_class, normal_channel=args.use_normals)
  84. if not args.use_cpu:
  85. classifier = classifier.cuda()
  86. checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') #加载模型权重
  87. classifier.load_state_dict(checkpoint['model_state_dict'])
  88. with torch.no_grad():
  89. instance_acc, class_acc = test(classifier.eval(), testDataLoader, vote_num=args.num_votes, num_class=num_class) #进行测试
  90. log_string('Test Instance Accuracy: %f, Class Accuracy: %f' % (instance_acc, class_acc))
  91. if __name__ == '__main__':
  92. args = parse_args()
  93. main(args)
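The script reports two numbers:

- Instance accuracy is the fraction of test shapes classified correctly.
- Class accuracy averages the accuracy of each category, so rare classes count as much as common ones.

Below is a minimal NumPy sketch of both, computed over a whole test set; the function name `accuracies` is hypothetical. The script itself accumulates these per batch, so its numbers can differ slightly from this pooled version.

Voting has an effect even in eval mode. `farthest_point_sample` picks its first centroid with `torch.randint`, so each of the `num_votes` passes groups the points slightly differently, and averaging the scores smooths that out.

```python
import numpy as np


def accuracies(pred, target, num_class):
    """Instance accuracy and mean per-class accuracy over a whole test set.

    pred, target: 1-D integer arrays of predicted and true labels.
    """
    instance_acc = float((pred == target).mean())
    per_class = [(pred[target == c] == c).mean() for c in range(num_class) if (target == c).any()]
    return instance_acc, float(np.mean(per_class))


# Class 0 is common and easy, class 1 is rare and hard
target = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])
pred = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
print(accuracies(pred, target, num_class=2))  # (0.9, 0.75)
```

Here 9 of 10 shapes are correct, but class 1 is only half right, so class accuracy (0.75) exposes the weak class that instance accuracy (0.9) hides.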

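The official classification setup uses the ModelNet dataset. Put the downloaded ModelNet40 dataset under ./data/; the test script reads it from data/modelnet40_normal_resampled/.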

```shell
# ModelNet40
## Select different models in ./models
## e.g., pointnet2_ssg without normal features
python train_classification.py --model pointnet2_cls_ssg --log_dir pointnet2_cls_ssg
python test_classification.py --log_dir pointnet2_cls_ssg
## e.g., pointnet2_ssg with normal features
python train_classification.py --model pointnet2_cls_ssg --use_normals --log_dir pointnet2_cls_ssg_normal
python test_classification.py --use_normals --log_dir pointnet2_cls_ssg_normal
## e.g., pointnet2_ssg with uniform sampling
python train_classification.py --model pointnet2_cls_ssg --use_uniform_sample --log_dir pointnet2_cls_ssg_fps
python test_classification.py --use_uniform_sample --log_dir pointnet2_cls_ssg_fps
```

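--model selects the network, that is, a file in the models folder (given without .py). --log_dir names the experiment folder under log/classification/ where checkpoints and logs are stored; the commands above simply name it after the model. test_classification.py only needs --log_dir because it recovers the model name from that folder.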

4.2 Semantic Segmentation

Segmentation uses the S3DIS 3D indoor dataset:

Download 3D indoor parsing dataset (**S3DIS**) [here](http://buildingparser.stanford.edu/dataset.html) and save in `data/s3dis/Stanford3dDataset_v1.2_Aligned_Version/`.
```
cd data_utils
python collect_indoor3d_data.py
```
Processed data will be saved in `data/s3dis/stanford_indoor3d/`.

Run:
```
## Check model in ./models
## e.g., pointnet2_ssg
python train_semseg.py --model pointnet2_sem_seg --test_area 5 --log_dir pointnet2_sem_seg
python test_semseg.py --log_dir pointnet2_sem_seg --test_area 5 --visual
```

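Statement: this article was contributed by a user and does not represent the position of the wpsshop blog. Copyright belongs to the original author, and this site assumes no legal liability. If you find infringing content, please contact us. When reposting, please credit the source: https://www.wpsshop.cn/w/2023面试高手/article/detail/115910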