Skip to content

Commit

Permalink
Make Rotation.from_axis_angle return a jaxlie.SO3
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jul 12, 2024
1 parent 0ec08f8 commit 05bce55
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def compute_R() -> tuple[jtp.Matrix, jtp.Array]:
axis = jnp.array(joint_axis).astype(float).squeeze()

pre_H_suc = jaxlie.SE3.from_rotation(
rotation=jaxlie.SO3.from_matrix(Rotation.from_axis_angle(vector=s * axis))
rotation=Rotation.from_axis_angle(vector=s * axis)
)

S = jnp.vstack(jnp.hstack([jnp.zeros(3), axis]))
Expand Down
12 changes: 7 additions & 5 deletions src/jaxsim/math/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ def theta_is_not_zero(theta_and_v: Tuple[jtp.Float, jtp.Vector]) -> jtp.Matrix:

return R.transpose()

return jax.lax.cond(
pred=(theta == 0.0),
true_fun=lambda operand: jnp.eye(3),
false_fun=theta_is_not_zero,
operand=(theta, vector),
return jaxlie.SO3.from_matrix(
jax.lax.cond(
pred=(theta == 0.0),
true_fun=lambda operand: jnp.eye(3),
false_fun=theta_is_not_zero,
operand=(theta, vector),
)
)

0 comments on commit 05bce55

Please sign in to comment.