Move util functions out of splatfacto (#3538)
* Move util functions out of splatfacto

Some of the SH utils aren't currently used anywhere else, but it still makes sense to move them out of splatfacto.

I also moved the k-nearest-neighbors helper to utils since it doesn't depend on the model class.

* fix assert

* fix sh test

* convert to using degrees, not levels
akristoffersen authored Dec 5, 2024
1 parent e8bf472 commit a8888e7
Showing 8 changed files with 207 additions and 181 deletions.
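The main convention change in this commit is that spherical-harmonic code now takes a `degree` argument rather than `levels` (level = degree + 1). Below is a minimal sketch of the relationship, using the `num_sh_bases` helper that the commit moves into `nerfstudio.utils.spherical_harmonics`; the value of `MAX_SH_DEGREE` is not visible in this view and is presumably 3, matching the old `levels > 4` check.

```python
def num_sh_bases(degree: int) -> int:
    """Number of spherical harmonic basis functions for a given degree (the moved helper)."""
    return (degree + 1) ** 2


# degree 0 -> 1 level  -> 1 component   (DC term only)
# degree 3 -> 4 levels -> 16 components (the old `levels=4` maximum)
for degree in range(4):
    print(f"degree={degree}  levels={degree + 1}  components={num_sh_bases(degree)}")
```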
15 changes: 9 additions & 6 deletions nerfstudio/field_components/encodings.py
@@ -28,8 +28,9 @@

from nerfstudio.field_components.base_field_component import FieldComponent
from nerfstudio.utils.external import TCNN_EXISTS, tcnn
from nerfstudio.utils.math import components_from_spherical_harmonics, expected_sin, generate_polyhedron_basis
from nerfstudio.utils.math import expected_sin, generate_polyhedron_basis
from nerfstudio.utils.printing import print_tcnn_speed_warning
from nerfstudio.utils.spherical_harmonics import MAX_SH_DEGREE, components_from_spherical_harmonics


class Encoding(FieldComponent):
@@ -756,14 +757,16 @@ class SHEncoding(Encoding):
"""Spherical harmonic encoding
Args:
levels: Number of spherical harmonic levels to encode.
levels: Number of spherical harmonic levels to encode. (level = sh degree + 1)
"""

def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
super().__init__(in_dim=3)

if levels <= 0 or levels > 4:
raise ValueError(f"Spherical harmonic encoding only supports 1 to 4 levels, requested {levels}")
if levels <= 0 or levels > MAX_SH_DEGREE + 1:
raise ValueError(
f"Spherical harmonic encoding only supports 1 to {MAX_SH_DEGREE + 1} levels, requested {levels}"
)

self.levels = levels

@@ -778,7 +781,7 @@ def __init__(self, levels: int = 4, implementation: Literal["tcnn", "torch"] = "torch") -> None:
)

@classmethod
def get_tcnn_encoding_config(cls, levels) -> dict:
def get_tcnn_encoding_config(cls, levels: int) -> dict:
"""Get the encoding configuration for tcnn if implemented"""
encoding_config = {
"otype": "SphericalHarmonics",
@@ -792,7 +795,7 @@ def get_out_dim(self) -> int:
@torch.no_grad()
def pytorch_fwd(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
"""Forward pass using pytorch. Significantly slower than TCNN implementation."""
return components_from_spherical_harmonics(levels=self.levels, directions=in_tensor)
return components_from_spherical_harmonics(degree=self.levels - 1, directions=in_tensor)

def forward(self, in_tensor: Float[Tensor, "*bs input_dim"]) -> Float[Tensor, "*bs output_dim"]:
if self.tcnn_encoding is not None:
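For context, a minimal usage sketch of the updated encoding under the torch backend; it assumes `get_out_dim()` still returns `levels**2`, which is not shown in the hunk above.

```python
import torch

from nerfstudio.field_components.encodings import SHEncoding

encoding = SHEncoding(levels=4, implementation="torch")
directions = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)
out = encoding(directions)
# levels=4 corresponds to degree 3, i.e. (3 + 1) ** 2 = 16 components per direction.
print(out.shape)  # expected: torch.Size([8, 16])
```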
5 changes: 3 additions & 2 deletions nerfstudio/model_components/renderers.py
@@ -38,7 +38,8 @@

from nerfstudio.cameras.rays import RaySamples
from nerfstudio.utils import colors
from nerfstudio.utils.math import components_from_spherical_harmonics, safe_normalize
from nerfstudio.utils.math import safe_normalize
from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics

BackgroundColor = Union[Literal["random", "last_sample", "black", "white"], Float[Tensor, "3"], Float[Tensor, "*bs 3"]]
BACKGROUND_COLOR_OVERRIDE: Optional[Float[Tensor, "3"]] = None
@@ -268,7 +269,7 @@ def forward(
sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3)

levels = int(math.sqrt(sh.shape[-1]))
components = components_from_spherical_harmonics(levels=levels, directions=directions)
components = components_from_spherical_harmonics(degree=levels - 1, directions=directions)

rgb = sh * components[..., None, :] # [..., num_samples, 3, sh_components]
rgb = torch.sum(rgb, dim=-1) # [..., num_samples, 3]
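The renderer still infers the number of levels from the coefficient count and now converts to a degree at the call site. A small sketch of that arithmetic, outside the renderer itself:

```python
import math

import torch

sh = torch.randn(1024, 48)                          # e.g. 3 color channels x 16 SH coefficients
sh = sh.view(*sh.shape[:-1], 3, sh.shape[-1] // 3)  # [..., 3, 16]
levels = int(math.sqrt(sh.shape[-1]))               # 16 coefficients -> 4 levels
degree = levels - 1                                 # value passed as `degree=` above
print(levels, degree)                               # 4 3
```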
89 changes: 3 additions & 86 deletions nerfstudio/models/splatfacto.py
@@ -19,11 +19,9 @@

from __future__ import annotations

import math
from dataclasses import dataclass, field
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

import numpy as np
import torch
from gsplat.strategy import DefaultStrategy

@@ -42,70 +40,10 @@
from nerfstudio.model_components.lib_bilagrid import BilateralGrid, color_correct, slice, total_variation_loss
from nerfstudio.models.base_model import Model, ModelConfig
from nerfstudio.utils.colors import get_color
from nerfstudio.utils.math import k_nearest_sklearn, random_quat_tensor
from nerfstudio.utils.misc import torch_compile
from nerfstudio.utils.rich_utils import CONSOLE


def num_sh_bases(degree: int) -> int:
"""
Returns the number of spherical harmonic bases for a given degree.
"""
assert degree <= 4, "We don't support degree greater than 4."
return (degree + 1) ** 2


def quat_to_rotmat(quat):
assert quat.shape[-1] == 4, quat.shape
w, x, y, z = torch.unbind(quat, dim=-1)
mat = torch.stack(
[
1 - 2 * (y**2 + z**2),
2 * (x * y - w * z),
2 * (x * z + w * y),
2 * (x * y + w * z),
1 - 2 * (x**2 + z**2),
2 * (y * z - w * x),
2 * (x * z - w * y),
2 * (y * z + w * x),
1 - 2 * (x**2 + y**2),
],
dim=-1,
)
return mat.reshape(quat.shape[:-1] + (3, 3))


def random_quat_tensor(N):
"""
Defines a random quaternion tensor of shape (N, 4)
"""
u = torch.rand(N)
v = torch.rand(N)
w = torch.rand(N)
return torch.stack(
[
torch.sqrt(1 - u) * torch.sin(2 * math.pi * v),
torch.sqrt(1 - u) * torch.cos(2 * math.pi * v),
torch.sqrt(u) * torch.sin(2 * math.pi * w),
torch.sqrt(u) * torch.cos(2 * math.pi * w),
],
dim=-1,
)


def RGB2SH(rgb):
"""
Converts from RGB values [0,1] to the 0th spherical harmonic coefficient
"""
C0 = 0.28209479177387814
return (rgb - 0.5) / C0


def SH2RGB(sh):
"""
Converts from the 0th spherical harmonic coefficient to RGB values [0,1]
"""
C0 = 0.28209479177387814
return sh * C0 + 0.5
from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases


def resize_image(image: torch.Tensor, d: int):
@@ -243,8 +181,7 @@ def populate_modules(self):
means = torch.nn.Parameter(self.seed_points[0]) # (Location, Color)
else:
means = torch.nn.Parameter((torch.rand((self.config.num_random, 3)) - 0.5) * self.config.random_scale)
distances, _ = self.k_nearest_sklearn(means.data, 3)
distances = torch.from_numpy(distances)
distances, _ = k_nearest_sklearn(means.data, 3)
# find the average of the three nearest neighbors for each point and use that as the scale
avg_dist = distances.mean(dim=-1, keepdim=True)
scales = torch.nn.Parameter(torch.log(avg_dist.repeat(1, 3)))
@@ -392,26 +329,6 @@ def load_state_dict(self, dict, **kwargs): # type: ignore
self.gauss_params[name] = torch.nn.Parameter(torch.zeros(new_shape, device=self.device))
super().load_state_dict(dict, **kwargs)

def k_nearest_sklearn(self, x: torch.Tensor, k: int):
"""
Find k-nearest neighbors using sklearn's NearestNeighbors.
x: The data tensor of shape [num_samples, num_features]
k: The number of neighbors to retrieve
"""
# Convert tensor to numpy array
x_np = x.cpu().numpy()

# Build the nearest neighbors model
from sklearn.neighbors import NearestNeighbors

nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric="euclidean").fit(x_np)

# Find the k-nearest neighbors
distances, indices = nn_model.kneighbors(x_np)

# Exclude the point itself from the result and return
return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32)

def set_crop(self, crop_box: Optional[OrientedBox]):
self.crop_box = crop_box

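A sketch of the updated splatfacto initialization path: `k_nearest_sklearn` now returns torch tensors directly, so the old `torch.from_numpy` conversion disappears. The random means below stand in for seed points and are not the model's defaults.

```python
import torch

from nerfstudio.utils.math import k_nearest_sklearn

means = (torch.rand((1000, 3)) - 0.5) * 10.0     # placeholder point cloud
distances, _ = k_nearest_sklearn(means, 3)       # float32 tensor, shape (1000, 3)
avg_dist = distances.mean(dim=-1, keepdim=True)  # mean distance to the 3 nearest neighbors
scales = torch.log(avg_dist.repeat(1, 3))        # per-axis log-scale initialization
print(scales.shape)                              # torch.Size([1000, 3])
```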
132 changes: 63 additions & 69 deletions nerfstudio/utils/math.py
@@ -20,78 +20,12 @@
from typing import Literal, Tuple

import torch
from jaxtyping import Bool, Float
from jaxtyping import Bool, Float, Int
from torch import Tensor

from nerfstudio.data.scene_box import OrientedBox


def components_from_spherical_harmonics(
levels: int, directions: Float[Tensor, "*batch 3"]
) -> Float[Tensor, "*batch components"]:
"""
Returns value for each component of spherical harmonics.
Args:
levels: Number of spherical harmonic levels to compute.
directions: Spherical harmonic coefficients
"""
num_components = levels**2
components = torch.zeros((*directions.shape[:-1], num_components), device=directions.device)

assert 1 <= levels <= 5, f"SH levels must be in [1,4], got {levels}"
assert directions.shape[-1] == 3, f"Direction input should have three dimensions. Got {directions.shape[-1]}"

x = directions[..., 0]
y = directions[..., 1]
z = directions[..., 2]

xx = x**2
yy = y**2
zz = z**2

# l0
components[..., 0] = 0.28209479177387814

# l1
if levels > 1:
components[..., 1] = 0.4886025119029199 * y
components[..., 2] = 0.4886025119029199 * z
components[..., 3] = 0.4886025119029199 * x

# l2
if levels > 2:
components[..., 4] = 1.0925484305920792 * x * y
components[..., 5] = 1.0925484305920792 * y * z
components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999
components[..., 7] = 1.0925484305920792 * x * z
components[..., 8] = 0.5462742152960396 * (xx - yy)

# l3
if levels > 3:
components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy)
components[..., 10] = 2.890611442640554 * x * y * z
components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1)
components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3)
components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1)
components[..., 14] = 1.445305721320277 * z * (xx - yy)
components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy)

# l4
if levels > 4:
components[..., 16] = 2.5033429417967046 * x * y * (xx - yy)
components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy)
components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1)
components[..., 19] = 0.6690465435572892 * y * z * (7 * zz - 3)
components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3)
components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3)
components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1)
components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy)
components[..., 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))

return components
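The function removed above reappears in `nerfstudio.utils.spherical_harmonics` with a `degree` argument instead of `levels`; that file's diff is not displayed here, but the call sites in `encodings.py` and `renderers.py` confirm the new signature. A minimal check of the component count, assuming the output values are unchanged by the move:

```python
import torch

from nerfstudio.utils.spherical_harmonics import components_from_spherical_harmonics

directions = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)
for degree in range(4):
    components = components_from_spherical_harmonics(degree=degree, directions=directions)
    assert components.shape[-1] == (degree + 1) ** 2  # degree d -> (d + 1)^2 components
```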


@dataclass
class Gaussians:
"""Stores Gaussians
@@ -323,7 +257,9 @@ def masked_reduction(


def normalized_depth_scale_and_shift(
prediction: Float[Tensor, "1 32 mult"], target: Float[Tensor, "1 32 mult"], mask: Bool[Tensor, "1 32 mult"]
prediction: Float[Tensor, "1 32 mult"],
target: Float[Tensor, "1 32 mult"],
mask: Bool[Tensor, "1 32 mult"],
):
"""
More info here: https://arxiv.org/pdf/2206.00665.pdf supplementary section A2 Depth Consistency Loss
@@ -405,7 +341,10 @@ def _compute_tesselation_weights(v: int) -> Tensor:


def _tesselate_geodesic(
vertices: Float[Tensor, "N 3"], faces: Float[Tensor, "M 3"], v: int, eps: float = 1e-4
vertices: Float[Tensor, "N 3"],
faces: Float[Tensor, "M 3"],
v: int,
eps: float = 1e-4,
) -> Tensor:
"""Tesselate the vertices of a geodesic polyhedron.
@@ -518,3 +457,58 @@ def generate_polyhedron_basis(

basis = verts.flip(-1)
return basis


def random_quat_tensor(N: int) -> Float[Tensor, "*batch 4"]:
"""
Defines a random quaternion tensor.
Args:
N: Number of quaternions to generate
Returns:
a random quaternion tensor of shape (N, 4)
"""
u = torch.rand(N)
v = torch.rand(N)
w = torch.rand(N)
return torch.stack(
[
torch.sqrt(1 - u) * torch.sin(2 * math.pi * v),
torch.sqrt(1 - u) * torch.cos(2 * math.pi * v),
torch.sqrt(u) * torch.sin(2 * math.pi * w),
torch.sqrt(u) * torch.cos(2 * math.pi * w),
],
dim=-1,
)


def k_nearest_sklearn(
x: torch.Tensor, k: int, metric: str = "euclidean"
) -> Tuple[Float[Tensor, "*batch k"], Int[Tensor, "*batch k"]]:
"""
Find k-nearest neighbors using sklearn's NearestNeighbors.
Args:
x: input tensor
k: number of neighbors to find
metric: metric to use for distance computation
Returns:
distances: distances to the k-nearest neighbors
indices: indices of the k-nearest neighbors
"""
# Convert tensor to numpy array
x_np = x.cpu().numpy()

# Build the nearest neighbors model
from sklearn.neighbors import NearestNeighbors

nn_model = NearestNeighbors(n_neighbors=k + 1, algorithm="auto", metric=metric).fit(x_np)

# Find the k-nearest neighbors
distances, indices = nn_model.kneighbors(x_np)

# Exclude the point itself from the result and return
return torch.tensor(distances[:, 1:], dtype=torch.float32), torch.tensor(indices[:, 1:], dtype=torch.int64)
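A short usage sketch of the two helpers in their new home; the return dtypes follow the implementation above.

```python
import torch

from nerfstudio.utils.math import k_nearest_sklearn, random_quat_tensor

quats = random_quat_tensor(1000)
print(quats.norm(dim=-1).mean())       # ~1.0: uniformly sampled unit quaternions

points = torch.rand(1000, 3)
distances, indices = k_nearest_sklearn(points, k=3)
print(distances.dtype, indices.dtype)  # torch.float32 torch.int64
print(distances.shape, indices.shape)  # torch.Size([1000, 3]) each; the query point itself is excluded
```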
