当前位置:   article > 正文

Pointnet++ 网络结构以及代码实现_pointnet++代码

pointnet++代码

前言:

pointnet++是在pointnet的基础上发展而来的,而pointnet对于局部结构的识别能力有所缺陷,从pointnet的网络我们也可以看出,pointnet(如图一)是对整体的特征进行了maxpooling操作,忽略了局部特征,而pointnet++采用了一个叫深度的层次特征学习模式以提高局部结构的识别能力。
具体细节还请参考论文:

PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space

图一:

pointnet++的网络结构

 一、分类任务

1.1 分层抽取特征(set abstraction)

1.1.1sampling:在点云中采样作为中心点,采用最远点采样法(farthest point sampling

点云的数量为N,批次为B,需要采取npoint个中心点,xyz的第三个维度代表点云集的xyz空间坐标数据

最远点采样的步骤为:

1.初始化中心点centroids为[B,npoints]维度的全0张量,初始化距离distance全为10的10次方维度为[B,N]的张量,farthest初始化为随机从N个点选取的一个点。

2.首先先将随机初始点为最远值作为第一个中心点,然后就算点云中每个点与第一个中心点的距离,存在dist的中,这里采用的是欧式距离,公式举例为(x1-x2)**2+(y1-y2)**2+(z1-z2)**2,就是两点各个坐标的差值平方和。然后将此距离与distance做比较,将距离张量dist中小于distance中对应位置的值的距离更新到distance张量中。取distance中的最大值作为最远值点,centroids中更新为储存着第一个和第二个中心点,然后重复上面操作,依次以更新后的farthest点作为中心点,计算距离,取样,直到取到npoint个数的点为止。

  1. def farthest_point_sample(xyz, npoint):
  2. """
  3. Input:
  4. xyz: pointcloud data, [B, N, 3]
  5. npoint: number of samples
  6. Return:
  7. centroids: sampled pointcloud index, [B, npoint]
  8. """
  9. device = xyz.device
  10. B, N, C = xyz.shape
  11. centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
  12. distance = torch.ones(B, N).to(device) * 1e10
  13. farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
  14. #torch.randint函数会生成指定范围内的随机整数,并返回一个张量
  15. batch_indices = torch.arange(B, dtype=torch.long).to(device)
  16. for i in range(npoint):
  17. centroids[:, i] = farthest
  18. #随机值farthest作为为点云集合的中心点
  19. centroid = xyz[batch_indices, farthest, :].view(B, 1, 3)
  20. #选最远点的点作为质心,形状为B,1,3
  21. dist = torch.sum((xyz - centroid) ** 2, -1)
  22. #采用欧式距离计算每个点与质点的距离,最后dist的形状是B,N
  23. mask = dist < distance
  24. #标记哪些点小于distance,mask是个B,N的布尔码数组,标记了小于distance的信息
  25. distance[mask] = dist[mask]
  26. #dist[mask]是个一维数组,含有对应true顺序的dist的数值从而与distance的数值更新
  27. #将距离张量dist中小于distance中对应位置的值的距离更新到distance张量中,
  28. #从而更新每个点到采样点的距离。
  29. farthest = torch.max(distance, -1)[1]
  30. #farthest是distance中与质心最远点的点的索引
  31. return centroids

1.1.2grouping ,分组层,找距离中心点附近最近的K个点,组成local points region。这样的话就可以更加关注点云的局部信息,具体操作如下:

1.在query_ball_point函数中将上一步每一个采样的中心点需要以它们为中心采样周围的点组成成一个group,中心点与它group里面的其他点假设都在一个球体内,中心点为质心,计算其他点与质心的的距离,将不在球内的点(距离大于r平方的点)标记为N,然后选取离他最近的nsample个点为同一个组的采样点。如果质心附近点云稀疏的话(不够nsample个采样点),则将第一个点复制,将前nsample中不满足条件的点替换为第一个点,同样取样nsample个点。最后返回采样group的索引

2.根据已经提取出来的group_idx,在points(所有的点云数据集)中提取出,new_xyz,new_points,这些为points的子集,为每一个中心点采取一个group的集合,new_xyz的最后一维只包含xyz等空间信息,而new_points的最后一维包含其他特征,比如法向量nx,ny,nz

  1. def query_ball_point(radius, nsample, xyz, new_xyz):
  2. """
  3. Input:
  4. radius: local region radius
  5. nsample: max sample number in local region
  6. xyz: all points, [B, N, 3]
  7. new_xyz: query points, [B, S, 3]
  8. Return:
  9. group_idx: grouped points index, [B, S, nsample]
  10. """
  11. device = xyz.device
  12. B, N, C = xyz.shape
  13. _, S, _ = new_xyz.shape
  14. group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
  15. #将[0,....,N-1]先用view变化为(1,1,N)相当于1行N列,然后将第一维度上复制B次,第二维度(行)复制S次,
  16. #第三维度复制1次,最后是B,S,N的形状
  17. sqrdists = square_distance(new_xyz, xyz)
  18. #sqrdists: [B, S, N] 记录中心点与所有点之间的欧氏距离
  19. group_idx[sqrdists > radius ** 2] = N
  20. #为了处理未找到有效邻域点的情况,并对应于球形邻域搜索中的点筛选操作,将距离大于半径的邻域点排除在外。
  21. group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]
  22. #对最后一个维度采用升序排序排序,选出距离最近的nsample个点,形状B,S,nsample
  23. group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
  24. #考虑到一个group不足nsample个点,用第一个点复制代替
  25. #得到的group_first张量是一个形状为[B, S, nsample]的张量,其中每个元素表示每个查询点的第一个邻域点的索引
  26. mask = group_idx == N
  27. group_idx[mask] = group_first[mask]
  28. #对于在nsample内若存在大于半径球内的N点值,则将大于group的点替换成第一个点,最后返回group的索引
  29. return group_idx
  30. def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
  31. """
  32. Input:
  33. npoint:
  34. radius:
  35. nsample:
  36. xyz: input points position data, [B, N, 3]
  37. points: input points data, [B, N, D]
  38. Return:
  39. new_xyz: sampled points position data, [B, npoint, nsample, 3]
  40. new_points: sampled points data, [B, npoint, nsample, 3+D]
  41. """
  42. B, N, C = xyz.shape
  43. S = npoint
  44. fps_idx = farthest_point_sample(xyz, npoint) #获取了最远采样的几个点的索引[B, npoint]
  45. new_xyz = index_points(xyz, fps_idx) #获取最远点采样点[B,npoint,C]
  46. idx = query_ball_point(radius, nsample, xyz, new_xyz) #获取每个中心点采样nsample个点的下标[B,npoint,nsample]的索引
  47. grouped_xyz = index_points(xyz, idx) # 获取所有采样的点的分组[B,npoint,nsample,C]
  48. grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) #每个group点减去质心的坐标
  49. if points is not None:
  50. grouped_points = index_points(points, idx)
  51. new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)
  52. # 最后一个特征维度进行拼接[B, npoint, nsample, C+D]
  53. else:
  54. new_points = grouped_xyz_norm
  55. if returnfps:
  56. return new_xyz, new_points, grouped_xyz, fps_idx
  57. else:
  58. return new_xyz, new_points
  59. def sample_and_group_all(xyz, points):
  60. """
  61. Input:
  62. xyz: input points position data, [B, N, 3]
  63. points: input points data, [B, N, D]
  64. Return:
  65. new_xyz: sampled points position data, [B, 1, 3]
  66. new_points: sampled points data, [B, 1, N, 3+D]
  67. """
  68. #直接将所有点作为一个group,即增加一个长度为1的维度而已
  69. device = xyz.device
  70. B, N, C = xyz.shape
  71. new_xyz = torch.zeros(B, 1, C).to(device)
  72. # new_xyz代表中心点,用原点表示
  73. grouped_xyz = xyz.view(B, 1, N, C)
  74. # grouped_xyz减去中心点:每个区域的点减去区域的中心值,由于中心点为原点,所以结果仍然是grouped_xyz
  75. if points is not None:
  76. new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
  77. # view(B, 1, N, -1),-1代表自动计算,即结果等于view(B, 1, N, D)
  78. else:
  79. new_points = grouped_xyz
  80. return new_xyz, new_points

