当前位置:   article > 正文

论文代码-PointNet++_mask=dist

mask=dist

PointNet++代码解读

引用说明

文章中的代码全部来自于Github仓库: Pointnet_Pointnet2_pytorch

本文更关注语义分割以及零件分割部分代码

工具方法&工具类

来自于models文件夹下的pointnet2_utils.py文件

  • 点云归一化
    归一化有利于损失函数梯度下降过程的求解
# 很简单就是求出质心然后用所有点对于质心距离的最大值作为缩放标准
def pc_normalize(pc):
    l = pc.shape[0]  # 没用,不知道为啥有这句
    centroid = np.mean(pc, axis=0)  # axis=0是对每列求均值 也就是xyz的均值
    pc = pc - centroid  # 先做减法 x0-x y0-y z0-z
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))  # 上一步的矩阵平方之后 axis=1是对每行求和 开根号不就是距离了么,找出距离的最大值为m
    pc = pc / m  # 缩放
    return pc
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 最远点采样
# xyz原始点集 npoint采样出来的点索引
# 将xyz分为selected和left两个集合,每次迭代从left中选出一个点这个点要求是距离selected集合最远的点,直到selected的数量为n
def farthest_point_sample(xyz, npoint):
    device = xyz.device  # 为了cuda计算
    B, N, C = xyz.shape  # B是batch N是xyz原来的点个数 C是3或6
    # 这个就是最终要返回的npoint,n个point的索引,初始化为0
    centroids = torch.zeros(B, npoint, dtype=torch.long).to(device)
    # distance一开始我很难理解,现看下面的再回来看。这里记录了每个batch中N个点距离当前已经采样到的点的一个最小距离,在这里取最大值作为下一个最远点
    distance = torch.ones(B, N).to(device) * 1e10
    # 这个为每一个batch随机初始化了一个当前的最远点的索引,内容为[0,N) B个
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    # 初始化一个batch的索引向量吧
    batch_indices = torch.arange(B, dtype=torch.long).to(device)
    for i in range(npoint):  # 开始迭代
        centroids[:, i] = farthest  # 每次先读取当前最远点索引到centroids中 第i个最远点
        centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) # 每次先根据索引值取出点的xyz坐标
        dist = torch.sum((xyz - centroid) ** 2, -1)  # 这里求的是每个点到当前最远点的欧氏距离平方
        mask = dist < distance  # 技巧,将dist中小于distance的都更新到distance中
        distance[mask] = dist[mask]  # 那么distance记录的是每个batch中所有点距离已经出现点的最小距离
        farthest = torch.max(distance, -1)[1]  # 取出最小距离中的最大值作为下一个最远点
    return centroids
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 球查询
# 用于查询在空间球体内部的点,输入为半径,最大采样点个数,原始全部点,查询点(一般是球的中心点)
def query_ball_point(radius, nsample, xyz, new_xyz):
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape  # S个中心点
    group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    sqrdists = square_distance(new_xyz, xyz)  # 用到了欧氏距离函数补充里有介绍,这里计算了查询点和全部点的两两距离
    group_idx[sqrdists > radius ** 2] = N  # 如果距离大于半径,那么不管了直接将idx设置为N
    group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]  # 升序排序 取前面nsample个就行
    group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])  # 这里将每个group中第一个点的值提取出来然后构造成一个B,S,nsample的维度,方便后面mask操作
    # 下面的目的就是由于可能采样数不足,球形空间内没有nsample个点,因此剩余的为N的都用第一个点的值代替
    # 密度小到周围没有一个点的时候应该全都是自己本身了
    mask = group_idx == N
    group_idx[mask] = group_first[mask]
    return group_idx
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

PointNet++ encoder

接下来的部分就是图中的两个操作

  • 采样分组
