当前位置:   article > 正文

3DGS学习(六)—— 参数更新_3dgs 数学推到

3dgs 数学推到

参数更新

参考文章:3dgs中的数学推导

协方差矩阵的参数更新

  • 直接通过pytorch自带的更新机制,通过渲染后计算损失,只能更新2D协方差矩阵 Σ ′ \Sigma^\prime Σ,再通过公式逆推出3d空间协方差矩阵 Σ \Sigma Σ的值。该过程处理矩阵计算多且复杂,计算效率低下。
  • 为了提高计算效率,我们需要显示的表示 Σ \Sigma Σ,即利用前面学习到的将该矩阵拆分成旋转矩阵 R R R以及放缩矩阵 S S S
    Σ = R S S ⊤ R ⊤ \boldsymbol{\Sigma}=\boldsymbol{R}\boldsymbol{S}\boldsymbol{S}^\top\boldsymbol{R}^\top Σ=RSSR
  • 通过旋转四元数,我们可以进一步将要更新的R矩阵内的9个参数压缩到4个参数
    q = q r + q i ⋅ i + q j ⋅ j + q k ⋅ k R ( q ) = 2 ( 1 2 − ( q j 2 + q k 2 ) ( q i q j − q r q k ) ( q i q k + q r q j ) ( q i q j + q r q k ) 1 2 − ( q i 2 + q k 2 ) ( q j q k − q r q i ) ( q i q k − q r q j ) ( q j q k + q r q i ) 1 2 − ( q i 2 + q j 2 ) )
    q=qr+qii+qjj+qkkR(q)=2(12(qj2+qk2)(qiqjqrqk)(qiqk+qrqj)(qiqj+qrqk)12(qi2+qk2)(qjqkqrqi)(qiqkqrqj)(qjqk+qrqi)12(qi2+qj2))
    q=qr+qii+qjj+qkkR(q)=2 21(qj2+qk2)(qiqj+qrqk)(qiqkqrqj)(qiqjqrqk)21(qi2+qk2)(qjqk+qrqi)(qiqk+qrqj)(qjqkqrqi)21(qi2+qj2)
  • 旋转四元数是一种用于表示三维空间中旋转的数学工具。它是四元数的一种特殊形式,由一个实部和三个虚部组成。

  • 旋转四元数通常表示为q = w + xi + yj + zk,其中w是实部,(x, y, z)是虚部,i、j、k是虚数单位。这里需要满足四元数的数学性质:i² = j² = k² = ijk = -1。

  • 旋转四元数的核心思想是,通过对旋转轴上的旋转角度进行编码,以及通过旋转轴的单位向量来表示旋转的方向。旋转四元数的实部(w)用于表示旋转角度的余弦值,而虚部(x, y, z)则表示旋转轴在单位向量上的三个分量。

  • 放缩矩阵更不需要记录整个矩阵的信息,只需要记录其在三个轴方向的缩放比即可。

综上所述,协方差矩阵的更新转变为更新旋转四元数 q q q和一个含缩放比信息的三维向量 s s s
3dgs推导了 q q q s s s的梯度,节约了自动微分的成本,具体可以参考3dgs原论文附录部分的梯度回传数学推导部分。
在这里插入图片描述

颜色的参数更新

官方代码中更新颜色部分的代码如下:

