Skip to content

Commit

Permalink
convert to using degrees, not levels
Browse files Browse the repository at this point in the history
  • Loading branch information
akristoffersen committed Dec 2, 2024
1 parent 872668d commit b44cc41
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 21 deletions.
10 changes: 5 additions & 5 deletions nerfstudio/field_components/encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,15 +757,15 @@ class SHEncoding(Encoding):
"""Spherical harmonic encoding
Args:
levels: Number of spherical harmonic levels to encode.
levels: Number of spherical harmonic levels to encode. (level = sh degree + 1)
"""

def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
super().__init__(in_dim=3)

if levels <= 0 or levels > MAX_SH_DEGREE:
if levels <= 0 or levels > MAX_SH_DEGREE + 1:
raise ValueError(
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE} levels, requested {levels}"
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE + 1} levels, requested {levels}"
)

self.levels = levels
Expand All @@ -781,7 +781,7 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "
)

@classmethod
def get_tcnn_encoding_config(cls, levels) -> dict:
def get_tcnn_encoding_config(cls, levels: int) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {
"otype": "SphericalHarmonics",
Expand All @@ -795,7 +795,7 @@ def get_out_dim(self) -> int:
@torch.no_grad()
def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
"""Forward pass using pytorch. Significantly slower than TCNN implementation."""
return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor)
return components_from_spherical_harmonics(degree=self.levels - 1, directions=in_tensor)

def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
if self.tcnn_encoding is not None:
Expand Down
2 changes: 1 addition & 1 deletion nerfstudio/model_components/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def forward(
sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3)

levels = int(math.sqrt(sh.shape[-1]))
components = components_from_spherical_harmonics(levels=levels, directions=directions)
components = components_from_spherical_harmonics(degree=levels - 1, directions=directions)

rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components]
rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3]
Expand Down
16 changes: 8 additions & 8 deletions nerfstudio/utils/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@


def components_from_spherical_harmonics(
levels: int, directions: Float[Tensor, "*batch 3"]
degree: int, directions: Float[Tensor, "*batch 3"]
) -> Float[Tensor, "*batch components"]:
"""
Returns value for each component of spherical harmonics.
Args:
levels: Number of spherical harmonic levels to compute.
degree: Number of spherical harmonic degrees to compute.
directions: Spherical harmonic coefficients
"""
num_components = levels**2
num_components = num_sh_bases(degree)
components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device)

assert 1 <= levels <= MAX_SH_DEGREE, f"SH levels must be in [1,{MAX_SH_DEGREE}], got {levels}"
assert 0 <= degree <= MAX_SH_DEGREE, f"SH degree must be in [0, {MAX_SH_DEGREE}], got {degree}"
assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}"

x = directions[..., 0]
Expand All @@ -49,21 +49,21 @@ def components_from_spherical_harmonics(
components[..., 0] = 0.28209479177387814

# l1
if levels > 1:
if degree > 0:
components[..., 1] = 0.4886025119029199 * y
components[..., 2] = 0.4886025119029199 * z
components[..., 3] = 0.4886025119029199 * x

# l2
if levels > 2:
if degree > 1:
components[..., 4] = 1.0925484305920792 * x * y
components[..., 5] = 1.0925484305920792 * y * z
components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999
components[..., 7] = 1.0925484305920792 * x * z
components[..., 8] = 0.5462742152960396 * (xx - yy)

# l3
if levels > 3:
if degree > 2:
components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy)
components[..., 10] = 2.890611442640554 * x * y * z
components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1)
Expand All @@ -73,7 +73,7 @@ def components_from_spherical_harmonics(
components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy)

# l4
if levels > 4:
if degree > 3:
components[..., 16] = 2.5033429417967046 * x * y * (xx - yy)
components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy)
components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1)
Expand Down
4 changes: 2 additions & 2 deletions tests/field_components/test_encodings.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ def test_tensor_cp_encoder():
def test_tensor_sh_encoder():
"""Test Spherical Harmonic encoder"""

levels = 4
levels = 5
out_dim = levels**2

with pytest.raises(ValueError):
encoder = encodings.SHEncoding(levels=5)
encoder = encodings.SHEncoding(levels=6)

encoder = encodings.SHEncoding(levels=levels)
assert encoder.get_out_dim() == out_dim
Expand Down
10 changes: 5 additions & 5 deletions tests/utils/test_spherical_harmonics.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
import pytest
import torch

from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics, num_sh_bases


@pytest.mark.parametrize("components", list(range(1, 5)))
def test_spherical_harmonics(components):
@pytest.mark.parametrize("degree", list(range(0, 5)))
def test_spherical_harmonics(degree):
torch.manual_seed(0)
N = 1000000

dx = torch.normal(0, 1, size=(N, 3))
dx = dx / torch.linalg.norm(dx, dim=-1, keepdim=True)
sh = components_from_spherical_harmonics(components, dx)
sh = components_from_spherical_harmonics(degree, dx)
matrix = (sh.T @ sh) / N * 4 * torch.pi
torch.testing.assert_close(matrix, torch.eye(components**2), rtol=0, atol=1.5e-2)
torch.testing.assert_close(matrix, torch.eye(num_sh_bases(degree)), rtol=0, atol=1.5e-2)

0 comments on commit b44cc41

Please sign in to comment.