# 有点像CNN中的卷积核的意思,将点云分成不同的小部分然后可以通过pointnet提取特征
# 输入分别为:最远点采样的点个数,半径,每个球体内部要采样的点个数,原始点集(只有位置信息3维),点集(可以包含特征信息D维),是否返回最远点采样的索引和分组坐标
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    B, N, C = xyz.shape
    S = npoint
    fps_idx = farthest_point_sample(xyz, npoint) # 最远点采样出来的索引
    new_xyz = index_points(xyz, fps_idx)  # 采样出来的坐标
    idx = query_ball_point(radius, nsample, xyz, new_xyz)  # 球查询出的索引
    grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 从xyz中查出分组的坐标
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)  # 求相对坐标了这里,改为相对于中心点的坐标值了,局部坐标系

    # 这里points的存在是因为可能迭代整个采样分组过程,如图就是两次,所以可能点云中不仅仅是xyz数据,还会有别的特征维度,因此如果有特征维度那么进行拼接
    if points is not None:
        grouped_points = index_points(points, idx)
        new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 拼接
    else:
        new_points = grouped_xyz_norm
    if returnfps:
        return new_xyz, new_points, grouped_xyz, fps_idx
    else:
        return new_xyz, new_points

# 这直接对全部数据进行一个分组分为一个组即npoint=1,其他没有区别
def sample_and_group_all(xyz, points):
    device = xyz.device
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is not None:
        new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
    else:
        new_points = grouped_xyz
    return new_xyz, new_points
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 集合抽象
    这里也有两种方法,一种普通的就是堆叠采样分组和对每个分组的点做PointNet提取特征,这里看下对于非均匀点云的MSG方法
# MSG是在球查询过程中用多个半径查询,然后赋予不同权重得出一个结果,对于不均匀的点云有鲁棒性
class PointNetSetAbstractionMsg(nn.Module):
    def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list):
        super(PointNetSetAbstractionMsg, self).__init__()
        # 这三项定义的是采样分组层的参数
        self.npoint = npoint
        self.radius_list = radius_list  # 半径是个list,存不同的半径
        self.nsample_list = nsample_list  # 不同半径有不同的采样个数
        # MLP层的初始化
        self.conv_blocks = nn.ModuleList()
        self.bn_blocks = nn.ModuleList()
        for i in range(len(mlp_list)):
            convs = nn.ModuleList()
            bns = nn.ModuleList()
            last_channel = in_channel + 3  # 这里对于输入通道增加3?是为了法向量统一? TODO
            for out_channel in mlp_list[i]:
                # 使用二维的1*1卷积来做,而不是PointNet的一维卷积
                convs.append(nn.Conv2d(last_channel, out_channel, 1))
                bns.append(nn.BatchNorm2d(out_channel))  # 二维batch_normal
                last_channel = out_channel
            self.conv_blocks.append(convs)
            self.bn_blocks.append(bns)

    def forward(self, xyz, points):
        xyz = xyz.permute(0, 2, 1)  # 首先将维度交换一下顺序
        if points is not None:
            points = points.permute(0, 2, 1)

        B, N, C = xyz.shape
        S = self.npoint
        new_xyz = index_points(xyz, farthest_point_sample(xyz, S))
        new_points_list = []  # 这里存储不同r下面提取的特征
        for i, radius in enumerate(self.radius_list):  # enumerate 下标和内容
            K = self.nsample_list[i]  # 在球形空间中取K个
            group_idx = query_ball_point(radius, K, xyz, new_xyz) # 定义好的球查询
            grouped_xyz = index_points(xyz, group_idx)
            grouped_xyz -= new_xyz.view(B, S, 1, C) # 相对坐标
            if points is not None:  # 这里还是如果有特征就连接,没有直接复制
                grouped_points = index_points(points, group_idx)
                grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1)
            else:
                grouped_points = grouped_xyz

            grouped_points = grouped_points.permute(0, 3, 2, 1)  # [B, D, K, S] 这里是为了二维卷积更方便
            for j in range(len(self.conv_blocks[i])):
                # 1*1的二维卷积每个group为一个通道那么就有S个通道
                # 这里相当于对D做1*1的卷积 为啥要对输入通道+3输出呢还是不明白?TODO
                conv = self.conv_blocks[i][j]  
                bn = self.bn_blocks[i][j]
                grouped_points =  F.relu(bn(conv(grouped_points))) # MLP
            new_points = torch.max(grouped_points, 2)[0]  # [B, D', S] 最大池化获得group的全局特征
            new_points_list.append(new_points)

        new_xyz = new_xyz.permute(0, 2, 1)  # 维度还回去
        new_points_concat = torch.cat(new_points_list, dim=1)  # 将几个r的特征连接起来 好像没有权重
        return new_xyz, new_points_concat
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 反向插值&跳连接

