diff --git a/src/jaxsim/high_level/__init__.py b/src/jaxsim/high_level/__init__.py deleted file mode 100644 index 258be2468..000000000 --- a/src/jaxsim/high_level/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ..api.common import VelRepr -from . import common, joint, link, model diff --git a/src/jaxsim/high_level/joint.py b/src/jaxsim/high_level/joint.py deleted file mode 100644 index db8d54876..000000000 --- a/src/jaxsim/high_level/joint.py +++ /dev/null @@ -1,148 +0,0 @@ -import dataclasses -import functools -from typing import Any - -import jax.numpy as jnp -import jax_dataclasses -from jax_dataclasses import Static - -import jaxsim.parsers -import jaxsim.typing as jtp -from jaxsim.utils import Vmappable, not_tracing, oop - - -@jax_dataclasses.pytree_dataclass -class Joint(Vmappable): - """ - High-level class to operate in r/o on a single joint of a simulated model. - """ - - joint_description: Static[jaxsim.parsers.descriptions.JointDescription] - - _parent_model: Any = dataclasses.field( - default=None, repr=False, compare=False, hash=False - ) - - @property - def parent_model(self) -> "jaxsim.high_level.model.Model": - """""" - - return self._parent_model - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def valid(self) -> jtp.Bool: - """""" - - return jnp.array(self.parent_model is not None, dtype=bool) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def index(self) -> jtp.Int: - """""" - - return jnp.array(self.joint_description.index, dtype=int) - - @functools.partial(oop.jax_tf.method_ro) - def dofs(self) -> jtp.Int: - """""" - - return jnp.array(1, dtype=int) - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def name(self) -> str: - """""" - - return self.joint_description.name - - @functools.partial(oop.jax_tf.method_ro) - def position(self, dof: int | None = None) -> jtp.Float: - """""" - - dof = dof if dof is not None else 0 - - return jnp.array( - self.parent_model.joint_positions(joint_names=(self.name(),))[dof], - dtype=float, - ) - - @functools.partial(oop.jax_tf.method_ro) - def velocity(self, dof: int | None = None) -> jtp.Float: - """""" - - dof = dof if dof is not None else 0 - - return jnp.array( - self.parent_model.joint_velocities(joint_names=(self.name(),))[dof], - dtype=float, - ) - - @functools.partial(oop.jax_tf.method_ro) - def force_target(self, dof: int | None = None) -> jtp.Float: - """""" - - dof = dof if dof is not None else 0 - - return jnp.array( - self.parent_model.joint_generalized_forces_targets( - joint_names=(self.name(),) - )[dof], - dtype=float, - ) - - @functools.partial(oop.jax_tf.method_ro) - def position_limit(self, dof: int | None = None) -> tuple[jtp.Float, jtp.Float]: - """""" - - dof = dof if dof is not None else 0 - - if not_tracing(dof) and dof != 0: - msg = "Only joints with 1 DoF are currently supported" - raise ValueError(msg) - - low, high = self.joint_description.position_limit - - return jnp.array(low, dtype=float), jnp.array(high, dtype=float) - - # ============= - # Motor methods - # ============= - @functools.partial(oop.jax_tf.method_ro) - def motor_inertia(self) -> jtp.Vector: - """""" - - return jnp.array(self.joint_description.motor_inertia, dtype=float) - - @functools.partial(oop.jax_tf.method_ro) - def motor_gear_ratio(self) -> jtp.Vector: - """""" - - return jnp.array(self.joint_description.motor_gear_ratio, dtype=float) - - @functools.partial(oop.jax_tf.method_ro) - def motor_viscous_friction(self) -> jtp.Vector: - """""" - - return jnp.array(self.joint_description.motor_viscous_friction, dtype=float) - - # ================= - # Multi-DoF methods - # ================= - - @functools.partial(oop.jax_tf.method_ro) - def joint_position(self) -> jtp.Vector: - """""" - - return self.parent_model.joint_positions(joint_names=(self.name(),)) - - @functools.partial(oop.jax_tf.method_ro) - def joint_velocity(self) -> jtp.Vector: - """""" - - return self.parent_model.joint_velocities(joint_names=(self.name(),)) - - @functools.partial(oop.jax_tf.method_ro) - def joint_force_target(self) -> jtp.Vector: - """""" - - return self.parent_model.joint_generalized_forces_targets( - joint_names=(self.name(),) - ) diff --git a/src/jaxsim/high_level/link.py b/src/jaxsim/high_level/link.py deleted file mode 100644 index d82e1d258..000000000 --- a/src/jaxsim/high_level/link.py +++ /dev/null @@ -1,259 +0,0 @@ -import dataclasses -import functools -from typing import Any - -import jax.lax -import jax.numpy as jnp -import jax_dataclasses -import numpy as np -from jax_dataclasses import Static - -import jaxsim.parsers -import jaxsim.typing as jtp -from jaxsim import sixd -from jaxsim.physics.algos.jacobian import jacobian -from jaxsim.utils import Vmappable, oop - -from .common import VelRepr - - -@jax_dataclasses.pytree_dataclass -class Link(Vmappable): - """ - High-level class to operate in r/o on a single link of a simulated model. - """ - - link_description: Static[jaxsim.parsers.descriptions.LinkDescription] - - _parent_model: Any = dataclasses.field( - default=None, repr=False, compare=False, hash=False - ) - - @property - def parent_model(self) -> "jaxsim.high_level.model.Model": - """""" - - return self._parent_model - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def valid(self) -> jtp.Bool: - """""" - - return jnp.array(self.parent_model is not None, dtype=bool) - - # ========== - # Properties - # ========== - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def name(self) -> str: - """""" - - return self.link_description.name - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def index(self) -> jtp.Int: - """""" - - return jnp.array(self.link_description.index, dtype=int) - - # ======== - # Dynamics - # ======== - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def mass(self) -> jtp.Float: - """""" - - return jnp.array(self.link_description.mass, dtype=float) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def spatial_inertia(self) -> jtp.Matrix: - """""" - - return jnp.array(self.link_description.inertia, dtype=float) - - @functools.partial(oop.jax_tf.method_ro, vmap_in_axes=(0, None)) - def com_position(self, in_link_frame: bool = True) -> jtp.Vector: - """""" - - from jaxsim.math.inertia import Inertia - - _, L_p_CoM, _ = Inertia.to_params(M=self.spatial_inertia()) - - def com_in_link_frame(): - return L_p_CoM.squeeze() - - def com_in_inertial_frame(): - W_H_L = self.transform() - W_p̃_CoM = W_H_L @ jnp.hstack([L_p_CoM.squeeze(), 1]) - - return W_p̃_CoM[0:3].squeeze() - - return jax.lax.select( - pred=in_link_frame, - on_true=com_in_link_frame(), - on_false=com_in_inertial_frame(), - ) - - # ========== - # Kinematics - # ========== - - @functools.partial(oop.jax_tf.method_ro) - def position(self) -> jtp.Vector: - """""" - - return self.transform()[0:3, 3] - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"]) - def orientation(self, dcm: bool = False) -> jtp.Vector: - """""" - - R = self.transform()[0:3, 0:3] - - to_wxyz = np.array([3, 0, 1, 2]) - return R if dcm else sixd.so3.SO3.from_matrix(R).as_quaternion_xyzw()[to_wxyz] - - @functools.partial(oop.jax_tf.method_ro) - def transform(self) -> jtp.Matrix: - """""" - - return self.parent_model.forward_kinematics()[self.index()] - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) - def velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector: - """""" - - v_WL = ( - self.jacobian(output_vel_repr=vel_repr) - @ self.parent_model.generalized_velocity() - ) - - return v_WL - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) - def linear_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector: - """""" - - return self.velocity(vel_repr=vel_repr)[0:3] - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["vel_repr"]) - def angular_velocity(self, vel_repr: VelRepr | None = None) -> jtp.Vector: - """""" - - return self.velocity(vel_repr=vel_repr)[3:6] - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"]) - def jacobian(self, output_vel_repr: VelRepr | None = None) -> jtp.Matrix: - """""" - - if output_vel_repr is None: - output_vel_repr = self.parent_model.velocity_representation - - # Compute the doubly left-trivialized free-floating jacobian - L_J_WL_B = jacobian( - model=self.parent_model.physics_model, - body_index=self.index(), - q=self.parent_model.data.model_state.joint_positions, - ) - - if self.parent_model.velocity_representation is VelRepr.Body: - L_J_WL_target = L_J_WL_B - - elif self.parent_model.velocity_representation is VelRepr.Inertial: - dofs = self.parent_model.dofs() - W_H_B = self.parent_model.base_transform() - - B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint() - zero_6n = jnp.zeros(shape=(6, dofs)) - - B_T_W = jnp.vstack( - [ - jnp.block([B_X_W, zero_6n]), - jnp.block([zero_6n.T, jnp.eye(dofs)]), - ] - ) - - L_J_WL_target = L_J_WL_B @ B_T_W - - elif self.parent_model.velocity_representation is VelRepr.Mixed: - dofs = self.parent_model.dofs() - W_H_B = self.parent_model.base_transform() - BW_H_B = jnp.array(W_H_B).at[0:3, 3].set(jnp.zeros(3)) - - B_X_BW = sixd.se3.SE3.from_matrix(BW_H_B).inverse().adjoint() - zero_6n = jnp.zeros(shape=(6, dofs)) - - B_T_BW = jnp.vstack( - [ - jnp.block([B_X_BW, zero_6n]), - jnp.block([zero_6n.T, jnp.eye(dofs)]), - ] - ) - - L_J_WL_target = L_J_WL_B @ B_T_BW - - else: - raise ValueError(self.parent_model.velocity_representation) - - if output_vel_repr is VelRepr.Body: - return L_J_WL_target - - elif output_vel_repr is VelRepr.Inertial: - W_H_L = self.transform() - W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint() - return W_X_L @ L_J_WL_target - - elif output_vel_repr is VelRepr.Mixed: - W_H_L = self.transform() - LW_H_L = jnp.array(W_H_L).at[0:3, 3].set(jnp.zeros(3)) - LW_X_L = sixd.se3.SE3.from_matrix(LW_H_L).adjoint() - return LW_X_L @ L_J_WL_target - - else: - raise ValueError(output_vel_repr) - - @functools.partial(oop.jax_tf.method_ro) - def external_force(self) -> jtp.Vector: - """ - Return the active external force acting on the link. - - This external force is a user input and is not computed by the physics engine. - During the simulation, this external force is summed to other terms like those - related to enforce contact constraints. - - Returns: - The active external 6D force acting on the link in the active representation. - """ - - # Get the external force stored in the inertial representation - W_f_ext = self.parent_model.data.model_input.f_ext[self.index()] - - # Express it in the active representation - if self.parent_model.velocity_representation is VelRepr.Inertial: - f_ext = W_f_ext - - elif self.parent_model.velocity_representation is VelRepr.Body: - W_H_L = self.transform() - W_X_L = sixd.se3.SE3.from_matrix(W_H_L).adjoint() - - f_ext = L_f_ext = W_X_L.transpose() @ W_f_ext - - elif self.parent_model.velocity_representation is VelRepr.Mixed: - W_p_L = self.transform()[0:3, 3] - W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L) - W_X_LW = sixd.se3.SE3.from_matrix(W_H_LW).adjoint() - - f_ext = LW_f_ext = W_X_LW.transpose() @ W_f_ext - - else: - raise ValueError(self.parent_model.velocity_representation) - - return f_ext - - @functools.partial(oop.jax_tf.method_ro) - def in_contact(self) -> jtp.Bool: - """""" - - return self.parent_model.in_contact()[self.index()] diff --git a/src/jaxsim/high_level/model.py b/src/jaxsim/high_level/model.py deleted file mode 100644 index 5e88057eb..000000000 --- a/src/jaxsim/high_level/model.py +++ /dev/null @@ -1,1686 +0,0 @@ -import dataclasses -import functools -import pathlib -from typing import Any, Dict, List, Optional, Tuple, Union - -import jax -import jax.numpy as jnp -import jax_dataclasses -import numpy as np -import rod -from jax_dataclasses import Static - -import jaxsim.physics.algos.aba -import jaxsim.physics.algos.crba -import jaxsim.physics.algos.forward_kinematics -import jaxsim.physics.algos.rnea -import jaxsim.physics.model.physics_model -import jaxsim.physics.model.physics_model_state -import jaxsim.typing as jtp -from jaxsim import high_level, logging, physics, sixd -from jaxsim.physics.algos import soft_contacts -from jaxsim.physics.algos.terrain import FlatTerrain, Terrain -from jaxsim.utils import JaxsimDataclass, Mutability, Vmappable, oop - -from .common import VelRepr - - -@jax_dataclasses.pytree_dataclass -class ModelData(JaxsimDataclass): - """ - Class used to store the model state and input at a given time. - """ - - model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState - model_input: jaxsim.physics.model.physics_model_state.PhysicsModelInput - contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState - - @staticmethod - def zero(physics_model: physics.model.physics_model.PhysicsModel) -> "ModelData": - """ - Return a ModelData object with all fields set to zero and initialized with the right shape. - - Args: - physics_model: The considered physics model. - - Returns: - The zero ModelData object of the given physics model. - """ - - return ModelData( - model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState.zero( - physics_model=physics_model - ), - model_input=jaxsim.physics.model.physics_model_state.PhysicsModelInput.zero( - physics_model=physics_model - ), - contact_state=jaxsim.physics.algos.soft_contacts.SoftContactsState.zero( - physics_model=physics_model - ), - ) - - -@jax_dataclasses.pytree_dataclass -class StepData(JaxsimDataclass): - """ - Class used to store the data computed at each step of the simulation. - """ - - t0: float - tf: float - dt: float - - # Starting model data and real input (tau, f_ext) computed at t0 - t0_model_data: ModelData = dataclasses.field(repr=False) - t0_model_input_real: jaxsim.physics.model.physics_model_state.PhysicsModelInput = ( - dataclasses.field(repr=False) - ) - - # ABA output - t0_base_acceleration: jtp.Vector = dataclasses.field(repr=False) - t0_joint_acceleration: jtp.Vector = dataclasses.field(repr=False) - - # (new ODEState) - # Starting from t0_model_data, can be obtained by integrating the ABA output - # and tangential_deformation_dot (which is fn of ode_state at t0) - tf_model_state: jaxsim.physics.model.physics_model_state.PhysicsModelState = ( - dataclasses.field(repr=False) - ) - tf_contact_state: jaxsim.physics.algos.soft_contacts.SoftContactsState = ( - dataclasses.field(repr=False) - ) - - aux: Dict[str, Any] = dataclasses.field(default_factory=dict) - - -@jax_dataclasses.pytree_dataclass -class Model(Vmappable): - """ - High-level class to operate on a simulated model. - """ - - model_name: Static[str] - - physics_model: physics.model.physics_model.PhysicsModel = dataclasses.field( - repr=False - ) - - velocity_representation: Static[VelRepr] = dataclasses.field(default=VelRepr.Mixed) - - data: ModelData = dataclasses.field(default=None, repr=False) - - # ======================== - # Initialization and state - # ======================== - - @staticmethod - def build_from_model_description( - model_description: Union[str, pathlib.Path, rod.Model], - model_name: str | None = None, - vel_repr: VelRepr = VelRepr.Mixed, - gravity: jtp.Array = jaxsim.physics.default_gravity(), - is_urdf: bool | None = None, - considered_joints: List[str] | None = None, - ) -> "Model": - """ - Build a Model object from a model description. - - Args: - model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. - model_name: The optional name of the model that overrides the one in the description. - vel_repr: The velocity representation to use. - gravity: The 3D gravity vector. - is_urdf: Whether the model description is a URDF or an SDF. This is automatically inferred if the model description is a path to a file. - considered_joints: The list of joints to consider. If None, all joints are considered. - - Returns: - The built Model object. - """ - - import jaxsim.parsers.rod - - # Parse the input resource (either a path to file or a string with the URDF/SDF) - # and build the -intermediate- model description - model_description = jaxsim.parsers.rod.build_model_description( - model_description=model_description, is_urdf=is_urdf - ) - - # Lump links together if not all joints are considered. - # Note: this procedure assigns a zero position to all joints not considered. - if considered_joints is not None: - model_description = model_description.reduce( - considered_joints=considered_joints - ) - - # Create the physics model from the model description - physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from( - model_description=model_description, gravity=gravity - ) - - # Build and return the high-level model - return Model.build( - physics_model=physics_model, - model_name=model_name, - vel_repr=vel_repr, - ) - - @staticmethod - def build_from_sdf( - sdf: Union[str, pathlib.Path], - model_name: str | None = None, - vel_repr: VelRepr = VelRepr.Mixed, - gravity: jtp.Array = jaxsim.physics.default_gravity(), - is_urdf: bool | None = None, - considered_joints: List[str] | None = None, - ) -> "Model": - """ - Build a Model object from an SDF description. - This is a deprecated method, use build_from_model_description instead. - """ - - msg = "Model.{} is deprecated, use Model.{} instead." - logging.warning( - msg=msg.format("build_from_sdf", "build_from_model_description") - ) - - return Model.build_from_model_description( - model_description=sdf, - model_name=model_name, - vel_repr=vel_repr, - gravity=gravity, - is_urdf=is_urdf, - considered_joints=considered_joints, - ) - - @staticmethod - def build( - physics_model: jaxsim.physics.model.physics_model.PhysicsModel, - model_name: str | None = None, - vel_repr: VelRepr = VelRepr.Mixed, - ) -> "Model": - """ - Build a Model object from a physics model. - - Args: - physics_model: The physics model. - model_name: The optional name of the model that overrides the one in the physics model. - vel_repr: The velocity representation to use. - - Returns: - The built Model object. - """ - - # Set the model name (if not provided, use the one from the model description) - model_name = ( - model_name if model_name is not None else physics_model.description.name - ) - - # Build the high-level model - model = Model( - physics_model=physics_model, - model_name=model_name, - velocity_representation=vel_repr, - ) - - # Zero the model data - with model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - model.zero() - - # Check model validity - if not model.valid(): - raise RuntimeError("The model is not valid.") - - # Return the high-level model - return model - - @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) - def reduce( - self, considered_joints: tuple[str, ...], keep_base_pose: bool = False - ) -> None: - """ - Reduce the model by lumping together the links connected by removed joints. - - Args: - considered_joints: The sequence of joints to consider. - keep_base_pose: A flag indicating whether to keep the base pose or not. - """ - - if self.vectorized: - raise RuntimeError("Cannot reduce a vectorized model.") - - # Reduce the model description. - # If considered_joints contains joints not existing in the model, the method - # will raise an exception. - reduced_model_description = self.physics_model.description.reduce( - considered_joints=list(considered_joints) - ) - - # Create the physics model from the reduced model description - physics_model = jaxsim.physics.model.physics_model.PhysicsModel.build_from( - model_description=reduced_model_description, - gravity=self.physics_model.gravity[0:3], - ) - - # Build the reduced high-level model - reduced_model = Model.build( - physics_model=physics_model, - model_name=self.name(), - vel_repr=self.velocity_representation, - ) - - # Extract the base pose - W_p_B = self.base_position() - W_Q_B = self.base_orientation(dcm=False) - - # Replace the current model with the reduced model. - # Since the structure of the PyTree changes, we disable validation. - self.physics_model = reduced_model.physics_model - self.data = reduced_model.data - - if keep_base_pose: - self.reset_base_position(position=W_p_B) - self.reset_base_orientation(orientation=W_Q_B, dcm=False) - - @functools.partial(oop.jax_tf.method_rw, jit=False) - def zero(self) -> None: - """""" - - self.data = ModelData.zero(physics_model=self.physics_model) - - @functools.partial(oop.jax_tf.method_rw, jit=False) - def zero_input(self) -> None: - """""" - - self.data.model_input = ModelData.zero( - physics_model=self.physics_model - ).model_input - - @functools.partial(oop.jax_tf.method_rw, jit=False) - def zero_state(self) -> None: - """""" - - model_data_zero = ModelData.zero(physics_model=self.physics_model) - self.data.model_state = model_data_zero.model_state - self.data.contact_state = model_data_zero.contact_state - - @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False) - def set_velocity_representation(self, vel_repr: VelRepr) -> None: - """""" - - if self.velocity_representation is vel_repr: - return - - self.velocity_representation = vel_repr - - # ========== - # Properties - # ========== - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def valid(self) -> jtp.Bool: - """""" - - valid = True - valid = valid and all(l.valid() for l in self.links()) - valid = valid and all(j.valid() for j in self.joints()) - return jnp.array(valid, dtype=bool) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def floating_base(self) -> jtp.Bool: - """""" - - return jnp.array(self.physics_model.is_floating_base, dtype=bool) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def dofs(self) -> jtp.Int: - """""" - - return self.joint_positions().size - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def name(self) -> str: - """""" - - return self.model_name - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def nr_of_links(self) -> jtp.Int: - """""" - - return jnp.array(len(self.links()), dtype=int) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def nr_of_joints(self) -> jtp.Int: - """""" - - return jnp.array(len(self.joints()), dtype=int) - - @functools.partial(oop.jax_tf.method_ro) - def total_mass(self) -> jtp.Float: - """""" - - return jnp.sum(jnp.array([l.mass() for l in self.links()]), dtype=float) - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def get_link(self, link_name: str) -> high_level.link.Link: - """""" - - if link_name not in self.link_names(): - msg = f"Link '{link_name}' is not part of model '{self.name()}'" - raise ValueError(msg) - - return self.links(link_names=(link_name,))[0] - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def get_joint(self, joint_name: str) -> high_level.joint.Joint: - """""" - - if joint_name not in self.joint_names(): - msg = f"Joint '{joint_name}' is not part of model '{self.name()}'" - raise ValueError(msg) - - return self.joints(joint_names=(joint_name,))[0] - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def link_names(self) -> tuple[str, ...]: - """""" - - return tuple(self.physics_model.description.links_dict.keys()) - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def joint_names(self) -> tuple[str, ...]: - """""" - - return tuple(self.physics_model.description.joints_dict.keys()) - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def links( - self, link_names: tuple[str, ...] | None = None - ) -> tuple[high_level.link.Link, ...]: - """""" - - all_links = { - l.name: high_level.link.Link( - link_description=l, _parent_model=self, batch_size=self.batch_size - ) - for l in sorted( - self.physics_model.description.links_dict.values(), - key=lambda l: l.index, - ) - } - - for l in all_links.values(): - l._set_mutability(self._mutability()) - - if link_names is None: - return tuple(all_links.values()) - - return tuple(all_links[name] for name in link_names) - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def joints( - self, joint_names: tuple[str, ...] | None = None - ) -> tuple[high_level.joint.Joint, ...]: - """""" - - all_joints = { - j.name: high_level.joint.Joint( - joint_description=j, _parent_model=self, batch_size=self.batch_size - ) - for j in sorted( - self.physics_model.description.joints_dict.values(), - key=lambda j: j.index, - ) - } - - for j in all_joints.values(): - j._set_mutability(self._mutability()) - - if joint_names is None: - return tuple(all_joints.values()) - - return tuple(all_joints[name] for name in joint_names) - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["link_names", "terrain"]) - def in_contact( - self, - link_names: tuple[str, ...] | None = None, - terrain: Terrain = FlatTerrain(), - ) -> jtp.Vector: - """""" - - link_names = link_names if link_names is not None else self.link_names() - - if set(link_names) - set(self.link_names()) != set(): - raise ValueError("One or more link names are not part of the model") - - from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel - - W_p_Ci, _ = collidable_points_pos_vel( - model=self.physics_model, - q=self.data.model_state.joint_positions, - qd=self.data.model_state.joint_velocities, - xfb=self.data.model_state.xfb(), - ) - - terrain_height = jax.vmap(terrain.height)(W_p_Ci[0, :], W_p_Ci[1, :]) - - below_terrain = W_p_Ci[2, :] <= terrain_height - - links_in_contact = jax.vmap( - lambda link_index: jnp.where( - self.physics_model.gc.body == link_index, - below_terrain, - jnp.zeros_like(below_terrain, dtype=bool), - ).any() - )(jnp.array([link.index() for link in self.links(link_names=link_names)])) - - return links_in_contact - - # ================= - # Multi-DoF methods - # ================= - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) - def joint_positions(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector: - """""" - - return self.data.model_state.joint_positions[ - self._joint_indices(joint_names=joint_names) - ] - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) - def joint_random_positions( - self, - joint_names: tuple[str, ...] | None = None, - key: jax.Array | None = None, - ) -> jtp.Vector: - """""" - - if key is None: - key = jax.random.PRNGKey(seed=0) - - s_min, s_max = self.joint_limits(joint_names=joint_names) - - s_random = jax.random.uniform( - minval=s_min, - maxval=s_max, - key=key, - shape=s_min.shape, - ) - - return s_random - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) - def joint_velocities( - self, joint_names: tuple[str, ...] | None = None - ) -> jtp.Vector: - """""" - - return self.data.model_state.joint_velocities[ - self._joint_indices(joint_names=joint_names) - ] - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) - def joint_generalized_forces_targets( - self, joint_names: tuple[str, ...] | None = None - ) -> jtp.Vector: - """""" - - return self.data.model_input.tau[self._joint_indices(joint_names=joint_names)] - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["joint_names"]) - def joint_limits( - self, joint_names: tuple[str, ...] | None = None - ) -> Tuple[jtp.Vector, jtp.Vector]: - """""" - - # Consider all joints if not specified otherwise - joint_names = joint_names if joint_names is not None else self.joint_names() - - # Create a (Dofs, 2) matrix containing the joint limits - limits = jnp.vstack( - jnp.array([j.position_limit() for j in self.joints(joint_names)]) - ) - - # Get the limits, reordering them in case low > high - s_low = jnp.min(limits, axis=1) - s_high = jnp.max(limits, axis=1) - - return s_low, s_high - - # ========= - # Base link - # ========= - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def base_frame(self) -> str: - """""" - - return self.physics_model.description.root.name - - @functools.partial(oop.jax_tf.method_ro) - def base_position(self) -> jtp.Vector: - """""" - - return self.data.model_state.base_position.squeeze() - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["dcm"]) - def base_orientation(self, dcm: bool = False) -> jtp.Vector: - """""" - - # Normalize the quaternion before using it. - # Our integration logic has a Baumgarte stabilization term makes the quaternion - # norm converge to 1, but it does not enforce to be 1 at all the time instants. - base_unit_quaternion = ( - self.data.model_state.base_quaternion.squeeze() - / jnp.linalg.norm(self.data.model_state.base_quaternion) - ) - - # wxyz -> xyzw - to_xyzw = np.array([1, 2, 3, 0]) - - return ( - base_unit_quaternion - if not dcm - else sixd.so3.SO3.from_quaternion_xyzw( - base_unit_quaternion[to_xyzw] - ).as_matrix() - ) - - @functools.partial(oop.jax_tf.method_ro) - def base_transform(self) -> jtp.MatrixJax: - """""" - - W_R_B = self.base_orientation(dcm=True) - W_p_B = jnp.vstack(self.base_position()) - - return jnp.vstack( - [ - jnp.block([W_R_B, W_p_B]), - jnp.array([0, 0, 0, 1]), - ] - ) - - @functools.partial(oop.jax_tf.method_ro) - def base_velocity(self) -> jtp.Vector: - """""" - - W_v_WB = jnp.hstack( - [ - self.data.model_state.base_linear_velocity, - self.data.model_state.base_angular_velocity, - ] - ) - - return self.inertial_to_active_representation(array=W_v_WB) - - @functools.partial(oop.jax_tf.method_ro) - def external_forces(self) -> jtp.Matrix: - """ - Return the active external forces acting on the robot. - - The external forces are a user input and are not computed by the physics engine. - During the simulation, these external forces are summed to other terms like - the external forces due to the contact with the environment. - - Returns: - A matrix of shape (n_links, 6) containing the external forces acting on the - robot links. The forces are expressed in the active representation. - """ - - # Get the active external forces that are always stored internally - # in Inertial representation - W_f_ext = self.data.model_input.f_ext - - inertial_to_active = lambda f: self.inertial_to_active_representation( - f, is_force=True - ) - - return jax.vmap(inertial_to_active, in_axes=0)(W_f_ext) - - # ======================= - # Single link r/w methods - # ======================= - - @functools.partial( - oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"] - ) - def apply_external_force_to_link( - self, - link_name: str, - force: jtp.Array | None = None, - torque: jtp.Array | None = None, - additive: bool = True, - ) -> None: - """""" - - # Get the target link with the correct mutability - link = self.get_link(link_name=link_name) - link._set_mutability(mutability=self._mutability()) - - # Initialize zero force components if not set - force = force if force is not None else jnp.zeros(3) - torque = torque if torque is not None else jnp.zeros(3) - - # Build the target 6D force in the active representation - f_ext = jnp.hstack([force, torque]) - - # Convert the 6D force to the inertial representation - if self.velocity_representation is VelRepr.Inertial: - W_f_ext = f_ext - - elif self.velocity_representation is VelRepr.Body: - L_f_ext = f_ext - W_H_L = link.transform() - L_X_W = sixd.se3.SE3.from_matrix(W_H_L).inverse().adjoint() - - W_f_ext = L_X_W.transpose() @ L_f_ext - - elif self.velocity_representation is VelRepr.Mixed: - LW_f_ext = f_ext - - W_p_L = link.transform()[0:3, 3] - W_H_LW = jnp.eye(4).at[0:3, 3].set(W_p_L) - LW_X_W = sixd.se3.SE3.from_matrix(W_H_LW).inverse().adjoint() - - W_f_ext = LW_X_W.transpose() @ LW_f_ext - - else: - raise ValueError(self.velocity_representation) - - # Obtain the new 6D force considering the 'additive' flag - W_f_ext_current = self.data.model_input.f_ext[link.index(), :] - new_force = W_f_ext_current + W_f_ext if additive else W_f_ext - - # Update the model data - self.data.model_input.f_ext = self.data.model_input.f_ext.at[ - link.index(), : - ].set(new_force) - - @functools.partial( - oop.jax_tf.method_rw, jit=True, static_argnames=["link_name", "additive"] - ) - def apply_external_force_to_link_com( - self, - link_name: str, - force: jtp.Array | None = None, - torque: jtp.Array | None = None, - additive: bool = True, - ) -> None: - """""" - - # Get the target link with the correct mutability - link = self.get_link(link_name=link_name) - link._set_mutability(mutability=self._mutability()) - - # Initialize zero force components if not set - force = force if force is not None else jnp.zeros(3) - torque = torque if torque is not None else jnp.zeros(3) - - # Build the target 6D force in the active representation - f_ext = jnp.hstack([force, torque]) - - # Convert the 6D force to the inertial representation - if self.velocity_representation is VelRepr.Inertial: - W_f_ext = f_ext - - elif self.velocity_representation is VelRepr.Body: - GL_f_ext = f_ext - - W_H_L = link.transform() - L_p_CoM = link.com_position(in_link_frame=True) - L_H_GL = jnp.eye(4).at[0:3, 3].set(L_p_CoM) - W_H_GL = W_H_L @ L_H_GL - GL_X_W = sixd.se3.SE3.from_matrix(W_H_GL).inverse().adjoint() - - W_f_ext = GL_X_W.transpose() @ GL_f_ext - - elif self.velocity_representation is VelRepr.Mixed: - GW_f_ext = f_ext - - W_p_CoM = link.com_position(in_link_frame=False) - W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM) - GW_X_W = sixd.se3.SE3.from_matrix(W_H_GW).inverse().adjoint() - - W_f_ext = GW_X_W.transpose() @ GW_f_ext - - else: - raise ValueError(self.velocity_representation) - - # Obtain the new 6D force considering the 'additive' flag - W_f_ext_current = self.data.model_input.f_ext[link.index(), :] - new_force = W_f_ext_current + W_f_ext if additive else W_f_ext - - # Update the model data - self.data.model_input.f_ext = self.data.model_input.f_ext.at[ - link.index(), : - ].set(new_force) - - # ================================================ - # Generalized methods and free-floating quantities - # ================================================ - - @functools.partial(oop.jax_tf.method_ro) - def generalized_position(self) -> Tuple[jtp.Matrix, jtp.Vector]: - """""" - - return self.base_transform(), self.joint_positions() - - @functools.partial(oop.jax_tf.method_ro) - def generalized_velocity(self) -> jtp.Vector: - """""" - - return jnp.hstack([self.base_velocity(), self.joint_velocities()]) - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["output_vel_repr"]) - def generalized_free_floating_jacobian( - self, output_vel_repr: VelRepr | None = None - ) -> jtp.Matrix: - """""" - - if output_vel_repr is None: - output_vel_repr = self.velocity_representation - - # The body frame of the Link.jacobian method is the link frame L. - # In this method, we want instead to use the base link B as body frame. - # Therefore, we always get the link jacobian having Inertial as output - # representation, and then we convert it to the desired output representation. - if output_vel_repr is VelRepr.Inertial: - to_output = lambda J: J - - elif output_vel_repr is VelRepr.Body: - - def to_output(W_J_Wi): - W_H_B = self.base_transform() - B_X_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint() - return B_X_W @ W_J_Wi - - elif output_vel_repr is VelRepr.Mixed: - - def to_output(W_J_Wi): - W_H_B = self.base_transform() - W_H_BW = jnp.array(W_H_B).at[0:3, 0:3].set(jnp.eye(3)) - BW_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint() - return BW_X_W @ W_J_Wi - - else: - raise ValueError(output_vel_repr) - - # Get the link jacobians in Inertial representation and convert them to the - # target output representation in which the body frame is the base link B - J_free_floating = jnp.vstack( - [ - to_output( - self.get_link(link_name=link_name).jacobian( - output_vel_repr=VelRepr.Inertial - ) - ) - for link_name in self.link_names() - ] - ) - - return J_free_floating - - @functools.partial(oop.jax_tf.method_ro) - def free_floating_mass_matrix(self) -> jtp.Matrix: - """""" - - M_body = jaxsim.physics.algos.crba.crba( - model=self.physics_model, - q=self.data.model_state.joint_positions, - ) - - if self.velocity_representation is VelRepr.Body: - return M_body - - elif self.velocity_representation is VelRepr.Inertial: - zero_6n = jnp.zeros(shape=(6, self.dofs())) - B_X_W = sixd.se3.SE3.from_matrix(self.base_transform()).inverse().adjoint() - - invT = jnp.vstack( - [ - jnp.block([B_X_W, zero_6n]), - jnp.block([zero_6n.T, jnp.eye(self.dofs())]), - ] - ) - - return invT.T @ M_body @ invT - - elif self.velocity_representation is VelRepr.Mixed: - zero_6n = jnp.zeros(shape=(6, self.dofs())) - W_H_BW = self.base_transform().at[0:3, 3].set(jnp.zeros(3)) - BW_X_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint() - - invT = jnp.vstack( - [ - jnp.block([BW_X_W, zero_6n]), - jnp.block([zero_6n.T, jnp.eye(self.dofs())]), - ] - ) - - return invT.T @ M_body @ invT - - else: - raise ValueError(self.velocity_representation) - - @functools.partial(oop.jax_tf.method_ro) - def free_floating_bias_forces(self) -> jtp.Vector: - """""" - - with self.editable(validate=True) as model: - model.zero_input() - - return jnp.hstack( - model.inverse_dynamics( - base_acceleration=jnp.zeros(6), joint_accelerations=None - ) - ) - - @functools.partial(oop.jax_tf.method_ro) - def free_floating_gravity_forces(self) -> jtp.Vector: - """""" - - with self.editable(validate=True) as model: - model.zero_input() - model.data.model_state.joint_velocities = jnp.zeros_like( - model.data.model_state.joint_velocities - ) - model.data.model_state.base_linear_velocity = jnp.zeros_like( - model.data.model_state.base_linear_velocity - ) - model.data.model_state.base_angular_velocity = jnp.zeros_like( - model.data.model_state.base_angular_velocity - ) - - return jnp.hstack( - model.inverse_dynamics( - base_acceleration=jnp.zeros(6), joint_accelerations=None - ) - ) - - @functools.partial(oop.jax_tf.method_ro) - def momentum(self) -> jtp.Vector: - """""" - - with self.editable(validate=True) as m: - m.set_velocity_representation(vel_repr=VelRepr.Body) - - # Compute the momentum in body-fixed velocity representation. - # Note: the first 6 rows of the mass matrix define the jacobian of the - # floating-base momentum. - B_h = m.free_floating_mass_matrix()[0:6, :] @ m.generalized_velocity() - - W_H_B = self.base_transform() - B_X_W: jtp.Array = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint() - - W_h = B_X_W.T @ B_h - return self.inertial_to_active_representation(array=W_h, is_force=True) - - # =========== - # CoM methods - # =========== - - @functools.partial(oop.jax_tf.method_ro) - def com_position(self) -> jtp.Vector: - """""" - - m = self.total_mass() - - W_H_L = self.forward_kinematics() - W_H_B = self.base_transform() - B_H_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().as_matrix() - - com_links = [ - ( - l.mass() - * B_H_W - @ W_H_L[l.index()] - @ jnp.hstack([l.com_position(in_link_frame=True), 1]) - ) - for l in self.links() - ] - - B_ph_CoM = (1 / m) * jnp.sum(jnp.array(com_links), axis=0) - - return (W_H_B @ B_ph_CoM)[0:3] - - # ========== - # Algorithms - # ========== - - @functools.partial(oop.jax_tf.method_ro) - def forward_kinematics(self) -> jtp.Array: - """""" - - W_H_i = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model( - model=self.physics_model, - q=self.data.model_state.joint_positions, - xfb=self.data.model_state.xfb(), - ) - - return W_H_i - - @functools.partial(oop.jax_tf.method_ro) - def inverse_dynamics( - self, - joint_accelerations: jtp.Vector | None = None, - base_acceleration: jtp.Vector | None = None, - ) -> Tuple[jtp.Vector, jtp.Vector]: - """ - Compute inverse dynamics with the RNEA algorithm. - - Args: - joint_accelerations: the joint accelerations to consider. - base_acceleration: the base acceleration in the active representation to consider. - - Returns: - A tuple containing the 6D force in active representation applied to the base - to obtain the considered base acceleration, and the joint torques to apply - to obtain the considered joint accelerations. - """ - - # Build joint accelerations if not provided - joint_accelerations = ( - joint_accelerations - if joint_accelerations is not None - else jnp.zeros_like(self.joint_positions()) - ) - - # Build base acceleration if not provided - base_acceleration = ( - base_acceleration if base_acceleration is not None else jnp.zeros(6) - ) - - if base_acceleration.size != 6: - raise ValueError(base_acceleration.size) - - def to_inertial(C_vd_WB, W_H_C, C_v_WB, W_vl_WC): - W_X_C = sixd.se3.SE3.from_matrix(W_H_C).adjoint() - C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint() - - if self.velocity_representation != VelRepr.Mixed: - return W_X_C @ C_vd_WB - else: - from jaxsim.math.cross import Cross - - C_v_WC = C_X_W @ jnp.hstack([W_vl_WC, jnp.zeros(3)]) - return W_X_C @ (C_vd_WB + Cross.vx(C_v_WC) @ C_v_WB) - - if self.velocity_representation is VelRepr.Inertial: - W_H_C = W_H_W = jnp.eye(4) - W_vl_WC = W_vl_WW = jnp.zeros(3) - - elif self.velocity_representation is VelRepr.Body: - W_H_C = W_H_B = self.base_transform() - W_vl_WC = W_vl_WB = self.base_velocity()[0:3] - - elif self.velocity_representation is VelRepr.Mixed: - W_H_B = self.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_vl_WC = W_vl_W_BW = self.base_velocity()[0:3] - - else: - raise ValueError(self.velocity_representation) - - # We need to convert the derivative of the base acceleration to the Inertial - # representation. In Mixed representation, this conversion is not a plain - # transformation with just X, but it also involves a cross product in ℝ⁶. - W_v̇_WB = to_inertial( - C_vd_WB=base_acceleration, - W_H_C=W_H_C, - C_v_WB=self.base_velocity(), - W_vl_WC=W_vl_WC, - ) - - # Compute RNEA - W_f_B, tau = jaxsim.physics.algos.rnea.rnea( - model=self.physics_model, - xfb=self.data.model_state.xfb(), - q=self.data.model_state.joint_positions, - qd=self.data.model_state.joint_velocities, - qdd=joint_accelerations, - a0fb=W_v̇_WB, - f_ext=self.data.model_input.f_ext, - ) - - # Adjust shape - tau = jnp.atleast_1d(tau.squeeze()) - - # Express W_f_B in the active representation - f_B = self.inertial_to_active_representation(array=W_f_B, is_force=True) - - return f_B, tau - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["prefer_aba"]) - def forward_dynamics( - self, tau: jtp.Vector | None = None, prefer_aba: float = True - ) -> Tuple[jtp.Vector, jtp.Vector]: - """""" - - return ( - self.forward_dynamics_aba(tau=tau) - if prefer_aba - else self.forward_dynamics_crb(tau=tau) - ) - - @functools.partial(oop.jax_tf.method_ro) - def forward_dynamics_aba( - self, tau: jtp.Vector | None = None - ) -> Tuple[jtp.Vector, jtp.Vector]: - """""" - - # Build joint torques if not provided - tau = tau if tau is not None else jnp.zeros_like(self.joint_positions()) - - # Compute ABA - W_v̇_WB, s̈ = jaxsim.physics.algos.aba.aba( - model=self.physics_model, - xfb=self.data.model_state.xfb(), - q=self.data.model_state.joint_positions, - qd=self.data.model_state.joint_velocities, - tau=tau, - f_ext=self.data.model_input.f_ext, - ) - - def to_active(W_vd_WB, W_H_C, W_v_WB, W_vl_WC): - C_X_W = sixd.se3.SE3.from_matrix(W_H_C).inverse().adjoint() - - if self.velocity_representation != VelRepr.Mixed: - return C_X_W @ W_vd_WB - else: - from jaxsim.math.cross import Cross - - W_v_WC = jnp.hstack([W_vl_WC, jnp.zeros(3)]) - return C_X_W @ (W_vd_WB - Cross.vx(W_v_WC) @ W_v_WB) - - if self.velocity_representation is VelRepr.Inertial: - W_H_C = W_H_W = jnp.eye(4) - W_vl_WC = W_vl_WW = jnp.zeros(3) - - elif self.velocity_representation is VelRepr.Body: - W_H_C = W_H_B = self.base_transform() - W_vl_WC = W_vl_WB = self.base_velocity()[0:3] - - elif self.velocity_representation is VelRepr.Mixed: - W_H_B = self.base_transform() - W_H_C = W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3)) - W_vl_WC = W_vl_W_BW = self.base_velocity()[0:3] - - else: - raise ValueError(self.velocity_representation) - - # We need to convert the derivative of the base acceleration to the active - # representation. In Mixed representation, this conversion is not a plain - # transformation with just X, but it also involves a cross product in ℝ⁶. - C_v̇_WB = to_active( - W_vd_WB=W_v̇_WB.squeeze(), - W_H_C=W_H_C, - W_v_WB=jnp.hstack( - [ - self.data.model_state.base_linear_velocity, - self.data.model_state.base_angular_velocity, - ] - ), - W_vl_WC=W_vl_WC, - ) - - # Adjust shape - s̈ = jnp.atleast_1d(s̈.squeeze()) - - return C_v̇_WB, s̈ - - @functools.partial(oop.jax_tf.method_ro) - def forward_dynamics_crb( - self, tau: jtp.Vector | None = None - ) -> Tuple[jtp.Vector, jtp.Vector]: - """""" - - # Build joint torques if not provided - τ = tau if tau is not None else jnp.zeros(shape=(self.dofs(),)) - τ = jnp.atleast_1d(τ.squeeze()) - τ = jnp.vstack(τ) if τ.size > 0 else jnp.empty(shape=(0, 1)) - - # Extract motor parameters from the physics model - GR = self.motor_gear_ratios() - IM = self.motor_inertias() - KV = jnp.diag(self.motor_viscous_frictions()) - - # Compute auxiliary quantities - Γ = jnp.diag(GR) - K̅ᵥ = Γ.T @ KV @ Γ - - # Compute terms of the floating-base EoM - M = self.free_floating_mass_matrix() - h = jnp.vstack(self.free_floating_bias_forces()) - J = self.generalized_free_floating_jacobian() - f_ext = jnp.vstack(self.external_forces().flatten()) - S = jnp.block([jnp.zeros(shape=(self.dofs(), 6)), jnp.eye(self.dofs())]).T - - # Configure the slice for motors - sl_m = np.s_[M.shape[0] - self.dofs() :] - - # Add the motor related terms to the EoM - M = M.at[sl_m, sl_m].set(M[sl_m, sl_m] + jnp.diag(Γ.T @ IM @ Γ)) - h = h.at[sl_m].set(h[sl_m] + K̅ᵥ @ self.joint_velocities()[:, None]) - S = S.at[sl_m].set(S[sl_m]) - - # Compute the generalized acceleration by inverting the EoM - ν̇ = jax.lax.select( - pred=self.floating_base(), - on_true=jnp.linalg.inv(M) @ ((S @ τ) - h + J.T @ f_ext), - on_false=jnp.vstack( - [ - jnp.zeros(shape=(6, 1)), - jnp.linalg.inv(M[6:, 6:]) - @ ((S @ τ)[6:] - h[6:] + J[:, 6:].T @ f_ext), - ] - ), - ).squeeze() - - # Extract the base acceleration in the active representation. - # Note that this is an apparent acceleration (relevant in Mixed representation), - # therefore it cannot be always expressed in different frames with just a - # 6D transformation X. - v̇_WB = ν̇[0:6] - - # Extract the joint accelerations - s̈ = jnp.atleast_1d(ν̇[6:]) - - return v̇_WB, s̈ - - # ====== - # Energy - # ====== - - @functools.partial(oop.jax_tf.method_ro) - def mechanical_energy(self) -> jtp.Float: - """""" - - K = self.kinetic_energy() - U = self.potential_energy() - - return K + U - - @functools.partial(oop.jax_tf.method_ro) - def kinetic_energy(self) -> jtp.Float: - """""" - - with self.editable(validate=True) as m: - m.set_velocity_representation(vel_repr=VelRepr.Body) - - nu = m.generalized_velocity() - M = m.free_floating_mass_matrix() - - return 0.5 * nu.T @ M @ nu - - @functools.partial(oop.jax_tf.method_ro) - def potential_energy(self) -> jtp.Float: - """""" - - m = self.total_mass() - W_p_CoM = jnp.hstack([self.com_position(), 1]) - gravity = self.physics_model.gravity[3:6].squeeze() - - return -(m * jnp.hstack([gravity, 0]) @ W_p_CoM) - - # =========== - # Set targets - # =========== - - @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) - def set_joint_generalized_force_targets( - self, forces: jtp.Vector, joint_names: tuple[str, ...] | None = None - ) -> None: - """""" - - if joint_names is None: - joint_names = self.joint_names() - - if forces.size != len(joint_names): - raise ValueError("Wrong arguments size", forces.size, len(joint_names)) - - self.data.model_input.tau = self.data.model_input.tau.at[ - self._joint_indices(joint_names=joint_names) - ].set(forces) - - # ========== - # Reset data - # ========== - - @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) - def reset_joint_positions( - self, positions: jtp.Vector, joint_names: tuple[str, ...] | None = None - ) -> None: - """""" - - if joint_names is None: - joint_names = self.joint_names() - - if positions.size != len(joint_names): - raise ValueError("Wrong arguments size", positions.size, len(joint_names)) - - if positions.size == 0: - return - - # TODO: joint position limits - - self.data.model_state.joint_positions = jnp.atleast_1d( - jnp.array( - self.data.model_state.joint_positions.at[ - self._joint_indices(joint_names=joint_names) - ].set(positions), - dtype=float, - ) - ) - - @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) - def reset_joint_velocities( - self, velocities: jtp.Vector, joint_names: tuple[str, ...] | None = None - ) -> None: - """""" - - if joint_names is None: - joint_names = self.joint_names() - - if velocities.size != len(joint_names): - raise ValueError("Wrong arguments size", velocities.size, len(joint_names)) - - if velocities.size == 0: - return - - # TODO: joint velocity limits - - self.data.model_state.joint_velocities = jnp.atleast_1d( - jnp.array( - self.data.model_state.joint_velocities.at[ - self._joint_indices(joint_names=joint_names) - ].set(velocities), - dtype=float, - ) - ) - - @functools.partial(oop.jax_tf.method_rw) - def reset_base_position(self, position: jtp.Vector) -> None: - """""" - - self.data.model_state.base_position = jnp.array(position, dtype=float) - - @functools.partial(oop.jax_tf.method_rw, static_argnames=["dcm"]) - def reset_base_orientation(self, orientation: jtp.Array, dcm: bool = False) -> None: - """""" - - if dcm: - to_wxyz = np.array([3, 0, 1, 2]) - orientation_xyzw = sixd.so3.SO3.from_matrix( - orientation - ).as_quaternion_xyzw() - orientation = orientation_xyzw[to_wxyz] - - unit_quaternion = orientation / jnp.linalg.norm(orientation) - self.data.model_state.base_quaternion = jnp.array(unit_quaternion, dtype=float) - - @functools.partial(oop.jax_tf.method_rw) - def reset_base_transform(self, transform: jtp.Matrix) -> None: - """""" - - if transform.shape != (4, 4): - raise ValueError(transform.shape) - - self.reset_base_position(position=transform[0:3, 3]) - self.reset_base_orientation(orientation=transform[0:3, 0:3], dcm=True) - - @functools.partial(oop.jax_tf.method_rw) - def reset_base_velocity(self, base_velocity: jtp.VectorJax) -> None: - """""" - - if not self.physics_model.is_floating_base: - msg = "Changing the base velocity of a fixed-based model is not allowed" - raise RuntimeError(msg) - - # Remove extra dimensions - base_velocity = base_velocity.squeeze() - - # Check for a valid shape - if base_velocity.shape != (6,): - raise ValueError(base_velocity.shape) - - # Convert, if needed, to the representation used internally (VelRepr.Inertial) - if self.velocity_representation is VelRepr.Inertial: - base_velocity_inertial = base_velocity - - elif self.velocity_representation is VelRepr.Body: - w_X_b = sixd.se3.SE3.from_rotation_and_translation( - rotation=sixd.so3.SO3.from_matrix(self.base_orientation(dcm=True)), - translation=self.base_position(), - ).adjoint() - - base_velocity_inertial = w_X_b @ base_velocity - - elif self.velocity_representation is VelRepr.Mixed: - w_X_bw = sixd.se3.SE3.from_rotation_and_translation( - rotation=sixd.so3.SO3.identity(), - translation=self.base_position(), - ).adjoint() - - base_velocity_inertial = w_X_bw @ base_velocity - - else: - raise ValueError(self.velocity_representation) - - self.data.model_state.base_linear_velocity = jnp.array( - base_velocity_inertial[0:3], dtype=float - ) - - self.data.model_state.base_angular_velocity = jnp.array( - base_velocity_inertial[3:6], dtype=float - ) - - # =========== - # Integration - # =========== - - @functools.partial( - oop.jax_tf.method_rw, - static_argnames=["sub_steps", "integrator_type", "terrain"], - vmap_in_axes=(0, 0, 0, None, None, None, 0, None), - ) - def integrate( - self, - t0: jtp.Float, - tf: jtp.Float, - sub_steps: int = 1, - integrator_type: Optional[ - "jaxsim.simulation.ode_integration.IntegratorType" - ] = None, - terrain: soft_contacts.Terrain = soft_contacts.FlatTerrain(), - contact_parameters: soft_contacts.SoftContactsParams = soft_contacts.SoftContactsParams(), - clear_inputs: bool = False, - ) -> StepData: - """""" - - from jaxsim.simulation import ode_data, ode_integration - from jaxsim.simulation.ode_integration import IntegratorType - - if integrator_type is None: - integrator_type = IntegratorType.EulerForward - - x0 = ode_integration.ode.ode_data.ODEState( - physics_model=self.data.model_state, - soft_contacts=self.data.contact_state, - ) - - ode_input = ode_integration.ode.ode_data.ODEInput( - physics_model=self.data.model_input - ) - - assert isinstance(integrator_type, IntegratorType) - - # Integrate the model dynamics - ode_states, aux = ode_integration.ode_integration_fixed_step( - x0=x0, - t=jnp.array([t0, tf], dtype=float), - ode_input=ode_input, - physics_model=self.physics_model, - soft_contacts_params=contact_parameters, - num_sub_steps=sub_steps, - terrain=terrain, - integrator_type=integrator_type, - return_aux=True, - ) - - # Get quantities at t0 - t0_model_data = self.data - t0_model_input = jax.tree_util.tree_map( - lambda l: l[0], - aux["ode_input"], - ) - t0_model_input_real = jax.tree_util.tree_map( - lambda l: l[0], - aux["ode_input_real"], - ) - t0_model_acceleration = jax.tree_util.tree_map( - lambda l: l[0], - aux["model_acceleration"], - ) - - # Get quantities at tf - ode_states: ode_data.ODEState - tf_model_state = jax.tree_util.tree_map( - lambda l: l[-1], ode_states.physics_model - ) - tf_contact_state = jax.tree_util.tree_map( - lambda l: l[-1], ode_states.soft_contacts - ) - - # Clear user inputs (joint torques and external forces) if asked - model_input = jax.lax.cond( - pred=clear_inputs, - false_fun=lambda: t0_model_input.physics_model, - true_fun=lambda: jaxsim.physics.model.physics_model_state.PhysicsModelInput.zero( - physics_model=self.physics_model - ), - ) - - # Update model state - self.data = ModelData( - model_state=tf_model_state, - contact_state=tf_contact_state, - model_input=model_input, - ) - - return StepData( - t0=t0, - tf=tf, - dt=(tf - t0), - t0_model_data=t0_model_data, - t0_model_input_real=t0_model_input_real.physics_model, - t0_base_acceleration=t0_model_acceleration[0:6], - t0_joint_acceleration=t0_model_acceleration[6:], - tf_model_state=tf_model_state, - tf_contact_state=tf_contact_state, - aux={ - "t0": jax.tree_util.tree_map( - lambda l: l[0], - aux, - ), - "tf": jax.tree_util.tree_map( - lambda l: l[-1], - aux, - ), - }, - ) - - # ============== - # Motor dynamics - # ============== - - @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) - def set_motor_inertias( - self, inertias: jtp.Vector, joint_names: tuple[str, ...] | None = None - ) -> None: - joint_names = joint_names or self.joint_names() - - if inertias.size != len(joint_names): - raise ValueError("Wrong arguments size", inertias.size, len(joint_names)) - - self.physics_model._joint_motor_inertia.update( - dict(zip(self.physics_model._joint_motor_inertia, inertias)) - ) - - logging.info("Setting attribute `motor_inertias`") - - @functools.partial(oop.jax_tf.method_rw, jit=False) - def set_motor_gear_ratios( - self, gear_ratios: jtp.Vector, joint_names: tuple[str, ...] | None = None - ) -> None: - joint_names = joint_names or self.joint_names() - - if gear_ratios.size != len(joint_names): - raise ValueError("Wrong arguments size", gear_ratios.size, len(joint_names)) - - # Check on gear ratios if motor_inertias are not zero - for idx, gr in enumerate(gear_ratios): - if gr != 0 and self.motor_inertias()[idx] == 0: - raise ValueError( - f"Zero motor inertia with non-zero gear ratio found in position {idx}" - ) - - self.physics_model._joint_motor_gear_ratio.update( - dict(zip(self.physics_model._joint_motor_gear_ratio, gear_ratios)) - ) - - logging.info("Setting attribute `motor_gear_ratios`") - - @functools.partial(oop.jax_tf.method_rw, static_argnames=["joint_names"]) - def set_motor_viscous_frictions( - self, - viscous_frictions: jtp.Vector, - joint_names: tuple[str, ...] | None = None, - ) -> None: - joint_names = joint_names or self.joint_names() - - if viscous_frictions.size != len(joint_names): - raise ValueError( - "Wrong arguments size", viscous_frictions.size, len(joint_names) - ) - - self.physics_model._joint_motor_viscous_friction.update( - dict( - zip( - self.physics_model._joint_motor_viscous_friction, - viscous_frictions, - ) - ) - ) - - logging.info("Setting attribute `motor_viscous_frictions`") - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def motor_inertias(self) -> jtp.Vector: - return jnp.array( - [*self.physics_model._joint_motor_inertia.values()], dtype=float - ) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def motor_gear_ratios(self) -> jtp.Vector: - return jnp.array( - [*self.physics_model._joint_motor_gear_ratio.values()], dtype=float - ) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def motor_viscous_frictions(self) -> jtp.Vector: - return jnp.array( - [*self.physics_model._joint_motor_viscous_friction.values()], dtype=float - ) - - # =============== - # Private methods - # =============== - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"]) - def inertial_to_active_representation( - self, array: jtp.Array, is_force: bool = False - ) -> jtp.Array: - """""" - - W_array = array.squeeze() - - if W_array.size != 6: - raise ValueError(W_array.size) - - if self.velocity_representation is VelRepr.Inertial: - return W_array - - elif self.velocity_representation is VelRepr.Body: - W_H_B = self.base_transform() - - if not is_force: - B_Xv_W = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint() - B_array = B_Xv_W @ W_array - - else: - B_Xf_W = sixd.se3.SE3.from_matrix(W_H_B).adjoint().T - B_array = B_Xf_W @ W_array - - return B_array - - elif self.velocity_representation is VelRepr.Mixed: - W_H_BW = jnp.eye(4).at[0:3, 3].set(self.base_position()) - - if not is_force: - BW_Xv_W = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint() - BW_array = BW_Xv_W @ W_array - - else: - BW_Xf_W = sixd.se3.SE3.from_matrix(W_H_BW).adjoint().T - BW_array = BW_Xf_W @ W_array - - return BW_array - - else: - raise ValueError(self.velocity_representation) - - @functools.partial(oop.jax_tf.method_ro, static_argnames=["is_force"]) - def active_to_inertial_representation( - self, array: jtp.Array, is_force: bool = False - ) -> jtp.Array: - """""" - - array = array.squeeze() - - if array.size != 6: - raise ValueError(array.size) - - if self.velocity_representation is VelRepr.Inertial: - W_array = array - return W_array - - elif self.velocity_representation is VelRepr.Body: - B_array = array - W_H_B = self.base_transform() - - if not is_force: - W_Xv_B: jtp.Array = sixd.se3.SE3.from_matrix(W_H_B).adjoint() - W_array = W_Xv_B @ B_array - - else: - W_Xf_B = sixd.se3.SE3.from_matrix(W_H_B).inverse().adjoint().T - W_array = W_Xf_B @ B_array - - return W_array - - elif self.velocity_representation is VelRepr.Mixed: - BW_array = array - W_H_BW = jnp.eye(4).at[0:3, 3].set(self.base_position()) - - if not is_force: - W_Xv_BW: jtp.Array = sixd.se3.SE3.from_matrix(W_H_BW).adjoint() - W_array = W_Xv_BW @ BW_array - - else: - W_Xf_BW = sixd.se3.SE3.from_matrix(W_H_BW).inverse().adjoint().T - W_array = W_Xf_BW @ BW_array - - return W_array - - else: - raise ValueError(self.velocity_representation) - - def _joint_indices(self, joint_names: tuple[str, ...] | None = None) -> jtp.Vector: - """""" - - if joint_names is None: - joint_names = self.joint_names() - - if set(joint_names) - set(self.joint_names()) != set(): - raise ValueError("One or more joint names are not part of the model") - - # Note: joints share the same index as their child link, therefore the first - # joint has index=1. We need to subtract one to get the right entry of - # data stored in the PhysicsModelState arrays. - joint_indices = [ - j.joint_description.index - 1 for j in self.joints(joint_names=joint_names) - ] - - return np.array(joint_indices, dtype=int) diff --git a/src/jaxsim/math/conv.py b/src/jaxsim/math/conv.py deleted file mode 100644 index 7ce98e237..000000000 --- a/src/jaxsim/math/conv.py +++ /dev/null @@ -1,114 +0,0 @@ -import jax.numpy as jnp - -import jaxsim.typing as jtp - -from .skew import Skew - - -class Convert: - @staticmethod - def coordinates_tf(X: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix: - """ - Transform coordinates from one frame to another using a transformation matrix. - - Args: - X (jtp.Matrix): The transformation matrix (4x4 or 6x6). - p (jtp.Matrix): The coordinates to be transformed (3xN). - - Returns: - jtp.Matrix: Transformed coordinates (3xN). - - Raises: - ValueError: If the input matrix p does not have shape (3, N). - """ - X = X.squeeze() - p = p.squeeze() - - # If p has shape (X,), transform it to a column vector - p = jnp.vstack(p) if len(p.shape) == 1 else p - rows_p, cols_p = p.shape - - if rows_p != 3: - raise ValueError(p.shape) - - R = X[0:3, 0:3] - r = -Skew.vee(R.T @ X[0:3, 3:6]) - - if cols_p > 1: - r = jnp.tile(r, (1, cols_p)) - - assert r.shape == p.shape, (r.shape, p.shape) - - xp = R @ (p - r) - return jnp.vstack(xp) - - @staticmethod - def velocities_threed(v_6d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix: - """ - Compute 3D velocities based on 6D velocities and positions. - - Args: - v_6d (jtp.Matrix): The 6D velocities (6xN). - p (jtp.Matrix): The positions (3xN). - - Returns: - jtp.Matrix: 3D velocities (3xN). - - Raises: - ValueError: If the input matrices have incompatible shapes. - """ - v = v_6d.squeeze() - p = p.squeeze() - - # If the arrays have shape (X,), transform them to column vectors - v = jnp.vstack(v) if len(v.shape) == 1 else v - p = jnp.vstack(p) if len(p.shape) == 1 else p - - rows_v, cols_v = v.shape - _, cols_p = p.shape - - if cols_v == 1 and cols_p > 1: - v = jnp.repeat(v, cols_p, axis=1) - - if rows_v == 6: - vp = v[0:3, :] + jnp.cross(v[3:6, :], p, axis=0) - else: - raise ValueError(v.shape) - - return jnp.vstack(vp) - - @staticmethod - def forces_sixd(f_3d: jtp.Matrix, p: jtp.Matrix) -> jtp.Matrix: - """ - Compute 6D forces based on 3D forces and positions. - - Args: - f_3d (jtp.Matrix): The 3D forces (3xN). - p (jtp.Matrix): The positions (3xN). - - Returns: - jtp.Matrix: 6D forces (6xN). - - Raises: - ValueError: If the input matrices have incompatible shapes. - """ - f = f_3d.squeeze() - p = p.squeeze() - - # If the arrays have shape (X,), transform them to column vectors - fp = jnp.vstack(f) if len(f.shape) == 1 else f - p = jnp.vstack(p) if len(p.shape) == 1 else p - - _, cols_p = p.shape - rows_fp, cols_fp = fp.shape - - # Number of columns must match - if cols_p != cols_fp: - raise ValueError(cols_p, cols_fp) - - if rows_fp == 3: - f = jnp.vstack([fp, jnp.cross(p, fp, axis=0)]) - else: - raise ValueError(fp.shape) - - return jnp.vstack(f) diff --git a/src/jaxsim/math/plucker.py b/src/jaxsim/math/plucker.py deleted file mode 100644 index da4383887..000000000 --- a/src/jaxsim/math/plucker.py +++ /dev/null @@ -1,100 +0,0 @@ -from typing import Tuple - -import jax.numpy as jnp - -import jaxsim.typing as jtp - -from .skew import Skew - - -class Plucker: - @staticmethod - def from_rot_and_trans(dcm: jtp.Matrix, translation: jtp.Vector) -> jtp.Matrix: - """ - Computes the Plücker matrix from a rotation matrix and a translation vector. - - Args: - dcm: A 3x3 rotation matrix. - translation: A 3x1 translation vector. - - Returns: - A 6x6 Plücker matrix. - """ - R = dcm - - X = jnp.block( - [ - [R, -R @ Skew.wedge(vector=translation)], - [jnp.zeros(shape=(3, 3)), R], - ] - ) - - return X - - @staticmethod - def to_rot_and_trans(adjoint: jtp.Matrix) -> Tuple[jtp.Matrix, jtp.Vector]: - """ - Computes the rotation matrix and translation vector from a Plücker matrix. - - Args: - adjoint: A 6x6 Plücker matrix. - - Returns: - A tuple containing the 3x3 rotation matrix and the 3x1 translation vector. - """ - X = adjoint - - R = X[0:3, 0:3] - p = -Skew.vee(R.T @ X[0:3, 3:6]) - - return R, p - - @staticmethod - def from_transform(transform: jtp.Matrix) -> jtp.Matrix: - """ - Computes the Plücker matrix from a homogeneous transformation matrix. - - Args: - transform: A 4x4 homogeneous transformation matrix. - - Returns: - A 6x6 Plücker matrix. - """ - H = transform - - R = H[0:3, 0:3] - p = H[0:3, 3] - - X = jnp.block( - [ - [R, Skew.wedge(vector=p) @ R], - [jnp.zeros(shape=(3, 3)), R], - ] - ) - - return X - - @staticmethod - def to_transform(adjoint: jtp.Matrix) -> jtp.Matrix: - """ - Computes the homogeneous transformation matrix from a Plücker matrix. - - Args: - adjoint: A 6x6 Plücker matrix. - - Returns: - A 4x4 homogeneous transformation matrix. - """ - X = adjoint - - R = X[0:3, 0:3] - o_x_R = X[0:3, 3:6] - - H = jnp.vstack( - [ - jnp.hstack([R, Skew.vee(matrix=o_x_R @ R.T)]), - [0, 0, 0, 1], - ] - ) - - return H diff --git a/src/jaxsim/physics/__init__.py b/src/jaxsim/physics/__init__.py deleted file mode 100644 index 1527fcdbc..000000000 --- a/src/jaxsim/physics/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -import numpy.typing - -from . import algos, model - - -def default_gravity() -> numpy.typing.NDArray: - import jax.numpy as jnp - - return jnp.array([0, 0, -9.80]) - - -# from . import dyn, models, spatial, threed, utils diff --git a/src/jaxsim/physics/algos/__init__.py b/src/jaxsim/physics/algos/__init__.py deleted file mode 100644 index d6fd929e7..000000000 --- a/src/jaxsim/physics/algos/__init__.py +++ /dev/null @@ -1 +0,0 @@ -StandardGravity = 9.81 diff --git a/src/jaxsim/physics/algos/aba_motors.py b/src/jaxsim/physics/algos/aba_motors.py deleted file mode 100644 index fd7873ffe..000000000 --- a/src/jaxsim/physics/algos/aba_motors.py +++ /dev/null @@ -1,284 +0,0 @@ -from typing import Tuple - -import jax -import jax.numpy as jnp -import numpy as np - -import jaxsim.typing as jtp -from jaxsim.math.adjoint import Adjoint -from jaxsim.math.cross import Cross -from jaxsim.physics.model.physics_model import PhysicsModel - -from . import utils - - -def aba( - model: PhysicsModel, - xfb: jtp.Vector, - q: jtp.Vector, - qd: jtp.Vector, - tau: jtp.Vector, - f_ext: jtp.Matrix | None = None, -) -> Tuple[jtp.Vector, jtp.Vector]: - """ - Articulated Body Algorithm (ABA) algorithm with motor dynamics for forward dynamics. - """ - - x_fb, q, qd, _, tau, f_ext = utils.process_inputs( - physics_model=model, xfb=xfb, q=q, qd=qd, tau=tau, f_ext=f_ext - ) - - # Extract data from the physics model - pre_X_λi = model.tree_transforms - M = model.spatial_inertias - i_X_pre = model.joint_transforms(q=q) - S = model.motion_subspaces(q=q) - λ = model.parent_array() - - # Extract motor parameters from the physics model - Γ = jnp.array([*model._joint_motor_gear_ratio.values()]) - IM = jnp.array( - [jnp.eye(6) * m for m in [*model._joint_motor_inertia.values()]] * model.NB - ) - K̅ᵥ = Γ.T * jnp.array([*model._joint_motor_viscous_friction.values()]) * Γ - m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0) - - # Initialize buffers - v = jnp.array([jnp.zeros([6, 1])] * model.NB) - MA = jnp.array([jnp.zeros([6, 6])] * model.NB) - pA = jnp.array([jnp.zeros([6, 1])] * model.NB) - c = jnp.array([jnp.zeros([6, 1])] * model.NB) - i_X_λi = jnp.zeros_like(i_X_pre) - - m_v = jnp.array([jnp.zeros([6, 1])] * model.NB) - m_c = jnp.array([jnp.zeros([6, 1])] * model.NB) - pR = jnp.array([jnp.zeros([6, 1])] * model.NB) - - # Base pose B_X_W and velocity - base_quat = jnp.vstack(x_fb[0:4]) - base_pos = jnp.vstack(x_fb[4:7]) - base_vel = jnp.vstack(jnp.hstack([x_fb[10:13], x_fb[7:10]])) - - # 6D transform of base velocity - B_X_W = Adjoint.from_quaternion_and_translation( - quaternion=base_quat, - translation=base_pos, - inverse=True, - normalize_quaternion=True, - ) - i_X_λi = i_X_λi.at[0].set(B_X_W) - - # Transforms link -> base - i_X_0 = jnp.zeros_like(pre_X_λi) - i_X_0 = i_X_0.at[0].set(jnp.eye(6)) - - # Initialize base quantities - if model.is_floating_base: - # Base velocity v₀ - v_0 = B_X_W @ base_vel - v = v.at[0].set(v_0) - - # AB inertia (Mᴬ) and AB bias forces (pᴬ) - MA_0 = M[0] - MA = MA.at[0].set(MA_0) - pA_0 = Cross.vx_star(v[0]) @ MA_0 @ v[0] - Adjoint.inverse( - B_X_W - ).T @ jnp.vstack(f_ext[0]) - pA = pA.at[0].set(pA_0) - - Pass1Carry = Tuple[ - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - ] - - pass_1_carry = (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0) - - # Pass 1 - def loop_body_pass1(carry: Pass1Carry, i: jtp.Int) -> Tuple[Pass1Carry, None]: - ii = i - 1 - i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0 = carry - - # Compute parent-to-child transform - i_X_λi_i = i_X_pre[i] @ pre_X_λi[i] - i_X_λi = i_X_λi.at[i].set(i_X_λi_i) - - # Propagate link velocity - vJ = S[i] * qd[ii] * (qd.size != 0) - m_vJ = m_S[i] * qd[ii] * (qd.size != 0) - - v_i = i_X_λi[i] @ v[λ[i]] + vJ - v = v.at[i].set(v_i) - - m_v_i = i_X_λi[i] @ v[λ[i]] + m_vJ - m_v = m_v.at[i].set(m_v_i) - - c_i = Cross.vx(v[i]) @ vJ - c = c.at[i].set(c_i) - m_c_i = Cross.vx(m_v[i]) @ m_vJ - m_c = m_c.at[i].set(m_c_i) - - # Initialize articulated-body inertia - MA_i = jnp.array(M[i]) - MA = MA.at[i].set(MA_i) - - # Initialize articulated-body bias forces - i_X_0_i = i_X_λi[i] @ i_X_0[model.parent[i]] - i_X_0 = i_X_0.at[i].set(i_X_0_i) - i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T - - pA_i = Cross.vx_star(v[i]) @ M[i] @ v[i] - i_Xf_W @ jnp.vstack(f_ext[i]) - pA = pA.at[i].set(pA_i) - - pR_i = Cross.vx_star(m_v[i]) @ IM[i] @ m_v[i] - K̅ᵥ[i] * m_v[i] - pR = pR.at[i].set(pR_i) - - return (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), None - - (i_X_λi, v, c, m_v, m_c, MA, pA, pR, i_X_0), _ = jax.lax.scan( - f=loop_body_pass1, - init=pass_1_carry, - xs=np.arange(start=1, stop=model.NB), - ) - - U = jnp.zeros_like(S) - m_U = jnp.zeros_like(S) - d = jnp.zeros(shape=(model.NB, 1)) - u = jnp.zeros(shape=(model.NB, 1)) - m_u = jnp.zeros(shape=(model.NB, 1)) - - Pass2Carry = Tuple[ - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - ] - - pass_2_carry = (U, m_U, d, u, m_u, MA, pA) - - # Pass 2 - def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> Tuple[Pass2Carry, None]: - ii = i - 1 - U, m_U, d, u, m_u, MA, pA = carry - - # Compute intermediate results - u_i = tau[ii] - S[i].T @ pA[i] if tau.size != 0 else -S[i].T @ pA[i] - u = u.at[i].set(u_i.squeeze()) - - has_motors = jnp.allclose(Γ[i], 1.0) - - m_u_i = ( - tau[ii] / Γ[i] * has_motors - m_S[i].T @ pR[i] - if tau.size != 0 - else -m_S[i].T @ pR[i] - ) - m_u = m_u.at[i].set(m_u_i.squeeze()) - - U_i = MA[i] @ S[i] - U = U.at[i].set(U_i) - - m_U_i = IM[i] @ m_S[i] - m_U = m_U.at[i].set(m_U_i) - - d_i = S[i].T @ MA[i] @ S[i] + m_S[i].T @ IM[i] @ m_S[i] - d = d.at[i].set(d_i.squeeze()) - - # Compute the articulated-body inertia and bias forces of this link - Ma = MA[i] + IM[i] - U[i] / d[i] @ U[i].T - m_U[i] / d[i] @ m_U[i].T - pa = ( - pA[i] - + pR[i] - + Ma[i] @ c[i] - + IM[i] @ m_c[i] - + U[i] / d[i] * u[i] - + m_U[i] / d[i] * m_u[i] - ) - - # Propagate them to the parent, handling the base link - def propagate( - MA_pA: Tuple[jtp.MatrixJax, jtp.MatrixJax] - ) -> Tuple[jtp.MatrixJax, jtp.MatrixJax]: - MA, pA = MA_pA - - MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i] - MA = MA.at[λ[i]].set(MA_λi) - - pA_λi = pA[λ[i]] + i_X_λi[i].T @ pa - pA = pA.at[λ[i]].set(pA_λi) - - return MA, pA - - MA, pA = jax.lax.cond( - pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(), - true_fun=propagate, - false_fun=lambda MA_pA: MA_pA, - operand=(MA, pA), - ) - - return (U, m_U, d, u, m_u, MA, pA), None - - (U, m_U, d, u, m_u, MA, pA), _ = jax.lax.scan( - f=loop_body_pass2, - init=pass_2_carry, - xs=np.flip(np.arange(start=1, stop=model.NB)), - ) - - if model.is_floating_base: - a0 = jnp.linalg.solve(-MA[0], pA[0]) - else: - a0 = -B_X_W @ jnp.vstack(model.gravity) - - a = jnp.zeros_like(S) - a = a.at[0].set(a0) - qdd = jnp.zeros_like(q) - - Pass3Carry = Tuple[jtp.MatrixJax, jtp.VectorJax] - pass_3_carry = (a, qdd) - - # Pass 3 - def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> Tuple[Pass3Carry, None]: - ii = i - 1 - a, qdd = carry - - # Propagate link accelerations - a_i = i_X_λi[i] @ a[λ[i]] + c[i] - - # Compute joint accelerations - qdd_ii = (u[i] + m_u[i] - (U[i].T + m_U[i].T) @ a_i) / d[i] - qdd = qdd.at[ii].set(qdd_ii.squeeze()) if qdd.size != 0 else qdd - - a_i = a_i + S[i] * qdd[ii] if qdd.size != 0 else a_i - a = a.at[i].set(a_i) - - return (a, qdd), None - - (a, qdd), _ = jax.lax.scan( - f=loop_body_pass3, - init=pass_3_carry, - xs=np.arange(1, model.NB), - ) - - # Handle 1 DoF models - qdd = jnp.atleast_1d(qdd.squeeze()) - qdd = jnp.vstack(qdd) if qdd.size > 0 else jnp.empty(shape=(0, 1)) - - # Get the resulting base acceleration (w/o gravity) in body-fixed representation - B_a_WB = a[0] - - # Convert the base acceleration to inertial-fixed representation, and add gravity - W_a_WB = jnp.vstack( - jnp.linalg.solve(B_X_W, B_a_WB) + jnp.vstack(model.gravity) - if model.is_floating_base - else jnp.zeros(6) - ) - - return W_a_WB, qdd diff --git a/src/jaxsim/physics/algos/rnea_motors.py b/src/jaxsim/physics/algos/rnea_motors.py deleted file mode 100644 index 84a8c02cd..000000000 --- a/src/jaxsim/physics/algos/rnea_motors.py +++ /dev/null @@ -1,196 +0,0 @@ -from typing import Tuple - -import jax -import jax.numpy as jnp -import numpy as np - -import jaxsim.typing as jtp -from jaxsim.math.adjoint import Adjoint -from jaxsim.math.cross import Cross -from jaxsim.physics.model.physics_model import PhysicsModel - -from . import utils - - -def rnea( - model: PhysicsModel, - xfb: jtp.Vector, - q: jtp.Vector, - qd: jtp.Vector, - qdd: jtp.Vector, - a0fb: jtp.Vector = jnp.zeros(6), - f_ext: jtp.Matrix | None = None, -) -> Tuple[jtp.Vector, jtp.Vector]: - """ - Recursive Newton-Euler Algorithm (RNEA) algorithm for inverse dynamics. - """ - - xfb, q, qd, qdd, _, f_ext = utils.process_inputs( - physics_model=model, xfb=xfb, q=q, qd=qd, qdd=qdd, f_ext=f_ext - ) - - a0fb = a0fb.squeeze() - gravity = model.gravity.squeeze() - - if a0fb.shape[0] != 6: - raise ValueError(a0fb.shape) - - M = model.spatial_inertias - pre_X_λi = model.tree_transforms - i_X_pre = model.joint_transforms(q=q) - S = model.motion_subspaces(q=q) - i_X_λi = jnp.zeros_like(pre_X_λi) - - Γ = jnp.array([*model._joint_motor_gear_ratio.values()]) - IM = jnp.array([*model._joint_motor_inertia.values()]) - K_v = jnp.array([*model._joint_motor_viscous_friction.values()]) - K̅ᵥ = jnp.diag(Γ.T * jnp.diag(K_v) * Γ) - m_S = jnp.concatenate([S[:1], S[1:] * Γ[:, None, None]], axis=0) - - i_X_0 = jnp.zeros_like(pre_X_λi) - i_X_0 = i_X_0.at[0].set(jnp.eye(6)) - - # Parent array mapping: i -> λ(i). - # Exception: λ(0) must not be used, it's initialized to -1. - λ = model.parent_array() - - v = jnp.array([jnp.zeros([6, 1])] * model.NB) - a = jnp.array([jnp.zeros([6, 1])] * model.NB) - f = jnp.array([jnp.zeros([6, 1])] * model.NB) - - v_m = jnp.array([jnp.zeros([6, 1])] * model.NB) - a_m = jnp.array([jnp.zeros([6, 1])] * model.NB) - f_m = jnp.array([jnp.zeros([6, 1])] * model.NB) - - # 6D transform of base velocity - B_X_W = Adjoint.from_quaternion_and_translation( - quaternion=xfb[0:4], - translation=xfb[4:7], - inverse=True, - normalize_quaternion=True, - ) - i_X_λi = i_X_λi.at[0].set(B_X_W) - - a_0 = -B_X_W @ jnp.vstack(gravity) - a = a.at[0].set(a_0) - - if model.is_floating_base: - W_v_WB = jnp.vstack(jnp.hstack([xfb[10:13], xfb[7:10]])) - - v_0 = B_X_W @ W_v_WB - v = v.at[0].set(v_0) - - a_0 = B_X_W @ (jnp.vstack(a0fb) - jnp.vstack(gravity)) - a = a.at[0].set(a_0) - - f_0 = ( - M[0] @ a[0] - + Cross.vx_star(v[0]) @ M[0] @ v[0] - - Adjoint.inverse(B_X_W).T @ jnp.vstack(f_ext[0]) - ) - f = f.at[0].set(f_0) - - ForwardPassCarry = Tuple[ - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - jtp.MatrixJax, - ] - forward_pass_carry = (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m) - - def forward_pass( - carry: ForwardPassCarry, i: jtp.Int - ) -> Tuple[ForwardPassCarry, None]: - ii = i - 1 - i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m = carry - - vJ = S[i] * qd[ii] - vJ_m = m_S[i] * qd[ii] - - i_X_λi_i = i_X_pre[i] @ pre_X_λi[i] - i_X_λi = i_X_λi.at[i].set(i_X_λi_i) - - v_i = i_X_λi[i] @ v[λ[i]] + vJ - v = v.at[i].set(v_i) - - v_i_m = i_X_λi[i] @ v_m[λ[i]] + vJ_m - v_m = v_m.at[i].set(v_i_m) - - a_i = i_X_λi[i] @ a[λ[i]] + S[i] * qdd[ii] + Cross.vx(v[i]) @ vJ - a = a.at[i].set(a_i) - - a_i_m = i_X_λi[i] @ a_m[λ[i]] + m_S[i] * qdd[ii] + Cross.vx(v_m[i]) @ vJ_m - a_m = a_m.at[i].set(a_i_m) - - i_X_0_i = i_X_λi[i] @ i_X_0[λ[i]] - i_X_0 = i_X_0.at[i].set(i_X_0_i) - i_Xf_W = Adjoint.inverse(i_X_0[i] @ B_X_W).T - - f_i = ( - M[i] @ a[i] - + Cross.vx_star(v[i]) @ M[i] @ v[i] - - i_Xf_W @ jnp.vstack(f_ext[i]) - ) - f = f.at[i].set(f_i) - - f_i_m = IM[i] * a_m[i] + Cross.vx_star(v_m[i]) * IM[i] @ v_m[i] - f_m = f_m.at[i].set(f_i_m) - - return (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), None - - (i_X_λi, v, v_m, a, a_m, i_X_0, f, f_m), _ = jax.lax.scan( - f=forward_pass, - init=forward_pass_carry, - xs=np.arange(start=1, stop=model.NB), - ) - - tau = jnp.zeros_like(q) - - BackwardPassCarry = Tuple[jtp.MatrixJax, jtp.MatrixJax, jtp.MatrixJax] - backward_pass_carry = (tau, f, f_m) - - def backward_pass( - carry: BackwardPassCarry, i: jtp.Int - ) -> Tuple[BackwardPassCarry, None]: - ii = i - 1 - tau, f, f_m = carry - - value = S[i].T @ f[i] + m_S[i].T @ f_m[i] # + K̅ᵥ[i] * qd[ii] - tau = tau.at[ii].set(value.squeeze()) - - def update_f(ffm: Tuple[jtp.MatrixJax, jtp.MatrixJax]) -> jtp.MatrixJax: - f, f_m = ffm - f_λi = f[λ[i]] + i_X_λi[i].T @ f[i] - f = f.at[λ[i]].set(f_λi) - - f_m_λi = f_m[λ[i]] + i_X_λi[i].T @ f_m[i] - f_m = f_m.at[λ[i]].set(f_m_λi) - return f, f_m - - f, f_m = jax.lax.cond( - pred=jnp.array([λ[i] != 0, model.is_floating_base]).any(), - true_fun=update_f, - false_fun=lambda f: f, - operand=(f, f_m), - ) - - return (tau, f, f_m), None - - (tau, f, f_m), _ = jax.lax.scan( - f=backward_pass, - init=backward_pass_carry, - xs=np.flip(np.arange(start=1, stop=model.NB)), - ) - - # Handle 1 DoF models - tau = jnp.atleast_1d(tau.squeeze()) - tau = jnp.vstack(tau) if tau.size > 0 else jnp.empty(shape=(0, 1)) - - # Express the base 6D force in the world frame - W_f0 = B_X_W.T @ jnp.vstack(f[0]) - - return W_f0, tau diff --git a/src/jaxsim/physics/model/__init__.py b/src/jaxsim/physics/model/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/jaxsim/physics/model/physics_model.py b/src/jaxsim/physics/model/physics_model.py deleted file mode 100644 index 34143f55b..000000000 --- a/src/jaxsim/physics/model/physics_model.py +++ /dev/null @@ -1,388 +0,0 @@ -import dataclasses -from typing import Dict, Union - -import jax.lax -import jax.numpy as jnp -import jax_dataclasses -import jaxlie -import numpy as np -from jax_dataclasses import Static - -import jaxsim.parsers -import jaxsim.physics -import jaxsim.typing as jtp -from jaxsim.parsers.descriptions import JointDescriptor, JointType -from jaxsim.physics import default_gravity -from jaxsim.utils import JaxsimDataclass, not_tracing - -from .ground_contact import GroundContact -from .physics_model_state import PhysicsModelState - - -@jax_dataclasses.pytree_dataclass -class PhysicsModel(JaxsimDataclass): - """ - A read-only class to store all the information necessary to run RBDAs on a model. - - This class contains information about the physics model, including the number of bodies, initial state, gravity, - floating base configuration, ground contact points, and more. - - Attributes: - NB (Static[int]): The number of bodies in the physics model. - initial_state (PhysicsModelState): The initial state of the physics model (default: None). - gravity (jtp.Vector): The gravity vector (default: [0, 0, 0, 0, 0, 0]). - is_floating_base (Static[bool]): A flag indicating whether the model has a floating base (default: False). - gc (GroundContact): The ground contact points of the model (default: empty GroundContact instance). - description (Static[jaxsim.parsers.descriptions.model.ModelDescription]): A description of the model (default: None). - """ - - NB: Static[int] - initial_state: PhysicsModelState = dataclasses.field(default=None) - gravity: jtp.Vector = dataclasses.field( - default_factory=lambda: jnp.hstack( - [np.zeros(3), jaxsim.physics.default_gravity()] - ) - ) - is_floating_base: Static[bool] = dataclasses.field(default=False) - gc: GroundContact = dataclasses.field(default_factory=lambda: GroundContact()) - description: Static[jaxsim.parsers.descriptions.model.ModelDescription] = ( - dataclasses.field(default=None) - ) - - _parent_array_dict: Static[Dict[int, int]] = dataclasses.field(default_factory=dict) - _jtype_dict: Static[Dict[int, Union[JointType, JointDescriptor]]] = ( - dataclasses.field(default_factory=dict) - ) - _tree_transforms_dict: Dict[int, jtp.Matrix] = dataclasses.field( - default_factory=dict - ) - _link_inertias_dict: Dict[int, jtp.Matrix] = dataclasses.field(default_factory=dict) - - _joint_friction_static: Dict[int, float] = dataclasses.field(default_factory=dict) - _joint_friction_viscous: Dict[int, float] = dataclasses.field(default_factory=dict) - - _joint_limit_spring: Dict[int, float] = dataclasses.field(default_factory=dict) - _joint_limit_damper: Dict[int, float] = dataclasses.field(default_factory=dict) - - _joint_motor_inertia: Dict[int, float] = dataclasses.field(default_factory=dict) - _joint_motor_gear_ratio: Dict[int, float] = dataclasses.field(default_factory=dict) - _joint_motor_viscous_friction: Dict[int, float] = dataclasses.field( - default_factory=dict - ) - - _link_masses: jtp.Vector = dataclasses.field(init=False) - _link_spatial_inertias: jtp.Vector = dataclasses.field(init=False) - _joint_position_limits_min: jtp.Matrix = dataclasses.field(init=False) - _joint_position_limits_max: jtp.Matrix = dataclasses.field(init=False) - - def __post_init__(self): - if self.initial_state is None: - initial_state = PhysicsModelState.zero(physics_model=self) - object.__setattr__(self, "initial_state", initial_state) - - ordered_links = sorted( - list(self.description.links_dict.values()), - key=lambda l: l.index, - ) - - ordered_joints = sorted( - list(self.description.joints_dict.values()), - key=lambda j: j.index, - ) - - from jaxsim.utils import Mutability - - with self.mutable_context( - mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=False - ): - self._link_masses = jnp.stack([link.mass for link in ordered_links]) - self._link_spatial_inertias = jnp.stack( - [self._link_inertias_dict[l.index] for l in ordered_links] - ) - - s_min = jnp.array([j.position_limit[0] for j in ordered_joints]) - s_max = jnp.array([j.position_limit[1] for j in ordered_joints]) - self._joint_position_limits_min = jnp.vstack([s_min, s_max]).min(axis=0) - self._joint_position_limits_max = jnp.vstack([s_min, s_max]).max(axis=0) - - @staticmethod - def build_from( - model_description: jaxsim.parsers.descriptions.model.ModelDescription, - gravity: jtp.Vector = default_gravity(), - ) -> "PhysicsModel": - if gravity.size != 3: - raise ValueError(gravity.size) - - # Currently, we assume that the link frame matches the frame of its parent joint - for l in model_description: - if not jnp.allclose(l.pose, jnp.eye(4)): - raise ValueError(f"Link '{l.name}' has unsupported pose:\n{l.pose}") - - # =================================== - # Initialize physics model parameters - # =================================== - - # Get the number of bodies, including the base link - num_of_bodies = len(model_description) - - # Build the parent array λ of the floating-base model. - # Note: the parent of the base link is not set since it's not defined. - parent_array_dict = { - link.index: link.parent.index - for link in model_description - if link.parent is not None - } - - # Get the 6D inertias of all links - link_spatial_inertias_dict = { - link.index: link.inertia for link in iter(model_description) - } - - # Dict from the joint index to its type. - # Note: the joint index is equal to its child link index. - joint_types_dict = { - joint.index: joint.jtype for joint in model_description.joints - } - - # Dicts from the joint index to the static and viscous friction. - # Note: the joint index is equal to its child link index. - joint_friction_static = { - joint.index: jnp.array(joint.friction_static, dtype=float) - for joint in model_description.joints - } - joint_friction_viscous = { - joint.index: jnp.array(joint.friction_viscous, dtype=float) - for joint in model_description.joints - } - - # Dicts from the joint index to the spring and damper joint limits parameters. - # Note: the joint index is equal to its child link index. - joint_limit_spring = { - joint.index: jnp.array(joint.position_limit_spring, dtype=float) - for joint in model_description.joints - } - joint_limit_damper = { - joint.index: jnp.array(joint.position_limit_damper, dtype=float) - for joint in model_description.joints - } - - # Dicts from the joint index to the motor inertia, gear ratio and viscous friction. - # Note: the joint index is equal to its child link index. - joint_motor_inertia = { - joint.index: jnp.array(joint.motor_inertia, dtype=float) - for joint in model_description.joints - } - joint_motor_gear_ratio = { - joint.index: jnp.array(joint.motor_gear_ratio, dtype=float) - for joint in model_description.joints - } - joint_motor_viscous_friction = { - joint.index: jnp.array(joint.motor_viscous_friction, dtype=float) - for joint in model_description.joints - } - - # Transform between model's root and model's base link - # (this is just the pose of the base link in the SDF description) - base_link = model_description.links_dict[model_description.link_names()[0]] - R_H_B = model_description.transform(name=base_link.name) - tree_transform_0 = jaxlie.SE3.from_matrix(matrix=R_H_B).adjoint() - - # Helper to compute the transform pre(i)_H_λ(i). - # Given a joint 'i', it is the coordinate transform between its predecessor - # frame [pre(i)] and the frame of its parent link [λ(i)]. - prei_H_λi = lambda j: model_description.relative_transform( - relative_to=j.name, name=j.parent.name - ) - - # Compute the tree transforms: pre(i)_X_λ(i). - # Given a joint 'i', it is the coordinate transform between its predecessor - # frame [pre(i)] and the frame of its parent link [λ(i)]. - tree_transforms_dict = { - 0: tree_transform_0, - **{ - j.index: jaxlie.SE3.from_matrix(matrix=prei_H_λi(j)).adjoint() - for j in model_description.joints - }, - } - - # ======================= - # Build the initial state - # ======================= - - # Initial joint positions - q0 = jnp.array( - [ - model_description.joints_dict[j.name].initial_position - for j in model_description.joints - ] - ) - - # Build the initial state - initial_state = PhysicsModelState( - joint_positions=q0, - joint_velocities=jnp.zeros_like(q0), - base_position=model_description.root_pose.root_position, - base_quaternion=model_description.root_pose.root_quaternion, - ) - - # ======================= - # Build the physics model - # ======================= - - # Initialize the model - physics_model = PhysicsModel( - NB=num_of_bodies, - initial_state=initial_state, - _parent_array_dict=parent_array_dict, - _jtype_dict=joint_types_dict, - _tree_transforms_dict=tree_transforms_dict, - _link_inertias_dict=link_spatial_inertias_dict, - _joint_friction_static=joint_friction_static, - _joint_friction_viscous=joint_friction_viscous, - _joint_limit_spring=joint_limit_spring, - _joint_limit_damper=joint_limit_damper, - _joint_motor_gear_ratio=joint_motor_gear_ratio, - _joint_motor_inertia=joint_motor_inertia, - _joint_motor_viscous_friction=joint_motor_viscous_friction, - gravity=jnp.hstack([gravity.squeeze(), np.zeros(3)]), - is_floating_base=True, - gc=GroundContact.build_from(model_description=model_description), - description=model_description, - ) - - # Floating-base models - if not model_description.fixed_base: - return physics_model - - # Fixed-base models - with jax_dataclasses.copy_and_mutate(physics_model) as physics_model_fixed: - physics_model_fixed.is_floating_base = False - - return physics_model_fixed - - def dofs(self) -> int: - return len(list(self._jtype_dict.keys())) - - def set_gravity(self, gravity: jtp.Vector) -> None: - gravity = gravity.squeeze() - - if gravity.size == 3: - self.gravity = jnp.hstack([gravity, 0, 0, 0]) - - elif gravity.size == 6: - self.gravity = gravity - - else: - raise ValueError(gravity.shape) - - @property - def parent(self) -> jtp.Vector: - return self.parent_array() - - def parent_array(self) -> jtp.Vector: - """Returns λ(i)""" - return jnp.array([-1] + list(self._parent_array_dict.values()), dtype=int) - - def support_body_array(self, body_index: jtp.Int) -> jtp.Vector: - """Returns κ(i)""" - - κ_bool = self.support_body_array_bool(body_index=body_index) - return jnp.array(jnp.where(κ_bool)[0], dtype=int) - - def support_body_array_bool(self, body_index: jtp.Int) -> jtp.Vector: - active_link = body_index - κ_bool = jnp.zeros(self.NB, dtype=bool) - - for i in np.flip(np.arange(start=0, stop=self.NB)): - κ_bool, active_link = jax.lax.cond( - pred=(i == active_link), - false_fun=lambda: (κ_bool, active_link), - true_fun=lambda: ( - κ_bool.at[active_link].set(True), - self.parent[active_link], - ), - ) - - return κ_bool - - @property - def tree_transforms(self) -> jtp.Array: - X_tree = jnp.array( - [ - self._tree_transforms_dict.get(idx, jnp.eye(6)) - for idx in np.arange(start=0, stop=self.NB) - ] - ) - - return X_tree - - @property - def spatial_inertias(self) -> jtp.Array: - M_links = jnp.array( - [ - self._link_inertias_dict.get(idx, jnp.zeros(6)) - for idx in np.arange(start=0, stop=self.NB) - ] - ) - - return M_links - - def jtype(self, joint_index: int) -> JointType: - if joint_index == 0 or joint_index >= self.NB: - raise ValueError(joint_index) - - return self._jtype_dict[joint_index] - - def joint_transforms(self, q: jtp.Vector) -> jtp.Array: - from jaxsim.math.joint import jcalc - - if not_tracing(q): - if q.shape[0] != self.dofs(): - raise ValueError(q.shape) - - Xj = jnp.stack( - [jnp.zeros(shape=(6, 6))] - + [ - jcalc(jtyp=self.jtype(index + 1), q=joint_position)[0] - for index, joint_position in enumerate(q) - ] - ) - - return Xj - - def motion_subspaces(self, q: jtp.Vector) -> jtp.Array: - from jaxsim.math.joint import jcalc - - if not_tracing(var=q): - if q.shape[0] != self.dofs(): - raise ValueError(q.shape) - - SS = jnp.stack( - [jnp.vstack(jnp.zeros(6))] - + [ - jcalc(jtyp=self.jtype(index + 1), q=joint_position)[1] - for index, joint_position in enumerate(q) - ] - ) - - return SS - - def __eq__(self, other: "PhysicsModel") -> bool: - same = True - same = same and self.NB == other.NB - same = same and np.allclose(self.gravity, other.gravity) - - return same - - def __hash__(self): - return hash(self.__repr__()) - - def __repr__(self) -> str: - attributes = [ - f"dofs: {self.dofs()},", - f"links: {self.NB},", - f"floating_base: {self.is_floating_base},", - ] - attributes_string = "\n ".join(attributes) - - return f"{type(self).__name__}(\n {attributes_string}\n)" diff --git a/src/jaxsim/simulation/__init__.py b/src/jaxsim/simulation/__init__.py deleted file mode 100644 index 21dd7c132..000000000 --- a/src/jaxsim/simulation/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .ode_data import ODEInput, ODEState diff --git a/src/jaxsim/simulation/integrators.py b/src/jaxsim/simulation/integrators.py deleted file mode 100644 index ca0e00930..000000000 --- a/src/jaxsim/simulation/integrators.py +++ /dev/null @@ -1,393 +0,0 @@ -import enum -from typing import Any, Callable - -import jax -import jax.numpy as jnp -from jax.tree_util import tree_map - -import jaxsim.typing as jtp -from jaxsim.math.quaternion import Quaternion -from jaxsim.physics.algos.soft_contacts import SoftContactsState -from jaxsim.physics.model.physics_model_state import PhysicsModelState -from jaxsim.simulation.ode_data import ODEState -from jaxsim.sixd import se3, so3 - -Time = jtp.FloatLike -TimeStep = jtp.FloatLike -TimeHorizon = jtp.VectorLike - -State = jtp.PyTree -StateDerivative = jtp.PyTree - -StateDerivativeCallable = Callable[ - [State, Time], tuple[StateDerivative, dict[str, Any]] -] - - -class IntegratorType(enum.IntEnum): - RungeKutta4 = enum.auto() - EulerForward = enum.auto() - EulerSemiImplicit = enum.auto() - EulerSemiImplicitManifold = enum.auto() - - -# ======================= -# Single-step integration -# ======================= - - -def integrator_fixed_single_step( - dx_dt: StateDerivativeCallable, - x0: State | ODEState, - t0: Time, - tf: Time, - integrator_type: IntegratorType, - num_sub_steps: int = 1, -) -> tuple[State | ODEState, dict[str, Any]]: - """ - Advance a state vector by integrating a sytem dynamics with a fixed-step integrator. - - Args: - dx_dt: Callable that computes the state derivative. - x0: Initial state. - t0: Initial time. - tf: Final time. - integrator_type: Integrator type. - num_sub_steps: Number of sub-steps to break the integration into. - - Returns: - The final state and a dictionary including auxiliary data at t0. - """ - - # Compute the sub-step size. - # We break dt in configurable sub-steps. - dt = tf - t0 - sub_step_dt = dt / num_sub_steps - - # Initialize the carry - Carry = tuple[State | ODEState, Time] - carry_init: Carry = (x0, t0) - - def forward_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]: - """ - Forward Euler integrator. - """ - - # Unpack the carry - x_t0, t0 = carry - - # Compute the state derivative - dxdt_t0, _ = dx_dt(x_t0, t0) - - # Integrate the dynamics - x_tf = jax.tree_util.tree_map( - lambda x, dxdt: x + sub_step_dt * dxdt, x_t0, dxdt_t0 - ) - - # Update the time - tf = t0 + sub_step_dt - - # Pack the carry - carry = (x_tf, tf) - - return carry, None - - def rk4_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]: - """ - Runge-Kutta 4 integrator. - """ - - # Unpack the carry - x_t0, t0 = carry - - # Helper to forward the state to compute k2 and k3 at midpoint and k4 at final - euler_mid = lambda x, dxdt: x + (0.5 * sub_step_dt) * dxdt - euler_fin = lambda x, dxdt: x + sub_step_dt * dxdt - - # Compute the RK4 slopes - k1, _ = dx_dt(x_t0, t0) - k2, _ = dx_dt(tree_map(euler_mid, x_t0, k1), t0 + 0.5 * sub_step_dt) - k3, _ = dx_dt(tree_map(euler_mid, x_t0, k2), t0 + 0.5 * sub_step_dt) - k4, _ = dx_dt(tree_map(euler_fin, x_t0, k3), t0 + sub_step_dt) - - # Average the slopes and compute the RK4 state derivative - average = lambda k1, k2, k3, k4: (k1 + 2 * k2 + 2 * k3 + k4) / 6 - dxdt = jax.tree_util.tree_map(average, k1, k2, k3, k4) - - # Integrate the dynamics - x_tf = jax.tree_util.tree_map(euler_fin, x_t0, dxdt) - - # Update the time - tf = t0 + sub_step_dt - - # Pack the carry - carry = (x_tf, tf) - - return carry, None - - def semi_implicit_euler_body_fun(carry: Carry, xs: None) -> tuple[Carry, None]: - """ - Semi-implicit Euler integrator. - """ - - # Unpack the carry - x_t0, t0 = carry - - # Compute the state derivative. - # We only keep the quantities related to the acceleration and discard those - # related to the velocity since we are going to use those implicitly integrated - # from the accelerations. - StateDerivative = ODEState - dxdt_t0: StateDerivative = dx_dt(x_t0, t0)[0] - - # Extract the initial position ∈ ℝ⁷⁺ⁿ and initial velocity ∈ ℝ⁶⁺ⁿ. - # This integrator, contrarily to most of the other ones, is not generic. - # It expects to operate on an x object of class ODEState. - pos_t0 = x_t0.physics_model.position() - vel_t0 = x_t0.physics_model.velocity() - - # Extract the velocity derivative - d_vel_dt = dxdt_t0.physics_model.velocity() - - # ============================================= - # Perform semi-implicit Euler integration [1-4] - # ============================================= - - # 1. Integrate the accelerations obtaining the implicit velocities - # 2. Compute the derivative of the generalized position - # 3. Integrate the implicit velocities - # 4. Integrate the remaining state - # 5. Outside the loop: integrate the quaternion on SO(3) manifold - - # ---------------------------------------------------------------- - # 1. Integrate the accelerations obtaining the implicit velocities - # ---------------------------------------------------------------- - - vel_tf = vel_t0 + sub_step_dt * d_vel_dt - - # ----------------------------------------------------- - # 2. Compute the derivative of the generalized position - # ----------------------------------------------------- - - # Extract the implicit angular velocity and the initial base quaternion - W_ω_WB = vel_tf[3:6] - W_Q_B = x_t0.physics_model.base_quaternion - - # Compute the quaternion derivative and the base position derivative - W_Qd_B = Quaternion.derivative( - quaternion=W_Q_B, omega=W_ω_WB, omega_in_body_fixed=False - ).squeeze() - - # Compute the transform of the mixed base frame at t0 - W_H_BW = jnp.vstack( - [ - jnp.block([jnp.eye(3), jnp.vstack(x_t0.physics_model.base_position)]), - jnp.array([0, 0, 0, 1]), - ] - ) - - # The derivative W_ṗ_B of the base position is the linear component of the - # mixed velocity B[W]_v_WB. We need to compute it from the velocity in - # inertial-fixed representation W_vl_WB. - W_v_WB = vel_tf[0:6] - BW_Xv_W = se3.SE3.from_matrix(W_H_BW).inverse().adjoint() - BW_vl_WB = (BW_Xv_W @ W_v_WB)[0:3] - - # Compute the derivative of the generalized position - d_pos_tf = ( - jnp.hstack([BW_vl_WB, vel_tf[6:]]) - if integrator_type is IntegratorType.EulerSemiImplicitManifold - else jnp.hstack([BW_vl_WB, W_Qd_B, vel_tf[6:]]) - ) - - # ------------------------------------ - # 3. Integrate the implicit velocities - # ------------------------------------ - - pos_tf = pos_t0 + sub_step_dt * d_pos_tf - joint_positions = ( - pos_tf[3:] - if integrator_type is IntegratorType.EulerSemiImplicitManifold - else pos_tf[7:] - ) - base_quaternion = ( - jnp.zeros_like(x_t0.base_quaternion) - if integrator_type is IntegratorType.EulerSemiImplicitManifold - else pos_tf[3:7] - ) - - # --------------------------------- - # 4. Integrate the remaining state - # --------------------------------- - - # Integrate the derivative of the tangential material deformation - m = x_t0.soft_contacts.tangential_deformation - ṁ = dxdt_t0.soft_contacts.tangential_deformation - tangential_deformation_tf = m + sub_step_dt * ṁ - - # Pack the new state into an ODEState object - x_tf = ODEState( - physics_model=PhysicsModelState( - base_position=pos_tf[0:3], - base_quaternion=base_quaternion, - joint_positions=joint_positions, - base_linear_velocity=vel_tf[0:3], - base_angular_velocity=vel_tf[3:6], - joint_velocities=vel_tf[6:], - ), - soft_contacts=SoftContactsState( - tangential_deformation=tangential_deformation_tf - ), - ) - - # Update the time - tf = t0 + sub_step_dt - - # Pack the carry - carry = (x_tf, tf) - - return carry, None - - _integrator_registry = { - IntegratorType.RungeKutta4: rk4_body_fun, - IntegratorType.EulerForward: forward_euler_body_fun, - IntegratorType.EulerSemiImplicit: semi_implicit_euler_body_fun, - IntegratorType.EulerSemiImplicitManifold: semi_implicit_euler_body_fun, - } - - # Get the body function for the selected integrator - body_fun = _integrator_registry[integrator_type] - - # Integrate over the given horizon - (x_tf, _), _ = jax.lax.scan( - f=body_fun, init=carry_init, xs=None, length=num_sub_steps - ) - - if integrator_type is IntegratorType.EulerSemiImplicitManifold: - # Indices to convert quaternions between serializations - to_xyzw = jnp.array([1, 2, 3, 0]) - to_wxyz = jnp.array([3, 0, 1, 2]) - - # Get the initial quaternion and the implicitly integrated angular velocity - W_ω_WB_tf = x_tf.physics_model.base_angular_velocity - W_Q_B_t0 = so3.SO3.from_quaternion_xyzw( - x0.physics_model.base_quaternion[to_xyzw] - ) - - # Integrate the quaternion on its manifold using the implicit angular velocity, - # transformed in body-fixed representation since jaxlie uses this convention - B_R_W = W_Q_B_t0.inverse().as_matrix() - W_Q_B_tf = W_Q_B_t0 @ so3.SO3.exp(tangent=dt * B_R_W @ W_ω_WB_tf) - - # Store the quaternion in the final state - x_tf = x_tf.replace( - physics_model=x_tf.physics_model.replace( - base_quaternion=W_Q_B_tf.as_quaternion_xyzw()[to_wxyz] - ) - ) - - # Compute the aux dictionary at t0 - _, aux_t0 = dx_dt(x0, t0) - - return x_tf, aux_t0 - - -# =============================== -# Adapter: single step -> horizon -# =============================== - - -def integrate_single_step_over_horizon( - integrator_single_step: Callable[[Time, Time, State], tuple[State, dict[str, Any]]], - t: TimeHorizon, - x0: State, -) -> tuple[State, dict[str, Any]]: - """ - Integrate a single-step integrator over a given horizon. - - Args: - integrator_single_step: A single-step integrator. - t: The vector of time instants of the integration horizon. - x0: The initial state of the integration horizon. - - Returns: - The final state and auxiliary data produced by the integrator. - """ - - # Initialize the carry - carry_init = (x0, t) - - def body_fun(carry: tuple, idx: int) -> tuple[tuple, jtp.PyTree]: - # Unpack the carry - x_t0, horizon = carry - - # Get the integration interval - t0 = horizon[idx] - tf = horizon[idx + 1] - - # Perform a single-step integration of the ODE - x_tf, aux_t0 = integrator_single_step(t0, tf, x_t0) - - # Prepare returned data - out = (x_t0, aux_t0) - carry = (x_tf, horizon) - - return carry, out - - # Integrate over the given horizon - _, (x_horizon, aux_horizon) = jax.lax.scan( - f=body_fun, init=carry_init, xs=jnp.arange(start=0, stop=len(t), dtype=int) - ) - - return x_horizon, aux_horizon - - -# =================================================================== -# Integration over horizon (same APIs of jax.experimental.ode.odeint) -# =================================================================== - - -def odeint( - func, - y0: State, - t: TimeHorizon, - *args, - num_sub_steps: int = 1, - return_aux: bool = False, - integrator_type: IntegratorType = None, -): - """ - Integrate a system of ODEs with a fixed-step integrator. - - Args: - func: A function that computes the time-derivative of the state. - y0: The initial state. - t: The vector of time instants of the integration horizon. - *args: Additional arguments to be passed to the function func. - num_sub_steps: The number of sub-steps to be performed within each integration step. - return_aux: Whether to return the auxiliary data produced by the integrator. - - Returns: - The state of the system at the end of the integration horizon, and optionally - the auxiliary data produced by the integrator. - """ - - # Close func over additional inputs and parameters - dx_dt_closure = lambda x, ts: func(x, ts, *args) - - # Close one-step integration over its arguments - integrator_single_step = lambda t0, tf, x0: integrator_fixed_single_step( - dx_dt=dx_dt_closure, - x0=x0, - t0=t0, - tf=tf, - num_sub_steps=num_sub_steps, - integrator_type=integrator_type, - ) - - # Integrate the state and compute optional auxiliary data over the horizon - out, aux = integrate_single_step_over_horizon( - integrator_single_step=integrator_single_step, t=t, x0=y0 - ) - - return (out, aux) if return_aux else out diff --git a/src/jaxsim/simulation/ode.py b/src/jaxsim/simulation/ode.py deleted file mode 100644 index 3dd9a429b..000000000 --- a/src/jaxsim/simulation/ode.py +++ /dev/null @@ -1,290 +0,0 @@ -from typing import Any, Dict, Tuple - -import jax -import jax.numpy as jnp -import numpy as np - -import jaxsim.typing as jtp -from jaxsim.physics import algos -from jaxsim.physics.algos.soft_contacts import ( - SoftContacts, - SoftContactsParams, - collidable_points_pos_vel, -) -from jaxsim.physics.algos.terrain import FlatTerrain, Terrain -from jaxsim.physics.model.physics_model import PhysicsModel - -from . import ode_data - - -def compute_contact_forces( - physics_model: PhysicsModel, - ode_state: ode_data.ODEState, - soft_contacts_params: SoftContactsParams = SoftContactsParams(), - terrain: Terrain = FlatTerrain(), -) -> Tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]: - """ - Compute the contact forces acting on the collidable points of the model. - - Args: - physics_model: The physics model to consider. - ode_state: The state of the ODE corresponding to the physics model. - soft_contacts_params: The parameters of the soft contacts model. - terrain: The terrain model. - - Returns: - A tuple containing: - - The contact forces expressed in the world frame acting on the model's links. - - The derivative of the tangential deformation of the terrain dynamics. - - The contact forces expressed in the world frame acting on the model's collidable points. - """ - - # Compute position and linear mixed velocity of all model's collidable points - # collidable_points_kinematics - pos_cp, vel_cp = collidable_points_pos_vel( - model=physics_model, - q=ode_state.physics_model.joint_positions, - qd=ode_state.physics_model.joint_velocities, - xfb=ode_state.physics_model.xfb(), - ) - - # Compute the forces acting on the collidable points due to contact with - # the compliant ground surface. Apply vmap to process all points together. - contact_forces_points, tangential_deformation_dot = jax.vmap( - SoftContacts(parameters=soft_contacts_params, terrain=terrain).contact_model - )(pos_cp.T, vel_cp.T, ode_state.soft_contacts.tangential_deformation.T) - - contact_forces_points = contact_forces_points.T - tangential_deformation_dot = tangential_deformation_dot.T - - # Initialize the contact forces, one per body - contact_forces_links = jnp.zeros_like( - ode_data.ODEInput.zero(physics_model).physics_model.f_ext - ) - - # Combine the contact forces of all collidable points belonging to the same body - for body_idx in set(physics_model.gc.body): - body_idx = int(body_idx) - contact_forces_links = contact_forces_links.at[body_idx, :].set( - jnp.sum(contact_forces_points[:, physics_model.gc.body == body_idx], axis=1) - ) - - return contact_forces_links, tangential_deformation_dot, contact_forces_points.T - - -def dx_dt( - x: ode_data.ODEState, - t: jtp.Float | None, - physics_model: PhysicsModel, - soft_contacts_params: SoftContactsParams = SoftContactsParams(), - ode_input: ode_data.ODEInput | None = None, - terrain: Terrain = FlatTerrain(), -) -> Tuple[ode_data.ODEState, Dict[str, Any]]: - """ - Compute the state derivative of the ODE corresponding to the physics model. - - Args: - x: The state of the ODE. - t: The current time. - physics_model: The physics model to consider. - soft_contacts_params: The parameters of the soft contacts model. - ode_input: The input of the ODE. - terrain: The terrain model. - - Returns: - A tuple containing: - - The state derivative of the ODE. - - A dictionary containing auxiliary information. - """ - - if t is not None and isinstance(t, np.ndarray) and t.size != 1: - raise ValueError(t.size) - - # Initialize arguments - ode_state = x - ode_input = ( - ode_input - if ode_input is not None - else ode_data.ODEInput.zero(physics_model=physics_model) - ) - - # ====================== - # Compute contact forces - # ====================== - - # Initialize the collidable points contact forces - contact_forces_points = None - - # Initialize the contact forces, one per body - contact_forces_links = jnp.zeros_like(ode_input.physics_model.f_ext) - - # Initialize the derivative of the tangential deformation - tangential_deformation_dot = jnp.zeros_like( - ode_state.soft_contacts.tangential_deformation - ) - - if len(physics_model.gc.body) > 0: - ( - contact_forces_links, - tangential_deformation_dot, - contact_forces_points, - ) = compute_contact_forces( - physics_model=physics_model, - soft_contacts_params=soft_contacts_params, - ode_state=ode_state, - terrain=terrain, - ) - - # ===================== - # Joint position limits - # ===================== - - if physics_model.dofs() > 0: - # Get the joint position limits - s_min, s_max = jnp.array( - [j.position_limit for j in physics_model.description.joints_dict.values()] - ).T - - # Get the spring/damper parameters of joint limits enforcement - k_damper = jnp.array(list(physics_model._joint_limit_damper.values())) - - # Compute the joint torques that enforce joint limits - s = ode_state.physics_model.joint_positions - tau_min = jnp.where(s <= s_min, k_damper * (s_min - s), 0) - tau_max = jnp.where(s >= s_max, k_damper * (s_max - s), 0) - tau_limit = tau_max + tau_min - - else: - tau_limit = jnp.zeros_like(ode_input.physics_model.tau) - - # ============== - # Joint friction - # ============== - - if physics_model.dofs() > 0: - # Static and viscous joint friction parameters - kc = jnp.array(list(physics_model._joint_friction_static.values())) - kv = jnp.array(list(physics_model._joint_friction_viscous.values())) - - # Compute the joint friction torque - tau_friction = -( - jnp.diag(kc) @ jnp.sign(ode_state.physics_model.joint_positions) - + jnp.diag(kv) @ ode_state.physics_model.joint_velocities - ) - - else: - tau_friction = jnp.zeros_like(ode_input.physics_model.tau) - - # ======================== - # Compute forward dynamics - # ======================== - - # Compute the total forces applied to the bodies - total_forces = ode_input.physics_model.f_ext + contact_forces_links - - # Compute the joint torques to actuate - tau = ode_input.physics_model.tau + tau_friction + tau_limit - - # Compute forward dynamics with the ABA algorithm - W_a_WB, qdd = algos.aba.aba( - model=physics_model, - xfb=ode_state.physics_model.xfb(), - q=ode_state.physics_model.joint_positions, - qd=ode_state.physics_model.joint_velocities, - tau=tau, - f_ext=total_forces, - ) - - # ========================================= - # Compute the state derivative of base link - # ========================================= - - if not physics_model.is_floating_base: - W_Qd_B = jnp.zeros(4) - BW_v_WB = jnp.zeros(3) - - else: - from jaxsim.math.conv import Convert - from jaxsim.math.quaternion import Quaternion - - W_Qd_B = Quaternion.derivative( - quaternion=ode_state.physics_model.base_quaternion, - omega=ode_state.physics_model.base_angular_velocity, - omega_in_body_fixed=False, - ).squeeze() - - # Compute linear component of mixed velocity - BW_v_WB = Convert.velocities_threed( - v_6d=jnp.hstack( - [ - ode_state.physics_model.base_linear_velocity, - ode_state.physics_model.base_angular_velocity, - ] - ), - p=ode_state.physics_model.base_position, - ).squeeze() - - # Derivative of xfb (floating-base state) - xd_fb = jnp.hstack([W_Qd_B, BW_v_WB, W_a_WB.squeeze()]).squeeze() - - # ===================================== - # Build the full derivative of ODEState - # ===================================== - - def fix_one_dof(vector: jtp.Vector) -> jtp.Vector | None: - """Fix the shape of computed quantities for models with just 1 DoF.""" - - if vector is None: - return None - - return jnp.array([vector]) if vector.shape == () else vector - - # Fill the PhysicsModelState object included in the input ODEState to store the - # returned PhysicsModelState derivative - physics_model_state_derivative = ode_state.physics_model.replace( - joint_positions=fix_one_dof(ode_state.physics_model.joint_velocities.squeeze()), - joint_velocities=fix_one_dof(qdd.squeeze()), - base_quaternion=xd_fb.squeeze()[0:4], - base_position=xd_fb.squeeze()[4:7], - base_angular_velocity=xd_fb.squeeze()[10:13], - base_linear_velocity=xd_fb.squeeze()[7:10], - ) - - # Fill the SoftContactsState object included in the input ODEState to store the - # returned SoftContactsState derivative - soft_contacts_state_derivative = ode_state.soft_contacts.replace( - tangential_deformation=tangential_deformation_dot.squeeze(), - ) - - # We store the state derivative using the ODEState class so that the pytree - # structure remains consistent, allowing to use our generic pytree integrators - state_derivative = ode_data.ODEState( - physics_model=physics_model_state_derivative, - soft_contacts=soft_contacts_state_derivative, - ) - - # =============================== - # Build auxiliary data and return - # =============================== - - # Real ODEInput containing the real joint forces that have been actuated and - # the total external forces (= original external forces + terrain + limits) - ode_input_real = ode_data.ODEInput( - physics_model=ode_data.PhysicsModelInput(tau=tau, f_ext=total_forces) - ) - - # Pack the inertial-fixed floating-base acceleration - W_nud_WB = jnp.hstack([W_a_WB.squeeze(), qdd.squeeze()]) - - # Build the auxiliary data - aux_dict = { - "model_acceleration": W_nud_WB, - "ode_input": ode_input, - "ode_input_real": ode_input_real, - "contact_forces_links": contact_forces_links, - "contact_forces_points": contact_forces_points, - "tangential_deformation_dot": tangential_deformation_dot, - } - - # Return the state derivative as a generic PyTree, and the dict with auxiliary info - return state_derivative, aux_dict diff --git a/src/jaxsim/simulation/ode_integration.py b/src/jaxsim/simulation/ode_integration.py deleted file mode 100644 index 90ad227ac..000000000 --- a/src/jaxsim/simulation/ode_integration.py +++ /dev/null @@ -1,62 +0,0 @@ -import enum -import functools -from typing import Any, Dict, Tuple, Union - -import jax.flatten_util -from jax.experimental.ode import odeint - -import jaxsim.typing as jtp -from jaxsim.physics.algos.soft_contacts import SoftContactsParams -from jaxsim.physics.algos.terrain import FlatTerrain, Terrain -from jaxsim.physics.model.physics_model import PhysicsModel -from jaxsim.simulation import integrators, ode -from jaxsim.simulation.integrators import IntegratorType - - -@jax.jit -def ode_integration_rk4_adaptive( - x0: jtp.Array, - t: integrators.TimeHorizon, - physics_model: PhysicsModel, - *args, - **kwargs, -) -> jtp.Array: - # Close function over its inputs and parameters - dx_dt_closure = lambda x, ts: ode.dx_dt(x, ts, physics_model, *args) - - return odeint(dx_dt_closure, x0, t, **kwargs) - - -@functools.partial( - jax.jit, static_argnames=["num_sub_steps", "integrator_type", "return_aux"] -) -def ode_integration_fixed_step( - x0: ode.ode_data.ODEState, - t: integrators.TimeHorizon, - physics_model: PhysicsModel, - integrator_type: IntegratorType, - soft_contacts_params: SoftContactsParams = SoftContactsParams(), - terrain: Terrain = FlatTerrain(), - ode_input: ode.ode_data.ODEInput | None = None, - *args, - num_sub_steps: int = 1, - return_aux: bool = False, -) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict]]: - # Close func over additional inputs and parameters - dx_dt_closure = lambda x, ts: ode.dx_dt( - x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args - ) - - # Integrate over the horizon - out = integrators.odeint( - func=dx_dt_closure, - y0=x0, - t=t, - num_sub_steps=num_sub_steps, - return_aux=return_aux, - integrator_type=integrator_type, - ) - - # Return output pytree and, optionally, the aux dict - state = out if not return_aux else out[0] - return (state, out[1]) if return_aux else state diff --git a/src/jaxsim/simulation/simulator.py b/src/jaxsim/simulation/simulator.py deleted file mode 100644 index ac51b6ee8..000000000 --- a/src/jaxsim/simulation/simulator.py +++ /dev/null @@ -1,543 +0,0 @@ -import dataclasses -import functools -import pathlib -from typing import Dict, List, Optional, Union - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - -import jax -import jax.numpy as jnp -import jax_dataclasses -import rod -from jax_dataclasses import Static - -import jaxsim.high_level -import jaxsim.physics -import jaxsim.typing as jtp -from jaxsim import logging -from jaxsim.high_level.common import VelRepr -from jaxsim.high_level.model import Model, StepData -from jaxsim.parsers import descriptions -from jaxsim.physics.algos.soft_contacts import SoftContactsParams -from jaxsim.physics.algos.terrain import FlatTerrain, Terrain -from jaxsim.physics.model.physics_model import PhysicsModel -from jaxsim.utils import Mutability, Vmappable, oop - -from . import simulator_callbacks as scb -from .ode_integration import IntegratorType - - -@jax_dataclasses.pytree_dataclass -class SimulatorData(Vmappable): - """ - Data used by the simulator. - - It can be used as JaxSim state in a functional programming style. - """ - - # Simulation time stored in ns in order to prevent floats approximation - time_ns: jtp.Int = dataclasses.field( - default_factory=lambda: jnp.array(0, dtype=jnp.uint64) - ) - - # Terrain and contact parameters - terrain: Terrain = dataclasses.field(default_factory=lambda: FlatTerrain()) - contact_parameters: SoftContactsParams = dataclasses.field( - default_factory=lambda: SoftContactsParams() - ) - - # Dictionary containing all handled models - models: Dict[str, Model] = dataclasses.field(default_factory=dict) - - # Default gravity vector (could be overridden for individual models) - gravity: jtp.Vector = dataclasses.field( - default_factory=lambda: jaxsim.physics.default_gravity() - ) - - -@jax_dataclasses.pytree_dataclass -class JaxSim(Vmappable): - """The JaxSim simulator.""" - - # Step size stored in ns in order to prevent floats approximation - step_size_ns: Static[jtp.Int] = dataclasses.field( - default_factory=lambda: jnp.array(1_000_000, dtype=jnp.uint64) - ) - - # Number of sub-steps performed at each integration step. - # Note: there is no collision detection performed in sub-steps. - steps_per_run: Static[jtp.Int] = dataclasses.field(default=1) - - # Default velocity representation (could be overridden for individual models) - velocity_representation: Static[VelRepr] = dataclasses.field( - default=VelRepr.Inertial - ) - - # Integrator type - integrator_type: Static[IntegratorType] = dataclasses.field( - default=IntegratorType.EulerForward - ) - - # Simulator data - data: SimulatorData = dataclasses.field(default_factory=lambda: SimulatorData()) - - @staticmethod - def build( - step_size: jtp.Float, - steps_per_run: jtp.Int = 1, - velocity_representation: VelRepr = VelRepr.Inertial, - integrator_type: IntegratorType = IntegratorType.EulerSemiImplicit, - simulator_data: SimulatorData | None = None, - ) -> "JaxSim": - """ - Build a JaxSim simulator object. - - Args: - step_size: The integration step size in seconds. - steps_per_run: Number of sub-steps performed at each integration step. - velocity_representation: Default velocity representation of simulated models. - integrator_type: Type of integrator used for integrating the equations of motion. - simulator_data: Optional simulator data to initialize the simulator state. - - Returns: - The JaxSim simulator object. - """ - - return JaxSim( - step_size_ns=jnp.array(step_size * 1e9, dtype=jnp.uint64), - steps_per_run=int(steps_per_run), - velocity_representation=velocity_representation, - integrator_type=integrator_type, - data=simulator_data if simulator_data is not None else SimulatorData(), - ) - - @functools.partial( - oop.jax_tf.method_rw, static_argnames=["remove_models"], validate=False - ) - def reset(self, remove_models: bool = True) -> None: - """ - Reset the simulator. - - Args: - remove_models: Flag indicating whether to remove all models from the simulator. - If False, the models are kept but their state is reset. - """ - - self.data.time_ns = jnp.zeros_like(self.data.time_ns) - - if remove_models: - self.data.models = {} - else: - _ = [m.zero() for m in self.models()] - - @functools.partial(oop.jax_tf.method_rw, jit=False) - def set_step_size(self, step_size: float) -> None: - """ - Set the integration step size. - - Args: - step_size: The integration step size in seconds. - """ - - self.step_size_ns = jnp.array(step_size * 1e9, dtype=jnp.uint64) - - @functools.partial(oop.jax_tf.method_ro, jit=False) - def step_size(self) -> jtp.Float: - """ - Get the integration step size. - - Returns: - The integration step size in seconds. - """ - - return jnp.array(self.step_size_ns / 1e9, dtype=float) - - @functools.partial(oop.jax_tf.method_ro) - def dt(self) -> jtp.Float: - """ - Return the integration step size in seconds. - - Returns: - The integration step size in seconds. - """ - - return jnp.array((self.step_size_ns * self.steps_per_run) / 1e9, dtype=float) - - @functools.partial(oop.jax_tf.method_ro) - def time(self) -> jtp.Float: - """ - Return the current simulation time in seconds. - - Returns: - The current simulation time in seconds. - """ - - return jnp.array(self.data.time_ns / 1e9, dtype=float) - - @functools.partial(oop.jax_tf.method_ro) - def gravity(self) -> jtp.Vector: - """ - Return the 3D gravity vector. - - Returns: - The 3D gravity vector. - """ - - return jnp.array(self.data.gravity, dtype=float) - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def model_names(self) -> tuple[str, ...]: - """ - Return the list of model names. - - Returns: - The list of model names. - """ - - return tuple(self.data.models.keys()) - - @functools.partial( - oop.jax_tf.method_ro, static_argnames=["model_name"], jit=False, vmap=False - ) - def get_model(self, model_name: str) -> Model: - """ - Return the model with the given name. - - Args: - model_name: The name of the model to return. - - Returns: - The model with the given name. - """ - - if model_name not in self.data.models: - raise ValueError(f"Failed to find model '{model_name}'") - - return self.data.models[model_name] - - @functools.partial(oop.jax_tf.method_ro, jit=False, vmap=False) - def models(self, model_names: tuple[str, ...] | None = None) -> tuple[Model, ...]: - """ - Return the simulated models. - - Args: - model_names: Optional list of model names to return. - If None, all models are returned. - - Returns: - The list of simulated models. - """ - - model_names = model_names if model_names is not None else self.model_names() - return tuple(self.data.models[name] for name in model_names) - - @functools.partial(oop.jax_tf.method_rw) - def set_gravity(self, gravity: jtp.Vector) -> None: - """ - Set the gravity vector to all the simulated models. - - Args: - gravity: The 3D gravity vector. - """ - - gravity = jnp.array(gravity, dtype=float) - - if gravity.size != 3: - raise ValueError(gravity) - - self.data.gravity = gravity - - for model in self.data.models.values(): - model.physics_model.set_gravity(gravity=gravity) - - @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) - def insert_model_from_description( - self, - model_description: Union[pathlib.Path, str, rod.Model], - model_name: str | None = None, - considered_joints: List[str] | None = None, - ) -> Model: - """ - Insert a model from a model description. - - Args: - model_description: A path to an SDF/URDF file, a string containing its content, or a pre-parsed/pre-built rod model. - model_name: The optional name of the model that overrides the one in the description. - considered_joints: Optional list of joints to consider. It is also useful to specify the joint serialization. - - Returns: - The newly inserted model. - """ - - if self.vectorized: - raise RuntimeError("Cannot insert a model in a vectorized simulation") - - # Build the model from the given model description - model = jaxsim.high_level.model.Model.build_from_model_description( - model_description=model_description, - model_name=model_name, - vel_repr=self.velocity_representation, - considered_joints=considered_joints, - ) - - # Make sure the model is not already part of the simulation - if model.name() in self.model_names(): - msg = f"Model '{model.name()}' is already part of the simulation" - raise ValueError(msg) - - # Insert the model - self.data.models[model.name()] = model - - # Return the newly inserted model - return self.data.models[model.name()] - - @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) - def insert_model_from_sdf( - self, - sdf: Union[pathlib.Path, str], - model_name: str | None = None, - considered_joints: List[str] | None = None, - ) -> Model: - """ - Insert a model from an SDF resource. - """ - - msg = "JaxSim.{} is deprecated, use JaxSim.{} instead." - logging.warning( - msg=msg.format("insert_model_from_sdf", "insert_model_from_description") - ) - - return self.insert_model_from_description( - model_description=sdf, - model_name=model_name, - considered_joints=considered_joints, - ) - - @functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False) - def insert_model( - self, - model_description: descriptions.ModelDescription, - model_name: str | None = None, - ) -> Model: - """ - Insert a model from a model description object. - - Args: - model_description: The model description object. - model_name: Optional name of the model to insert. - - Returns: - The newly inserted model. - """ - - if self.vectorized: - raise RuntimeError("Cannot insert a model in a vectorized simulation") - - model_name = model_name if model_name is not None else model_description.name - - if model_name in self.model_names(): - msg = f"Model '{model_name}' is already part of the simulation" - raise ValueError(msg) - - # Build the physics model the model description - physics_model = PhysicsModel.build_from( - model_description=model_description, gravity=self.gravity() - ) - - # Build the high-level model from the physics model - model = jaxsim.high_level.model.Model.build( - model_name=model_name, - physics_model=physics_model, - vel_repr=self.velocity_representation, - ) - - # Insert the model into the simulators - self.data.models[model.name()] = model - - # Return the newly inserted model - return self.data.models[model.name()] - - @functools.partial( - oop.jax_tf.method_rw, - jit=False, - validate=False, - static_argnames=["model_name"], - ) - def remove_model(self, model_name: str) -> None: - """ - Remove a model from the simulator. - - Args: - model_name: The name of the model to remove. - """ - - if model_name not in self.model_names(): - msg = f"Model '{model_name}' is not part of the simulation" - raise ValueError(msg) - - _ = self.data.models.pop(model_name) - - @functools.partial(oop.jax_tf.method_rw, vmap_in_axes=(0, None)) - def step(self, clear_inputs: bool = False) -> Dict[str, StepData]: - """ - Advance the simulation by one step. - - Args: - clear_inputs: Zero the inputs of the models after the integration. - - Returns: - A dictionary containing the StepData of all models. - """ - - # Compute the initial and final time of the integration as integers - t0_ns = jnp.array(self.data.time_ns, dtype=jnp.uint64) - dt_ns = jnp.array(self.step_size_ns * self.steps_per_run, dtype=jnp.uint64) - - # Compute the final time using integer arithmetics - tf_ns = t0_ns + dt_ns - - # We collect the StepData of all models - step_data = {} - - for model in self.models(): - # Integrate individually all models and collect their StepData. - # We use the context manager to make sure that the PyTree of the models - # never changes, so that it never triggers JIT recompilations. - with model.editable(validate=True) as integrated_model: - step_data[model.name()] = integrated_model.integrate( - t0=jnp.array(t0_ns, dtype=float) / 1e9, - tf=jnp.array(tf_ns, dtype=float) / 1e9, - sub_steps=self.steps_per_run, - integrator_type=self.integrator_type, - terrain=self.data.terrain, - contact_parameters=self.data.contact_parameters, - clear_inputs=clear_inputs, - ) - - self.data.models[model.name()].data = integrated_model.data - - # Store the final time - self.data.time_ns += dt_ns - - return step_data - - @functools.partial( - oop.jax_tf.method_ro, - static_argnames=["horizon_steps"], - vmap_in_axes=(0, None, 0, None), - ) - def step_over_horizon( - self, - horizon_steps: jtp.Int, - callback_handler: ( - Union["scb.SimulatorCallback", "scb.CallbackHandler"] | None - ) = None, - clear_inputs: jtp.Bool = False, - ) -> Union[ - "JaxSim", - tuple["JaxSim", tuple["scb.SimulatorCallback", tuple[jtp.PyTree, jtp.PyTree]]], - ]: - """ - Advance the simulation by a given number of steps. - - Args: - horizon_steps: The number of steps to advance the simulation. - callback_handler: A callback handler to inject custom login in the simulation loop. - clear_inputs: Zero the inputs of the models after the integration. - - Returns: - The updated simulator if no callback handler is provided, otherwise a tuple - containing the updated simulator and a tuple containing callback data. - The optional callback data is a tuple containing the updated callback object, - the produced pre-step output, and the produced post-step output. - """ - - # Process a mutable copy of the simulator - original_mutability = self._mutability() - sim = self.copy().mutable(validate=True) - - # Helper to get callbacks from the handler - get_cb = lambda h, cb_name: ( - getattr(h, cb_name) if h is not None and hasattr(h, cb_name) else None - ) - - # Get the callbacks - configure_cb: Optional[scb.ConfigureCallbackSignature] = get_cb( - h=callback_handler, cb_name="configure_cb" - ) - pre_step_cb: Optional[scb.PreStepCallbackSignature] = get_cb( - h=callback_handler, cb_name="pre_step_cb" - ) - post_step_cb: Optional[scb.PostStepCallbackSignature] = get_cb( - h=callback_handler, cb_name="post_step_cb" - ) - - # Callback: configuration - sim = configure_cb(sim) if configure_cb is not None else sim - - # Initialize the carry - Carry = tuple[JaxSim, scb.CallbackHandler] - carry_init: Carry = (sim, callback_handler) - - def body_fun( - carry: Carry, xs: None - ) -> tuple[Carry, tuple[jtp.PyTree, jtp.PyTree]]: - sim, callback_handler = carry - - # Make sure to pass a mutable version of the simulator to the callbacks - sim = sim.mutable(validate=True) - - # Callback: pre-step - sim, out_pre_step = ( - pre_step_cb(sim) if pre_step_cb is not None else (sim, None) - ) - - # Integrate all models - step_data = sim.step(clear_inputs=clear_inputs) - - # Callback: post-step - sim, out_post_step = ( - post_step_cb(sim, step_data) - if post_step_cb is not None - else (sim, None) - ) - - # Pack the carry - carry = (sim, callback_handler) - - return carry, (out_pre_step, out_post_step) - - # Integrate over the given horizon - (sim, callback_handler), ( - out_pre_step_horizon, - out_post_step_horizon, - ) = jax.lax.scan(f=body_fun, init=carry_init, xs=None, length=horizon_steps) - - # Enforce original mutability of the entire simulator - sim._set_mutability(original_mutability) - - return ( - sim - if callback_handler is None - else ( - sim, - (callback_handler, (out_pre_step_horizon, out_post_step_horizon)), - ) - ) - - def vectorize(self: Self, batch_size: int) -> Self: - """ - Inherit docs. - """ - - jaxsim_vec: JaxSim = super().vectorize(batch_size=batch_size) # noqa - - # We need to manually specify the batch size of the handled models - with jaxsim_vec.mutable_context(mutability=Mutability.MUTABLE): - for model in jaxsim_vec.models(): - model.batch_size = batch_size - - return jaxsim_vec diff --git a/src/jaxsim/simulation/simulator_callbacks.py b/src/jaxsim/simulation/simulator_callbacks.py deleted file mode 100644 index 9de5f09e3..000000000 --- a/src/jaxsim/simulation/simulator_callbacks.py +++ /dev/null @@ -1,79 +0,0 @@ -import abc -from typing import Callable, Dict, Tuple - -import jaxsim.typing as jtp -from jaxsim.high_level.model import StepData - -ConfigureCallbackSignature = Callable[["jaxsim.JaxSim"], "jaxsim.JaxSim"] -PreStepCallbackSignature = Callable[ - ["jaxsim.JaxSim"], Tuple["jaxsim.JaxSim", jtp.PyTree] -] -PostStepCallbackSignature = Callable[ - ["jaxsim.JaxSim", Dict[str, StepData]], Tuple["jaxsim.JaxSim", jtp.PyTree] -] - - -class SimulatorCallback(abc.ABC): - """ - A base class for simulator callbacks. - """ - - pass - - -class ConfigureCallback(SimulatorCallback): - """ - A callback class to define logic for configuring the simulator before taking the first step. - """ - - @property - def configure_cb(self) -> ConfigureCallbackSignature: - return lambda sim: self.configure(sim=sim) - - @abc.abstractmethod - def configure(self, sim: "jaxsim.JaxSim") -> "jaxsim.JaxSim": - pass - - -class PreStepCallback(SimulatorCallback): - """ - A callback class for performing actions before each simulation step. - """ - - @property - def pre_step_cb(self) -> PreStepCallbackSignature: - return lambda sim: self.pre_step(sim=sim) - - @abc.abstractmethod - def pre_step(self, sim: "jaxsim.JaxSim") -> Tuple["jaxsim.JaxSim", jtp.PyTree]: - pass - - -class PostStepCallback(SimulatorCallback): - """ - A callback class for performing actions after each simulation step. - """ - - @property - def post_step_cb(self) -> PostStepCallbackSignature: - return lambda sim, step_data: self.post_step(sim=sim, step_data=step_data) - - @abc.abstractmethod - def post_step( - self, sim: "jaxsim.JaxSim", step_data: Dict[str, StepData] - ) -> Tuple["jaxsim.JaxSim", jtp.PyTree]: - pass - - -class CallbackHandler(ConfigureCallback, PreStepCallback, PostStepCallback): - """ - A class that handles callbacks for the simulator. - - Note: - The are different simulation stages with associated callbacks: - - `configure`: runs before the first step is taken. - - `pre_step`: runs at each step before integrating the dynamics and advancing the time. - - `post_step`: runs at each step after the integration of the dynamics. - """ - - pass diff --git a/src/jaxsim/simulation/utils.py b/src/jaxsim/simulation/utils.py deleted file mode 100644 index d03d8c39d..000000000 --- a/src/jaxsim/simulation/utils.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Tuple - -from jaxsim import logging - - -def check_valid_shape( - what: str, shape: Tuple, expected_shape: Tuple, valid: bool -) -> bool: - valid_shape = shape == expected_shape - - if not valid_shape: - logging.debug(f"Shape of {what} differs: {shape}, {expected_shape}") - return False - - return valid diff --git a/src/jaxsim/utils/__init__.py b/src/jaxsim/utils/__init__.py index 0e9509c29..8d55d4ecc 100644 --- a/src/jaxsim/utils/__init__.py +++ b/src/jaxsim/utils/__init__.py @@ -3,7 +3,3 @@ from .hashless import HashlessObject from .jaxsim_dataclass import JaxsimDataclass from .tracing import not_tracing, tracing -from .vmappable import Vmappable - -# Leave this below the others to prevent circular imports -from .oop import jax_tf # isort: skip diff --git a/src/jaxsim/utils/oop.py b/src/jaxsim/utils/oop.py deleted file mode 100644 index f3169524a..000000000 --- a/src/jaxsim/utils/oop.py +++ /dev/null @@ -1,536 +0,0 @@ -import contextlib -import dataclasses -import functools -import inspect -import os -from typing import Any, Callable, Generator, TypeVar - -import jax -import jax.flatten_util -from typing_extensions import ParamSpec - -from jaxsim import logging -from jaxsim.utils import tracing - -from . import Mutability, Vmappable - -_P = ParamSpec("_P") -_R = TypeVar("_R") - - -class jax_tf: - """ - Class containing decorators applicable to methods of Vmappable objects. - """ - - # Environment variables that can be used to disable the transformations - EnvVarOOP: str = "JAXSIM_OOP_DECORATORS" - EnvVarJitOOP: str = "JAXSIM_OOP_DECORATORS_JIT" - EnvVarVmapOOP: str = "JAXSIM_OOP_DECORATORS_VMAP" - EnvVarCacheOOP: str = "JAXSIM_OOP_DECORATORS_CACHE" - - @staticmethod - def method_ro( - fn: Callable[_P, _R], - jit: bool = True, - static_argnames: tuple[str, ...] | list[str] = (), - vmap: bool | None = None, - vmap_in_axes: tuple[int, ...] | int | None = None, - vmap_out_axes: tuple[int, ...] | int | None = None, - ) -> Callable[_P, _R]: - """ - Decorator for r/o methods of classes inheriting from Vmappable. - """ - - return jax_tf.method( - fn=fn, - read_only=True, - validate=True, - jit_enabled=jit, - static_argnames=static_argnames, - vmap_enabled=vmap, - vmap_in_axes=vmap_in_axes, - vmap_out_axes=vmap_out_axes, - ) - - @staticmethod - def method_rw( - fn: Callable[_P, _R], - validate: bool = True, - jit: bool = True, - static_argnames: tuple[str, ...] | list[str] = (), - vmap: bool | None = None, - vmap_in_axes: tuple[int, ...] | int | None = None, - vmap_out_axes: tuple[int, ...] | int | None = None, - ) -> Callable[_P, _R]: - """ - Decorator for r/w methods of classes inheriting from Vmappable. - """ - - return jax_tf.method( - fn=fn, - read_only=False, - validate=validate, - jit_enabled=jit, - static_argnames=static_argnames, - vmap_enabled=vmap, - vmap_in_axes=vmap_in_axes, - vmap_out_axes=vmap_out_axes, - ) - - @staticmethod - def method( - fn: Callable[_P, _R], - read_only: bool = True, - validate: bool = True, - jit_enabled: bool = True, - static_argnames: tuple[str, ...] | list[str] = (), - vmap_enabled: bool | None = None, - vmap_in_axes: tuple[int, ...] | int | None = None, - vmap_out_axes: tuple[int, ...] | int | None = None, - ): - """ - Decorator for methods of classes inheriting from Vmappable. - - This decorator enables executing the methods on an object characterized by a - desired mutability, that is selected considering the r/o and validation flags. - It also allows to transform the method with the jit/vmap transformations. - If the Vmappable object is vectorized, the method is automatically vmapped, and - the in_axes are properly post-processed to simplify the combination with jit. - - Args: - fn: The method to decorate. - read_only: Whether the method operates on a read-only object. - validate: Whether r/w methods should preserve the pytree structure. - jit_enabled: Whether to apply the jit transformation. - static_argnames: The names of the arguments that should be static. - vmap_enabled: Whether to apply the vmap transformation. - vmap_in_axes: The in_axes to use for the vmap transformation. - vmap_out_axes: The out_axes to use for the vmap transformation. - - Returns: - The decorated method. - """ - - @functools.wraps(fn) - def wrapper(*args: _P.args, **kwargs: _P.kwargs): - """The wrapper function that is returned by this decorator.""" - - # Methods of classes inheriting from Vmappable decorated by this wrapper - # automatically support jit/vmap/mutability features when called standalone. - # However, when objects are arguments of plain functions transformed with - # jit/vmap, and decorated methods are called inside those functions, we need - # to disable this decorator to avoid double wrapping and execution errors. - # We do so by iterating over the arguments, and checking whether they are - # being traced by JAX. - for argument in list(args) + list(kwargs.values()): - try: - argument_flat, _ = jax.flatten_util.ravel_pytree(argument) - - if tracing(argument_flat): - return fn(*args, **kwargs) - except: - continue - - # =============================================================== - # Wrap fn so that jit/vmap/mutability transformations are applied - # =============================================================== - - # Initialize the mutability of the instance over which the method is running. - # * In r/o methods, this approach prevents any type of mutation. - # * In r/w methods, this approach allows to catch early JIT recompilations - # caused by unwanted changes in the pytree structure. - if read_only: - mutability = Mutability.FROZEN - else: - mutability = ( - Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION - ) - - # Extract the class instance over which fn is called - instance: Vmappable = args[0] - assert isinstance(instance, Vmappable) - - # Save the original mutability - original_mutability = instance._mutability() - - # Inspect the environment to detect whether to enforce disabling jit/vmap - deco_on = jax_tf.env_var_on(jax_tf.EnvVarOOP) - jit_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarJitOOP) and deco_on - vmap_enabled_env = jax_tf.env_var_on(jax_tf.EnvVarVmapOOP) and deco_on - - # Allow disabling the cache of jit-compiled functions. - # It can be useful for debugging or testing purposes. - wrap_fn = ( - jax_tf.wrap_fn - if jax_tf.env_var_on(jax_tf.EnvVarCacheOOP) and deco_on - else jax_tf.wrap_fn.__wrapped__ - ) - - # Get the transformed function (possibly cached by functools.cache). - # Note that all the arguments of the following methods, when hashed, should - # uniquely identify the returned function so that a new function is built - # when arguments change and either jit or vmap have to be called again. - fn_db = wrap_fn( - fn=fn, # noqa - mutability=mutability, - jit=jit_enabled_env and jit_enabled, - static_argnames=tuple(static_argnames), - vmap=vmap_enabled_env - and ( - vmap_enabled is True - or (vmap_enabled is None and instance.vectorized) - ), - in_axes=vmap_in_axes, - out_axes=vmap_out_axes, - ) - - # Call the transformed (mutable/jit/vmap) method - out, obj = fn_db(*args, **kwargs) - - if read_only: - # Restore the original mutability - instance._set_mutability(mutability=original_mutability) - - return out - - # ================================================================= - # From here we assume that the wrapper is operating on a r/w method - # ================================================================= - - from jax_dataclasses._dataclasses import JDC_STATIC_MARKER - - # Select the right runtime mutability. The only difference here is when a r/w - # method is called on a frozen object. In this case, we enable updating the - # pytree data and preserve its structure only if validation is enabled. - mutability_dict = { - Mutability.MUTABLE_NO_VALIDATION: Mutability.MUTABLE_NO_VALIDATION, - Mutability.MUTABLE: Mutability.MUTABLE, - Mutability.FROZEN: ( - Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION - ), - } - - # We need to replace all the dynamic leafs of the original instance with those - # computed by the functional transformation. - # We do so by iterating over the fields of the jax_dataclasses and ignoring - # all the fields that are marked as static. - # Caveats: https://github.com/ami-iit/jaxsim/pull/48#issuecomment-1746635121. - with instance.mutable_context( - mutability=mutability_dict[instance._mutability()] - ): - for f in dataclasses.fields(instance): # noqa - if ( - hasattr(f, "type") - and hasattr(f.type, "__metadata__") - and JDC_STATIC_MARKER in f.type.__metadata__ - ): - continue - - try: - setattr(instance, f.name, getattr(obj, f.name)) - except AssertionError as exc: - logging.debug(f"Old object:\n{getattr(instance, f.name)}") - logging.debug(f"New object:\n{getattr(obj, f.name)}") - raise RuntimeError( - f"Failed to update field '{f.name}'" - ) from exc - - return out - - return wrapper - - @staticmethod - @functools.cache - def wrap_fn( - fn: Callable, - mutability: Mutability, - jit: bool, - static_argnames: tuple[str, ...] | list[str], - vmap: bool, - in_axes: tuple[int, ...] | int | None, - out_axes: tuple[int, ...] | int | None, - ) -> Callable: - """ - Transform a method with jit/vmap and execute it on an object characterized - by the desired mutability. - - Note: - The method should take the object (self) as first argument. - - Note: - This returned transformed method is cached by considering the hash of all - the arguments. It will re-apply jit/vmap transformations only if needed. - - Args: - fn: The method to consider. - mutability: The mutability of the object on which the method is called. - jit: Whether to apply jit transformations. - static_argnames: The names of the arguments that should be considered static. - vmap: Whether to apply vmap transformations. - in_axes: The axes along which to vmap input arguments. - out_axes: The axes along which to vmap output arguments. - - Note: - In order to simplify the application of vmap, we close the method arguments - over all the non-mapped input arguments. Furthermore, for improving the - compatibility with jit, we also close the vmap application over the static - arguments. - - Returns: - The transformed method operating on an object with the desired mutability. - We maintain the same signature of the original method. - """ - - # Extract the signature of the function - sig = inspect.signature(fn) - - # All static arguments must be actual arguments of fn - for name in static_argnames: - if name not in sig.parameters: - raise ValueError(f"Static argument '{name}' not found in {fn}") - - # If in_axes is a tuple, its dimension should match the number of arguments - if isinstance(in_axes, tuple) and len(in_axes) != len(sig.parameters): - msg = "The length of 'in_axes' must match the number of arguments ({})" - raise ValueError(msg.format(len(sig.parameters))) - - # Check that static arguments are not mapped with vmap. - # This case would not work since static arguments are not traces and vmap need - # to trace arguments in order to map them. - if isinstance(in_axes, tuple): - for mapped_axis, arg_name in zip(in_axes, sig.parameters.keys()): - if mapped_axis is not None and arg_name in static_argnames: - raise ValueError( - f"Static argument '{arg_name}' cannot be mapped with vmap" - ) - - def fn_tf_vmap(*args, function_to_vmap: Callable, **kwargs): - """Wrapper applying the vmap transformation""" - - # Canonicalize the arguments so that all of them are kwargs - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - - # Build a dictionary mapping all arguments to a mapped axis, even when - # the None is passed (defaults to in_axes=0) or and int is passed (defaults - # to in_axes=). - match in_axes: - case None: - argname_to_mapped_axis = {name: 0 for name in bound.arguments} - case tuple(): - argname_to_mapped_axis = { - name: in_axes[i] for i, name in enumerate(bound.arguments) - } - case int(): - argname_to_mapped_axis = {name: in_axes for name in bound.arguments} - case _: - raise ValueError(in_axes) - - # Build a dictionary (argument_name -> argument) for all mapped arguments. - # Note that a mapped argument is an argument whose axis is not None and - # is not a static jit argument. - vmap_mapped_args = { - arg: value - for arg, value in bound.arguments.items() - if argname_to_mapped_axis[arg] is not None - and arg not in static_argnames - } - - # Build a dictionary (argument_name -> argument) for all unmapped arguments - vmap_unmapped_args = { - arg: value - for arg, value in bound.arguments.items() - if arg not in vmap_mapped_args - } - - # Disable mapping of non-vectorized default arguments - for arg, value in argname_to_mapped_axis.items(): - if arg in vmap_mapped_args and value == sig.parameters[arg].default: - logging.debug(f"Disabling vmapping of default argument '{arg}'") - argname_to_mapped_axis[arg] = None - - # Close the function over the unmapped arguments of vmap - fn_closed = lambda *mapped_args: function_to_vmap( - **vmap_unmapped_args, **dict(zip(vmap_mapped_args.keys(), mapped_args)) - ) - - # Create the in_axes tuple of only the mapped arguments - in_axes_mapped = tuple( - argname_to_mapped_axis[name] for name in vmap_mapped_args - ) - - # If all in_axes are the same, simplify in_axes tuple to be just an integer - if len(set(in_axes_mapped)) == 1: - in_axes_mapped = list(set(in_axes_mapped))[0] - - # If, instead, in_axes has different elements, we need to replace the mapped - # axis of "self" with a pytree having as leafs the mapped axis. - # This is because the vmap in_axes specification must be a tree prefix of - # the corresponding value. - if isinstance(in_axes_mapped, tuple) and "self" in vmap_mapped_args: - argname_to_mapped_axis["self"] = jax.tree_util.tree_map( - lambda _: argname_to_mapped_axis["self"], vmap_mapped_args["self"] - ) - in_axes_mapped = tuple( - argname_to_mapped_axis[name] for name in vmap_mapped_args - ) - - # Apply the vmap transformation and call the function passing only the - # mapped arguments. The unmapped arguments have been closed over. - # Note: we altered the "in_axes" tuple so that it does not have any - # None elements. - # Note: if "in_axes_mapped" is a tuple, the following fails if we pass kwargs, - # we need to pass the unpacked args tuple instead. - return jax.vmap( - fn_closed, - in_axes=in_axes_mapped, - **dict(out_axes=out_axes) if out_axes is not None else {}, - )(*list(vmap_mapped_args.values())) - - def fn_tf_jit(*args, function_to_jit: Callable, **kwargs): - """Wrapper applying the jit transformation""" - - # Canonicalize the arguments so that all of them are kwargs - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - - # Apply the jit transformation and call the function passing all arguments - # as keyword arguments - return jax.jit(function_to_jit, static_argnames=static_argnames)( - **bound.arguments - ) - - # First applied wrapper that executes fn in a mutable context - fn_mutable = functools.partial( - jax_tf.call_class_method_in_mutable_context, - fn=fn, - jit=jit, - mutability=mutability, - ) - - # Second applied wrapper that transforms fn with vmap - fn_vmap = ( - fn_mutable - if not vmap - else functools.partial(fn_tf_vmap, function_to_vmap=fn_mutable) - ) - - # Third applied wrapper that transforms fn with jit - fn_jit_vmap = ( - fn_vmap - if not jit - else functools.partial(fn_tf_jit, function_to_jit=fn_vmap) - ) - - return fn_jit_vmap - - @staticmethod - def call_class_method_in_mutable_context( - *args, fn: Callable, jit: bool, mutability: Mutability, **kwargs - ) -> tuple[Any, Vmappable]: - """ - Wrapper to call a method on an object with the desired mutable context. - - Args: - fn: The method to call. - jit: Whether the method is being jit compiled or not. - mutability: The desired mutability context. - *args: The positional arguments to pass to the method (including self). - **kwargs: The keyword arguments to pass to the method. - - Returns: - A tuple containing the return value of the method and the object - possibly updated by the method if it is in read-write. - - Note: - This approach enables to jit-compile methods of a stateful object without - leaking traces, therefore obtaining a jax-compatible OOP pattern. - """ - - # Log here whether the method is being jit compiled or not. - # This log message does not get printed from compiled code, so here is the - # most appropriate place to be sure that we log it correctly. - if jit: - logging.debug(msg=f"JIT compiling {fn}") - - # Canonicalize the arguments so that all of them are kwargs - sig = inspect.signature(fn) - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - - # Extract the class instance over which fn is called - instance: Vmappable = bound.arguments["self"] - - # Select the right mutability. If the instance is mutable with validation - # disabled, we override the input mutability so that we do not fail in case - # of mismatched tree structure. - mut = ( - Mutability.MUTABLE_NO_VALIDATION - if instance._mutability() is Mutability.MUTABLE_NO_VALIDATION - else mutability - ) - - # Call fn in a mutable context - with instance.mutable_context(mutability=mut): - # Methods could call other decorated methods. When it happens, the decorator - # of the called method is invoked, that applies jit and vmap transformations. - # This is not desired as it calls vmap inside an already vmapped method. - # We work around this occurrence by disabling the jit/vmap decorators of all - # methods called inside fn through a context manager. - # Note that we already work around this in the beginning of the wrapper - # function by detecting traced arguments, but the decorator works also - # when jit=False and vmap=False, therefore only enforcing the mutability. - with jax_tf.disabled_oop_decorators(): - out = fn(**bound.arguments) - - return out, instance - - @staticmethod - def env_var_on(var_name: str, default_value: str = "1") -> bool: - """ - Check whether an environment variable is set to a value that is considered on. - - Args: - var_name: The name of the environment variable. - default_value: The default variable value to consider if the variable has not - been exported. - - Returns: - True if the environment variable contains an on value, False otherwise. - """ - - on_values = {"1", "true", "on", "yes"} - return os.environ.get(var_name, default_value).lower() in on_values - - @staticmethod - @contextlib.contextmanager - def disabled_oop_decorators() -> Generator[None, None, None]: - """ - Context manager to disable the application of jax transformations performed by - the decorators of this class. - - Note: when the transformations are disabled, the only logic still applied is - the selection of the object mutability over which the method is running. - """ - - # Check whether the environment variable is part of the environment and - # save its value. We restore the original value before exiting the context. - env_cache = ( - None if jax_tf.EnvVarOOP not in os.environ else os.environ[jax_tf.EnvVarOOP] - ) - - # Disable both jit and vmap transformations - os.environ[jax_tf.EnvVarOOP] = "0" - - try: - # Execute the code in the context with disabled transformations - yield - - finally: - # Restore the original value of the environment variable or remove it if - # it was not present before entering the context - if env_cache is not None: - os.environ[jax_tf.EnvVarOOP] = env_cache - else: - _ = os.environ.pop(jax_tf.EnvVarOOP) diff --git a/src/jaxsim/utils/vmappable.py b/src/jaxsim/utils/vmappable.py deleted file mode 100644 index 0e449f4b8..000000000 --- a/src/jaxsim/utils/vmappable.py +++ /dev/null @@ -1,117 +0,0 @@ -import dataclasses -from typing import Type - -import jax -import jax.numpy as jnp -import jax_dataclasses - -from . import JaxsimDataclass, Mutability - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - - -@jax_dataclasses.pytree_dataclass -class Vmappable(JaxsimDataclass): - """Abstract class with utilities for vmappable pytrees.""" - - batch_size: jax_dataclasses.Static[int] = dataclasses.field( - default=int(0), repr=False, compare=False, hash=False, kw_only=True - ) - - @property - def vectorized(self) -> bool: - """Marks this pytree as vectorized.""" - - return self.batch_size > 0 - - @classmethod - def build_from_list(cls: Type[Self], list_of_obj: list[Self]) -> Self: - """ - Build a vectorized pytree from a list of pytree of the same type. - - Args: - list_of_obj: The list of pytrees to vectorize. - - Returns: - The vectorized pytree having as leaves the stacked leaves of the input list. - """ - - if set(type(el) for el in list_of_obj) != {cls}: - msg = "The input list must contain only objects of type '{}'" - raise ValueError(msg.format(cls.__name__)) - - # Create a pytree by stacking all the leafs of the input list - data_vec: Vmappable = jax.tree_map( - lambda *leafs: jnp.array(leafs), *list_of_obj - ) - - # Store the batch dimension - with data_vec.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - data_vec.batch_size = len(list_of_obj) - - # Detect the most common mutability in the input list - mutabilities = [e._mutability() for e in list_of_obj] - mutability = max(set(mutabilities), key=mutabilities.count) - - # Update the mutability of the vectorized pytree - data_vec._set_mutability(mutability) - - return data_vec - - def vectorize(self: Self, batch_size: int) -> Self: - """ - Return a vectorized version of this pytree. - - Args: - batch_size: The batch size. - - Returns: - A vectorized version of this pytree obtained by stacking the leaves of the - original pytree along a new batch dimension (the first one). - """ - - if self.vectorized: - raise RuntimeError("Cannot vectorize an already vectorized object") - - if batch_size == 0: - return self.copy() - - # TODO validate if mutability is maintained - - return self.__class__.build_from_list(list_of_obj=[self] * batch_size) - - def extract_element(self: Self, index: int) -> Self: - """ - Extract the i-th element from a vectorized pytree. - - Args: - index: The index of the element to extract. - - Returns: - A non vectorized pytree obtained by extracting the i-th element from the - vectorized pytree. - """ - - if index < 0: - raise ValueError("The index of the desired element cannot be negative") - - if index == 0 and self.batch_size == 0: - return self.copy() - - if not self.vectorized: - raise RuntimeError("Cannot extract elements from a non-vectorized object") - - if index >= self.batch_size: - raise ValueError("The index must be smaller than the batch size") - - # Get the i-th pytree by extracting the i-th element from the vectorized pytree - data = jax.tree_map(lambda leaf: leaf[index], self) - - # Update the batch size of the extracted scalar pytree - with data.mutable_context(mutability=Mutability.MUTABLE): - data.batch_size = 0 - - return data