
Commit

Augmented quaternions are used only for train data
AndReGeist committed Mar 25, 2024
1 parent 66d8fdb · commit f1a669c
Showing 3 changed files with 30 additions and 28 deletions.
hitchhiking_rotations/utils/conversions.py (48 changes: 25 additions & 23 deletions)
@@ -3,27 +3,27 @@
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
-from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles
+from hitchhiking_rotations.utils.euler_helper import euler_angles_to_matrix, matrix_to_euler_angles
import roma
import torch


-def euler_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
+def euler_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
    return euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY")


-def quaternion_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
+def quaternion_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
    # without normalization
    # normalize first
    x = inp.reshape(-1, 4)
    return roma.unitquat_to_rotmat(x / x.norm(dim=1, keepdim=True))
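
A quick illustrative check (not part of the commit): raw network outputs are generally not unit quaternions, which is why the function normalizes before calling roma.unitquat_to_rotmat:

    import torch

    q_raw = torch.randn(4, 4)            # arbitrary 4-vectors, not unit quaternions
    R = quaternion_to_rotmat(q_raw)      # normalizes internally, returns valid rotations
    print(torch.allclose(torch.det(R), torch.ones(4), atol=1e-5))  # True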


-def gramschmidt_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
+def gramschmidt_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
    return roma.special_gramschmidt(inp.reshape(-1, 3, 2))


-def symmetric_orthogonalization(x):
+def symmetric_orthogonalization(x, **kwargs):
    """Maps 9D input vectors onto SO(3) via symmetric orthogonalization.

    x: should have size [batch_size, 9]
@@ -40,72 +40,76 @@ def symmetric_orthogonalization(x):
    return r
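
The body of this function is collapsed in the diff view. As an aside, the standard SVD-based construction the docstring describes (cf. Levinson et al., 2020) looks roughly like the sketch below; this is an assumption about the technique, not the repo's hidden code:

    import torch

    def svd_orthogonalization_sketch(x: torch.Tensor) -> torch.Tensor:
        # Map [batch_size, 9] onto SO(3): project each 3x3 block onto the
        # nearest rotation matrix via SVD, fixing the sign so det(R) = +1.
        m = x.reshape(-1, 3, 3)
        u, s, vT = torch.linalg.svd(m)
        d = torch.ones_like(s)
        d[:, -1] = torch.det(torch.matmul(u, vT))  # +1 or -1 per batch element
        return torch.matmul(u * d.unsqueeze(1), vT)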


-def procrustes_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
+def procrustes_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
    return symmetric_orthogonalization(inp)
    return roma.special_procrustes(inp.reshape(-1, 3, 3))


-def rotvec_to_rotmat(inp: torch.Tensor) -> torch.Tensor:
+def rotvec_to_rotmat(inp: torch.Tensor, **kwargs) -> torch.Tensor:
    return roma.rotvec_to_rotmat(inp.reshape(-1, 3))


# rotmat to x / maybe here reshape is missing


-def rotmat_to_euler(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_euler(base: torch.Tensor, **kwargs) -> torch.Tensor:
    return matrix_to_euler_angles(base, convention="XZY")


-def rotmat_to_quaternion(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_quaternion(base: torch.Tensor, **kwargs) -> torch.Tensor:
    return roma.rotmat_to_unitquat(base)


-def rotmat_to_quaternion_rand_flip(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_quaternion_rand_flip(base: torch.Tensor, **kwargs) -> torch.Tensor:
    rep = roma.rotmat_to_unitquat(base)
    rand_flipping = torch.rand(base.shape[0]) > 0.5
    rep[rand_flipping] *= -1
    return rep


-def rotmat_to_quaternion_canonical(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_quaternion_canonical(base: torch.Tensor, **kwargs) -> torch.Tensor:
    rep = roma.rotmat_to_unitquat(base)
    rep[rep[:, 3] < 0] *= -1
    return rep
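
As an aside (illustrative, not repo code): q and -q encode the same rotation, which is the ambiguity this canonicalization resolves by forcing a nonnegative scalar part (roma uses xyzw ordering, so index 3 is the scalar):

    import roma, torch

    R = roma.random_rotmat(5)
    q = roma.rotmat_to_unitquat(R)       # sign of q is arbitrary
    print(torch.allclose(roma.unitquat_to_rotmat(q), R, atol=1e-5))   # True
    print(torch.allclose(roma.unitquat_to_rotmat(-q), R, atol=1e-5))  # also True
    q_can = rotmat_to_quaternion_canonical(R)   # scalar part now >= 0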


-def rotmat_to_quaternion_aug(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_quaternion_aug(base: torch.Tensor, mode: str) -> torch.Tensor:
    """Performs memory-efficient quaternion augmentation by randomly
    selecting some quaternions in the batch for which the scalar part
    is smaller than 0.1, then multiplying the selected quaternions by -1.
    """
    rep = rotmat_to_quaternion_canonical(base)
-    idxs = torch.arange(rep.size(0), device=rep.device)[rep[:, 3] < 0.1]
-    num_rows_to_flip = torch.randint(0, idxs.size(0) + 1, (1,)).item()
-    random_indices = torch.randperm(idxs.size(0), device=rep.device)[:num_rows_to_flip]
-    selected_idxs = idxs[random_indices]
-    rep[selected_idxs] *= -1
+
+    if mode == "train":
+        idxs = torch.arange(rep.size(0), device=rep.device)[rep[:, 3] < 0.1]
+        num_rows_to_flip = torch.randint(0, idxs.size(0) + 1, (1,)).item()
+        random_indices = torch.randperm(idxs.size(0), device=rep.device)[:num_rows_to_flip]
+        selected_idxs = idxs[random_indices]
+        rep[selected_idxs] *= -1

    return rep
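
Usage under the new signature (illustrative): the sign flips now fire only when the caller passes mode="train", which is exactly what the trainer change below does:

    import roma, torch

    R = roma.random_rotmat(8)
    q_train = rotmat_to_quaternion_aug(R, mode="train")  # flips possible where scalar < 0.1
    q_test = rotmat_to_quaternion_aug(R, mode="test")    # augmentation skipped
    assert torch.equal(q_test, rotmat_to_quaternion_canonical(R))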


-def rotmat_to_gramschmidt(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_gramschmidt(base: torch.Tensor, **kwargs) -> torch.Tensor:
    return base[:, :, :2]


-def rotmat_to_gramschmidt_f(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_gramschmidt_f(base: torch.Tensor, **kwargs) -> torch.Tensor:
    return base[:, :, :2].reshape(-1, 6)
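
For intuition (sketch, not repo code): the 6D representation keeps only the first two columns of R; the inverse map re-orthonormalizes them and rebuilds the third column, so the round trip is exact for valid rotations:

    import roma, torch

    R = roma.random_rotmat(4)
    rep6 = rotmat_to_gramschmidt_f(R)    # [4, 6]: first two columns, flattened
    R_rec = gramschmidt_to_rotmat(rep6)  # Gram-Schmidt restores the third column
    print(torch.allclose(R_rec, R, atol=1e-5))  # True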


-def rotmat_to_procrustes(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_procrustes(base: torch.Tensor, **kwargs) -> torch.Tensor:
    return base


-def rotmat_to_rotvec(base: torch.Tensor) -> torch.Tensor:
+def rotmat_to_rotvec(base: torch.Tensor, **kwargs) -> torch.Tensor:
    return roma.rotmat_to_rotvec(base)


def test_all():
    from scipy.spatial.transform import Rotation
    import numpy as np
+    from torch import from_numpy as tr

    rs = Rotation.random(1000)
    euler = rs.as_euler("XZY", degrees=False)
@@ -114,8 +118,6 @@ def test_all():
    quat_hm = np.where(quat[:, 3:4] < 0, -quat, quat)
    rotvec = rs.as_rotvec()

-    tr = lambda x: torch.from_numpy(x)
-
    # euler_to_rotmat
    print(np.allclose(euler_to_rotmat(tr(euler)).numpy(), rot))
    print(np.allclose(quaternion_to_rotmat(tr(quat)).numpy(), rot))
hitchhiking_rotations/utils/helper.py (6 changes: 3 additions & 3 deletions)
@@ -3,15 +3,15 @@
# All rights reserved. Licensed under the MIT license.
# See LICENSE file in the project root for details.
#
-def passthrough(*x):
+def passthrough(*x, **kwargs):
    if len(x) == 1:
        return x[0]
    return x


-def flatten(x):
+def flatten(x, **kwargs):
    return x.reshape(x.shape[0], -1)


-def n_3x3(x):
+def n_3x3(x, **kwargs):
    return x.reshape(-1, 3, 3)
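
These one-liners gain **kwargs for the same reason as the conversions above: the trainer passes mode="train"/"test" to whichever preprocess function the experiment configures, so every candidate must tolerate the extra keyword even when it ignores it. A minimal sketch of the pattern (illustrative):

    import torch

    x = torch.randn(8, 3, 3)
    for preprocess in (passthrough, flatten, n_3x3):
        out = preprocess(x, mode="train")  # 'mode' is absorbed by **kwargs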
hitchhiking_rotations/utils/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -78,7 +78,7 @@ def train_batch(self, x, target, epoch):
        with torch.no_grad():
            pp_target = self.preprocess_target(target)

-        x = self.preprocess_input(x)
+        x = self.preprocess_input(x, mode="train")
        pred = self.model(x)
        pred_loss = self.postprocess_pred_loss(pred)

@@ -96,7 +96,7 @@ def train_batch(self, x, target, epoch):
    @torch.no_grad()
    def test_batch(self, x, target, epoch, mode):
        self.model.eval()
-        x = self.preprocess_input(x)
+        x = self.preprocess_input(x, mode="test")
        pred = self.model(x)
        pred_loss = self.postprocess_pred_loss(pred)
        pp_target = self.preprocess_target(target)
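
Taken together: preprocess_input is one of the conversion/helper functions above, so threading mode through these two call sites is what restricts the quaternion augmentation to training batches (note that test_batch hard-codes mode="test" regardless of its own mode argument). A hypothetical wiring, assuming the experiment config binds preprocess_input to rotmat_to_quaternion_aug:

    import roma
    from hitchhiking_rotations.utils.conversions import rotmat_to_quaternion_aug

    R = roma.random_rotmat(16)
    preprocess_input = rotmat_to_quaternion_aug  # hypothetical binding
    x_train = preprocess_input(R, mode="train")  # augmented quaternions
    x_eval = preprocess_input(R, mode="test")    # canonical quaternions only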