反向插值:结构图中Segmentation部分,红色为N个点第一个蓝色为N1个点。反向插值的意思就是,在蓝色层的坐标系中,对于每一个点找到红色层中距离他最近的几个点,如何得到这个点的特征呢?将距离最近的这几个点的特征加权求和即可。
跳连接:经过反向插值得到的特则很其实是global的,跳连接是为了得到局部特征。在特征抽象的过程中其实已经有蓝色层这些点的特征了,那么直接将抽象过程的特征和反向插值后的特征做一个concat就有了一个既有global又有located特征的张量。

# 差值之后对每个点做MLP获得特征
class PointNetFeaturePropagation(nn.Module):
    def __init__(self, in_channel, mlp):
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp: # 这边主要就是一个MLP的定义
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))  # 这时候又是一维卷积了
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            interpolated_points = points2.repeat(1, N, 1)  # 如果只有一个点,那么直接复制N个
        else:
            dists = square_distance(xyz1, xyz2)  # 计算欧氏距离的平方
            dists, idx = dists.sort(dim=-1)  # 距离排序
            dists, idx = dists[:, :, :3], idx[:, :, :3]  # [B, N, 3] 构造结构

            dist_recip = 1.0 / (dists + 1e-8)  # 防止分母为0 取距离的倒数构造权重,距离越小权重越高
            norm = torch.sum(dist_recip, dim=2, keepdim=True)
            weight = dist_recip / norm  # 归一化
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) # 加权求和得到全局特征

        if points1 is not None:  # encoder中的局部特征如果有,就和全局特征拼接
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for i, conv in enumerate(self.mlp_convs):
            bn = self.mlp_bns[i]
            new_points = F.relu(bn(conv(new_points)))  # MLP
        return new_points
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 补充
    • square_distance(src, dst)根据其英文注释可以很清楚的看出这是求点集与点集之间欧式距离的平方的函数(没有开根号)
    • index_points(points, idx)在points中取出索引为idx的点 points是带batch维度的数据

零件分割网络

代码来自于models/pointnet2_part_seg_msg.py

class get_model(nn.Module):
    def __init__(self, num_classes, normal_channel=False): # 这里传入的num_classes就是总的部件数量
        super(get_model, self).__init__()
        # 首先如果有法向量 增加三个通道
        if normal_channel:
            additional_channel = 3
        else:
            additional_channel = 0
        self.normal_channel = normal_channel
        # 第一个set abstraction 512个组 每个组多个半径多个采样数 这次第一个输入通道就是3或6很好理解
        self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 3+additional_channel, [[32, 32, 64], [64, 64, 128], [64, 96, 128]])
        # 第二个 在512个组中挑128个 128+128+64 是输入通道,因为上一个sa有三个半径拼接后是64 + 128 + 128
        self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]])
        # 最后一个直接group_all=True分成一个组 输入维度256+256+3 3是xyz
        self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True)
        # MLP层
        self.fp3 = PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256]) # 1536 = 1024 + 512 -> 256
        self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128]) # 576 = 320 + 256 -> 128
        self.fp1 = PointNetFeaturePropagation(in_channel=150+additional_channel, mlp=[128, 128]) # 150=128+16+6 这个+6 是 3+3 看forward中有解释 128是 fp2的输出通道
        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.5)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, cls_label):
        # Set Abstraction layers
        B,C,N = xyz.shape
        if self.normal_channel:
            l0_points = xyz
            l0_xyz = xyz[:,:3,:]
        else:
            l0_points = xyz
            l0_xyz = xyz
        l1_xyz, l1_points = self.sa1(l0_xyz, l0_points)  # 这里输出是320
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)  # 这里输出是512
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)  # 这里输出是1024
        # Feature Propagation layers
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)  # 512 + 1024输入 输出是 256
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)  # 320 + 256输入 输出是 128
        cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N)  # 输入的点云分类个数 16类 one-hot key
        l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points)  # 16+3+3 + 128 输入 输出是 128
        # FC layers
        feat = F.relu(self.bn1(self.conv1(l0_points))) # 对128做个MLP提取最终分类结果
        x = self.drop1(feat) # dropout层
        x = self.conv2(x)  # 128 -> 50个类别
        x = F.log_softmax(x, dim=1)  # 还是log_softmax
        x = x.permute(0, 2, 1)
        return x, l3_points