以上函数还使用了一个index_points的函数如下:

主要功能是可以用batch_indices,以及idx(两个维度必须匹配),根据点云索引从点云集中抽取出特定的点云数据。关于这个索引方法可以看看numpy的整数索引方法。

  1. def index_points(points, idx):
  2. """
  3. Input:
  4. points: input points data, [B, N, C]
  5. idx: sample index data, [B, S]
  6. Return:
  7. new_points:, indexed points data, [B, S, C]
  8. """
  9. device = points.device
  10. B = points.shape[0]
  11. view_shape = list(idx.shape)
  12. view_shape[1:] = [1] * (len(view_shape) - 1)
  13. #view_shape[1:]=[s]然后把[1]赋给[s],变为[B,1]
  14. repeat_shape = list(idx.shape)
  15. repeat_shape[0] = 1
  16. #repeat_shape形状为[1,S]
  17. batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
  18. #arrange生成[0, ..., B - 1], view后变为列向量[B, 1], repeat后[B, S]
  19. new_points = points[batch_indices, idx, :]
  20. # 从points中取出每个batch_indices对应索引的数据点
  21. return new_points

1.1.3特征提取层

将上面进行过采样以及分组处理后的点进行pointnet网络,这样一来,pointnet就可以关注到局部的细节,需要进行两次set abstraction的提取,下面是set abstraion的代码:

  1. class PointNetSetAbstraction(nn.Module):
  2. def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
  3. super(PointNetSetAbstraction, self).__init__()
  4. self.npoint = npoint
  5. self.radius = radius
  6. self.nsample = nsample
  7. self.mlp_convs = nn.ModuleList()
  8. self.mlp_bns = nn.ModuleList()
  9. last_channel = in_channel
  10. for out_channel in mlp:
  11. self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
  12. self.mlp_bns.append(nn.BatchNorm2d(out_channel))
  13. last_channel = out_channel
  14. self.group_all = group_all
  15. def forward(self, xyz, points):
  16. """
  17. Input:
  18. xyz: input points position data, [B, C, N]
  19. points: input points data, [B, D, N]
  20. Return:
  21. new_xyz: sampled points position data, [B, C, S]
  22. new_points_concat: sample points feature data, [B, D', S]
  23. """
  24. xyz = xyz.permute(0, 2, 1)
  25. if points is not None:
  26. points = points.permute(0, 2, 1)
  27. if self.group_all:
  28. new_xyz, new_points = sample_and_group_all(xyz, points)
  29. else:
  30. new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
  31. # new_xyz: sampled points position data, [B, npoint, C]
  32. # new_points: sampled points data, [B, npoint, nsample, C+D]
  33. new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint]
  34. for i, conv in enumerate(self.mlp_convs):
  35. bn = self.mlp_bns[i]
  36. new_points = F.relu(bn(conv(new_points)))
  37. #经过多层感知机以及maxpooling,相当于局部pointnet
  38. new_points = torch.max(new_points, 2)[0]
  39. new_xyz = new_xyz.permute(0, 2, 1)
  40. return new_xyz, new_points