// Backward pass for conversion of spherical harmonics to RGB for
// each Gaussian.
__device__ void computeColorFromSH(int idx, int deg, int max_coeffs, const glm::vec3* means, glm::vec3 campos, const float* shs, const bool* clamped, const glm::vec3* dL_dcolor, glm::vec3* dL_dmeans, glm::vec3* dL_dshs)
{
	// Compute intermediate values, as it is done during forward
	glm::vec3 pos = means[idx];
	glm::vec3 dir_orig = pos - campos;
	glm::vec3 dir = dir_orig / glm::length(dir_orig);

	glm::vec3* sh = ((glm::vec3*)shs) + idx * max_coeffs;

	// Use PyTorch rule for clamping: if clamping was applied,
	// gradient becomes 0.
	glm::vec3 dL_dRGB = dL_dcolor[idx];
	dL_dRGB.x *= clamped[3 * idx + 0] ? 0 : 1;
	dL_dRGB.y *= clamped[3 * idx + 1] ? 0 : 1;
	dL_dRGB.z *= clamped[3 * idx + 2] ? 0 : 1;

	glm::vec3 dRGBdx(0, 0, 0);
	glm::vec3 dRGBdy(0, 0, 0);
	glm::vec3 dRGBdz(0, 0, 0);
	float x = dir.x;
	float y = dir.y;
	float z = dir.z;

	// Target location for this Gaussian to write SH gradients to
	glm::vec3* dL_dsh = dL_dshs + idx * max_coeffs;

	// No tricks here, just high school-level calculus.
	float dRGBdsh0 = SH_C0;
	dL_dsh[0] = dRGBdsh0 * dL_dRGB;
	if (deg > 0)
	{
		float dRGBdsh1 = -SH_C1 * y;
		float dRGBdsh2 = SH_C1 * z;
		float dRGBdsh3 = -SH_C1 * x;
		dL_dsh[1] = dRGBdsh1 * dL_dRGB;
		dL_dsh[2] = dRGBdsh2 * dL_dRGB;
		dL_dsh[3] = dRGBdsh3 * dL_dRGB;

		dRGBdx = -SH_C1 * sh[3];
		dRGBdy = -SH_C1 * sh[1];
		dRGBdz = SH_C1 * sh[2];

		if (deg > 1)
		{
			float xx = x * x, yy = y * y, zz = z * z;
			float xy = x * y, yz = y * z, xz = x * z;

			float dRGBdsh4 = SH_C2[0] * xy;
			float dRGBdsh5 = SH_C2[1] * yz;
			float dRGBdsh6 = SH_C2[2] * (2.f * zz - xx - yy);
			float dRGBdsh7 = SH_C2[3] * xz;
			float dRGBdsh8 = SH_C2[4] * (xx - yy);
			dL_dsh[4] = dRGBdsh4 * dL_dRGB;
			dL_dsh[5] = dRGBdsh5 * dL_dRGB;
			dL_dsh[6] = dRGBdsh6 * dL_dRGB;
			dL_dsh[7] = dRGBdsh7 * dL_dRGB;
			dL_dsh[8] = dRGBdsh8 * dL_dRGB;

			dRGBdx += SH_C2[0] * y * sh[4] + SH_C2[2] * 2.f * -x * sh[6] + SH_C2[3] * z * sh[7] + SH_C2[4] * 2.f * x * sh[8];
			dRGBdy += SH_C2[0] * x * sh[4] + SH_C2[1] * z * sh[5] + SH_C2[2] * 2.f * -y * sh[6] + SH_C2[4] * 2.f * -y * sh[8];
			dRGBdz += SH_C2[1] * y * sh[5] + SH_C2[2] * 2.f * 2.f * z * sh[6] + SH_C2[3] * x * sh[7];

			if (deg > 2)
			{
				float dRGBdsh9 = SH_C3[0] * y * (3.f * xx - yy);
				float dRGBdsh10 = SH_C3[1] * xy * z;
				float dRGBdsh11 = SH_C3[2] * y * (4.f * zz - xx - yy);
				float dRGBdsh12 = SH_C3[3] * z * (2.f * zz - 3.f * xx - 3.f * yy);
				float dRGBdsh13 = SH_C3[4] * x * (4.f * zz - xx - yy);
				float dRGBdsh14 = SH_C3[5] * z * (xx - yy);
				float dRGBdsh15 = SH_C3[6] * x * (xx - 3.f * yy);
				dL_dsh[9] = dRGBdsh9 * dL_dRGB;
				dL_dsh[10] = dRGBdsh10 * dL_dRGB;
				dL_dsh[11] = dRGBdsh11 * dL_dRGB;
				dL_dsh[12] = dRGBdsh12 * dL_dRGB;
				dL_dsh[13] = dRGBdsh13 * dL_dRGB;
				dL_dsh[14] = dRGBdsh14 * dL_dRGB;
				dL_dsh[15] = dRGBdsh15 * dL_dRGB;

				dRGBdx += (
					SH_C3[0] * sh[9] * 3.f * 2.f * xy +
					SH_C3[1] * sh[10] * yz +
					SH_C3[2] * sh[11] * -2.f * xy +
					SH_C3[3] * sh[12] * -3.f * 2.f * xz +
					SH_C3[4] * sh[13] * (-3.f * xx + 4.f * zz - yy) +
					SH_C3[5] * sh[14] * 2.f * xz +
					SH_C3[6] * sh[15] * 3.f * (xx - yy));

				dRGBdy += (
					SH_C3[0] * sh[9] * 3.f * (xx - yy) +
					SH_C3[1] * sh[10] * xz +
					SH_C3[2] * sh[11] * (-3.f * yy + 4.f * zz - xx) +
					SH_C3[3] * sh[12] * -3.f * 2.f * yz +
					SH_C3[4] * sh[13] * -2.f * xy +
					SH_C3[5] * sh[14] * -2.f * yz +
					SH_C3[6] * sh[15] * -3.f * 2.f * xy);

				dRGBdz += (
					SH_C3[1] * sh[10] * xy +
					SH_C3[2] * sh[11] * 4.f * 2.f * yz +
					SH_C3[3] * sh[12] * 3.f * (2.f * zz - xx - yy) +
					SH_C3[4] * sh[13] * 4.f * 2.f * xz +
					SH_C3[5] * sh[14] * (xx - yy));
			}
		}
	}

	// The view direction is an input to the computation. View direction
	// is influenced by the Gaussian's mean, so SHs gradients
	// must propagate back into 3D position.
	glm::vec3 dL_ddir(glm::dot(dRGBdx, dL_dRGB), glm::dot(dRGBdy, dL_dRGB), glm::dot(dRGBdz, dL_dRGB));

	// Account for normalization of direction
	float3 dL_dmean = dnormvdv(float3{ dir_orig.x, dir_orig.y, dir_orig.z }, float3{ dL_ddir.x, dL_ddir.y, dL_ddir.z });

	// Gradients of loss w.r.t. Gaussian means, but only the portion 
	// that is caused because the mean affects the view-dependent color.
	// Additional mean gradient is accumulated in below methods.
	dL_dmeans[idx] += glm::vec3(dL_dmean.x, dL_dmean.y, dL_dmean.z);
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 输入参数包括:

idx: 代表当前高斯函数的索引;
deg: 指定球谐函数的阶数;
max_coeffs: 球谐函数的系数个数;
means: 存储每个高斯函数的均值向量;
campos: 相机位置;
shs: 存储每个高斯函数的球谐函数系数;
clamped: 存储每个高斯函数是否需要进行截断;
dL_dcolor: 目标颜色对 RGB 颜色空间的导数;
dL_dmeans: 目标颜色对高斯函数均值的导数;
dL_dshs: 目标颜色对球谐函数系数的导数。

  • 该函数主要实现以下过程:

计算相机与当前高斯函数均值之间的方向向量;
根据 PyTorch 规则,如果某个高斯函数需要进行截断,则其梯度为 0;
计算 RGB 颜色空间中的梯度;
计算 RGB 颜色空间中每个分量对坐标轴的偏导数;
根据球谐函数的定义,计算球谐函数系数与 RGB 颜色空间中每个分量之间的导数关系;
根据相机位置和方向向量,计算目标颜色对高斯函数均值的导数;
将高斯函数均值的导数累加到 dL_dmeans 中;

声明:本文内容由网友自发贡献,不代表【wpsshop博客】立场,版权归原作者所有,本站不承担相应法律责任。如您发现有侵权的内容,请联系我们。转载请注明出处:https://www.wpsshop.cn/w/繁依Fanyi0/article/detail/585176
推荐阅读
相关标签
  

闽ICP备14008679号