
Reproducing GraspNet and Running It on Your Own Data (PyCharm)

graspnet

Reference: Baseline model for "GraspNet-1Billion: A Large-Scale Benchmark for General Object Grasping" (CVPR 2020). [paper] [dataset] [API] [doc]

I. Downloading GraspNet

1. Installation

Get the code:

git clone https://github.com/graspnet/graspnet-baseline.git
cd graspnet-baseline

Install the required packages via pip:

pip install -r requirements.txt

Compile and install pointnet2:

cd pointnet2
python setup.py install

Compile and install knn:

cd knn
python setup.py install

Install graspnetAPI into the same Python environment used for graspnet-baseline:

git clone https://github.com/graspnet/graspnetAPI.git
cd graspnetAPI
pip install .

Build the documentation manually (optional):

cd docs
pip install -r requirements.txt
bash build_doc.sh

2. Environment setup

Required packages (a quick import check is sketched after the list):

  • Python 3
  • PyTorch 1.6
  • Open3d >=0.8
  • TensorBoard 2.3
  • NumPy
  • SciPy
  • Pillow
  • tqdm
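
A quick import check, sketched below under the assumption that the steps above completed without errors, confirms that the core dependencies and graspnetAPI are visible in the environment:

# Run inside the graspnet environment after installation.
import numpy as np
import torch
import open3d as o3d
from graspnetAPI import GraspGroup   # installed from the graspnetAPI repo above

print('torch:', torch.__version__, '| CUDA available:', torch.cuda.is_available())
print('open3d:', o3d.__version__)
print('numpy:', np.__version__)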

3. Generating tolerance labels

Download tolerance.tar from Google Drive / Baidu Pan (links in the original repository), or generate it by running the generation script in the repository. Then move the archive into the dataset folder and extract it:

mv tolerance.tar dataset/
cd dataset
tar -xvf tolerance.tar

4. Downloading the pretrained weights

The pretrained weights can be downloaded from the locations given in the original repository: checkpoint-rs.tar and checkpoint-kn.tar were trained on RealSense data and Kinect data respectively.
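
To verify that a downloaded checkpoint is intact, you can load it on the CPU and look at the keys that demo.py reads (a minimal sketch; the path is a placeholder for wherever you stored the weights):

import torch

# Inspect the checkpoint without needing a GPU; the path is a placeholder.
ckpt = torch.load('logs/log_rs/checkpoint.tar', map_location='cpu')
print(list(ckpt.keys()))            # should include 'model_state_dict' and 'epoch'
print('trained epochs:', ckpt['epoch'])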

II. Reproducing the demo

1. Edit the run configuration

Open demo.py, expand the run configuration dropdown at the top right of PyCharm, and choose "Edit Configurations".

Enter the pretrained-weight path in the "Parameters" field.

Fill in the parameters according to the weights you downloaded, and note that the paths must be changed to where your own files are stored. demo.py itself only needs --checkpoint_path; the extra flags in the second and third example lines below (--dump_dir, --camera, --dataset_root, --log_dir, --batch_size) are used by the repository's test and training scripts rather than by demo.py.

--checkpoint_path logs/log_kn/checkpoint.tar
--dump_dir logs/dump_rs --checkpoint_path logs/log_rs/checkpoint.tar --camera realsense --dataset_root /data/Benchmark/graspnet
--log_dir logs/log_rs --batch_size 2

Parameter input format:

--arg1 path1 --arg2 path2
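
To make the mapping concrete, here is a small sketch (my own illustration, not repository code) of how such a parameter string is consumed by demo.py's argparse parser; the checkpoint path is a placeholder:

# Minimal sketch of how the PyCharm "Parameters" field maps onto demo.py's flags.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path')
parser.add_argument('--num_point', type=int, default=20000)

# Equivalent to typing "--checkpoint_path logs/log_rs/checkpoint.tar" in PyCharm:
cfgs = parser.parse_args(['--checkpoint_path', 'logs/log_rs/checkpoint.tar'])
print(cfgs.checkpoint_path)   # -> logs/log_rs/checkpoint.tar
print(cfgs.num_point)         # -> 20000 (default, since it was not supplied)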

2. Running the demo

The input images can be found in graspnet-baseline/doc/example_data.
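
For reference, the stock demo.py ends with an entry point along the following lines, pointing at that example folder (paraphrased; only the data directory matters here):

if __name__ == '__main__':
    data_dir = 'doc/example_data'   # example scene shipped with the repository
    demo(data_dir)                  # loads color/depth/mask/meta and visualizes grasps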

Running demo.py produces a 3D visualization of the scene with the predicted 6-DoF grasp poses.

Final visualization of the predicted grasps.

III. Grasp prediction with your own data

1. Data overview

A RealSense L515 depth camera is used, connected to the computer via its data cable.

RealSense L515 camera parameters:

factor_depth        4000        depth scale (raw depth value / factor_depth = depth in meters)

intrinsic_matrix                camera intrinsics

    [ 1351.72      0        979.26  ]
    [    0      1352.93     556.038 ]
    [    0         0          1     ]
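
To make the roles of these numbers concrete, the following is a minimal sketch (my own illustration; the pixel and depth values are made up) of how a single depth reading is back-projected into a 3D camera-frame point, which is essentially what create_point_cloud_from_depth_image does with factor_depth and the intrinsic matrix:

import numpy as np

# RealSense L515 parameters from above
factor_depth = 4000.0
K = np.array([[1351.72,    0.0,   979.26 ],
              [   0.0,  1352.93,  556.038],
              [   0.0,     0.0,     1.0  ]])
fx, fy = K[0, 0], K[1, 1]
cx, cy = K[0, 2], K[1, 2]

# Hypothetical depth reading at pixel (u, v)
u, v = 640, 360
raw_depth = 3200                      # value stored in depth.png at row v, column u

z = raw_depth / factor_depth          # depth in meters (0.8 m here)
x = (u - cx) * z / fx                 # back-project to camera coordinates
y = (v - cy) * z / fy
print(x, y, z)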

2. Data input

Required packages:

  • pyrealsense2
  • opencv-python (cv2)

The script below does the following:

  • Captures the real scene and saves it as images that serve as the grasp-prediction input
  • Aligns depth_image with color_image
  • Uses your own camera intrinsics (focal lengths, optical center) and depth scale, which are read from meta.mat

Complete code:

  1. """ Demo to show prediction results.
  2. Author: chenxi-wang
  3. """
  4. import os
  5. import sys
  6. import numpy as np
  7. import open3d as o3d
  8. import argparse
  9. import importlib
  10. import scipy.io as scio
  11. from PIL import Image
  12. import torch
  13. from graspnetAPI import GraspGroup
  14. import pyrealsense2 as rs
  15. import cv2
  16. from matplotlib import pyplot as plt
  17. ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
  18. sys.path.append(os.path.join(ROOT_DIR, 'models'))
  19. sys.path.append(os.path.join(ROOT_DIR, 'dataset'))
  20. sys.path.append(os.path.join(ROOT_DIR, 'utils'))
  21. from models.graspnet import GraspNet, pred_decode
  22. from graspnet_dataset import GraspNetDataset
  23. from collision_detector import ModelFreeCollisionDetector
  24. from data_utils import CameraInfo, create_point_cloud_from_depth_image
  25. parser = argparse.ArgumentParser()
  26. parser.add_argument('--checkpoint_path', required=True, help='Model checkpoint path')
  27. parser.add_argument('--num_point', type=int, default=20000, help='Point Number [default: 20000]')
  28. parser.add_argument('--num_view', type=int, default=300, help='View Number [default: 300]')
  29. parser.add_argument('--collision_thresh', type=float, default=0.01, help='Collision Threshold in collision detection [default: 0.01]')
  30. parser.add_argument('--voxel_size', type=float, default=0.01, help='Voxel Size to process point clouds before collision detection [default: 0.01]')
  31. cfgs = parser.parse_args()
  32. def get_net():
  33. # Init the model
  34. net = GraspNet(input_feature_dim=0, num_view=cfgs.num_view, num_angle=12, num_depth=4,
  35. cylinder_radius=0.05, hmin=-0.02, hmax_list=[0.01,0.02,0.03,0.04], is_training=False)
  36. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  37. net.to(device)
  38. # Load checkpoint
  39. checkpoint = torch.load(cfgs.checkpoint_path)
  40. net.load_state_dict(checkpoint['model_state_dict'])
  41. start_epoch = checkpoint['epoch']
  42. print("-> loaded checkpoint %s (epoch: %d)"%(cfgs.checkpoint_path, start_epoch))
  43. # set model to eval mode
  44. net.eval()
  45. return net
  46. def get_and_process_data(data_dir):
  47. # load data
  48. color = np.array(Image.open(os.path.join(data_dir, 'color.png')), dtype=np.float32) / 255.0
  49. depth = np.array(Image.open(os.path.join(data_dir, 'depth.png')))
  50. workspace_mask = np.array(Image.open(os.path.join(data_dir, 'workspace_mask.png')))
  51. meta = scio.loadmat(os.path.join(data_dir, 'meta.mat'))# Resize depth to match color image resolution while preserving spatial alignment
  52. color_height, color_width = color.shape[:2]
  53. depth = cv2.resize(depth, (color_width, color_height), interpolation=cv2.INTER_NEAREST)
  54. intrinsic = meta['intrinsic_matrix']
  55. factor_depth =meta['factor_depth']
  56. # generate cloud
  57. camera = CameraInfo(1280.0, 720.0, intrinsic[0][0], intrinsic[1][1], intrinsic[0][2], intrinsic[1][2], factor_depth)
  58. cloud = create_point_cloud_from_depth_image(depth, camera, organized=True)
  59. # get valid points
  60. mask = (workspace_mask & (depth > 0))
  61. cloud_masked = cloud[mask]
  62. color_masked = color[mask]
  63. # sample points
  64. if len(cloud_masked) >= cfgs.num_point:
  65. idxs = np.random.choice(len(cloud_masked), cfgs.num_point, replace=False)
  66. else:
  67. idxs1 = np.arange(len(cloud_masked))
  68. idxs2 = np.random.choice(len(cloud_masked), cfgs.num_point-len(cloud_masked), replace=True)
  69. idxs = np.concatenate([idxs1, idxs2], axis=0)
  70. cloud_sampled = cloud_masked[idxs]
  71. color_sampled = color_masked[idxs]
  72. # convert data
  73. cloud = o3d.geometry.PointCloud()
  74. cloud.points = o3d.utility.Vector3dVector(cloud_masked.astype(np.float32))
  75. cloud.colors = o3d.utility.Vector3dVector(color_masked.astype(np.float32))
  76. end_points = dict()
  77. cloud_sampled = torch.from_numpy(cloud_sampled[np.newaxis].astype(np.float32))
  78. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  79. cloud_sampled = cloud_sampled.to(device)
  80. end_points['point_clouds'] = cloud_sampled
  81. end_points['cloud_colors'] = color_sampled
  82. return end_points, cloud
  83. def get_grasps(net, end_points):
  84. # Forward pass
  85. with torch.no_grad():
  86. end_points = net(end_points)
  87. grasp_preds = pred_decode(end_points)
  88. gg_array = grasp_preds[0].detach().cpu().numpy()
  89. gg = GraspGroup(gg_array)
  90. return gg
  91. def collision_detection(gg, cloud):
  92. mfcdetector = ModelFreeCollisionDetector(cloud, voxel_size=cfgs.voxel_size)
  93. collision_mask = mfcdetector.detect(gg, approach_dist=0.05, collision_thresh=cfgs.collision_thresh)
  94. gg = gg[~collision_mask]
  95. return gg
  96. def vis_grasps(gg, cloud):
  97. gg.nms()
  98. gg.sort_by_score()
  99. gg = gg[:50]
  100. grippers = gg.to_open3d_geometry_list()
  101. o3d.visualization.draw_geometries([cloud, *grippers])
  102. def demo(data_dir):
  103. net = get_net()
  104. end_points, cloud = get_and_process_data(data_dir)
  105. gg = get_grasps(net, end_points)
  106. if cfgs.collision_thresh > 0:
  107. gg = collision_detection(gg, np.array(cloud.points))
  108. vis_grasps(gg, cloud)
  109. def input1():
  110. # Create a pipeline
  111. pipeline = rs.pipeline()
  112. # Create a config object to configure the pipeline
  113. config = rs.config()
  114. config.enable_stream(rs.stream.depth, 1024, 768, rs.format.z16, 30)
  115. config.enable_stream(rs.stream.color, 1280, 720, rs.format.bgr8, 30)
  116. # Start the pipeline
  117. pipeline.start(config)
  118. align = rs.align(rs.stream.color) # Create align object for depth-color alignment
  119. try:
  120. while True:
  121. # Wait for a coherent pair of frames: color and depth
  122. frames = pipeline.wait_for_frames()
  123. aligned_frames = align.process(frames)
  124. if not aligned_frames:
  125. continue # If alignment fails, go back to the beginning of the loop
  126. color_frame = aligned_frames.get_color_frame()
  127. aligned_depth_frame = aligned_frames.get_depth_frame()
  128. if not color_frame or not aligned_depth_frame:
  129. continue
  130. # Convert aligned_depth_frame and color_frame to numpy arrays
  131. aligned_depth_image = np.asanyarray(aligned_depth_frame.get_data())
  132. color_image = np.asanyarray(color_frame.get_data())
  133. # Display the aligned depth image
  134. aligned_depth_colormap = cv2.applyColorMap(cv2.convertScaleAbs(aligned_depth_image, alpha=0.03),
  135. cv2.COLORMAP_JET)
  136. cv2.imshow("Aligned Depth colormap", aligned_depth_colormap)
  137. cv2.imshow("Aligned Depth Image", aligned_depth_image)
  138. cv2.imwrite('./data1/depth.png', aligned_depth_image)
  139. # Display the color image
  140. cv2.imshow("Color Image", color_image)
  141. cv2.imwrite('./data1/color.png', color_image)
  142. # Press 'q' to quit
  143. if cv2.waitKey(1) & 0xFF == ord('q'):
  144. break
  145. finally:
  146. # Stop the pipeline and close all windows
  147. pipeline.stop()
  148. cv2.destroyAllWindows()
  149. if __name__=='__main__':
  150. input1()
  151. data_dir = 'data1'
  152. demo(data_dir)

Here data1 is your own data folder, which contains:

color.png: the color image captured by your own camera

depth.png: the depth image aligned to the color image

workspace_mask.png: copied directly from the demo example data

meta.mat: copied from the demo example data, with the parameters inside replaced by your own camera's parameters (see the sketch below)

Remember to pass the checkpoint path in the run configuration the same way as for demo.py.
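
The meta.mat modification can be done with SciPy. The following is a minimal sketch (my addition, not from the original post), assuming the L515 parameters listed earlier; get_and_process_data only reads the 'intrinsic_matrix' and 'factor_depth' keys:

import numpy as np
import scipy.io as scio

# RealSense L515 parameters from the section above
intrinsic_matrix = np.array([[1351.72,    0.0,   979.26 ],
                             [   0.0,  1352.93,  556.038],
                             [   0.0,     0.0,     1.0  ]])
factor_depth = np.array([[4000]])   # loadmat returns scalars as 2D arrays, so store it the same way

scio.savemat('data1/meta.mat', {
    'intrinsic_matrix': intrinsic_matrix,
    'factor_depth': factor_depth,
})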

3. Results

Color image (RGB)

Depth image

Generated 6-DoF grasp poses in the 3D visualization

4. Keeping only the top-scoring grasps

Modify part of the code:

def vis_grasps(gg, cloud, num_top_grasps=10):
    gg.nms()
    gg.sort_by_score()                    # sorts grasps from highest score to lowest
    gg = gg[:num_top_grasps]              # keep only the top num_top_grasps grasps
    grippers = gg.to_open3d_geometry_list()
    o3d.visualization.draw_geometries([cloud, *grippers])

Run the program.
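
If you need the single best grasp as numbers rather than a visualization (for example, to send to a robot controller), the sketch below (my addition; it relies on graspnetAPI's Grasp accessors such as translation and rotation_matrix) shows how to pull the top-scoring grasp out of a GraspGroup:

def best_grasp(gg):
    # Assumes gg is a GraspGroup that has already passed collision_detection.
    gg.nms()
    gg.sort_by_score()          # highest score first
    g = gg[0]                   # indexing with an int returns a single Grasp
    print('score:', g.score)
    print('width:', g.width)
    print('translation (xyz, camera frame):', g.translation)
    print('rotation matrix:\n', g.rotation_matrix)
    return g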

End of the experiment.