class get_loss(nn.Module):
    def __init__(self):
        super(get_loss, self).__init__()

    def forward(self, pred, target, trans_feat):
        total_loss = F.nll_loss(pred, target)  # 没啥说的还是nll_loss
        return total_loss
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56

训练&测试

代码来自于train_partseg.py,测试代码和训练的验证部分差不多,不赘述

# 路径定义
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))

# key是class标签 value是part的列表
seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
               'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 'Table': [47, 48, 49],
               'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
# 将上面的定义转换成 {0:Airplane, 1:Airplane, ...49:Table} 的样子
seg_label_to_cat = {}  # key是part的标签 value是class标签 
for cat in seg_classes.keys():
    for label in seg_classes[cat]:
        seg_label_to_cat[label] = cat

# 节省显存的做法 用inplace
def inplace_relu(m):
    classname = m.__class__.__name__
    if classname.find('ReLU') != -1:
        m.inplace=True

# 对于物体类别 class 做一个one-hot编码 partnet默认16个类别
def to_categorical(y, num_classes):
    new_y = torch.eye(num_classes)[y.cpu().data.numpy(),]
    if (y.is_cuda):
        return new_y.cuda()
    return new_y


def parse_args():
    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--model', type=str, default='pointnet_part_seg', help='model name')  # 使用msg还是ssg
    parser.add_argument('--batch_size', type=int, default=16, help='batch Size during training')  # 批次大小
    parser.add_argument('--epoch', default=251, type=int, help='epoch to run')  # 循环几次
    parser.add_argument('--learning_rate', default=0.001, type=float, help='initial learning rate')  # 学习率
    parser.add_argument('--gpu', type=str, default='0', help='specify GPU devices')  # 指定GPU编号
    parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD')  # 优化方法
    parser.add_argument('--log_dir', type=str, default=None, help='log path')  # log的存储地址
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')  # 权重衰退
    parser.add_argument('--npoint', type=int, default=2048, help='point Number')  # 采样值
    parser.add_argument('--normal', action='store_true', default=False, help='use normals')  # 是否有法向量
    parser.add_argument('--step_size', type=int, default=20, help='decay step for lr decay')  # 步长
    parser.add_argument('--lr_decay', type=float, default=0.5, help='decay rate for lr decay')  # 学习率衰减指数

    return parser.parse_args()


