当前位置:   article > 正文

一文带你入门NeRF——利用PyTorch实现NeRF代码详解(附代码)_nerf三维重建pytorch

nerf三维重建pytorch

作者:大森林 | 来源:3DCV

在公众号「3DCV」后台,回复「原论文」即可获取代码。

添加微信:dddvisiona,备注:NeRF,拉你入群。文末附行业细分群。

1. NeRF定义

神经辐射场(NeRF)是一种利用神经网络来表示和渲染复杂的三维场景的方法。它可以从一组二维图片中学习出一个连续的三维函数,这个函数可以给出空间中任意位置和方向上的颜色和密度。通过体积渲染的技术,NeRF可以从任意视角合成出逼真的图像,包括透明和半透明物体,以及复杂的光线传播效果。

2. NeRF优势

NeRF模型相比于其他新的视图合成和场景表示方法有以下几个优势:

1)NeRF不需要离散化的三维表示,如网格或体素,因此可以避免模型精度和细节程度受到限制。NeRF也可以自适应地处理不同形状和大小的场景,而不需要人工调整参数。

2)NeRF使用位置编码的方式将位置和角度信息映射到高频域,使得网络能够更好地捕捉场景的细微结构和变化。NeRF还使用视角相关的颜色预测,能够生成不同视角下不同的光照效果。

3)NeRF使用分段随机采样的方式来近似体积渲染的积分,这样可以保证采样位置的连续性,同时避免网络过拟合于离散点的信息。NeRF还使用多层级体素采样的技巧,以提高渲染效率和质量。

3. NeRF实现步骤

1)定义一个全连接的神经网络,它的输入是空间位置和视角方向输出是颜色和密度

2)使用位置编码的方式将输入映射到高频域,以便网络能够捕捉细微的结构和变化。

3)使用分段随机采样的方式从每条光线上采样一些点,然后用神经网络预测这些点的颜色和密度。

4)使用体积渲染的公式计算每条光线上的颜色和透明度,作为最终的图像输出。