分类任务中,两层set abstraion层后再接一个pointnet,得到一个关于全局的特征张量,然后通过多层感知机变化通道数,最后经过softmax输出各类别的概率。

二、分割任务

分割需要对每一个点进行分类,在前面的步骤中经过采样分组和pointnet已经将点云进行了下采样,所以分割任务中需要将特征上采样进行还原到以前的维度。作者提出了一种基于距离插值的分层特征传播(Feature Propagation)策略,从网络图看,先是将第一次经过pointnet的特征(我们当他当layer2层)与第二次经过pointnet的特征(我们把它称layer3层),做距离差值,然后还原到第一次pointnet后的特征维度,然后与没做过pointnet的层继续做距离差值。

以下是距离差值的公式:

 代码如下,实现逻辑是:

  1. def square_distance(src, dst):
  2. """
  3. Calculate Euclid distance between each two points.
  4. src^T * dst = xn * xm + yn * ym + zn * zm;
  5. sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
  6. sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
  7. dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
  8. = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
  9. Input:
  10. src: source points, [B, N, C]
  11. dst: target points, [B, M, C]
  12. Output:
  13. dist: per-point square distance, [B, N, M]
  14. """
  15. B, N, _ = src.shape
  16. _, M, _ = dst.shape
  17. dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
  18. #乘法运算实际上是计算了两个向量之间的内积。
  19. dist += torch.sum(src ** 2, -1).view(B, N, 1)
  20. dist += torch.sum(dst ** 2, -1).view(B, 1, M)
  21. return dist