def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER 超参数'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR 新建需要的文件目录,log和模型weight的存储'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    exp_dir = Path('./log/')
    exp_dir.mkdir(exist_ok=True)
    exp_dir = exp_dir.joinpath('part_seg')
    exp_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        exp_dir = exp_dir.joinpath(timestr)
    else:
        exp_dir = exp_dir.joinpath(args.log_dir)
    exp_dir.mkdir(exist_ok=True)
    checkpoints_dir = exp_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = exp_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    args = parse_args()
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'  # 根目录

    # 获取训练&测试数据
    TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='trainval', normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
    TEST_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='test', normal_channel=args.normal)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))

    num_classes = 16  # 分类
    num_part = 50  # part个数

    '''MODEL LOADING'''
    MODEL = importlib.import_module(args.model)
    shutil.copy('models/%s.py' % args.model, str(exp_dir))
    shutil.copy('models/pointnet2_utils.py', str(exp_dir))

    classifier = MODEL.get_model(num_part, normal_channel=args.normal).cuda()
    criterion = MODEL.get_loss().cuda()
    classifier.apply(inplace_relu)  # 用inplace来relu 节省显存

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)

    try:  # 是否有现有模型,如果有就直接使用pretrain
        checkpoint = torch.load(str(exp_dir) + '/checkpoints/best_model.pth')
        start_epoch = checkpoint['epoch']
        classifier.load_state_dict(checkpoint['model_state_dict'])
        log_string('Use pretrain model')
    except:  # 没有就从头初始化
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0
        classifier = classifier.apply(weights_init)
    # 优化方法的设置
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            classifier.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(classifier.parameters(), lr=args.learning_rate, momentum=0.9)
    # bn的动量 为了权重优化
    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECCAY = 0.5
    MOMENTUM_DECCAY_STEP = args.step_size

    # 这里为了存储训练结果
    best_acc = 0
    global_epoch = 0
    best_class_avg_iou = 0
    best_inctance_avg_iou = 0

    for epoch in range(start_epoch, args.epoch):
        mean_correct = []

        log_string('Epoch %d (%d/%s):' % (global_epoch + 1, epoch + 1, args.epoch))
        '''Adjust learning rate and BN momentum'''
        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)  # 动态调整学习率
        log_string('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP))  # 动态调整动量
        if momentum < 0.01:
            momentum = 0.01
        print('BN momentum updated to: %f' % momentum)
        classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
        classifier = classifier.train() # 调用模型的训练函数

        '''learning one epoch'''
        for i, (points, label, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
            optimizer.zero_grad()  # 梯度清零
            # 点云数据预处理
            points = points.data.numpy()
            points[:, :, 0:3] = provider.random_scale_point_cloud(points[:, :, 0:3])  # 随机缩放点云,默认0.8-1.25
            points[:, :, 0:3] = provider.shift_point_cloud(points[:, :, 0:3])  # 随机移动点云 -0.1到0.1
            points = torch.Tensor(points)
            points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda() # 分成点,预测标签,真是标签
            points = points.transpose(2, 1)  # 换一下顺序

            seg_pred, trans_feat = classifier(points, to_categorical(label, num_classes))  #forward函数 返回的是 seg_pred 是那个log_softmax值,和分类值
            seg_pred = seg_pred.contiguous().view(-1, num_part)
            target = target.view(-1, 1)[:, 0]
            pred_choice = seg_pred.data.max(1)[1]  # 找到最大的那个作为part的预测值

            correct = pred_choice.eq(target.data).cpu().sum()  # 正确率
            mean_correct.append(correct.item() / (args.batch_size * args.npoint))  # 平均正确率
            loss = criterion(seg_pred, target, trans_feat)  # 损失 nll_loss 值
            loss.backward()  # 反向传播
            optimizer.step()  # 更新一次参数

        train_instance_acc = np.mean(mean_correct)  # 平均准确率
        log_string('Train accuracy is: %.5f' % train_instance_acc)
        # 接下来是评估训练,类似于测试
        with torch.no_grad():  # 测试无需记录梯度
            test_metrics = {}
            total_correct = 0
            total_seen = 0
            total_seen_class = [0 for _ in range(num_part)]
            total_correct_class = [0 for _ in range(num_part)]
            shape_ious = {cat: [] for cat in seg_classes.keys()}  # 类别16个
            seg_label_to_cat = {}  # {0:Airplane, 1:Airplane, ...49:Table}

            for cat in seg_classes.keys():
                for label in seg_classes[cat]:
                    seg_label_to_cat[label] = cat

            classifier = classifier.eval() # 这里就不是train了就是评估了

            for batch_id, (points, label, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                cur_batch_size, NUM_POINT, _ = points.size()  # current batch size 超参数 然后某个点云中的点有NUM_POINT个
                points, label, target = points.float().cuda(), label.long().cuda(), target.long().cuda()
                points = points.transpose(2, 1)
                seg_pred, _ = classifier(points, to_categorical(label, num_classes))  # 不需要分类结果了这里
                cur_pred_val = seg_pred.cpu().data.numpy()  # [BATCH_SIZE,NUM_POINT,PART_NUM]
                cur_pred_val_logits = cur_pred_val
                cur_pred_val = np.zeros((cur_batch_size, NUM_POINT)).astype(np.int32)
                target = target.cpu().data.numpy()  # [BATCH_SIZE, NUM_POINT]

                for i in range(cur_batch_size):
                    cat = seg_label_to_cat[target[i, 0]]
                    logits = cur_pred_val_logits[i, :, :]
                    cur_pred_val[i, :] = np.argmax(logits[:, seg_classes[cat]], 1) + seg_classes[cat][0]

                correct = np.sum(cur_pred_val == target)
                total_correct += correct
                total_seen += (cur_batch_size * NUM_POINT)

                for l in range(num_part):
                    total_seen_class[l] += np.sum(target == l)
                    total_correct_class[l] += (np.sum((cur_pred_val == l) & (target == l)))

                for i in range(cur_batch_size):
                    segp = cur_pred_val[i, :]
                    segl = target[i, :]
                    cat = seg_label_to_cat[segl[0]]
                    # part_iou
                    part_ious = [0.0 for _ in range(len(seg_classes[cat]))]
                    for l in seg_classes[cat]:
                        if (np.sum(segl == l) == 0) and (
                                np.sum(segp == l) == 0):  # part is not present, no prediction as well
                            part_ious[l - seg_classes[cat][0]] = 1.0
                        else:
                            part_ious[l - seg_classes[cat][0]] = np.sum((segl == l) & (segp == l)) / float(
                                np.sum((segl == l) | (segp == l)))
                    # shape_iou
                    shape_ious[cat].append(np.mean(part_ious))  # 每个类别的平均iou

            all_shape_ious = []
            for cat in shape_ious.keys():
                for iou in shape_ious[cat]:
                    all_shape_ious.append(iou)
                shape_ious[cat] = np.mean(shape_ious[cat])
            mean_shape_ious = np.mean(list(shape_ious.values()))
            test_metrics['accuracy'] = total_correct / float(total_seen)
            test_metrics['class_avg_accuracy'] = np.mean(
                np.array(total_correct_class) / np.array(total_seen_class, dtype=np.float))
            for cat in sorted(shape_ious.keys()):
                log_string('eval mIoU of %s %f' % (cat + ' ' * (14 - len(cat)), shape_ious[cat]))
            test_metrics['class_avg_iou'] = mean_shape_ious
            test_metrics['inctance_avg_iou'] = np.mean(all_shape_ious)

        log_string('Epoch %d test Accuracy: %f  Class avg mIOU: %f   Inctance avg mIOU: %f' % (
            epoch + 1, test_metrics['accuracy'], test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
        if (test_metrics['inctance_avg_iou'] >= best_inctance_avg_iou):  # 如果性能更好,那么替换现有模型
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/best_model.pth'
            log_string('Saving at %s' % savepath)
            state = {
                'epoch': epoch,
                'train_acc': train_instance_acc,
                'test_acc': test_metrics['accuracy'],
                'class_avg_iou': test_metrics['class_avg_iou'],
                'inctance_avg_iou': test_metrics['inctance_avg_iou'],
                'model_state_dict': classifier.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, savepath)
            log_string('Saving model....')

        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
        if test_metrics['class_avg_iou'] > best_class_avg_iou:
            best_class_avg_iou = test_metrics['class_avg_iou']
        if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
            best_inctance_avg_iou = test_metrics['inctance_avg_iou']
        log_string('Best accuracy is: %.5f' % best_acc)
        log_string('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        log_string('Best inctance avg mIOU is: %.5f' % best_inctance_avg_iou)
        global_epoch += 1


if __name__ == '__main__':
    args = parse_args()  # 解析出参数,然后根据参数训练
    main(args)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293

跑自己的数据

# TODO
  • 1
声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/正经夜光杯/article/detail/766028
推荐阅读