5)使用渲染损失函数来优化神经网络的参数,使得渲染的图像与输入的图像尽可能接近。

  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. # 定义一个全连接的神经网络,它的输入是空间位置和视角方向,输出是颜色和密度。
  5. class NeRF(nn.Module):
  6.     def __init__(self, D=8, W=256input_ch=3input_ch_views=3output_ch=4, skips=[4]):
  7.         super().__init__()
  8.         # 定义位置编码后的位置信息的线性层,如果层数在skips列表中,则将原始位置信息与隐藏层拼接
  9.         self.pts_linears = nn.ModuleList(
  10.             [nn.Linear(input_ch, W)] + [nn.Linear(W, W) if i not in skips else nn.Linear(W + input_ch, W) for i in range(D-1)])
  11.         # 定义位置编码后的视角方向信息的线性层
  12.         self.views_linears = nn.ModuleList([nn.Linear(W + input_ch_views, W//2)] + [nn.Linear(W//2, W//2for i in range(1)])
  13.         # 定义特征向量的线性层
  14.         self.feature_linear = nn.Linear(W//2, W)
  15.         # 定义透明度(alpha)值的线性层
  16.         self.alpha_linear = nn.Linear(W, 1)
  17.         # 定义RGB颜色的线性层
  18.         self.rgb_linear = nn.Linear(W + input_ch_views, 3)
  19.     def forward(self, x):
  20.         # x: (B, input_ch + input_ch_views)
  21.         # 提取位置和视角方向信息
  22.         p = x[:, :3] # (B, 3)
  23.         d = x[:, 3:] # (B, 3)
  24.         # 对输入进行位置编码,将低频信号映射到高频域
  25.         p = positional_encoding(p) # (B, input_ch)
  26.         d = positional_encoding(d) # (B, input_ch_views)
  27.         # 将位置信息输入网络
  28.         h = p
  29.         for i, l in enumerate(self.pts_linears):
  30.             h = l(h)
  31.             h = F.relu(h)
  32.             if i in skips:
  33.                 h = torch.cat([h, p], -1) # 如果层数在skips列表中,则将原始位置信息与隐藏层拼接
  34.         # 将视角方向信息与隐藏层拼接,并输入网络
  35.         h = torch.cat([h, d], -1)
  36.         for i, l in enumerate(self.views_linears):
  37.             h = l(h)
  38.             h = F.relu(h)
  39.         # 预测特征向量和透明度(alpha)值
  40.         feature = self.feature_linear(h) # (B, W)
  41.         alpha = self.alpha_linear(feature) # (B, 1)
  42.         
  43.         # 使用特征向量和视角方向信息预测RGB颜色
  44.         rgb = torch.cat([feature, d], -1
  45.         rgb = self.rgb_linear(rgb) # (B, 3)
  46.         return torch.cat([rgb, alpha], -1) # (B, 4)
  47. # 定义位置编码函数
  48. def positional_encoding(x):
  49.     # x: (B, C)
  50.     B, C = x.shape
  51.     L = int(C // 2) # 计算位置编码的长度
  52.     freqs = torch.logspace(0., L - 1, steps=L).to(x.device) * math.pi # 计算频率系数,呈指数增长
  53.     freqs = freqs[None].repeat(B, 1) # (B, L)
  54.     x_pos_enc_low = torch.sin(x[:, :L] * freqs) # 对前一半的输入进行正弦变换,得到低频部分 (B, L)
  55.     x_pos_enc_high = torch.cos(x[:, :L] * freqs) # 对前一半的输入进行余弦变换,得到高频部分 (B, L)
  56.     x_pos_enc = torch.cat([x_pos_enc_low, x_pos_enc_high], dim=-1) # 将低频和高频部分拼接,得到位置编码后的输入 (B, C)
  57.     return x_pos_enc
  58. # 定义体积渲染函数
  59. def volume_rendering(rays_o, rays_d, model):
  60.     # rays_o: (B, 3), 每条光线的起点
  61.     # rays_d: (B, 3), 每条光线的方向
  62.     B = rays_o.shape[0]
  63.     # 在每条光线上采样一些点
  64.     near, far = 0., 1. # 近平面和远平面
  65.     N_samples = 64 # 每条光线的采样数
  66.     t_vals = torch.linspace(near, far, N_samples).to(rays_o.device) # (N_samples,)
  67.     t_vals = t_vals.expand(B, N_samples) # (B, N_samples)
  68.     z_vals = near * (1. - t_vals) + far * t_vals # 计算每个采样点的深度值 (B, N_samples)
  69.     z_vals = z_vals.unsqueeze(-1) # (B, N_samples, 1)
  70.     pts = rays_o.unsqueeze(1+ rays_d.unsqueeze(1* z_vals # 计算每个采样点的空间位置 (B, N_samples, 3)
  71.     # 将采样点和视角方向输入网络
  72.     pts_flat = pts.reshape(-13) # (B*N_samples, 3)
  73.     rays_d_flat = rays_d.unsqueeze(1).expand(-1, N_samples, -1).reshape(-13) # (B*N_samples, 3)
  74.     x_flat = torch.cat([pts_flat, rays_d_flat], -1) # (B*N_samples, 6)
  75.     y_flat = model(x_flat) # (B*N_samples, 4)
  76.     y = y_flat.reshape(B, N_samples, 4) # (B, N_samples, 4)
  77.     # 提取RGB颜色和透明度(alpha)值
  78.     rgb = y[..., :3] # (B, N_samples, 3)
  79.     alpha = y[..., 3] # (B, N_samples)
  80.     # 计算每个采样点的权重
  81.     dists = torch.cat([z_vals[..., 1:] - z_vals[..., :-1], torch.tensor([1e10]).to(z_vals.device).expand(B, 1)], -1) # 计算相邻采样点之间的距离,最后一个距离设为很大的值 (B, N_samples)
  82.     alpha = 1. - torch.exp(-alpha * dists) # 计算每个采样点的不透明度,即1减去透明度的指数衰减 (B, N_samples)
  83.     weights = alpha * torch.cumprod(torch.cat([torch.ones((B, 1)).to(alpha.device), 1. - alpha + 1e-10], -1), -1)[:, :-1] # 计算每个采样点的权重,即不透明度乘以之前所有采样点的透明度累积积,最后一个权重设为0 (B, N_samples)
  84.     # 计算每条光线的最终颜色和透明度
  85.     rgb_map = torch.sum(weights.unsqueeze(-1* rgb, -2) # 加权平均每个采样点的RGB颜色,得到每条光线的颜色 (B, 3)
  86.     depth_map = torch.sum(weights * z_vals.squeeze(-1), -1) # 加权平均每个采样点的深度值,得到每条光线的深度 (B,)
  87.     acc_map = torch.sum(weights, -1) # 累加每个采样点的权重,得到每条光线的不透明度 (B,)
  88.     
  89.     return rgb_map, depth_map, acc_map
  90. # 定义渲染损失函数
  91. def rendering_loss(rgb_map_pred, rgb_map_gt):
  92.     return ((rgb_map_pred - rgb_map_gt)**2).mean() # 计算预测的颜色与真实颜色之间的均方误差

综上所述,本代码实现了NeRF的核心结构,具体实现内容包括以下四个部分。

1)定义了NeRF网络结构,包含位置编码和多层全连接网络,输入是位置和视角,输出是颜色和密度。

2)实现了位置编码函数,通过正弦和余弦变换引入高频信息。

3)实现了体积渲染函数,在光线上采样点,查询NeRF网络预测颜色和密度,然后通过加权平均实现整体渲染。

4)定义了渲染损失函数,计算预测颜色和真实颜色的均方误差。

当然,本方案只是实现NeRF的一个基础方案,更多的细节还需要进行优化。需要完整学习代码的同学可以通过下面两个链接获取:

原论文及代码(NeRF: Neural Radiance Fields):https://github.com/bmild/nerf

大佬实现的pytorch版本(NeRF-pytorch):https://github.com/yenchenlin/nerf-pytorch

当然,为了方便下载,我们已经将上述两个源代码打包好了,请关注“3DCV”,回复:原论文获取完整详细代码

—END—

高效学习3D视觉三部曲

第一步 加入行业交流群,保持技术的先进性

目前工坊已经建立了3D视觉方向多个社群,包括SLAM、工业3D视觉、自动驾驶方向,细分群包括:

[工业方向]三维点云、结构光、机械臂、缺陷检测、三维测量、TOF、相机标定、综合群;

[SLAM方向]多传感器融合、ORB-SLAM、激光SLAM、机器人导航、RTK|GPS|UWB等传感器交流群、SLAM综合讨论群;

[自动驾驶方向]深度估计、Transformer、毫米波|激光雷达|视觉摄像头传感器讨论群、多传感器标定、自动驾驶综合群等。

[三维重建方向]NeRF、colmap、OpenMVS等。除了这些,还有求职、硬件选型、视觉产品落地等交流群。

大家可以添加小助理微信: cv3d008,备注:加群+方向+学校|公司, 小助理会拉你入群。

第二步 加入知识星球,问题及时得到解答
3.1 「3D视觉从入门到精通」技术星球

针对3D视觉领域的视频课程(三维重建、三维点云、结构光、手眼标定、相机标定、激光/视觉SLAM、自动驾驶等)、源码分享、知识点汇总、入门进阶学习路线、最新paper分享、疑问解答等进行深耕,更有各类大厂的算法工程人员进行技术指导。与此同时,星球将联合知名企业发布3D视觉相关算法开发岗位以及项目对接信息,打造成集技术与就业、项目对接为一体的铁杆粉丝聚集区,6000+星球成员为创造更好的AI世界共同进步,知识星球入口:「3D视觉从入门到精通」

学习3D视觉核心技术,扫描查看,3天内无条件退款

高质量教程资料、答疑解惑、助你高效解决问题

3.2 3D视觉岗求职星球

本星球:3D视觉岗求职星球 依托于公众号「3D视觉工坊」和「计算机视觉工坊」、「3DCV」,旨在发布3D视觉项目、3D视觉产品、3D视觉算法招聘信息,具体内容主要包括:

  • 收集汇总并发布3D视觉领域优秀企业的最新招聘信息。

  • 发布项目需求,包括2D、3D视觉、深度学习、VSLAM,自动驾驶、三维重建、结构光、机械臂位姿估计与抓取、光场重建、无人机、AR/VR等。

  • 分享3D视觉算法岗的秋招、春招准备攻略,心得体会,内推机会、实习机会等,涉及计算机视觉、SLAM、深度学习、自动驾驶、大数据等方向。

  • 星球内含有多家企业HR及猎头提供就业机会。群主和嘉宾既有21届/22届/23届参与招聘拿到算法offer(含有海康威视、阿里、美团、华为等大厂offer)。

  • 发布3D视觉行业新科技产品,触及行业新动向。

扫码加入,3D视觉岗求职星球,简历投起来

第三步 系统学习3D视觉,对模块知识体系,深刻理解并运行

如果大家对3D视觉某一个细分方向想系统学习[从理论、代码到实战],推荐3D视觉精品课程学习网址:www.3dcver.com

科研论文写作:

[1]国内首个面向三维视觉的科研方法与学术论文写作教程

基础课程:

[1]面向三维视觉算法的C++重要模块精讲:从零基础入门到进阶

[2]面向三维视觉的Linux嵌入式系统教程[理论+代码+实战]

[3]如何学习相机模型与标定?(代码+实战)

[4]ROS2从入门到精通:理论与实战

[5]彻底理解dToF雷达系统设计[理论+代码+实战]

工业3D视觉方向课程:

[1](第二期)从零搭建一套结构光3D重建系统[理论+源码+实践]

[2]保姆级线结构光(单目&双目)三维重建系统教程

[3]机械臂抓取从入门到实战课程(理论+源码)

[4]三维点云处理:算法与实战汇总

[5]彻底搞懂基于Open3D的点云处理教程!

[6]3D视觉缺陷检测教程:理论与实战!

SLAM方向课程:

[1]深度剖析面向机器人领域的3D激光SLAM技术原理、代码与实战

[1]彻底剖析激光-视觉-IMU-GPS融合SLAM算法:理论推导、代码讲解和实战

[2](第二期)彻底搞懂基于LOAM框架的3D激光SLAM:源码剖析到算法优化

[3]彻底搞懂视觉-惯性SLAM:VINS-Fusion原理精讲与源码剖析

[4]彻底剖析室内、室外激光SLAM关键算法和实战(cartographer+LOAM+LIO-SAM)

[5](第二期)ORB-SLAM3理论讲解与代码精析

视觉三维重建:

[1]彻底搞透视觉三维重建:原理剖析、代码讲解、及优化改进

自动驾驶方向课程:

[1] 深度剖析面向自动驾驶领域的车载传感器空间同步(标定)

[2] 国内首个面向自动驾驶目标检测领域的Transformer原理与实战课程

[3]单目深度估计方法:算法梳理与代码实现

[4]面向自动驾驶领域的3D点云目标检测全栈学习路线!(单模态+多模态/数据+代码)

[5]如何将深度学习模型部署到实际工程中?(分类+检测+分割)

无人机:

[1] 零基础入门四旋翼建模与控制(MATLAB仿真)[理论+实战]

最后

1、3D视觉文章投稿作者招募

2、3D视觉课程(自动驾驶、SLAM和工业3D视觉)主讲老师招募

3、顶会论文分享与3D视觉传感器行业直播邀请

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/很楠不爱3/article/detail/709180
推荐阅读
相关标签
  

闽ICP备14008679号