From b44cc41cbcfe1b54a54604a258fcfb39011500c8 Mon Sep 17 00:00:00 2001 From: akristoffersen Date: Sun, 1 Dec 2024 22:51:06 -0800 Subject: [PATCH] convert to using degrees, not levels --- nerfstudio/field_components/encodings.py | 10 +++++----- nerfstudio/model_components/renderers.py | 2 +- nerfstudio/utils/spherical_harmonics.py | 16 ++++++++-------- tests/field_components/test_encodings.py | 4 ++-- tests/utils/test_spherical_harmonics.py | 10 +++++----- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/nerfstudio/field_components/encodings.py b/nerfstudio/field_components/encodings.py index 287f168d25..b5f8bf4f0e 100644 --- a/nerfstudio/field_components/encodings.py +++ b/nerfstudio/field_components/encodings.py @@ -757,15 +757,15 @@ class SHEncoding(Encoding): """Spherical harmonic encoding Args: - levels: Number of spherical harmonic levels to encode. + levels: Number of spherical harmonic levels to encode. (level = sh degree + 1) """ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None: super().__init__(in_dim=3) - if levels <= 0 or levels > MAX_SH_DEGREE: + if levels <= 0 or levels > MAX_SH_DEGREE + 1: raise ValueError( - f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE} levels, requested {levels}" + f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE + 1} levels, requested {levels}" ) self.levels = levels @@ -781,7 +781,7 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = " ) @classmethod - def get_tcnn_encoding_config(cls, levels) -> dict: + def get_tcnn_encoding_config(cls, levels: int) -> dict: """Get the encoding configuration for tcnn if implemented""" encoding_config = { "otype": "SphericalHarmonics", @@ -795,7 +795,7 @@ def get_out_dim(self) -> int: @torch.no_grad() def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]: """Forward pass using pytorch. Significantly slower than TCNN implementation.""" - return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor) + return components_from_spherical_harmonics(degree=self.levels - 1, directions=in_tensor) def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]: if self.tcnn_encoding is not None: diff --git a/nerfstudio/model_components/renderers.py b/nerfstudio/model_components/renderers.py index b2f33d45d5..99c14ca7d0 100644 --- a/nerfstudio/model_components/renderers.py +++ b/nerfstudio/model_components/renderers.py @@ -269,7 +269,7 @@ def forward( sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3) levels = int(math.sqrt(sh.shape[-1])) - components = components_from_spherical_harmonics(levels=levels, directions=directions) + components = components_from_spherical_harmonics(degree=levels - 1, directions=directions) rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components] rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3] diff --git a/nerfstudio/utils/spherical_harmonics.py b/nerfstudio/utils/spherical_harmonics.py index 09968cb569..07936281cc 100644 --- a/nerfstudio/utils/spherical_harmonics.py +++ b/nerfstudio/utils/spherical_harmonics.py @@ -22,19 +22,19 @@ def components_from_spherical_harmonics( - levels: int, directions: Float[Tensor, "*batch 3"] + degree: int, directions: Float[Tensor, "*batch 3"] ) -> Float[Tensor, "*batch components"]: """ Returns value for each component of spherical harmonics. Args: - levels: Number of spherical harmonic levels to compute. + degree: Number of spherical harmonic degrees to compute. directions: Spherical harmonic coefficients """ - num_components = levels**2 + num_components = num_sh_bases(degree) components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device) - assert 1 <= levels <= MAX_SH_DEGREE, f"SH levels must be in [1,{MAX_SH_DEGREE}], got {levels}" + assert 0 <= degree <= MAX_SH_DEGREE, f"SH degree must be in [0, {MAX_SH_DEGREE}], got {degree}" assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}" x = directions[..., 0] @@ -49,13 +49,13 @@ def components_from_spherical_harmonics( components[..., 0] = 0.28209479177387814 # l1 - if levels > 1: + if degree > 0: components[..., 1] = 0.4886025119029199 * y components[..., 2] = 0.4886025119029199 * z components[..., 3] = 0.4886025119029199 * x # l2 - if levels > 2: + if degree > 1: components[..., 4] = 1.0925484305920792 * x * y components[..., 5] = 1.0925484305920792 * y * z components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999 @@ -63,7 +63,7 @@ def components_from_spherical_harmonics( components[..., 8] = 0.5462742152960396 * (xx - yy) # l3 - if levels > 3: + if degree > 2: components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy) components[..., 10] = 2.890611442640554 * x * y * z components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1) @@ -73,7 +73,7 @@ def components_from_spherical_harmonics( components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy) # l4 - if levels > 4: + if degree > 3: components[..., 16] = 2.5033429417967046 * x * y * (xx - yy) components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy) components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1) diff --git a/tests/field_components/test_encodings.py b/tests/field_components/test_encodings.py index a241fc52a1..63dc9a0261 100644 --- a/tests/field_components/test_encodings.py +++ b/tests/field_components/test_encodings.py @@ -125,11 +125,11 @@ def test_tensor_cp_encoder(): def test_tensor_sh_encoder(): """Test Spherical Harmonic encoder""" - levels = 4 + levels = 5 out_dim = levels**2 with pytest.raises(ValueError): - encoder = encodings.SHEncoding(levels=5) + encoder = encodings.SHEncoding(levels=6) encoder = encodings.SHEncoding(levels=levels) assert encoder.get_out_dim() == out_dim diff --git a/tests/utils/test_spherical_harmonics.py b/tests/utils/test_spherical_harmonics.py index 3970262609..a8949891a1 100644 --- a/tests/utils/test_spherical_harmonics.py +++ b/tests/utils/test_spherical_harmonics.py @@ -1,16 +1,16 @@ import pytest import torch -from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics +from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics, num_sh_bases -@pytest.mark.parametrize("components", list(range(1, 5))) -def test_spherical_harmonics(components): +@pytest.mark.parametrize("degree", list(range(0, 5))) +def test_spherical_harmonics(degree): torch.manual_seed(0) N = 1000000 dx = torch.normal(0, 1, size=(N, 3)) dx = dx / torch.linalg.norm(dx, dim=-1, keepdim=True) - sh = components_from_spherical_harmonics(components, dx) + sh = components_from_spherical_harmonics(degree, dx) matrix = (sh.T @ sh) / N * 4 * torch.pi - torch.testing.assert_close(matrix, torch.eye(components**2), rtol=0, atol=1.5e-2) + torch.testing.assert_close(matrix, torch.eye(num_sh_bases(degree)), rtol=0, atol=1.5e-2)