From 35df9713da1eecdafbdc3273402a97704fd130df Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Wed, 13 Mar 2024 19:54:17 +0100 Subject: [PATCH] [wip] Prepare transition to jaxsim.api from jaxsim.{high_level|physics} --- src/jaxsim/__init__.py | 7 +++--- src/jaxsim/api/__init__.py | 3 ++- src/jaxsim/api/common.py | 13 +++++++++- src/jaxsim/api/contact.py | 29 +++++++++++++---------- src/jaxsim/api/data.py | 25 +++++++++---------- src/jaxsim/api/link.py | 12 ++++++---- src/jaxsim/api/model.py | 19 ++++++++------- src/jaxsim/api/ode.py | 14 +++++------ src/jaxsim/api/ode_data.py | 6 +++++ src/jaxsim/api/references.py | 3 ++- src/jaxsim/high_level/__init__.py | 2 +- src/jaxsim/high_level/common.py | 11 --------- src/jaxsim/integrators/__init__.py | 2 +- src/jaxsim/integrators/fixed_step.py | 14 +++++------ src/jaxsim/math/adjoint.py | 4 ++-- src/jaxsim/math/quaternion.py | 6 ++--- src/jaxsim/math/rotation.py | 8 +++---- src/jaxsim/parsers/descriptions/link.py | 4 ++-- src/jaxsim/parsers/rod/utils.py | 15 ++++++------ src/jaxsim/physics/model/physics_model.py | 6 ++--- src/jaxsim/rbda/__init__.py | 9 +++++++ src/jaxsim/simulation/__init__.py | 3 --- src/jaxsim/simulation/ode_data.py | 6 ++--- src/jaxsim/sixd/__init__.py | 2 -- src/jaxsim/terrain/__init__.py | 2 ++ tests/utils_idyntree.py | 2 +- 26 files changed, 119 insertions(+), 108 deletions(-) create mode 100644 src/jaxsim/api/ode_data.py delete mode 100644 src/jaxsim/high_level/common.py create mode 100644 src/jaxsim/rbda/__init__.py delete mode 100644 src/jaxsim/sixd/__init__.py create mode 100644 src/jaxsim/terrain/__init__.py diff --git a/src/jaxsim/__init__.py b/src/jaxsim/__init__.py index c06d2c7cc..ee588a8fa 100644 --- a/src/jaxsim/__init__.py +++ b/src/jaxsim/__init__.py @@ -61,7 +61,6 @@ def _is_editable() -> bool: del _np_options del _is_editable -from . import high_level, logging, math, simulation, sixd -from .high_level.common import VelRepr -from .simulation.ode_integration import IntegratorType -from .simulation.simulator import JaxSim +from . import terrain # isort:skip +from . import api, integrators, logging, math, rbda +from .api.common import VelRepr diff --git a/src/jaxsim/api/__init__.py b/src/jaxsim/api/__init__.py index 4f2e5d968..250d44c65 100644 --- a/src/jaxsim/api/__init__.py +++ b/src/jaxsim/api/__init__.py @@ -1,2 +1,3 @@ +from . import common # isort:skip from . import model, data # isort:skip -from . import common, contact, joint, kin_dyn_parameters, link, ode, references +from . import contact, joint, kin_dyn_parameters, link, ode, ode_data, references diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index 0fc33e3e0..f014906d7 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -1,6 +1,7 @@ import abc import contextlib import dataclasses +import enum import functools from typing import ContextManager @@ -11,7 +12,6 @@ from jax_dataclasses import Static import jaxsim.typing as jtp -from jaxsim.high_level.common import VelRepr from jaxsim.utils import JaxsimDataclass, Mutability try: @@ -20,6 +20,17 @@ from typing_extensions import Self +@enum.unique +class VelRepr(enum.IntEnum): + """ + Enumeration of all supported 6D velocity representations. + """ + + Body = enum.auto() + Mixed = enum.auto() + Inertial = enum.auto() + + @jax_dataclasses.pytree_dataclass class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC): """ diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 745353e0a..8b38f7e3a 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -4,8 +4,8 @@ import jax.numpy as jnp import jaxsim.api as js +import jaxsim.rbda import jaxsim.typing as jtp -from jaxsim.physics.algos import soft_contacts @jax.jit @@ -28,9 +28,9 @@ def collidable_point_kinematics( the linear component of the mixed 6D frame velocity. """ - from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel + from jaxsim.rbda import soft_contacts - W_p_Ci, W_ṗ_Ci = collidable_points_pos_vel( + W_p_Ci, W_ṗ_Ci = soft_contacts.collidable_points_pos_vel( model=model.physics_model, q=data.state.physics_model.joint_positions, qd=data.state.physics_model.joint_velocities, @@ -101,9 +101,9 @@ def in_contact( if set(link_names) - set(model.link_names()) != set(): raise ValueError("One or more link names are not part of the model") - from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel + from jaxsim.rbda import soft_contacts - W_p_Ci, _ = collidable_points_pos_vel( + W_p_Ci, _ = soft_contacts.collidable_points_pos_vel( model=model.physics_model, q=data.state.physics_model.joint_positions, qd=data.state.physics_model.joint_velocities, @@ -134,7 +134,7 @@ def estimate_good_soft_contacts_parameters( number_of_active_collidable_points_steady_state: jtp.IntLike = 1, damping_ratio: jtp.FloatLike = 1.0, max_penetration: jtp.FloatLike | None = None, -) -> soft_contacts.SoftContactsParams: +) -> jaxsim.rbda.soft_contacts.SoftContactsParams: """ Estimate good soft contacts parameters for the given model. @@ -162,7 +162,8 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: """""" zero_data = js.data.JaxSimModelData.build( - model=model, soft_contacts_params=soft_contacts.SoftContactsParams() + model=model, + soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(), ) W_pz_CoM = js.model.com_position(model=model, data=zero_data)[2] @@ -181,12 +182,14 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: nc = number_of_active_collidable_points_steady_state - sc_parameters = soft_contacts.SoftContactsParams.build_default_from_physics_model( - physics_model=model.physics_model, - static_friction_coefficient=static_friction_coefficient, - max_penetration=max_δ, - number_of_active_collidable_points_steady_state=nc, - damping_ratio=damping_ratio, + sc_parameters = ( + jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_physics_model( + physics_model=model.physics_model, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_δ, + number_of_active_collidable_points_steady_state=nc, + damping_ratio=damping_ratio, + ) ) return sc_parameters diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index bec336c71..4be2a002e 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -11,19 +11,13 @@ import numpy as np import jaxsim.api as js -import jaxsim.physics.algos.aba -import jaxsim.physics.algos.crba -import jaxsim.physics.algos.forward_kinematics -import jaxsim.physics.algos.rnea -import jaxsim.physics.model.physics_model -import jaxsim.physics.model.physics_model_state +import jaxsim.rbda import jaxsim.typing as jtp -from jaxsim.high_level.common import VelRepr -from jaxsim.physics.algos import soft_contacts -from jaxsim.simulation.ode_data import ODEState from jaxsim.utils import Mutability from . import common +from .common import VelRepr +from .ode_data import ODEState try: from typing import Self @@ -41,9 +35,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): gravity: jtp.Array - soft_contacts_params: soft_contacts.SoftContactsParams = dataclasses.field( - repr=False + soft_contacts_params: jaxsim.rbda.soft_contacts.SoftContactsParams = ( + dataclasses.field(repr=False) ) + time_ns: jtp.Int = dataclasses.field( default_factory=lambda: jnp.array(0, dtype=jnp.uint64) ) @@ -96,8 +91,10 @@ def build( base_angular_velocity: jtp.Vector | None = None, joint_velocities: jtp.Vector | None = None, gravity: jtp.Vector | None = None, - soft_contacts_state: soft_contacts.SoftContactsState | None = None, - soft_contacts_params: soft_contacts.SoftContactsParams | None = None, + soft_contacts_state: jaxsim.rbda.soft_contacts.SoftContactsState | None = None, + soft_contacts_params: ( + jaxsim.rbda.soft_contacts.SoftContactsParams | None + ) = None, velocity_representation: VelRepr = VelRepr.Inertial, time: jtp.FloatLike | None = None, ) -> JaxSimModelData: @@ -186,7 +183,7 @@ def build( ode_state = ODEState.build( physics_model=model.physics_model, - physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState( + physics_model_state=js.ode_data.PhysicsModelState( base_position=base_position.astype(float), base_quaternion=base_quaternion.astype(float), joint_positions=joint_positions.astype(float), diff --git a/src/jaxsim/api/link.py b/src/jaxsim/api/link.py index 6754b2bd6..cdc895fbc 100644 --- a/src/jaxsim/api/link.py +++ b/src/jaxsim/api/link.py @@ -7,9 +7,10 @@ import numpy as np import jaxsim.api as js -import jaxsim.physics.algos.jacobian +import jaxsim.rbda import jaxsim.typing as jtp -from jaxsim.high_level.common import VelRepr + +from .common import VelRepr # ======================= # Index-related functions @@ -210,11 +211,12 @@ def jacobian( velocity representation. """ - if output_vel_repr is None: - output_vel_repr = data.velocity_representation + output_vel_repr = ( + output_vel_repr if output_vel_repr is not None else data.velocity_representation + ) # Compute the doubly left-trivialized free-floating jacobian - L_J_WL_B = jaxsim.physics.algos.jacobian.jacobian( + L_J_WL_B = jaxsim.rbda.jacobian.jacobian( model=model.physics_model, body_index=link_index, q=data.joint_positions(), diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index f7e94b9a3..1c71d5861 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -14,16 +14,13 @@ import jaxsim.api as js import jaxsim.parsers.descriptions -import jaxsim.physics.algos.aba -import jaxsim.physics.algos.crba -import jaxsim.physics.algos.forward_kinematics -import jaxsim.physics.algos.rnea import jaxsim.physics.model.physics_model +import jaxsim.physics.model.physics_model_state import jaxsim.typing as jtp -from jaxsim.high_level.common import VelRepr -from jaxsim.physics.algos.terrain import FlatTerrain, Terrain from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability +from .common import VelRepr + @jax_dataclasses.pytree_dataclass class JaxSimModel(JaxsimDataclass): @@ -37,8 +34,8 @@ class JaxSimModel(JaxsimDataclass): repr=False, compare=False, hash=False ) - terrain: Static[Terrain] = dataclasses.field( - default=FlatTerrain(), repr=False, compare=False, hash=False + terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field( + default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False ) built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field( @@ -388,7 +385,7 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp The first axis is the link index. """ - W_H_LL = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model( + W_H_LL = jaxsim.rbda.forward_kinematics.forward_kinematics_model( model=model.physics_model, q=data.state.physics_model.joint_positions, xfb=data.state.physics_model.xfb(), @@ -719,6 +716,8 @@ def free_floating_mass_matrix( The free-floating mass matrix of the model. """ + import jaxsim.physics.algos.crba + M_body = jaxsim.physics.algos.crba.crba( model=model.physics_model, q=data.state.physics_model.joint_positions, @@ -852,6 +851,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_vl_WC): velocity_representation=data.velocity_representation, ) + import jaxsim.physics.algos.rnea + # Compute RNEA with references.switch_velocity_representation(VelRepr.Inertial): W_f_B, τ = jaxsim.physics.algos.rnea.rnea( diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 55bd14ad0..208039f71 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -2,17 +2,15 @@ import jax import jax.numpy as jnp -import jaxlie import jaxsim.api as js -import jaxsim.physics.algos.soft_contacts +import jaxsim.rbda import jaxsim.typing as jtp -from jaxsim import VelRepr, integrators -from jaxsim.integrators.common import Time +from jaxsim.integrators import Time from jaxsim.math.quaternion import Quaternion -from jaxsim.physics.algos.soft_contacts import SoftContactsState -from jaxsim.physics.model.physics_model_state import PhysicsModelState -from jaxsim.simulation.ode_data import ODEState + +from .common import VelRepr +from .ode_data import ODEState, PhysicsModelState, SoftContactsState class SystemDynamicsFromModelAndData(Protocol): @@ -127,7 +125,7 @@ def system_velocity_dynamics( # Compute the 3D forces applied to each collidable point. W_f_Ci, ṁ = jax.vmap( - lambda p, ṗ, m: jaxsim.physics.algos.soft_contacts.SoftContacts( + lambda p, ṗ, m: jaxsim.rbda.soft_contacts.SoftContacts( parameters=data.soft_contacts_params, terrain=model.terrain ).contact_model(position=p, velocity=ṗ, tangential_deformation=m) )(W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation.T) diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py new file mode 100644 index 000000000..13d64bba0 --- /dev/null +++ b/src/jaxsim/api/ode_data.py @@ -0,0 +1,6 @@ +from jaxsim.physics.algos.soft_contacts import SoftContactsState +from jaxsim.physics.model.physics_model_state import ( + PhysicsModelInput, + PhysicsModelState, +) +from jaxsim.simulation.ode_data import ODEInput, ODEState diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index 3a82a4e97..92ac77932 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -8,9 +8,10 @@ import jaxsim.api as js import jaxsim.typing as jtp -from jaxsim import VelRepr from jaxsim.simulation.ode_data import ODEInput +from .common import VelRepr + try: from typing import Self except ImportError: diff --git a/src/jaxsim/high_level/__init__.py b/src/jaxsim/high_level/__init__.py index 8d485d4a5..258be2468 100644 --- a/src/jaxsim/high_level/__init__.py +++ b/src/jaxsim/high_level/__init__.py @@ -1,2 +1,2 @@ +from ..api.common import VelRepr from . import common, joint, link, model -from .common import VelRepr diff --git a/src/jaxsim/high_level/common.py b/src/jaxsim/high_level/common.py deleted file mode 100644 index 8cee02fdd..000000000 --- a/src/jaxsim/high_level/common.py +++ /dev/null @@ -1,11 +0,0 @@ -import enum - - -class VelRepr(enum.IntEnum): - """ - Enumeration of all supported 6D velocity representations. - """ - - Body = enum.auto() - Mixed = enum.auto() - Inertial = enum.auto() diff --git a/src/jaxsim/integrators/__init__.py b/src/jaxsim/integrators/__init__.py index bf6b0d023..28db9b609 100644 --- a/src/jaxsim/integrators/__init__.py +++ b/src/jaxsim/integrators/__init__.py @@ -1,2 +1,2 @@ from . import fixed_step -from .common import Integrator, Time, TimeStep +from .common import Integrator, SystemDynamics, Time, TimeStep diff --git a/src/jaxsim/integrators/fixed_step.py b/src/jaxsim/integrators/fixed_step.py index 53a0975c2..a3a2526b7 100644 --- a/src/jaxsim/integrators/fixed_step.py +++ b/src/jaxsim/integrators/fixed_step.py @@ -5,11 +5,11 @@ import jax_dataclasses import jaxlie -from jaxsim.simulation.ode_data import ODEState +import jaxsim.api as js from .common import ExplicitRungeKutta, PyTreeType, Time, TimeStep -ODEStateDerivative = ODEState +ODEStateDerivative = js.ode_data.ODEState # ===================================================== @@ -97,8 +97,8 @@ class ExplicitRungeKuttaSO3Mixin: @classmethod def post_process_state( - cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep - ) -> ODEState: + cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep + ) -> js.ode_data.ODEState: # Indices to convert quaternions between serializations. to_xyzw = jnp.array([1, 2, 3, 0]) @@ -130,15 +130,15 @@ def post_process_state( @jax_dataclasses.pytree_dataclass -class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[ODEState]): +class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]): pass @jax_dataclasses.pytree_dataclass -class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[ODEState]): +class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]): pass @jax_dataclasses.pytree_dataclass -class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]): +class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]): pass diff --git a/src/jaxsim/math/adjoint.py b/src/jaxsim/math/adjoint.py index 02b252324..100189dda 100644 --- a/src/jaxsim/math/adjoint.py +++ b/src/jaxsim/math/adjoint.py @@ -1,7 +1,7 @@ import jax.numpy as jnp +import jaxlie import jaxsim.typing as jtp -from jaxsim.sixd import so3 from .quaternion import Quaternion from .skew import Skew @@ -31,7 +31,7 @@ def from_quaternion_and_translation( assert quaternion.size == 4 assert translation.size == 3 - Q_sixd = so3.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion)) + Q_sixd = jaxlie.SO3.from_quaternion_xyzw(xyzw=Quaternion.to_xyzw(quaternion)) Q_sixd = Q_sixd if not normalize_quaternion else Q_sixd.normalize() return Adjoint.from_rotation_and_translation( diff --git a/src/jaxsim/math/quaternion.py b/src/jaxsim/math/quaternion.py index 09b43f890..40e627621 100644 --- a/src/jaxsim/math/quaternion.py +++ b/src/jaxsim/math/quaternion.py @@ -1,8 +1,8 @@ import jax.lax import jax.numpy as jnp +import jaxlie import jaxsim.typing as jtp -from jaxsim.sixd import so3 class Quaternion: @@ -43,7 +43,7 @@ def to_dcm(quaternion: jtp.Vector) -> jtp.Matrix: Returns: jtp.Matrix: Direction cosine matrix (DCM). """ - return so3.SO3.from_quaternion_xyzw( + return jaxlie.SO3.from_quaternion_xyzw( xyzw=Quaternion.to_xyzw(quaternion) ).as_matrix() @@ -59,7 +59,7 @@ def from_dcm(dcm: jtp.Matrix) -> jtp.Vector: jtp.Vector: Quaternion in XYZW representation. """ return Quaternion.to_wxyz( - xyzw=so3.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw() + xyzw=jaxlie.SO3.from_matrix(matrix=dcm).as_quaternion_xyzw() ) @staticmethod diff --git a/src/jaxsim/math/rotation.py b/src/jaxsim/math/rotation.py index 532615029..48cdd6767 100644 --- a/src/jaxsim/math/rotation.py +++ b/src/jaxsim/math/rotation.py @@ -2,9 +2,9 @@ import jax import jax.numpy as jnp +import jaxlie import jaxsim.typing as jtp -from jaxsim.sixd import so3 from .skew import Skew @@ -21,7 +21,7 @@ def x(theta: jtp.Float) -> jtp.Matrix: Returns: jtp.Matrix: 3D rotation matrix. """ - return so3.SO3.from_x_radians(theta=theta).as_matrix() + return jaxlie.SO3.from_x_radians(theta=theta).as_matrix() @staticmethod def y(theta: jtp.Float) -> jtp.Matrix: @@ -34,7 +34,7 @@ def y(theta: jtp.Float) -> jtp.Matrix: Returns: jtp.Matrix: 3D rotation matrix. """ - return so3.SO3.from_y_radians(theta=theta).as_matrix() + return jaxlie.SO3.from_y_radians(theta=theta).as_matrix() @staticmethod def z(theta: jtp.Float) -> jtp.Matrix: @@ -47,7 +47,7 @@ def z(theta: jtp.Float) -> jtp.Matrix: Returns: jtp.Matrix: 3D rotation matrix. """ - return so3.SO3.from_z_radians(theta=theta).as_matrix() + return jaxlie.SO3.from_z_radians(theta=theta).as_matrix() @staticmethod def from_axis_angle(vector: jtp.Vector) -> jtp.Matrix: diff --git a/src/jaxsim/parsers/descriptions/link.py b/src/jaxsim/parsers/descriptions/link.py index d2912d7f0..9676619d3 100644 --- a/src/jaxsim/parsers/descriptions/link.py +++ b/src/jaxsim/parsers/descriptions/link.py @@ -3,10 +3,10 @@ import jax.numpy as jnp import jax_dataclasses +import jaxlie from jax_dataclasses import Static import jaxsim.typing as jtp -from jaxsim.sixd import se3 from jaxsim.utils import JaxsimDataclass @@ -78,7 +78,7 @@ def lump_with( I_removed = link.inertia # Create the SE3 object. Note the inverse. - r_H_l = se3.SE3.from_matrix(lumped_H_removed).inverse() + r_H_l = jaxlie.SE3.from_matrix(lumped_H_removed).inverse() r_X_l = r_H_l.adjoint() # Move the inertia diff --git a/src/jaxsim/parsers/rod/utils.py b/src/jaxsim/parsers/rod/utils.py index 987b18f83..b281abc9f 100644 --- a/src/jaxsim/parsers/rod/utils.py +++ b/src/jaxsim/parsers/rod/utils.py @@ -1,15 +1,17 @@ import os from typing import Union -import jax.numpy as jnp +import jaxlie import numpy as np import numpy.typing as npt import rod +import jaxsim.typing as jtp +from jaxsim.math.inertia import Inertia from jaxsim.parsers import descriptions -def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray: +def from_sdf_inertial(inertial: rod.Inertial) -> jtp.Matrix: """ Extract the 6D inertia matrix from an SDF inertial element. @@ -20,9 +22,6 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray: The 6D inertia matrix of the link expressed in the link frame. """ - from jaxsim.math.inertia import Inertia - from jaxsim.sixd import se3 - # Extract the "mass" element m = inertial.mass @@ -52,13 +51,13 @@ def from_sdf_inertial(inertial: rod.Inertial) -> npt.NDArray: L_H_CoM = inertial.pose.transform() if inertial.pose is not None else np.eye(4) # We need its inverse - CoM_H_L = se3.SE3.from_matrix(matrix=L_H_CoM).inverse() - CoM_X_L: npt.NDArray = CoM_H_L.adjoint() + CoM_H_L = jaxlie.SE3.from_matrix(matrix=L_H_CoM).inverse() + CoM_X_L = CoM_H_L.adjoint() # Express the CoM inertia matrix in the link frame L M_L = CoM_X_L.T @ M_CoM @ CoM_X_L - return jnp.array(M_L) + return M_L.astype(dtype=float) def axis_to_jtype( diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py index 86ce19018..34143f55b 100644 --- a/src/jaxsim/physics/model/physics_model.py +++ b/src/jaxsim/physics/model/physics_model.py @@ -4,6 +4,7 @@ import jax.lax import jax.numpy as jnp import jax_dataclasses +import jaxlie import numpy as np from jax_dataclasses import Static @@ -12,7 +13,6 @@ import jaxsim.typing as jtp from jaxsim.parsers.descriptions import JointDescriptor, JointType from jaxsim.physics import default_gravity -from jaxsim.sixd import se3 from jaxsim.utils import JaxsimDataclass, not_tracing from .ground_contact import GroundContact @@ -185,7 +185,7 @@ def build_from( # (this is just the pose of the base link in the SDF description) base_link = model_description.links_dict[model_description.link_names()[0]] R_H_B = model_description.transform(name=base_link.name) - tree_transform_0 = se3.SE3.from_matrix(matrix=R_H_B).adjoint() + tree_transform_0 = jaxlie.SE3.from_matrix(matrix=R_H_B).adjoint() # Helper to compute the transform pre(i)_H_λ(i). # Given a joint 'i', it is the coordinate transform between its predecessor @@ -200,7 +200,7 @@ def build_from( tree_transforms_dict = { 0: tree_transform_0, **{ - j.index: se3.SE3.from_matrix(matrix=prei_H_λi(j)).adjoint() + j.index: jaxlie.SE3.from_matrix(matrix=prei_H_λi(j)).adjoint() for j in model_description.joints }, } diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py new file mode 100644 index 000000000..1929241a6 --- /dev/null +++ b/src/jaxsim/rbda/__init__.py @@ -0,0 +1,9 @@ +from jaxsim.physics.algos import ( + aba, + crba, + forward_kinematics, + jacobian, + rnea, + soft_contacts, + utils, +) diff --git a/src/jaxsim/simulation/__init__.py b/src/jaxsim/simulation/__init__.py index a56ae20fa..21dd7c132 100644 --- a/src/jaxsim/simulation/__init__.py +++ b/src/jaxsim/simulation/__init__.py @@ -1,4 +1 @@ -from . import integrators, ode, ode_data, simulator from .ode_data import ODEInput, ODEState -from .ode_integration import IntegratorType -from .simulator import JaxSim, SimulatorData diff --git a/src/jaxsim/simulation/ode_data.py b/src/jaxsim/simulation/ode_data.py index ec3930687..744578dc9 100644 --- a/src/jaxsim/simulation/ode_data.py +++ b/src/jaxsim/simulation/ode_data.py @@ -1,13 +1,11 @@ import jax.flatten_util import jax_dataclasses +import jaxsim.api as js import jaxsim.typing as jtp +from jaxsim.api.ode_data import PhysicsModelInput, PhysicsModelState from jaxsim.physics.algos.soft_contacts import SoftContactsState from jaxsim.physics.model.physics_model import PhysicsModel -from jaxsim.physics.model.physics_model_state import ( - PhysicsModelInput, - PhysicsModelState, -) from jaxsim.utils import JaxsimDataclass diff --git a/src/jaxsim/sixd/__init__.py b/src/jaxsim/sixd/__init__.py deleted file mode 100644 index ff7ee7c4a..000000000 --- a/src/jaxsim/sixd/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from jaxlie import _se3 as se3 -from jaxlie import _so3 as so3 diff --git a/src/jaxsim/terrain/__init__.py b/src/jaxsim/terrain/__init__.py new file mode 100644 index 000000000..065a93491 --- /dev/null +++ b/src/jaxsim/terrain/__init__.py @@ -0,0 +1,2 @@ +from jaxsim.physics.algos import terrain +from jaxsim.physics.algos.terrain import FlatTerrain, Terrain diff --git a/tests/utils_idyntree.py b/tests/utils_idyntree.py index c022cf5f7..76d2fb414 100644 --- a/tests/utils_idyntree.py +++ b/tests/utils_idyntree.py @@ -8,7 +8,7 @@ import numpy.typing as npt import jaxsim.api as js -from jaxsim.high_level.common import VelRepr +from jaxsim import VelRepr def build_kindyncomputations_from_jaxsim_model(