首先是计算layer2与layer3每个点之间的距离,然后进行升序排列,取靠的最近的layer3层三个点作为距离差值的点,取这三个距离的倒数相加,接着得出权值,用特征与权值相乘,得到差值后的新的点的特征值。产生的新特征与上一层的特征进行cat操作,再通过卷积等完成特征融合。

  1. class PointNetFeaturePropagation(nn.Module):
  2. def __init__(self, in_channel, mlp):
  3. super(PointNetFeaturePropagation, self).__init__()
  4. self.mlp_convs = nn.ModuleList()
  5. self.mlp_bns = nn.ModuleList()
  6. last_channel = in_channel
  7. for out_channel in mlp:
  8. self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
  9. self.mlp_bns.append(nn.BatchNorm1d(out_channel))
  10. last_channel = out_channel
  11. def forward(self, xyz1, xyz2, points1, points2):
  12. """
  13. Input:
  14. xyz1: input points position data, [B, C, N]
  15. xyz2: sampled input points position data, [B, C, S]
  16. points1: input points data, [B, D, N]
  17. points2: input points data, [B, D, S]
  18. Return:
  19. new_points: upsampled points data, [B, D', N]
  20. """
  21. xyz1 = xyz1.permute(0, 2, 1)
  22. xyz2 = xyz2.permute(0, 2, 1)
  23. points2 = points2.permute(0, 2, 1)
  24. B, N, C = xyz1.shape
  25. _, S, _ = xyz2.shape
  26. if S == 1:
  27. interpolated_points = points2.repeat(1, N, 1)
  28. #如果只有一个点,将复制N份上采样
  29. else:
  30. dists = square_distance(xyz1, xyz2)
  31. #计算layer2的xyz1的点与layer 3 的xyz2的点之间的距离,形状为[B,N,S]
  32. dists, idx = dists.sort(dim=-1)
  33. dists, idx = dists[:, :, :3], idx[:, :, :3]
  34. #然后将距离按照行的维度升序排列,也就是排列后可得每个N点离s个点最近的点,取三个最近点维度成为 [B, N, 3]
  35. dist_recip = 1.0 / (dists + 1e-8)
  36. #取距离的倒数,对应论文中的 Wi(x),然后将每行的三个距离的倒数相加
  37. norm = torch.sum(dist_recip, dim=2, keepdim=True)
  38. weight = dist_recip / norm
  39. #计算权重,离得近的点权重大。 两者相除就是每个距离占总和的比重 也就是weight
  40. interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2)
  41. #index_points(points2, idx),points2的维度为[B,S,D],idx的维度为[B,N,3],函数中batch_indices为[B,N,3],
  42. #最后得到的维度为[B,N,3,D],weight的维度view为[B,N,3,1]
  43. if points1 is not None:
  44. points1 = points1.permute(0, 2, 1)
  45. new_points = torch.cat([points1, interpolated_points], dim=-1)
  46. else:
  47. new_points = interpolated_points
  48. new_points = new_points.permute(0, 2, 1)
  49. for i, conv in enumerate(self.mlp_convs):
  50. bn = self.mlp_bns[i]
  51. new_points = F.relu(bn(conv(new_points)))
  52. return new_points

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/笔触狂放9/article/detail/101777?site
推荐阅读
相关标签
  

闽ICP备14008679号