Skip to content

Commit

Permalink
Simplify SO3 instantiation from quaternions
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jul 10, 2024
1 parent 8830119 commit 0ec08f8
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 28 deletions.
13 changes: 2 additions & 11 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import jaxsim.api as js
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.math import Quaternion
from jaxsim.rbda.contacts.soft import SoftContacts
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing
Expand Down Expand Up @@ -191,9 +190,7 @@ def build(

W_H_B = jaxlie.SE3.from_rotation_and_translation(
translation=base_position,
rotation=jaxlie.SO3.from_quaternion_xyzw(
base_quaternion[jnp.array([1, 2, 3, 0])]
),
rotation=jaxlie.SO3(wxyz=base_quaternion),
).as_matrix()

v_WB = JaxSimModelData.other_representation_to_inertial(
Expand Down Expand Up @@ -380,13 +377,7 @@ def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix
on_false=W_Q_B / jnp.linalg.norm(W_Q_B),
)

return (
W_Q_B
if not dcm
else jaxlie.SO3.from_quaternion_xyzw(
Quaternion.to_xyzw(wxyz=W_Q_B)
).as_matrix()
).astype(float)
return (W_Q_B if not dcm else jaxlie.SO3(wxyz=W_Q_B).as_matrix()).astype(float)

@jax.jit
def base_transform(self) -> jtp.Matrix:
Expand Down
3 changes: 1 addition & 2 deletions src/jaxsim/math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import jaxsim.typing as jtp

from .quaternion import Quaternion
from .skew import Skew


Expand Down Expand Up @@ -31,7 +30,7 @@ def from_quaternion_and_translation(
assert quaternion.size == 4
assert translation.size == 3

Q_sixd = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion))
Q_sixd = jaxlie.SO3(wxyz=quaternion)
Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize()

return Adjoint.from_rotation_and_translation(
Expand Down
6 changes: 2 additions & 4 deletions src/jaxsim/math/quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix:
Returns:
jtp.Matrix: Direction cosine matrix (DCM).
"""
return jaxlie.SO3.from_quaternion_xyzw(
xyzw=Quaternion.to_xyzw(quaternion)
).as_matrix()
return jaxlie.SO3(wxyz=quaternion).as_matrix()

@staticmethod
def from_dcm(dcm: jtp.Matrix) -> jtp.Vector:
Expand Down Expand Up @@ -158,7 +156,7 @@ def integration(
A_Q_B = jnp.array(quaternion).squeeze().astype(float)

# Build the initial SO(3) quaternion.
W_Q_B_t0 = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=A_Q_B))
W_Q_B_t0 = jaxlie.SO3(wxyz=A_Q_B)

# Integrate the quaternion on the manifold.
W_Q_B_tf = jax.lax.select(
Expand Down
4 changes: 1 addition & 3 deletions src/jaxsim/math/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import jaxsim.typing as jtp

from .quaternion import Quaternion


class Transform:

Expand Down Expand Up @@ -35,7 +33,7 @@ def from_quaternion_and_translation(
assert W_p_B.size == 3
assert W_Q_B.size == 4

A_R_B = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(W_Q_B))
A_R_B = jaxlie.SO3(wxyz=W_Q_B)
A_R_B = A_R_B if not normalize_quaternion else A_R_B.normalize()

A_H_B = jaxlie.SE3.from_rotation_and_translation(
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/rbda/aba.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
from jaxsim.math import Adjoint, Cross, StandardGravity

from . import utils

Expand Down Expand Up @@ -77,7 +77,7 @@ def aba(

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
rotation=jaxlie.SO3(wxyz=W_Q_B),
translation=W_p_B,
)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/rbda/collidable_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Quaternion, Skew
from jaxsim.math import Adjoint, Skew

from . import utils

Expand Down Expand Up @@ -57,7 +57,7 @@ def collidable_points_pos_vel(

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
rotation=jaxlie.SO3(wxyz=W_Q_B),
translation=W_p_B,
)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/rbda/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Quaternion
from jaxsim.math import Adjoint

from . import utils

Expand Down Expand Up @@ -42,7 +42,7 @@ def forward_kinematics_model(

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
rotation=jaxlie.SO3(wxyz=W_Q_B),
translation=W_p_B,
)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/rbda/rnea.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Cross, Quaternion, StandardGravity
from jaxsim.math import Adjoint, Cross, StandardGravity

from . import utils

Expand Down Expand Up @@ -82,7 +82,7 @@ def rnea(

# Compute the base transform.
W_H_B = jaxlie.SE3.from_rotation_and_translation(
rotation=jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(wxyz=W_Q_B)),
rotation=jaxlie.SO3(wxyz=W_Q_B),
translation=W_p_B,
)

Expand Down

0 comments on commit 0ec08f8

Please sign in to comment.