diff --git a/hitchhiking_rotations/utils/conversions.py b/hitchhiking_rotations/utils/conversions.py index 25b903f..a7117b8 100644 --- a/hitchhiking_rotations/utils/conversions.py +++ b/hitchhiking_rotations/utils/conversions.py @@ -82,7 +82,7 @@ def rotmat_to_quaternion_aug(base: torch.Tensor, mode: str) -> torch.Tensor: rep = rotmat_to_quaternion_canonical(base) if mode == "train": - rep[torch.logical_and(torch.rand(rep.size(0), device=rep.device) < 0.5, rep[:, 3] < 0.3)] *= -1 + rep[torch.logical_and(torch.rand(rep.size(0), device=rep.device) < 0.5, rep[:, 3] < 0.1)] *= -1 return rep