Main training script
```python
import argparse
import os
import torch
import torch.nn.parallel
import torch.utils.data
from utils import to_categorical
from collections import defaultdict
from torch.autograd import Variable
from data_utils.ShapeNetDataLoader import PartNormalDataset
import torch.nn.functional as F
import datetime
import logging
from pathlib import Path
from utils import test_partseg
from tqdm import tqdm
from model.pointnet2 import PointNet2PartSeg_msg_one_hot
from model.pointnet import PointNetDenseCls, PointNetLoss

seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43],
               'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46],
               'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
               'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
               'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
seg_label_to_cat = {}  # {0: Airplane, 1: Airplane, ... 49: Table}
for cat in seg_classes.keys():
    for label in seg_classes[cat]:
        seg_label_to_cat[label] = cat


def parse_args():
    parser = argparse.ArgumentParser('PointNet2')
    parser.add_argument('--batchsize', type=int, default=8, help='input batch size')
    parser.add_argument('--workers', type=int, default=0, help='number of data loading workers')
    parser.add_argument('--epoch', type=int, default=4, help='number of epochs for training')
    parser.add_argument('--pretrain', type=str, default=None, help='whether to use a pretrained model')
    parser.add_argument('--gpu', type=str, default='0', help='specify gpu device')
    parser.add_argument('--model_name', type=str, default='pointnet', help='Name of model')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for training')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay')
    parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer')
    parser.add_argument('--multi_gpu', type=str, default=None, help='whether to use multi-gpu training')
    parser.add_argument('--jitter', default=False, help='randomly jitter point cloud')
    parser.add_argument('--step_size', type=int, default=20, help='step size for learning rate decay')
    return parser.parse_args()


def main(args):
    # Create the output directories
    # os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu if args.multi_gpu is None else '0,1,2,3'
    '''CREATE DIR'''
    experiment_dir = Path('./experiment/')
    experiment_dir.mkdir(exist_ok=True)
    file_dir = Path(str(experiment_dir) + '/%sPartSeg-' % args.model_name +
                    str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')))
    file_dir.mkdir(exist_ok=True)
    checkpoints_dir = file_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = file_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    args = parse_args()
    logger = logging.getLogger(args.model_name)   # create the logger
    logger.setLevel(logging.INFO)                 # set the logging level
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')  # output format
    file_handler = logging.FileHandler(str(log_dir) + '/train_%s_partseg.txt' % args.model_name)  # file handler
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.info('---------------------------------------------------TRAINING---------------------------------------------------')
    logger.info('PARAMETER ...')
    logger.info(args)

    norm = True if args.model_name == 'pointnet' else False
    # Load the datasets
    TRAIN_DATASET = PartNormalDataset(npoints=2048, split='trainval', normalize=norm, jitter=args.jitter)
    dataloader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batchsize,
                                             shuffle=True, num_workers=int(args.workers))
    TEST_DATASET = PartNormalDataset(npoints=2048, split='test', normalize=norm, jitter=False)
    testdataloader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=8,
                                                 shuffle=True, num_workers=int(args.workers))
    print("The number of training data is:", len(TRAIN_DATASET))
    logger.info("The number of training data is:%d", len(TRAIN_DATASET))
    print("The number of test data is:", len(TEST_DATASET))
    logger.info("The number of test data is:%d", len(TEST_DATASET))

    num_classes = 16  # object categories
    num_part = 50     # part labels
    blue = lambda x: '\033[94m' + x + '\033[0m'
    model = PointNet2PartSeg_msg_one_hot(num_part) if args.model_name == 'pointnet2' \
        else PointNetDenseCls(cat_num=num_classes, part_num=num_part)

    if args.pretrain is not None:
        model.load_state_dict(torch.load(args.pretrain))
        print('load model %s' % args.pretrain)
        logger.info('load model %s' % args.pretrain)
    else:
        print('Training from scratch')
        logger.info('Training from scratch')
    pretrain = args.pretrain
    # Parse the starting epoch out of checkpoint names like 'pointnet_004_0.9123.pth'
    init_epoch = int(pretrain[-14:-11]) if args.pretrain is not None else 0

    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    # StepLR multiplies the learning rate by gamma (0.5 here) every step_size epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=0.5)

    '''GPU selection and multi-GPU'''
    # if args.multi_gpu is not None:
    #     device_ids = [int(x) for x in args.multi_gpu.split(',')]
    #     torch.backends.cudnn.benchmark = True
    #     model.cuda(device_ids[0])
    #     model = torch.nn.DataParallel(model, device_ids=device_ids)
    # else:
    #     model.cuda()

    criterion = PointNetLoss()
    LEARNING_RATE_CLIP = 1e-5

    history = defaultdict(lambda: list())  # record the loss of every step
    best_acc = 0
    best_class_avg_iou = 0
    best_inctance_avg_iou = 0

    for epoch in range(init_epoch, args.epoch):
        scheduler.step()  # update the learning rate
        lr = max(optimizer.param_groups[0]['lr'], LEARNING_RATE_CLIP)
        print('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        for i, data in tqdm(enumerate(dataloader, 0), total=len(dataloader), smoothing=0.9):
            points, label, target, norm_plt = data
            # points (8, 2048, 3), label (8, 1), target (8, 2048), norm_plt (8, 2048, 3)
            points, label, target = Variable(points.float()), Variable(label.long()), Variable(target.long())
            points = points.transpose(2, 1)      # points (8, 3, 2048)
            norm_plt = norm_plt.transpose(2, 1)  # norm_plt (8, 3, 2048)
            # points, label, target, norm_plt = points.cuda(), label.squeeze().cuda(), target.cuda(), norm_plt.cuda()
            points, label, target, norm_plt = points, label.squeeze(), target, norm_plt
            optimizer.zero_grad()
            model = model.train()
            if args.model_name == 'pointnet':
                labels_pred, seg_pred, trans_feat = model(points, to_categorical(label, 16))
                # labels_pred (8, 16), seg_pred (8, 2048, 50), trans_feat (8, 128, 128)
                seg_pred = seg_pred.contiguous().view(-1, num_part)  # (8*2048, 50)
                target = target.view(-1, 1)[:, 0]                    # (8*2048)
                loss, seg_loss, label_loss = criterion(labels_pred, label, seg_pred, target, trans_feat)
            else:
                seg_pred = model(points, norm_plt, to_categorical(label, 16))
                seg_pred = seg_pred.contiguous().view(-1, num_part)
                target = target.view(-1, 1)[:, 0]
                loss = F.nll_loss(seg_pred, target)
            history['loss'].append(loss.cpu().data.numpy())
            loss.backward()
            optimizer.step()

        forpointnet2 = args.model_name == 'pointnet2'
        # Evaluate on the test set after every epoch:
        # test_metrics  - dict of evaluation metrics
        # test_hist_acc - history of per-batch segmentation accuracy
        # cat_mean_iou  - mean IoU per category
        test_metrics, test_hist_acc, cat_mean_iou = test_partseg(model.eval(), testdataloader,
                                                                 seg_label_to_cat, 50, forpointnet2)
        print('Epoch %d %s accuracy: %f  Class avg mIOU: %f  Instance avg mIOU: %f' % (
            epoch, blue('test'), test_metrics['accuracy'],
            test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
        logger.info('Epoch %d %s Accuracy: %f  Class avg mIOU: %f  Instance avg mIOU: %f' % (
            epoch, blue('test'), test_metrics['accuracy'],
            test_metrics['class_avg_iou'], test_metrics['inctance_avg_iou']))
        # Save the best model
        if test_metrics['accuracy'] > best_acc:
            best_acc = test_metrics['accuracy']
            torch.save(model.state_dict(), '%s/%s_%.3d_%.4f.pth' % (checkpoints_dir, args.model_name, epoch, best_acc))
            logger.info(cat_mean_iou)
            logger.info('Save model..')
            print('Save model..')
            print(cat_mean_iou)
        if test_metrics['class_avg_iou'] > best_class_avg_iou:
            best_class_avg_iou = test_metrics['class_avg_iou']
        if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou:
            best_inctance_avg_iou = test_metrics['inctance_avg_iou']
        print('Best accuracy is: %.5f' % best_acc)
        logger.info('Best accuracy is: %.5f' % best_acc)
        print('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        logger.info('Best class avg mIOU is: %.5f' % best_class_avg_iou)
        print('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou)
        logger.info('Best instance avg mIOU is: %.5f' % best_inctance_avg_iou)


if __name__ == '__main__':
    args = parse_args()
    main(args)
```
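Both the training script and the evaluation code import `to_categorical` from `utils`, but the post never shows it. Here is a minimal sketch consistent with how it is called above, turning a `(B,)` tensor of category indices into a `(B, 16)` one-hot float tensor; the exact implementation in the referenced repo may differ slightly:

```python
import torch

def to_categorical(y, num_classes):
    """One-hot encode class indices: (B,) int tensor -> (B, num_classes) float tensor."""
    new_y = torch.eye(num_classes)[y.long().cpu()]  # row i of the identity matrix is the one-hot vector for class i
    return new_y.to(y.device)
```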
Data loading (annotations are still a work in progress)
```python
# *_*coding:utf-8 *_*
import os
import json
import warnings
import numpy as np
from torch.utils.data import Dataset
warnings.filterwarnings('ignore')


def pc_normalize(pc):
    """Center the cloud at the origin and scale it into the unit sphere."""
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc


def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05):
    """ Randomly jitter points. Jittering is per point.
        Input:  Nx3 array, original point cloud
        Return: Nx3 array, jittered point cloud
    """
    N, C = batch_data.shape
    assert(clip > 0)
    jittered_data = np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
    jittered_data += batch_data
    return jittered_data


class PartNormalDataset(Dataset):
    def __init__(self, npoints=2500, split='train', normalize=True, jitter=False):
        self.npoints = npoints
        self.root = r'G:\Pointnet_Pointnet2_pytorch-master\data\ShapeNet'
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')  # category index file
        self.cat = {}
        self.normalize = normalize
        self.jitter = jitter

        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()  # split on whitespace
                self.cat[ls[0]] = ls[1]    # {category name: synset offset directory}
        self.cat = {k: v for k, v in self.cat.items()}
        # print(self.cat)

        self.meta = {}
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
            train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        for item in self.cat:
            # print('category', item)
            self.meta[item] = []
            dir_point = os.path.join(self.root, self.cat[item])
            fns = sorted(os.listdir(dir_point))
            # print(fns[0][0:-4])
            if split == 'trainval':
                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
            elif split == 'train':
                fns = [fn for fn in fns if fn[0:-4] in train_ids]
            elif split == 'val':
                fns = [fn for fn in fns if fn[0:-4] in val_ids]
            elif split == 'test':
                fns = [fn for fn in fns if fn[0:-4] in test_ids]
            else:
                print('Unknown split: %s. Exiting..' % (split))
                exit(-1)
            for fn in fns:
                token = (os.path.splitext(os.path.basename(fn))[0])
                self.meta[item].append(os.path.join(dir_point, token + '.txt'))  # source file path for each sample

        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn))
        if split == 'trainval':
            self.datapath = self.datapath[:2000]  # keep only the first 2000 samples

        self.classes = dict(zip(self.cat, range(len(self.cat))))
        # Mapping from category ('Chair') to a list of int [10, 11, 12, 13] as segmentation labels
        self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35],
                            'Rocket': [41, 42, 43], 'Car': [8, 9, 10, 11], 'Laptop': [28, 29],
                            'Cap': [6, 7], 'Skateboard': [44, 45, 46], 'Mug': [36, 37],
                            'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27],
                            'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40],
                            'Chair': [12, 13, 14, 15], 'Knife': [22, 23]}
        for cat in sorted(self.seg_classes.keys()):
            print(cat, self.seg_classes[cat])

        self.cache = {}  # from index to (point_set, normal, seg, cls) tuple
        self.cache_size = 20000
        # At this point all file paths of the dataset have been collected.

    def __getitem__(self, index):
        # Load one sample (cached after the first access)
        if index in self.cache:
            point_set, normal, seg, cls = self.cache[index]
        else:
            fn = self.datapath[index]
            cat = self.datapath[index][0]
            cls = self.classes[cat]
            cls = np.array([cls]).astype(np.int32)
            data = np.loadtxt(fn[1]).astype(np.float32)
            point_set = data[:, 0:3]            # xyz coordinates
            normal = data[:, 3:6]               # point normals
            seg = data[:, -1].astype(np.int32)  # per-point part label
            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, normal, seg, cls)
        if self.normalize:
            point_set = pc_normalize(point_set)
        if self.jitter:
            point_set = jitter_point_cloud(point_set)  # assign the jittered result back
        choice = np.random.choice(len(seg), self.npoints, replace=True)  # resample to npoints
        point_set = point_set[choice, :]
        seg = seg[choice]
        normal = normal[choice, :]
        return point_set, cls, seg, normal

    def __len__(self):
        return len(self.datapath)


if __name__ == '__main__':
    TRAIN_DATASET = PartNormalDataset(npoints=2048, split='trainval', normalize=True, jitter=False)
```
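A quick way to sanity-check the loader is to wrap it in a `DataLoader` and inspect one batch; the shapes below match the annotations in the training loop. This is a hypothetical snippet: the dataset root is hard-coded in `__init__`, so adjust it to your machine first.

```python
import torch

# Assumes the ShapeNet part data is available at the root path hard-coded above.
dataset = PartNormalDataset(npoints=2048, split='trainval', normalize=True, jitter=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
points, cls, seg, normal = next(iter(loader))
print(points.shape)  # torch.Size([8, 2048, 3])  xyz coordinates
print(cls.shape)     # torch.Size([8, 1])        object category index
print(seg.shape)     # torch.Size([8, 2048])     per-point part labels
print(normal.shape)  # torch.Size([8, 2048, 3])  point normals
```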
Network architecture
```python
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F


class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        # Note: the "MLPs" of the paper are all implemented as 1x1 convolutions.
        # E.g. mapping 3 dims to 64 dims uses 64 kernels over the 3 input channels.
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
            batchsize, 1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden  # predict an offset from the identity transform
        x = x.view(-1, 3, 3)
        return x


class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k * k)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)
        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
            batchsize, 1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x


class PointNetEncoder(nn.Module):
    def __init__(self, global_feat=True, feature_transform=False, semseg=False):
        super(PointNetEncoder, self).__init__()
        self.stn = STN3d() if not semseg else STNkd(k=9)
        self.conv1 = torch.nn.Conv1d(3, 64, 1) if not semseg else torch.nn.Conv1d(9, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2, 1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat


class PointNetDenseCls(nn.Module):
    def __init__(self, cat_num=16, part_num=50):
        super(PointNetDenseCls, self).__init__()
        self.cat_num = cat_num
        self.part_num = part_num
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, 512, 1)
        self.conv5 = torch.nn.Conv1d(512, 2048, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(128)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(2048)
        self.fstn = STNkd(k=128)
        # classification network
        self.fc1 = nn.Linear(2048, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, cat_num)
        self.dropout = nn.Dropout(p=0.3)
        self.bnc1 = nn.BatchNorm1d(256)
        self.bnc2 = nn.BatchNorm1d(256)
        # segmentation network
        self.convs1 = torch.nn.Conv1d(4944, 256, 1)
        self.convs2 = torch.nn.Conv1d(256, 256, 1)
        self.convs3 = torch.nn.Conv1d(256, 128, 1)
        self.convs4 = torch.nn.Conv1d(128, part_num, 1)
        self.bns1 = nn.BatchNorm1d(256)
        self.bns2 = nn.BatchNorm1d(256)
        self.bns3 = nn.BatchNorm1d(128)

    def forward(self, point_cloud, label):
        batchsize, _, n_pts = point_cloud.size()
        # transform the input cloud
        trans = self.stn(point_cloud)
        point_cloud = point_cloud.transpose(2, 1)
        point_cloud_transformed = torch.bmm(point_cloud, trans)
        point_cloud_transformed = point_cloud_transformed.transpose(2, 1)
        # MLP
        out1 = F.relu(self.bn1(self.conv1(point_cloud_transformed)))
        out2 = F.relu(self.bn2(self.conv2(out1)))
        out3 = F.relu(self.bn3(self.conv3(out2)))
        # transform the features
        trans_feat = self.fstn(out3)
        x = out3.transpose(2, 1)
        net_transformed = torch.bmm(x, trans_feat)
        net_transformed = net_transformed.transpose(2, 1)
        # MLP
        out4 = F.relu(self.bn4(self.conv4(net_transformed)))
        out5 = self.bn5(self.conv5(out4))
        out_max = torch.max(out5, 2, keepdim=True)[0]
        out_max = out_max.view(-1, 2048)
        # classification network
        net = F.relu(self.bnc1(self.fc1(out_max)))
        net = F.relu(self.bnc2(self.dropout(self.fc2(net))))
        net = self.fc3(net)  # [B, 16]
        # segmentation network
        out_max = torch.cat([out_max, label], 1)
        expand = out_max.view(-1, 2048 + 16, 1).repeat(1, 1, n_pts)
        concat = torch.cat([expand, out1, out2, out3, out4, out5], 1)
        net2 = F.relu(self.bns1(self.convs1(concat)))
        net2 = F.relu(self.bns2(self.convs2(net2)))
        net2 = F.relu(self.bns3(self.convs3(net2)))
        net2 = self.convs4(net2)
        net2 = net2.transpose(2, 1).contiguous()
        net2 = F.log_softmax(net2.view(-1, self.part_num), dim=-1)
        net2 = net2.view(batchsize, n_pts, self.part_num)  # [B, N, 50]
        return net, net2, trans_feat


def feature_transform_reguliarzer(trans):
    d = trans.size()[1]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
    return loss


class PointNetSeg(nn.Module):
    def __init__(self, num_class, feature_transform=False, semseg=False):
        super(PointNetSeg, self).__init__()
        self.k = num_class
        self.feat = PointNetEncoder(global_feat=False, feature_transform=feature_transform, semseg=semseg)
        self.conv1 = torch.nn.Conv1d(1088, 512, 1)
        self.conv2 = torch.nn.Conv1d(512, 256, 1)
        self.conv3 = torch.nn.Conv1d(256, 128, 1)
        self.conv4 = torch.nn.Conv1d(128, self.k, 1)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn1_1 = nn.BatchNorm1d(1024)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)

    def forward(self, x):
        batchsize = x.size()[0]
        n_pts = x.size()[2]
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.conv4(x)
        x = x.transpose(2, 1).contiguous()
        x = F.log_softmax(x.view(-1, self.k), dim=-1)
        x = x.view(batchsize, n_pts, self.k)
        return x, trans_feat


if __name__ == '__main__':
    # quick sanity check of the network
    point = torch.randn(8, 3, 1024)
    label = torch.randn(8, 16)
    model = PointNetDenseCls()
    net, net2, trans_feat = model(point, label)
    print('net', net.shape)                # torch.Size([8, 16])
    print('net2', net2.shape)              # torch.Size([8, 1024, 50])
    print('trans_feat', trans_feat.shape)  # torch.Size([8, 128, 128])
```
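`feature_transform_reguliarzer` is the orthogonality penalty from the PointNet paper, ‖A·Aᵀ − I‖, which nudges the predicted feature transform toward a rotation. A quick hypothetical check (not in the original post): an orthogonal transform gives zero loss, an arbitrary one does not.

```python
import torch

Q = torch.eye(128).unsqueeze(0).repeat(4, 1, 1)  # batch of 4 identity (orthogonal) transforms
print(feature_transform_reguliarzer(Q).item())   # 0.0

A = torch.randn(4, 128, 128)                     # arbitrary transforms
print(feature_transform_reguliarzer(A).item())   # large positive value
```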
(Trace how the data flows through the network and it all falls into place; see the shape trace below.)
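To make that data flow concrete, here is the shape of each intermediate tensor in `PointNetDenseCls.forward`, derived from the layer definitions above, for input `point_cloud (B, 3, N)` and one-hot `label (B, 16)`:

```python
# point_cloud (B, 3, N) --STN3d--> trans (B, 3, 3); transformed cloud (B, 3, N)
# out1 (B, 64, N) -> out2 (B, 128, N) -> out3 (B, 128, N)
# STNkd(k=128) -> trans_feat (B, 128, 128); net_transformed (B, 128, N)
# out4 (B, 512, N) -> out5 (B, 2048, N) --max over N--> out_max (B, 2048)
# classification head: fc1/fc2/fc3 on out_max -> net (B, 16)
# cat(out_max, label) (B, 2064) -> expand (B, 2064, N)
# concat [expand, out1, out2, out3, out4, out5] (B, 2064+64+128+128+512+2048, N) = (B, 4944, N), matching convs1
# segmentation head convs1..convs4 -> net2 (B, N, 50) per-point log-probabilities
```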
Loss function
```python
class PointNetLoss(torch.nn.Module):
    def __init__(self, weight=1, mat_diff_loss_scale=0.001):
        super(PointNetLoss, self).__init__()
        self.mat_diff_loss_scale = mat_diff_loss_scale
        self.weight = weight

    def forward(self, labels_pred, label, seg_pred, seg, trans_feat):
        seg_loss = F.nll_loss(seg_pred, seg)                       # segmentation loss
        mat_diff_loss = feature_transform_reguliarzer(trans_feat)  # orthogonality regularizer
        label_loss = F.nll_loss(labels_pred, label)                # classification loss
        loss = self.weight * seg_loss + (1 - self.weight) * label_loss + mat_diff_loss * self.mat_diff_loss_scale
        return loss, seg_loss, label_loss
```
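Two things worth noting. First, `labels_pred` comes straight from `fc3` without a `log_softmax`, so the classification term is not a true NLL. Second, with the default `weight=1` that term is multiplied by zero anyway, so only the segmentation loss and the orthogonality penalty contribute. A toy invocation with random tensors shaped like the training loop (illustrative only):

```python
import torch
import torch.nn.functional as F

criterion = PointNetLoss(weight=1, mat_diff_loss_scale=0.001)
labels_pred = torch.randn(8, 16)                             # class scores, (B, 16)
label = torch.randint(0, 16, (8,))                           # ground-truth categories, (B,)
seg_pred = F.log_softmax(torch.randn(8 * 2048, 50), dim=-1)  # per-point log-probs, (B*N, 50)
seg = torch.randint(0, 50, (8 * 2048,))                      # per-point part labels, (B*N,)
trans_feat = torch.randn(8, 128, 128)                        # feature transform from STNkd
loss, seg_loss, label_loss = criterion(labels_pred, label, seg_pred, seg, trans_feat)
print(loss.item(), seg_loss.item(), label_loss.item())
```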
Test-set evaluation code
```python
import numpy as np
import pandas as pd
from collections import defaultdict
from torch.autograd import Variable
from tqdm import tqdm
# to_categorical is defined in the same utils module (see the sketch earlier in this post)


def compute_cat_iou(pred, target, iou_tabel):
    iou_list = []
    target = target.cpu().data.numpy()
    for j in range(pred.size(0)):
        batch_pred = pred[j]
        batch_target = target[j]
        batch_choice = batch_pred.data.max(1)[1].cpu().data.numpy()
        for cat in np.unique(batch_target):
            I = np.sum(np.logical_and(batch_choice == cat, batch_target == cat))
            U = np.sum(np.logical_or(batch_choice == cat, batch_target == cat))
            if U == 0:
                iou = 1  # If the union of ground-truth and prediction points is empty, count the part IoU as 1
            else:
                iou = I / float(U)
            iou_tabel[cat, 0] += iou
            iou_tabel[cat, 1] += 1
            iou_list.append(iou)
    return iou_tabel, iou_list


def test_partseg(model, loader, catdict, num_classes=50, forpointnet2=False):
    '''
    catdict = {0: Airplane, 1: Airplane, ... 49: Table}
    '''
    iou_tabel = np.zeros((len(catdict), 3))
    iou_list = []
    metrics = defaultdict(lambda: list())
    hist_acc = []
    for batch_id, (points, label, target, norm_plt) in tqdm(enumerate(loader), total=len(loader), smoothing=0.9):
        batchsize, num_point, _ = points.size()
        points, label, target, norm_plt = Variable(points.float()), Variable(label.long()), \
                                          Variable(target.long()), Variable(norm_plt.float())
        points = points.transpose(2, 1)
        norm_plt = norm_plt.transpose(2, 1)
        # points, label, target, norm_plt = points.cuda(), label.squeeze().cuda(), target.cuda(), norm_plt.cuda()
        points, label, target, norm_plt = points, label.squeeze(), target, norm_plt
        if forpointnet2:
            seg_pred = model(points, norm_plt, to_categorical(label, 16))
        else:
            labels_pred, seg_pred, _ = model(points, to_categorical(label, 16))  # (B, 16) / (B, 2048, 50)
        iou_tabel, iou = compute_cat_iou(seg_pred, target, iou_tabel)  # accumulate per-part IoU
        iou_list += iou
        seg_pred = seg_pred.contiguous().view(-1, num_classes)  # (B*2048, 50)
        target = target.view(-1, 1)[:, 0]                       # (B*2048)
        pred_choice = seg_pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()  # number of correctly segmented points
        metrics['accuracy'].append(correct.item() / (batchsize * num_point))  # accuracy of this batch
    iou_tabel[:, 2] = iou_tabel[:, 0] / iou_tabel[:, 1]
    hist_acc += metrics['accuracy']
    metrics['accuracy'] = np.mean(hist_acc)
    metrics['inctance_avg_iou'] = np.mean(iou_list)
    iou_tabel = pd.DataFrame(iou_tabel, columns=['iou', 'count', 'mean_iou'])
    iou_tabel['Category_IOU'] = [catdict[i] for i in range(len(catdict))]
    cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean()
    metrics['class_avg_iou'] = np.mean(cat_iou)  # IoU averaged over categories
    return metrics, hist_acc, cat_iou
```
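A hand-checkable toy run of `compute_cat_iou` (illustrative; three part classes instead of 50):

```python
import numpy as np
import torch

# One cloud of 4 points, 3 part classes; the predictions argmax to [0, 0, 1, 2].
pred = torch.log_softmax(torch.tensor(
    [[[4., 0., 0.],
      [4., 0., 0.],
      [0., 4., 0.],
      [0., 0., 4.]]]), dim=-1)
target = torch.tensor([[0, 0, 1, 1]])
iou_tabel, iou_list = compute_cat_iou(pred, target, np.zeros((3, 3)))
# Part 0: I=2, U=2 -> IoU 1.0; part 1: I=1, U=2 (points 3 and 4) -> IoU 0.5
print(iou_list)  # [1.0, 0.5]
```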
Addendum: the original author's code: https://github.com/yanx27/Pointnet_Pointnet2_pytorch