diff --git a/hitchhiking_rotations/utils/conversions.py b/hitchhiking_rotations/utils/conversions.py index ce5154c..ef99201 100644 --- a/hitchhiking_rotations/utils/conversions.py +++ b/hitchhiking_rotations/utils/conversions.py @@ -3,27 +3,27 @@ # All rights reserved. Licensed under the MIT license. # See LICENSE file in the project root for details. # -from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles +from hitchhiking_rotations.utils.euler_helper import euler_angles_to_matrix, matrix_to_euler_angles import roma import torch -def euler_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def euler_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY") -def quaternion_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def quaternion_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: # without normalization # normalize first x = inp.reshape(-1, 4) return roma.unitquat_to_rotmat(x / x.norm(dim=1, keepdim=True)) -def gramschmidt_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def gramschmidt_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return roma.special_gramschmidt(inp.reshape(-1, 3, 2)) -def symmetric_orthogonalization(x): +def symmetric_orthogonalization(x, **kwargs): """Maps 9D input vectors onto SO(3) via symmetric orthogonalization. x: should have size [batch_size, 9] @@ -40,72 +40,76 @@ def symmetric_orthogonalization(x): return r -def procrustes_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def procrustes_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return symmetric_orthogonalization(inp) return roma.special_procrustes(inp.reshape(-1, 3, 3)) -def rotvec_to_rotmat(inp: torch.Tensor) -> torch.Tensor: +def rotvec_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor: return roma.rotvec_to_rotmat(inp.reshape(-1, 3)) # rotmat to x / maybe here reshape is missing -def rotmat_to_euler(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_euler(base: torch.Tensor, **kwargs) -> torch.Tensor: return matrix_to_euler_angles(base, convention="XZY") -def rotmat_to_quaternion(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_quaternion(base: torch.Tensor, **kwargs) -> torch.Tensor: return roma.rotmat_to_unitquat(base) -def rotmat_to_quaternion_rand_flip(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_quaternion_rand_flip(base: torch.Tensor, **kwargs) -> torch.Tensor: rep = roma.rotmat_to_unitquat(base) rand_flipping = torch.rand(base.shape[0]) > 0.5 rep[rand_flipping] *= -1 return rep -def rotmat_to_quaternion_canonical(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_quaternion_canonical(base: torch.Tensor, **kwargs) -> torch.Tensor: rep = roma.rotmat_to_unitquat(base) rep[rep[:, 3] < 0] *= -1 return rep -def rotmat_to_quaternion_aug(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_quaternion_aug(base: torch.Tensor, mode: str) -> torch.Tensor: """Performs memory-efficient quaternion augmentation by randomly selecting some quaternions in the batch for which the scalar part is smaller than 0.1 then multiply the selected quaternions by -1. """ rep = rotmat_to_quaternion_canonical(base) - idxs = torch.arange(rep.size(0), device=rep.device)[rep[:, 3] < 0.1] - num_rows_to_flip = torch.randint(0, idxs.size(0) + 1, (1,)).item() - random_indices = torch.randperm(idxs.size(0), device=rep.device)[:num_rows_to_flip] - selected_idxs = idxs[random_indices] - rep[selected_idxs] *= -1 + + if mode == "train": + idxs = torch.arange(rep.size(0), device=rep.device)[rep[:, 3] < 0.1] + num_rows_to_flip = torch.randint(0, idxs.size(0) + 1, (1,)).item() + random_indices = torch.randperm(idxs.size(0), device=rep.device)[:num_rows_to_flip] + selected_idxs = idxs[random_indices] + rep[selected_idxs] *= -1 + return rep -def rotmat_to_gramschmidt(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_gramschmidt(base: torch.Tensor, **kwargs) -> torch.Tensor: return base[:, :, :2] -def rotmat_to_gramschmidt_f(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_gramschmidt_f(base: torch.Tensor, **kwargs) -> torch.Tensor: return base[:, :, :2].reshape(-1, 6) -def rotmat_to_procrustes(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_procrustes(base: torch.Tensor, **kwargs) -> torch.Tensor: return base -def rotmat_to_rotvec(base: torch.Tensor) -> torch.Tensor: +def rotmat_to_rotvec(base: torch.Tensor, **kwargs) -> torch.Tensor: return roma.rotmat_to_rotvec(base) def test_all(): from scipy.spatial.transform import Rotation import numpy as np + from torch import from_numpy as tr rs = Rotation.random(1000) euler = rs.as_euler("XZY", degrees=False) @@ -114,8 +118,6 @@ def test_all(): quat_hm = np.where(quat[:, 3:4] < 0, -quat, quat) rotvec = rs.as_rotvec() - tr = lambda x: torch.from_numpy(x) - # euler_to_rotmat print(np.allclose(euler_to_rotmat(tr(euler)).numpy(), rot)) print(np.allclose(quaternion_to_rotmat(tr(quat)).numpy(), rot)) diff --git a/hitchhiking_rotations/utils/helper.py b/hitchhiking_rotations/utils/helper.py index 7fb8bc1..34491aa 100644 --- a/hitchhiking_rotations/utils/helper.py +++ b/hitchhiking_rotations/utils/helper.py @@ -3,15 +3,15 @@ # All rights reserved. Licensed under the MIT license. # See LICENSE file in the project root for details. # -def passthrough(*x): +def passthrough(*x, **kwargs): if len(x) == 1: return x[0] return x -def flatten(x): +def flatten(x, **kwargs): return x.reshape(x.shape[0], -1) -def n_3x3(x): +def n_3x3(x, **kwargs): return x.reshape(-1, 3, 3) diff --git a/hitchhiking_rotations/utils/trainer.py b/hitchhiking_rotations/utils/trainer.py index 6dc5041..31df337 100644 --- a/hitchhiking_rotations/utils/trainer.py +++ b/hitchhiking_rotations/utils/trainer.py @@ -78,7 +78,7 @@ def train_batch(self, x, target, epoch): with torch.no_grad(): pp_target = self.preprocess_target(target) - x = self.preprocess_input(x) + x = self.preprocess_input(x, mode="train") pred = self.model(x) pred_loss = self.postprocess_pred_loss(pred) @@ -96,7 +96,7 @@ def train_batch(self, x, target, epoch): @torch.no_grad() def test_batch(self, x, target, epoch, mode): self.model.eval() - x = self.preprocess_input(x) + x = self.preprocess_input(x, mode="test") pred = self.model(x) pred_loss = self.postprocess_pred_loss(pred) pp_target = self.preprocess_target(target)