
Reproducing PointNet++ for Semantic Segmentation: Windows + PyTorch + S3DIS + Code


1. Platform

Windows 10

GPU RTX 3090 + CUDA 11.1 + cudnn 8.9.6

Python 3.9

Torch 1.9.1 + cu111

Original code used: https://github.com/yanx27/Pointnet_Pointnet2_pytorch
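A quick sanity check that this Torch build actually sees the GPU (a minimal sketch; the printed strings depend on your driver):

import torch

print(torch.__version__)              # expect '1.9.1+cu111'
print(torch.cuda.is_available())      # expect True
print(torch.cuda.get_device_name(0))  # e.g. 'NVIDIA GeForce RTX 3090'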

2. Data

Stanford3dDataset_v1.2_Aligned_Version
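For orientation, the aligned release unpacks into six areas; each room folder holds one fused point file plus per-object annotation files, which the preprocessing below merges into a single labeled array per room:

Stanford3dDataset_v1.2_Aligned_Version/
├── Area_1/
│   ├── conferenceRoom_1/
│   │   ├── conferenceRoom_1.txt        # x y z r g b, one point per line
│   │   └── Annotations/
│   │       ├── ceiling_1.txt
│   │       ├── chair_1.txt
│   │       └── ...
│   └── ...
└── Area_2/ ... Area_6/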

3. Code

Shared for anyone who needs it; please go easy on the code quality.

The upstream code has been simplified and annotated.

Segmentation results are saved as TXT, or written out as LAS point clouds via laspy.

Don't ask why everything lives on the C drive; the answer is a 2 TB Samsung 980 Pro.

3.1 File organization
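As implied by the import statements and paths in the two scripts below (run_collect_indoor3d_data.py is assumed to sit next to indoor3d_util.py, which it imports):

ProjectRoot/
├── dataset/
│   ├── run_collect_indoor3d_data.py    # preprocessing (location assumed)
│   ├── indoor3d_util.py
│   ├── S3DISDataLoader.py
│   └── stanford_indoor3d/              # *.npy files produced by preprocessing
├── PointNet2/
│   ├── dataProcess.py
│   ├── pointnet_sem_seg.py             # PointNet model + loss
│   └── pointnet2_sem_seg.py            # PointNet++ model + loss
├── trainModel/
│   ├── pointnet_model/
│   └── PointNet2_model/
├── testResultPN/                       # PointNet predictions (txt / las)
├── testResultPN2/                      # PointNet++ predictions (txt / las)
├── train_SematicSegmentation.py
├── test_SematicSegmentation.py
├── LOG_train.txt
└── LOG_test_eval.txt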

3.2 Data preprocessing

3.2.1 run_collect_indoor3d_data.py: generate the *.npy files

Only the paths were changed; the sketch below shows what that amounts to.
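A minimal sketch, assuming the upstream structure is kept (meta/anno_paths.txt lists each room's Annotations folder, and indoor3d_util.collect_point_label does the actual fusing):

# run_collect_indoor3d_data.py -- only the paths differ from upstream
import os
from indoor3d_util import DATA_PATH, collect_point_label

BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# each entry in meta/anno_paths.txt names one room's Annotations folder
anno_paths = [line.rstrip() for line in open(os.path.join(BASE_DIR, 'meta/anno_paths.txt'))]
anno_paths = [os.path.join(DATA_PATH, p) for p in anno_paths]

# write one fused x y z r g b label array per room, named <Area>_<room>.npy
output_folder = os.path.join(BASE_DIR, 'stanford_indoor3d')
os.makedirs(output_folder, exist_ok=True)

for anno_path in anno_paths:
    elements = anno_path.split('/')
    out_filename = elements[-3] + '_' + elements[-2] + '.npy'
    collect_point_label(anno_path, os.path.join(output_folder, out_filename), 'numpy')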

3.2.2 indoor3d_util.py

Only the paths were changed (see the sketch below).
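Here the edit is a single constant, the root of the raw dataset; a sketch assuming the upstream DATA_PATH name:

# indoor3d_util.py -- path section only
import os

ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# point DATA_PATH at the unpacked aligned S3DIS release
DATA_PATH = os.path.join(ROOT_DIR, 'Stanford3dDataset_v1.2_Aligned_Version')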

3.2.3 S3DISDataLoader.py

Only the paths were changed.

3.3 Training: train_SematicSegmentation.py

# Reference:
# https://github.com/yanx27/Pointnet_Pointnet2_pytorch
# First run in a terminal: python -m visdom.server
# then run this file
import argparse
import os
# import datetime
import logging
import importlib
import shutil
from tqdm import tqdm
import numpy as np
import time
import visdom
import torch
import warnings
warnings.filterwarnings('ignore')

from dataset.S3DISDataLoader import S3DISDataset
from PointNet2 import dataProcess

# PointNet
from PointNet2.pointnet_sem_seg import get_model as PNss
from PointNet2.pointnet_sem_seg import get_loss as PNloss
# PointNet++
from PointNet2.pointnet2_sem_seg import get_model as PN2SS
from PointNet2.pointnet2_sem_seg import get_loss as PN2loss

# True = PointNet++; False = PointNet
PN2bool = True
# PN2bool = False

# path of this file
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# output path for trained PointNet models
dirModel1 = ROOT_DIR + '/trainModel/pointnet_model'
if not os.path.exists(dirModel1):
    os.makedirs(dirModel1)
# output path for trained PointNet++ models
dirModel2 = ROOT_DIR + '/trainModel/PointNet2_model'
if not os.path.exists(dirModel2):
    os.makedirs(dirModel2)

# log path
pathLog = os.path.join(ROOT_DIR, 'LOG_train.txt')

# dataset path
pathDataset = os.path.join(ROOT_DIR, 'dataset/stanford_indoor3d/')

# segmentation classes
classNumber = 13
classes = ['ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', 'table', 'chair', 'sofa', 'bookcase',
           'board', 'clutter']
class2label = {cls: i for i, cls in enumerate(classes)}
seg_classes = class2label
seg_label_to_cat = {}
for i, cat in enumerate(seg_classes.keys()):
    seg_label_to_cat[i] = cat

# log to file and to the console
def log_string(str):
    logger.info(str)
    print(str)

def inplace_relu(m):
    classname = m.__class__.__name__
    if classname.find('ReLU') != -1:
        m.inplace = True

def parse_args():
    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--pnModel', type=bool, default=True, help='True = PointNet++; False = PointNet')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch Size during training [default: 32]')
    parser.add_argument('--epoch', default=320, type=int, help='Epoch to run [default: 320]')
    parser.add_argument('--learning_rate', default=0.001, type=float, help='Initial learning rate [default: 0.001]')
    parser.add_argument('--GPU', type=str, default='0', help='GPU to use [default: GPU 0]')
    parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD [default: Adam]')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]')
    parser.add_argument('--npoint', type=int, default=4096, help='Point Number [default: 4096]')
    parser.add_argument('--step_size', type=int, default=10, help='Decay step for lr decay [default: every 10 epochs]')
    parser.add_argument('--lr_decay', type=float, default=0.7, help='Decay rate for lr decay [default: 0.7]')
    parser.add_argument('--test_area', type=int, default=5, help='Which area to use for test, option: 1-6 [default: 5]')
    return parser.parse_args()

if __name__ == '__main__':
    # python -m visdom.server
    visdomTL = visdom.Visdom()
    visdomTLwindow = visdomTL.line([0], [0], opts=dict(title='train_loss'))
    visdomVL = visdom.Visdom()
    visdomVLwindow = visdomVL.line([0], [0], opts=dict(title='validate_loss'))
    visdomTVL = visdom.Visdom(env='PointNet++')

    # region create the log file
    logger = logging.getLogger("train")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler(pathLog)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    #endregion

    #region hyper-parameters
    args = parse_args()
    args.pnModel = PN2bool
    log_string('------------ hyper-parameter ------------')
    log_string(args)
    # select the GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
    pointNumber = args.npoint
    batchSize = args.batch_size
    #endregion

    # region dataset
    # train data
    trainData = S3DISDataset(split='train',
                             data_root=pathDataset, num_point=pointNumber,
                             test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)
    trainDataLoader = torch.utils.data.DataLoader(trainData, batch_size=batchSize, shuffle=True, num_workers=0,
                                                  pin_memory=True, drop_last=True,
                                                  worker_init_fn=lambda x: np.random.seed(x + int(time.time())))
    # validation data
    testData = S3DISDataset(split='test',
                            data_root=pathDataset, num_point=pointNumber,
                            test_area=args.test_area, block_size=1.0, sample_rate=1.0, transform=None)
    testDataLoader = torch.utils.data.DataLoader(testData, batch_size=batchSize, shuffle=False, num_workers=0,
                                                 pin_memory=True, drop_last=True)
    log_string("The number of training data is: %d" % len(trainData))
    log_string("The number of validation data is: %d" % len(testData))

    weights = torch.Tensor(trainData.labelweights).cuda()
    #endregion

    # region load model: resume from a pretrained checkpoint if one exists, else start fresh
    modelSS = None
    criterion = None
    if PN2bool:
        modelSS = PN2SS(classNumber).cuda()
        criterion = PN2loss().cuda()
        modelSS.apply(inplace_relu)
    else:
        modelSS = PNss(classNumber).cuda()
        criterion = PNloss().cuda()
        modelSS.apply(inplace_relu)

    # weight initialization
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)

    try:
        path_premodel = ''
        if PN2bool:
            path_premodel = os.path.join(dirModel2, 'best_model_S3DIS.pth')
        else:
            path_premodel = os.path.join(dirModel1, 'best_model_S3DIS.pth')
        checkpoint = torch.load(path_premodel)
        start_epoch = checkpoint['epoch']
        # print('pretrain epoch = ' + str(start_epoch))
        modelSS.load_state_dict(checkpoint['model_state_dict'])
        log_string('!!!!!!!!!! Use pretrain model')
    except Exception:
        log_string('...... starting new training ......')
        start_epoch = 0
        modelSS = modelSS.apply(weights_init)
    #endregion

    # start_epoch = 0
    # modelSS = modelSS.apply(weights_init)

    #region training parameters and options
    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            modelSS.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(modelSS.parameters(), lr=args.learning_rate, momentum=0.9)

    def bn_momentum_adjust(m, momentum):
        if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
            m.momentum = momentum

    LEARNING_RATE_CLIP = 1e-5
    MOMENTUM_ORIGINAL = 0.1
    MOMENTUM_DECCAY = 0.5
    MOMENTUM_DECCAY_STEP = args.step_size

    global_epoch = 0
    best_iou = 0
    #endregion

    for epoch in range(start_epoch, args.epoch):
        # region Train on chopped scenes
        log_string('****** Epoch %d (%d/%s) ******' % (global_epoch + 1, epoch + 1, args.epoch))
        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
        log_string('Learning rate:%f' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        momentum = MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // MOMENTUM_DECCAY_STEP))
        if momentum < 0.01:
            momentum = 0.01
        log_string('BN momentum updated to: %f' % momentum)
        modelSS = modelSS.apply(lambda x: bn_momentum_adjust(x, momentum))
        modelSS = modelSS.train()
        #endregion

        # region training
        num_batches = len(trainDataLoader)
        total_correct = 0
        total_seen = 0
        loss_sum = 0
        for i, (points, target) in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
            # zero the gradients
            optimizer.zero_grad()
            points = points.data.numpy()  # ndarray = bs,4096,9 (xyz rgb nxnynz)
            points[:, :, :3] = dataProcess.rotate_point_cloud_z(points[:, :, :3])  # augmentation: random rotation about z
            points = torch.Tensor(points)  # tensor = bs,4096,9
            points, target = points.float().cuda(), target.long().cuda()
            points = points.transpose(2, 1)  # tensor = bs,9,4096
            # forward pass
            seg_pred, trans_feat = modelSS(points)  # tensor = bs,4096,13 / tensor = bs,512,16
            seg_pred = seg_pred.contiguous().view(-1, classNumber)  # tensor = bs*4096,13
            # ground-truth labels
            batch_label = target.view(-1, 1)[:, 0].cpu().data.numpy()  # ndarray = bs*4096
            target = target.view(-1, 1)[:, 0]  # tensor = bs*4096
            # loss
            loss = criterion(seg_pred, target, trans_feat, weights)
            loss.backward()
            # let the optimizer update the model parameters
            optimizer.step()
            pred_choice = seg_pred.cpu().data.max(1)[1].numpy()  # ndarray = bs*4096
            correct = np.sum(pred_choice == batch_label)  # number of correctly predicted points
            total_correct += correct
            total_seen += (batchSize * pointNumber)
            loss_sum += loss
        log_string('Training mean loss: %f' % (loss_sum / num_batches))
        log_string('Training accuracy: %f' % (total_correct / float(total_seen)))

        # draw
        trainLoss = (loss_sum.item()) / num_batches
        visdomTL.line([trainLoss], [epoch + 1], win=visdomTLwindow, update='append')
        #endregion

        # region save the model
        if epoch % 1 == 0:
            modelpath = ''
            if PN2bool:
                modelpath = os.path.join(dirModel2, 'model' + str(epoch + 1) + '_S3DIS.pth')
            else:
                modelpath = os.path.join(dirModel1, 'model' + str(epoch + 1) + '_S3DIS.pth')
            state = {
                'epoch': epoch,
                'model_state_dict': modelSS.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            torch.save(state, modelpath)
            logger.info('Save model...' + modelpath)
        #endregion

        # region Evaluate on chopped scenes
        with torch.no_grad():
            num_batches = len(testDataLoader)
            total_correct = 0
            total_seen = 0
            loss_sum = 0
            labelweights = np.zeros(classNumber)
            total_seen_class = [0 for _ in range(classNumber)]
            total_correct_class = [0 for _ in range(classNumber)]
            total_iou_deno_class = [0 for _ in range(classNumber)]
            modelSS = modelSS.eval()

            log_string('****** Epoch Evaluation %d (%d/%s) ******' % (global_epoch + 1, epoch + 1, args.epoch))
            for i, (points, target) in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                points = points.data.numpy()  # ndarray = bs,4096,9
                points = torch.Tensor(points)  # tensor = bs,4096,9
                points, target = points.float().cuda(), target.long().cuda()  # tensor = bs,4096,9 / tensor = bs,4096
                points = points.transpose(2, 1)  # tensor = bs,9,4096
                seg_pred, trans_feat = modelSS(points)  # tensor = bs,4096,13 / tensor = bs,512,16
                pred_val = seg_pred.contiguous().cpu().data.numpy()  # ndarray = bs,4096,13
                seg_pred = seg_pred.contiguous().view(-1, classNumber)  # tensor = bs*4096,13
                batch_label = target.cpu().data.numpy()  # ndarray = bs,4096
                target = target.view(-1, 1)[:, 0]  # tensor = bs*4096
                loss = criterion(seg_pred, target, trans_feat, weights)
                loss_sum += loss
                pred_val = np.argmax(pred_val, 2)  # ndarray = bs,4096
                correct = np.sum((pred_val == batch_label))
                total_correct += correct
                total_seen += (batchSize * pointNumber)
                tmp, _ = np.histogram(batch_label, range(classNumber + 1))
                labelweights += tmp
                for l in range(classNumber):
                    total_seen_class[l] += np.sum((batch_label == l))
                    total_correct_class[l] += np.sum((pred_val == l) & (batch_label == l))
                    total_iou_deno_class[l] += np.sum(((pred_val == l) | (batch_label == l)))
            labelweights = labelweights.astype(np.float32) / np.sum(labelweights.astype(np.float32))
            mIoU = np.mean(np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6))
            log_string('eval mean loss: %f' % (loss_sum / float(num_batches)))
            log_string('eval point avg class IoU: %f' % (mIoU))
            log_string('eval point accuracy: %f' % (total_correct / float(total_seen)))
            log_string('eval point avg class acc: %f' % (
                np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6))))

            iou_per_class_str = '------- IoU --------\n'
            for l in range(classNumber):
                # note: upstream indexes labelweights with l - 1, an off-by-one; fixed to l here
                iou_per_class_str += 'class %s weight: %.3f, IoU: %.3f \n' % (
                    seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])), labelweights[l],
                    total_correct_class[l] / float(total_iou_deno_class[l]))
            log_string(iou_per_class_str)
            log_string('Eval mean loss: %f' % (loss_sum / num_batches))
            log_string('Eval accuracy: %f' % (total_correct / float(total_seen)))

            # draw
            valLoss = (loss_sum.item()) / num_batches
            visdomVL.line([valLoss], [epoch + 1], win=visdomVLwindow, update='append')

            # region keep the best model by mIoU
            if mIoU >= best_iou:
                best_iou = mIoU
                bestmodelpath = ''
                if PN2bool:
                    bestmodelpath = os.path.join(dirModel2, 'best_model_S3DIS.pth')
                else:
                    bestmodelpath = os.path.join(dirModel1, 'best_model_S3DIS.pth')
                state = {
                    'epoch': epoch,
                    'class_avg_iou': mIoU,
                    'model_state_dict': modelSS.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                torch.save(state, bestmodelpath)
                logger.info('Save best model......' + bestmodelpath)
            log_string('Best mIoU: %f' % best_iou)
            #endregion
        #endregion

        global_epoch += 1

        # draw both curves in one window
        visdomTVL.line(X=[epoch + 1], Y=[trainLoss], name="train loss", win='line', update='append',
                       opts=dict(showlegend=True, markers=False,
                                 title='PointNet++ train validate loss',
                                 xlabel='epoch', ylabel='loss'))
        visdomTVL.line(X=[epoch + 1], Y=[valLoss], name="validate loss", win='line', update='append')
        log_string('-------------------------------------------------\n\n')
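To train, start the visdom server first (the script posts its loss curves there), then launch the script; the flags are the ones defined in parse_args above:

python -m visdom.server
python train_SematicSegmentation.py --batch_size 32 --epoch 320 --test_area 5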

3.4 Prediction and testing: test_SematicSegmentation.py

# Reference:
# https://github.com/yanx27/Pointnet_Pointnet2_pytorch
import argparse
import sys
import os
import numpy as np
import logging
from pathlib import Path
import importlib
from tqdm import tqdm
import torch
import warnings
warnings.filterwarnings('ignore')

from dataset.S3DISDataLoader import ScannetDatasetWholeScene
from dataset.indoor3d_util import g_label2color

# PointNet
from PointNet2.pointnet_sem_seg import get_model as PNss
# PointNet++
from PointNet2.pointnet2_sem_seg import get_model as PN2SS

# True = PointNet++; False = PointNet
PN2bool = True
# PN2bool = False

# region functions: voting, logging, saving results as LAS
# accumulate votes for each point's predicted label
def add_vote(vote_label_pool, point_idx, pred_label, weight):
    B = pred_label.shape[0]
    N = pred_label.shape[1]
    for b in range(B):
        for n in range(N):
            if weight[b, n] != 0 and not np.isinf(weight[b, n]):
                vote_label_pool[int(point_idx[b, n]), int(pred_label[b, n])] += 1
    return vote_label_pool

# logging
def log_string(str):
    logger.info(str)
    print(str)

# save to LAS
import laspy
def SaveResultLAS(newLasPath, point_np, rgb_np, label1, label2):
    # data
    newx = point_np[:, 0]
    newy = point_np[:, 1]
    newz = point_np[:, 2]
    newred = rgb_np[:, 0]
    newgreen = rgb_np[:, 1]
    newblue = rgb_np[:, 2]
    newclassification = label1
    newuserdata = label2
    minx = min(newx)
    miny = min(newy)
    minz = min(newz)
    # create a new header
    newheader = laspy.LasHeader(point_format=3, version="1.2")
    newheader.scales = np.array([0.0001, 0.0001, 0.0001])
    newheader.offsets = np.array([minx, miny, minz])
    newheader.add_extra_dim(laspy.ExtraBytesParams(name="Classification", type=np.uint8))
    newheader.add_extra_dim(laspy.ExtraBytesParams(name="UserData", type=np.uint8))
    # create a LAS
    newlas = laspy.LasData(newheader)
    newlas.x = newx
    newlas.y = newy
    newlas.z = newz
    newlas.red = newred
    newlas.green = newgreen
    newlas.blue = newblue
    newlas.Classification = newclassification
    newlas.UserData = newuserdata
    # write
    newlas.write(newLasPath)

# hyper-parameters
def parse_args():
    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--pnModel', type=bool, default=True, help='True = PointNet++; False = PointNet')
    parser.add_argument('--batch_size', type=int, default=32, help='batch size in testing [default: 32]')
    parser.add_argument('--GPU', type=str, default='0', help='specify GPU device')
    parser.add_argument('--num_point', type=int, default=4096, help='point number [default: 4096]')
    parser.add_argument('--test_area', type=int, default=5, help='area for testing, option: 1-6 [default: 5]')
    parser.add_argument('--num_votes', type=int, default=1,
                        help='aggregate segmentation scores with voting [default: 1]')
    return parser.parse_args()
#endregion

# path of this file
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# model path
pathTrainModel = os.path.join(ROOT_DIR, 'trainModel/pointnet_model')
if PN2bool:
    pathTrainModel = os.path.join(ROOT_DIR, 'trainModel/PointNet2_model')

# results path
visual_dir = ROOT_DIR + '/testResultPN/'
if PN2bool:
    visual_dir = ROOT_DIR + '/testResultPN2/'
visual_dir = Path(visual_dir)
visual_dir.mkdir(exist_ok=True)

# log path
pathLog = os.path.join(ROOT_DIR, 'LOG_test_eval.txt')

# dataset path
pathDataset = os.path.join(ROOT_DIR, 'dataset/stanford_indoor3d/')

# segmentation classes, in label order
classNumber = 13
classes = ['ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door', 'table', 'chair', 'sofa', 'bookcase',
           'board', 'clutter']
class2label = {cls: i for i, cls in enumerate(classes)}
seg_classes = class2label
seg_label_to_cat = {}
for i, cat in enumerate(seg_classes.keys()):
    seg_label_to_cat[i] = cat

if __name__ == '__main__':
    #region LOG info
    logger = logging.getLogger("test_eval")
    logger.setLevel(logging.INFO)  # log levels: DEBUG, INFO, WARNING, ERROR, CRITICAL
    file_handler = logging.FileHandler(pathLog)
    file_handler.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    #endregion

    #region hyper-parameters
    args = parse_args()
    args.pnModel = PN2bool
    log_string('--- hyper-parameter ---')
    log_string(args)
    os.environ["CUDA_VISIBLE_DEVICES"] = args.GPU
    batchSize = args.batch_size
    pointNumber = args.num_point
    testArea = args.test_area
    voteNumber = args.num_votes
    #endregion

    #region ---------- load the semantic segmentation model ----------
    log_string("---------- Loading sematic segmentation model ----------")
    ssModel = None
    if PN2bool:
        ssModel = PN2SS(classNumber).cuda()
    else:
        ssModel = PNss(classNumber).cuda()
    path_model = os.path.join(pathTrainModel, 'best_model_S3DIS.pth')
    checkpoint = torch.load(path_model)
    ssModel.load_state_dict(checkpoint['model_state_dict'])
    ssModel = ssModel.eval()
    #endregion

    # Inference/evaluation needs no gradients; disabling them cuts memory use and speeds things up.
    log_string('--- Evaluation whole scene')
    with torch.no_grad():
        # IoU accumulators
        total_seen_class = [0 for _ in range(classNumber)]
        total_correct_class = [0 for _ in range(classNumber)]
        total_iou_deno_class = [0 for _ in range(classNumber)]

        # all files of the test area
        testDataset = ScannetDatasetWholeScene(pathDataset, split='test', test_area=testArea, block_points=pointNumber)
        scene_id_name = testDataset.file_list
        scene_id_name = [x[:-4] for x in scene_id_name]  # names without extension
        testCount = len(scene_id_name)
        testCount = 1  # only the first scene; remove this line to evaluate the whole area

        # iterate over the scenes to predict
        for batch_idx in range(testCount):
            log_string("Inference [%d/%d] %s ..." % (batch_idx + 1, testCount, scene_id_name[batch_idx]))
            # point data
            whole_scene_data = testDataset.scene_points_list[batch_idx]
            # ground truth
            whole_scene_label = testDataset.semantic_labels_list[batch_idx]
            whole_scene_labelR = np.reshape(whole_scene_label, (whole_scene_label.size, 1))
            # vote pool for the predicted labels
            vote_label_pool = np.zeros((whole_scene_label.shape[0], classNumber))

            # predict the same scene several times (voting)
            for _ in tqdm(range(voteNumber), total=voteNumber):
                scene_data, scene_label, scene_smpw, scene_point_index = testDataset[batch_idx]
                num_blocks = scene_data.shape[0]
                s_batch_num = (num_blocks + batchSize - 1) // batchSize
                batch_data = np.zeros((batchSize, pointNumber, 9))
                batch_label = np.zeros((batchSize, pointNumber))
                batch_point_index = np.zeros((batchSize, pointNumber))
                batch_smpw = np.zeros((batchSize, pointNumber))
                for sbatch in range(s_batch_num):
                    start_idx = sbatch * batchSize
                    end_idx = min((sbatch + 1) * batchSize, num_blocks)
                    real_batch_size = end_idx - start_idx
                    batch_data[0:real_batch_size, ...] = scene_data[start_idx:end_idx, ...]
                    batch_label[0:real_batch_size, ...] = scene_label[start_idx:end_idx, ...]
                    batch_point_index[0:real_batch_size, ...] = scene_point_index[start_idx:end_idx, ...]
                    batch_smpw[0:real_batch_size, ...] = scene_smpw[start_idx:end_idx, ...]
                    batch_data[:, :, 3:6] /= 1.0

                    torch_data = torch.Tensor(batch_data)
                    torch_data = torch_data.float().cuda()
                    torch_data = torch_data.transpose(2, 1)
                    seg_pred, _ = ssModel(torch_data)
                    batch_pred_label = seg_pred.contiguous().cpu().data.max(2)[1].numpy()
                    # accumulate votes into the label pool
                    vote_label_pool = add_vote(vote_label_pool, batch_point_index[0:real_batch_size, ...],
                                               batch_pred_label[0:real_batch_size, ...],
                                               batch_smpw[0:real_batch_size, ...])

            # region save prediction results
            # predicted labels
            pred_label = np.argmax(vote_label_pool, 1)
            pred_labelR = np.reshape(pred_label, (pred_label.size, 1))
            # point cloud + ground truth + predicted label
            pcrgb_ll = np.hstack((whole_scene_data, whole_scene_labelR, pred_labelR))
            # ---------- save as TXT ----------
            pathTXT = os.path.join(visual_dir, scene_id_name[batch_idx] + '.txt')
            np.savetxt(pathTXT, pcrgb_ll, fmt='%f', delimiter='\t')
            log_string('save:' + pathTXT)
            # ---------- save as LAS ----------
            pathLAS = os.path.join(visual_dir, scene_id_name[batch_idx] + '.las')
            SaveResultLAS(pathLAS, pcrgb_ll[:, 0:3], pcrgb_ll[:, 3:6], pcrgb_ll[:, 6], pcrgb_ll[:, 7])
            log_string('save:' + pathLAS)
            # endregion

            # per-scene IoU accumulators
            total_seen_class_tmp = [0 for _ in range(classNumber)]
            total_correct_class_tmp = [0 for _ in range(classNumber)]
            total_iou_deno_class_tmp = [0 for _ in range(classNumber)]
            for l in range(classNumber):
                total_seen_class_tmp[l] += np.sum((whole_scene_label == l))
                total_correct_class_tmp[l] += np.sum((pred_label == l) & (whole_scene_label == l))
                total_iou_deno_class_tmp[l] += np.sum(((pred_label == l) | (whole_scene_label == l)))
                total_seen_class[l] += total_seen_class_tmp[l]
                total_correct_class[l] += total_correct_class_tmp[l]
                total_iou_deno_class[l] += total_iou_deno_class_tmp[l]
            iou_map = np.array(total_correct_class_tmp) / (np.array(total_iou_deno_class_tmp, dtype=np.float64) + 1e-6)
            print(iou_map)
            arr = np.array(total_seen_class_tmp)
            tmp_iou = np.mean(iou_map[arr != 0])
            log_string('Mean IoU of %s: %.4f' % (scene_id_name[batch_idx], tmp_iou))

        IoU = np.array(total_correct_class) / (np.array(total_iou_deno_class, dtype=np.float64) + 1e-6)
        iou_per_class_str = '----- IoU -----\n'
        for l in range(classNumber):
            iou_per_class_str += 'class %s, IoU: %.3f \n' % (
                seg_label_to_cat[l] + ' ' * (14 - len(seg_label_to_cat[l])),
                total_correct_class[l] / float(total_iou_deno_class[l]))
        log_string(iou_per_class_str)
        log_string('eval point avg class IoU: %f' % np.mean(IoU))
        log_string('eval whole scene point avg class acc: %f' % (
            np.mean(np.array(total_correct_class) / (np.array(total_seen_class, dtype=np.float64) + 1e-6))))
        log_string('eval whole scene point accuracy: %f' % (
            np.sum(total_correct_class) / float(np.sum(total_seen_class) + 1e-6)))
        log_string('--------------------------------------\n\n')
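Typical invocation once a best_model_S3DIS.pth exists under trainModel/; for each scene a paired .txt and .las file lands in testResultPN2/ (or testResultPN/ for PointNet):

python test_SematicSegmentation.py --batch_size 32 --test_area 5 --num_votes 3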
