diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index f0d65c4a0..b8ae78585 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -37,14 +37,11 @@ def collidable_point_kinematics( the linear component of the mixed 6D frame velocity. """ - # Switch to inertial-fixed since the RBDAs expect velocities in this representation. - with data.switch_velocity_representation(VelRepr.Inertial): - - W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( - model=model, - link_transforms=data._link_transforms, - link_velocities=data._link_velocities, - ) + W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel( + model=model, + link_transforms=data._link_transforms, + link_velocities=data._link_velocities, + ) return W_p_Ci, W_ṗ_Ci @@ -164,7 +161,11 @@ def estimate_good_soft_contacts_parameters( def estimate_good_contact_parameters( model: js.model.JaxSimModel, *, + standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, static_friction_coefficient: jtp.FloatLike = 0.5, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + max_penetration: jtp.FloatLike | None = None, **kwargs, ) -> jaxsim.rbda.contacts.ContactParamsTypes: """ @@ -172,7 +173,12 @@ def estimate_good_contact_parameters( Args: model: The model to consider. + standard_gravity: The standard gravity acceleration. static_friction_coefficient: The static friction coefficient. + number_of_active_collidable_points_steady_state: + The number of active collidable points in steady state. + damping_ratio: The damping ratio. + max_penetration: The maximum penetration allowed. kwargs: Additional model-specific parameters passed to the builder method of the parameters class. @@ -190,8 +196,66 @@ def estimate_good_contact_parameters( specific application. """ + def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: + """ + Displacement between the CoM and the lowest collidable point using zero + joint positions. + """ + + zero_data = js.data.JaxSimModelData.build( + model=model, + ) + + W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2] + + if model.floating_base(): + W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1] + return 2 * (W_pz_CoM - W_pz_C.min()) + + return 2 * W_pz_CoM + + max_δ = ( + max_penetration + if max_penetration is not None + # Consider as default a 0.5% of the model height. + else 0.005 * estimate_model_height(model=model) + ) + + nc = number_of_active_collidable_points_steady_state + match model.contact_model: + case contacts.SoftContacts(): + assert isinstance(model.contact_model, contacts.SoftContacts) + + parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model( + model=model, + standard_gravity=standard_gravity, + static_friction_coefficient=static_friction_coefficient, + max_penetration=max_δ, + number_of_active_collidable_points_steady_state=nc, + damping_ratio=damping_ratio, + **kwargs, + ) + + case contacts.RigidContacts(): + assert isinstance(model.contact_model, contacts.RigidContacts) + + # Disable Baumgarte stabilization by default since it does not play + # well with the forward Euler integrator. + K = kwargs.get("K", 0.0) + + parameters = contacts.RigidContactsParams.build( + mu=static_friction_coefficient, + **( + dict( + K=K, + D=2 * jnp.sqrt(K), + ) + | kwargs + ), + ) + case contacts.RelaxedRigidContacts(): assert isinstance(model.contact_model, contacts.RelaxedRigidContacts) diff --git a/src/jaxsim/api/contact_model.py b/src/jaxsim/api/contact_model.py index 1eeca4c5d..adf73b454 100644 --- a/src/jaxsim/api/contact_model.py +++ b/src/jaxsim/api/contact_model.py @@ -5,6 +5,7 @@ import jaxsim.api as js import jaxsim.typing as jtp +from jaxsim.rbda.contacts import SoftContacts @jax.jit @@ -15,7 +16,7 @@ def link_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_torques: jtp.VectorLike | None = None, -) -> jtp.Matrix: +) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]: """ Compute the 6D contact forces of all links of the model in inertial representation. @@ -33,11 +34,14 @@ def link_contact_forces( """ # Compute the contact forces for each collidable point with the active contact model. - W_f_C, _ = model.contact_model.compute_contact_forces( + W_f_C, extended_contact_state = model.contact_model.compute_contact_forces( model=model, data=data, - link_forces=link_forces, - joint_force_references=joint_torques, + **( + dict(link_forces=link_forces, joint_force_references=joint_torques) + if not isinstance(model.contact_model, SoftContacts) + else {} + ), ) # Compute the 6D forces applied to the links equivalent to the forces applied @@ -46,7 +50,7 @@ def link_contact_forces( model=model, data=data, contact_forces=W_f_C ) - return W_f_L + return W_f_L, extended_contact_state @staticmethod diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index c13f9ea6c..04b398ebc 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -60,6 +60,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): _link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None) _link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None) + # Extended state for soft and rigid contact models. + contact_state: dict[str, jtp.Array] = dataclasses.field(default=None) + @staticmethod def build( model: js.model.JaxSimModel, @@ -70,6 +73,8 @@ def build( base_angular_velocity: jtp.VectorLike | None = None, joint_velocities: jtp.VectorLike | None = None, velocity_representation: VelRepr = VelRepr.Mixed, + *, + contact_state: dict[str, jtp.Array] | None = None, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with the given state. @@ -85,6 +90,7 @@ def build( The base angular velocity in the selected representation. joint_velocities: The joint velocities. velocity_representation: The velocity representation to use. It defaults to mixed if not provided. + contact_state: The optional contact state. Returns: A `JaxSimModelData` initialized with the given state. @@ -167,18 +173,29 @@ def build( ) ) + contact_state = ( + { + "tangential_deformation": jnp.zeros_like( + model.kin_dyn_parameters.contact_parameters.point + ) + } + if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts) + else contact_state or {} + ) + model_data = JaxSimModelData( velocity_representation=velocity_representation, _base_quaternion=base_quaternion, _base_position=base_position, _joint_positions=joint_positions, - _base_linear_velocity=v_WB[0:3], - _base_angular_velocity=v_WB[3:6], + _base_linear_velocity=W_v_WB[0:3], + _base_angular_velocity=W_v_WB[3:6], _joint_velocities=joint_velocities, _base_transform=W_H_B, _joint_transforms=joint_transforms, _link_transforms=link_transforms, - _link_velocities=link_velocities, + _link_velocities=link_velocities_inertial, + contact_state=contact_state or {}, ) if not model_data.valid(model=model): @@ -386,6 +403,8 @@ def replace( base_angular_velocity: jtp.Vector | None = None, base_position: jtp.Vector | None = None, velocity_representation: VelRepr | None = None, + *, + contact_state: dict[str, jtp.Array] | None = None, validate: bool = False, ) -> Self: """ @@ -438,8 +457,8 @@ def replace( base_quaternion=base_quaternion, joint_positions=joint_positions, joint_velocities=joint_velocities, - base_linear_velocity=base_linear_velocity, - base_angular_velocity=base_angular_velocity, + base_linear_velocity_inertial=base_linear_velocity, + base_angular_velocity_inertial=base_angular_velocity, ) return super().replace( @@ -454,6 +473,7 @@ def replace( _joint_transforms=joint_transforms, _link_transforms=link_transforms, _link_velocities=link_velocities, + contact_state=contact_state, validate=validate, ) diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index 147071f2f..4a8b982a9 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -1,3 +1,4 @@ +import jax import jax.numpy as jnp import jaxsim @@ -12,6 +13,8 @@ def semi_implicit_euler_integration( data: js.data.JaxSimModelData, base_acceleration_inertial: jtp.Vector, joint_accelerations: jtp.Vector, + *, + extended_contact_state: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the semi-implicit Euler method.""" # Step the dynamics forward. @@ -57,6 +60,14 @@ def semi_implicit_euler_integration( new_joint_position = data.joint_positions + dt * new_joint_velocities + # Integrate the leaves of the contact state PyTree. + integrated_contact_state = jax.tree.map( + lambda x, x_dot: None if x is None else x + dt * x_dot, + data.contact_state, + extended_contact_state, + is_leaf=lambda x: x is None, + ) + data = data.replace( model=model, validate=True, @@ -70,6 +81,7 @@ def semi_implicit_euler_integration( # it's equivalent to the one in inertial representation # See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9 base_angular_velocity=base_ang_velocity_mixed, + contact_state=integrated_contact_state, ) return data diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index cf91ab5aa..44a949569 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -41,13 +41,13 @@ class JaxSimModel(JaxsimDataclass): default_factory=jaxsim.terrain.FlatTerrain.build, repr=False ) - gravity: Static[float] = jaxsim.math.STANDARD_GRAVITY + gravity: Static[float] = -jaxsim.math.STANDARD_GRAVITY contact_model: Static[jaxsim.rbda.contacts.ContactModel | None] = dataclasses.field( default=None, repr=False ) - contacts_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field( + contact_params: Static[jaxsim.rbda.contacts.ContactsParams] = dataclasses.field( default=None, repr=False ) @@ -111,6 +111,7 @@ def build_from_model_description( terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, + gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, is_urdf: bool | None = None, considered_joints: Sequence[str] | None = None, ) -> JaxSimModel: @@ -131,6 +132,7 @@ def build_from_model_description( The contact model to consider. If not specified, a soft contacts model is used. contact_params: The parameters of the contact model. + gravity: The gravity constant. is_urdf: The optional flag to force the model description to be parsed as a URDF. This is usually automatically inferred. @@ -163,7 +165,8 @@ def build_from_model_description( time_step=time_step, terrain=terrain, contact_model=contact_model, - contacts_params=contact_params, + contact_params=contact_params, + gravity=gravity, ) # Store the origin of the model, in case downstream logic needs it. @@ -181,7 +184,7 @@ def build( time_step: jtp.FloatLike | None = None, terrain: jaxsim.terrain.Terrain | None = None, contact_model: jaxsim.rbda.contacts.ContactModel | None = None, - contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None, + contact_params: jaxsim.rbda.contacts.ContactsParams | None = None, gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY, ) -> JaxSimModel: """ @@ -201,7 +204,7 @@ def build( contact_model: The contact model to consider. If not specified, a soft contacts model is used. - contacts_params: The parameters of the soft contacts. + contact_params: The parameters of the soft contacts. gravity: The gravity constant. Returns: @@ -234,8 +237,8 @@ def build( else jaxsim.rbda.contacts.RelaxedRigidContacts.build() ) - if contacts_params is None: - contacts_params = contact_model._parameters_class() + if contact_params is None: + contact_params = contact_model._parameters_class() # Build the model. model = cls( @@ -246,8 +249,8 @@ def build( time_step=time_step, terrain=terrain, contact_model=contact_model, - contacts_params=contacts_params, - gravity=gravity, + contact_params=contact_params, + gravity=-gravity, # The following is wrapped as hashless since it's a static argument, and we # don't want to trigger recompilation if it changes. All relevant parameters # needed to compute kinematics and dynamics quantities are stored in the @@ -447,6 +450,8 @@ def reduce( time_step=model.time_step, terrain=model.terrain, contact_model=model.contact_model, + contact_params=model.contact_params, + gravity=-model.gravity, ) # Store the origin of the model, in case downstream logic needs it. @@ -2026,7 +2031,7 @@ def step( transform=W_H_L, is_force=True, ) - )(O_f_L_external, data.link_transforms) + )(O_f_L_external, data._link_transforms) τ_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() @@ -2052,7 +2057,7 @@ def step( # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_L_terrain = js.contact_model.link_contact_forces( + W_f_L_terrain, extended_contact_state = js.contact_model.link_contact_forces( model=model, data=data, link_forces=W_f_L_external, @@ -2065,6 +2070,26 @@ def step( W_f_L_total = W_f_L_external + W_f_L_terrain + # ============================= + # Update the contact state data + # ============================= + + contact_state = {} + + match model.contact_model: + + case jaxsim.rbda.contacts.SoftContacts(): + contact_state["tangential_deformation"] = extended_contact_state["m_dot"] + + case ( + jaxsim.rbda.contacts.RigidContacts() + | jaxsim.rbda.contacts.RelaxedRigidContacts() + ): + pass + + case _: + raise ValueError(f"Invalid contact model: {model.contact_model}") + # =============================== # Compute the system acceleration # =============================== @@ -2086,6 +2111,57 @@ def step( data=data, base_acceleration_inertial=W_v̇_WB, joint_accelerations=s̈, + extended_contact_state=contact_state, + ) + + if isinstance(model.contact_model, jaxsim.rbda.contacts.RigidContacts): + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + W_p_C = js.contact.collidable_point_positions(model, data_tf)[ + indices_of_enabled_collidable_points + ] + + # Compute the penetration depth of the collidable points. + δ, *_ = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, + in_axes=(0, 0, None), + )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) + + with data_tf.switch_velocity_representation(VelRepr.Mixed): + J_WC = js.contact.jacobian(model, data_tf)[ + indices_of_enabled_collidable_points + ] + M = js.model.free_floating_mass_matrix(model, data_tf) + BW_ν_pre_impact = data_tf.generalized_velocity() + + # Compute the impact velocity. + # It may be discontinuous in case new contacts are made. + BW_ν_post_impact = ( + jaxsim.rbda.contacts.RigidContacts.compute_impact_velocity( + generalized_velocity=BW_ν_pre_impact, + inactive_collidable_points=(δ <= 0), + M=M, + J_WC=J_WC, + ) + ) + + # Reset the generalized velocity. + data_tf = data_tf.replace( + model=model, + base_linear_velocity=BW_ν_post_impact[0:3], + base_angular_velocity=BW_ν_post_impact[3:6], + joint_velocities=BW_ν_post_impact[6:], + ) + + # ne parliamo dopo + # Restore the input velocity representation + data_tf = data_tf.replace( + model=model, + velocity_representation=data.velocity_representation, + validate=False, ) return data_tf diff --git a/src/jaxsim/math/__init__.py b/src/jaxsim/math/__init__.py index e7c221742..cf0bcb107 100644 --- a/src/jaxsim/math/__init__.py +++ b/src/jaxsim/math/__init__.py @@ -11,4 +11,4 @@ # Define the default standard gravity constant. -STANDARD_GRAVITY = -9.81 +STANDARD_GRAVITY = 9.81 diff --git a/src/jaxsim/rbda/__init__.py b/src/jaxsim/rbda/__init__.py index 25bafc1ae..eb89bd208 100644 --- a/src/jaxsim/rbda/__init__.py +++ b/src/jaxsim/rbda/__init__.py @@ -2,7 +2,7 @@ from .aba import aba from .collidable_points import collidable_points_pos_vel from .crba import crba -from .forward_kinematics import forward_kinematics, forward_kinematics_model +from .forward_kinematics import forward_kinematics_model from .jacobian import ( jacobian, jacobian_derivative_full_doubly_left, diff --git a/src/jaxsim/rbda/contacts/__init__.py b/src/jaxsim/rbda/contacts/__init__.py index 3688468cf..32f05e229 100644 --- a/src/jaxsim/rbda/contacts/__init__.py +++ b/src/jaxsim/rbda/contacts/__init__.py @@ -1,5 +1,9 @@ -from . import relaxed_rigid +from . import relaxed_rigid, rigid, soft from .common import ContactModel, ContactsParams from .relaxed_rigid import RelaxedRigidContacts, RelaxedRigidContactsParams +from .rigid import RigidContacts, RigidContactsParams +from .soft import SoftContacts, SoftContactsParams -ContactParamsTypes = RelaxedRigidContactsParams +ContactParamsTypes = ( + SoftContactsParams | RigidContactsParams | RelaxedRigidContactsParams +) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index f5610bf73..88aec8a42 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -341,7 +341,7 @@ def compute_contact_forces( model=model, position_constraint=position_constraint, velocity_constraint=velocity, - parameters=model.contacts_params, + parameters=model.contact_params, ) # Compute the Delassus matrix and the free mixed linear acceleration of diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py new file mode 100644 index 000000000..2e77f4582 --- /dev/null +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -0,0 +1,456 @@ +from __future__ import annotations + +import dataclasses +from typing import Any + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim.api as js +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.api.common import ModelDataWithVelocityRepresentation, VelRepr + +from . import common +from .common import ContactModel, ContactsParams + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class RigidContactsParams(ContactsParams): + """Parameters of the rigid contacts model.""" + + # Static friction coefficient + mu: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + # Baumgarte proportional term + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + # Baumgarte derivative term + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.0, dtype=float) + ) + + def __hash__(self) -> int: + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.mu), + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + ) + ) + + def __eq__(self, other: RigidContactsParams) -> bool: + return hash(self) == hash(other) + + @classmethod + def build( + cls: type[Self], + *, + mu: jtp.FloatLike | None = None, + K: jtp.FloatLike | None = None, + D: jtp.FloatLike | None = None, + ) -> Self: + """Create a `RigidContactParams` instance.""" + + return cls( + mu=jnp.array( + mu + if mu is not None + else cls.__dataclass_fields__["mu"].default_factory() + ).astype(float), + K=jnp.array( + K if K is not None else cls.__dataclass_fields__["K"].default_factory() + ).astype(float), + D=jnp.array( + D if D is not None else cls.__dataclass_fields__["D"].default_factory() + ).astype(float), + ) + + def valid(self) -> jtp.BoolLike: + """Check if the parameters are valid.""" + return bool( + jnp.all(self.mu >= 0.0) + and jnp.all(self.K >= 0.0) + and jnp.all(self.D >= 0.0) + ) + + +@jax_dataclasses.pytree_dataclass +class RigidContacts(ContactModel): + """Rigid contacts model.""" + + regularization_delassus: jax_dataclasses.Static[float] = dataclasses.field( + default=1e-6, kw_only=True + ) + + _solver_options_keys: jax_dataclasses.Static[tuple[str, ...]] = dataclasses.field( + default=("solver_tol",), kw_only=True + ) + _solver_options_values: jax_dataclasses.Static[tuple[Any, ...]] = dataclasses.field( + default=(1e-3,), kw_only=True + ) + + @property + def solver_options(self) -> dict[str, Any]: + """Get the solver options as a dictionary.""" + + return dict( + zip( + self._solver_options_keys, + self._solver_options_values, + strict=True, + ) + ) + + @classmethod + def build( + cls: type[Self], + regularization_delassus: jtp.FloatLike | None = None, + solver_options: dict[str, Any] | None = None, + **kwargs, + ) -> Self: + """ + Create a `RigidContacts` instance with specified parameters. + + Args: + regularization_delassus: + The regularization term to add to the diagonal of the Delassus matrix. + solver_options: The options to pass to the QP solver. + **kwargs: Extra arguments which are ignored. + + Returns: + The `RigidContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + # Get the default solver options. + default_solver_options = dict( + zip(cls._solver_options_keys, cls._solver_options_values, strict=True) + ) + + # Create the solver options to set by combining the default solver options + # with the user-provided solver options. + solver_options = default_solver_options | ( + solver_options if solver_options is not None else {} + ) + + # Make sure that the solver options are hashable. + # We need to check this because the solver options are static. + try: + hash(tuple(solver_options.values())) + except TypeError as exc: + raise ValueError( + "The values of the solver options must be hashable." + ) from exc + + return cls( + regularization_delassus=float( + regularization_delassus + if regularization_delassus is not None + else cls.__dataclass_fields__["regularization_delassus"].default + ), + _solver_options_keys=tuple(solver_options.keys()), + _solver_options_values=tuple(solver_options.values()), + **kwargs, + ) + + @staticmethod + def compute_impact_velocity( + inactive_collidable_points: jtp.ArrayLike, + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + generalized_velocity: jtp.VectorLike, + ) -> jtp.Vector: + """ + Return the new velocity of the system after a potential impact. + + Args: + inactive_collidable_points: The activation state of the collidable points. + M: The mass matrix of the system (in mixed representation). + J_WC: The Jacobian matrix of the collidable points (in mixed representation). + generalized_velocity: The generalized velocity of the system. + + Note: + The mass matrix `M`, the Jacobian `J_WC`, and the generalized velocity `generalized_velocity` + must be expressed in the same velocity representation. + """ + + # Compute system velocity after impact maintaining zero linear velocity of active points. + sl = jnp.s_[:, 0:3, :] + Jl_WC = J_WC[sl] + + # Zero out the jacobian rows of inactive points. + Jl_WC = jnp.vstack( + jnp.where( + inactive_collidable_points[:, jnp.newaxis, jnp.newaxis], + jnp.zeros_like(Jl_WC), + Jl_WC, + ) + ) + + A = jnp.vstack( + [ + jnp.hstack([M, -Jl_WC.T]), + jnp.hstack([Jl_WC, jnp.zeros((Jl_WC.shape[0], Jl_WC.shape[0]))]), + ] + ) + b = jnp.hstack([M @ generalized_velocity, jnp.zeros(Jl_WC.shape[0])]) + + BW_ν_post_impact = jnp.linalg.lstsq(A, b)[0] + + return BW_ν_post_impact[0 : M.shape[0]] + + @jax.jit + def compute_contact_forces( + self, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The model to consider. + data: The data of the considered model. + link_forces: + Optional `(n_links, 6)` matrix of external forces acting on the links, + expressed in the same representation of data. + joint_force_references: + Optional `(n_joints,)` vector of joint forces. + + Returns: + A tuple containing as first element the computed contact forces. + """ + + # Import qpax privately just in this method. + import qpax + + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + n_collidable_points = len(indices_of_enabled_collidable_points) + + link_forces = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() + if link_forces is not None + else jnp.zeros((model.number_of_links(), 6)) + ) + + joint_force_references = jnp.atleast_1d( + jnp.array(joint_force_references, dtype=float).squeeze() + if joint_force_references is not None + else jnp.zeros((model.number_of_joints(),)) + ) + + # Compute kin-dyn quantities used in the contact model. + with data.switch_velocity_representation(VelRepr.Mixed): + BW_ν = data.generalized_velocity() + + M = js.model.free_floating_mass_matrix(model=model, data=data) + + J_WC = js.contact.jacobian(model=model, data=data) + J̇_WC = js.contact.jacobian_derivative(model=model, data=data) + + W_H_C = js.contact.transforms(model=model, data=data) + + # Compute the position and linear velocities (mixed representation) of + # all enabled collidable points belonging to the robot. + position, velocity = js.contact.collidable_point_kinematics( + model=model, data=data + ) + + # Compute the penetration depth and velocity of the collidable points. + # Note that this function considers the penetration in the normal direction. + δ, δ_dot, n̂ = jax.vmap(common.compute_penetration_data, in_axes=(0, 0, None))( + position, velocity, model.terrain + ) + + # Build a references object to simplify converting link forces. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + + # Compute the generalized free acceleration. + with data.switch_velocity_representation(VelRepr.Mixed): + BW_ν̇_free = jnp.hstack( + js.ode.system_acceleration( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + joint_torques=references.joint_force_references(model=model), + ) + ) + + # Compute the free linear acceleration of the collidable points. + # Since we use doubly-mixed jacobian, this corresponds to W_p̈_C. + free_contact_acc = RigidContacts._linear_acceleration_of_collidable_points( + BW_nu=BW_ν, + BW_nu_dot=BW_ν̇_free, + CW_J_WC_BW=J_WC, + CW_J_dot_WC_BW=J̇_WC, + ).flatten() + + # Compute stabilization term. + baumgarte_term = RigidContacts._compute_baumgarte_stabilization_term( + inactive_collidable_points=(δ <= 0), + δ=δ, + δ_dot=δ_dot, + n=n̂, + K=model.contact_params.K, + D=model.contact_params.D, + ).flatten() + + # Compute the Delassus matrix. + delassus_matrix = RigidContacts._delassus_matrix(M=M, J_WC=J_WC) + + # Initialize regularization term of the Delassus matrix for + # better numerical conditioning. + Iε = self.regularization_delassus * jnp.eye(delassus_matrix.shape[0]) + + # Construct the quadratic cost function. + Q = delassus_matrix + Iε + q = free_contact_acc - baumgarte_term + + # Construct the inequality constraints. + G = RigidContacts._compute_ineq_constraint_matrix( + inactive_collidable_points=(δ <= 0), mu=model.contact_params.mu + ) + h_bounds = RigidContacts._compute_ineq_bounds( + n_collidable_points=n_collidable_points + ) + + # Construct the equality constraints. + A = jnp.zeros((0, 3 * n_collidable_points)) + b = jnp.zeros((0,)) + + # Solve the following optimization problem with qpax: + # + # min_{x} 0.5 x⊤ Q x + q⊤ x + # + # s.t. A x = b + # G x ≤ h + # + # TODO: add possibility to notify if the QP problem did not converge. + solution, _, _, _, converged, _ = qpax.solve_qp( # noqa: F841 + Q=Q, q=q, A=A, b=b, G=G, h=h_bounds, **self.solver_options + ) + + # Reshape the optimized solution to be a matrix of 3D contact forces. + CW_fl_C = solution.reshape(-1, 3) + + # Convert the contact forces from mixed to inertial-fixed representation. + W_f_C = jax.vmap( + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + transform=W_H_C, + other_representation=VelRepr.Mixed, + is_force=True, + ) + ), + )(CW_fl_C, W_H_C) + + return W_f_C, {} + + @staticmethod + def _delassus_matrix( + M: jtp.MatrixLike, + J_WC: jtp.MatrixLike, + ) -> jtp.Matrix: + + sl = jnp.s_[:, 0:3, :] + J_WC_lin = jnp.vstack(J_WC[sl]) + + delassus_matrix = J_WC_lin @ jnp.linalg.pinv(M) @ J_WC_lin.T + return delassus_matrix + + @staticmethod + def _compute_ineq_constraint_matrix( + inactive_collidable_points: jtp.Vector, mu: jtp.FloatLike + ) -> jtp.Matrix: + """ + Compute the inequality constraint matrix for a single collidable point. + + Rows 0-3: enforce the friction pyramid constraint, + Row 4: last one is for the non negativity of the vertical force + Row 5: contact complementarity condition + """ + G_single_point = jnp.array( + [ + [1, 0, -mu], + [0, 1, -mu], + [-1, 0, -mu], + [0, -1, -mu], + [0, 0, -1], + [0, 0, 0], + ] + ) + G = jnp.tile(G_single_point, (len(inactive_collidable_points), 1, 1)) + G = G.at[:, 5, 2].set(inactive_collidable_points) + + G = jax.scipy.linalg.block_diag(*G) + return G + + @staticmethod + def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: + + n_constraints = 6 * n_collidable_points + return jnp.zeros(shape=(n_constraints,)) + + @staticmethod + def _linear_acceleration_of_collidable_points( + BW_nu: jtp.ArrayLike, + BW_nu_dot: jtp.ArrayLike, + CW_J_WC_BW: jtp.MatrixLike, + CW_J_dot_WC_BW: jtp.MatrixLike, + ) -> jtp.Matrix: + + BW_ν = BW_nu + BW_ν̇ = BW_nu_dot + CW_J̇_WC_BW = CW_J_dot_WC_BW + + # Compute the linear acceleration of the collidable points. + # Since we use doubly-mixed jacobians, this corresponds to W_p̈_C. + CW_a_WC = jnp.vstack(CW_J̇_WC_BW) @ BW_ν + jnp.vstack(CW_J_WC_BW) @ BW_ν̇ + + CW_a_WC = CW_a_WC.reshape(-1, 6) + return CW_a_WC[:, 0:3].squeeze() + + @staticmethod + def _compute_baumgarte_stabilization_term( + inactive_collidable_points: jtp.ArrayLike, + δ: jtp.ArrayLike, + δ_dot: jtp.ArrayLike, + n: jtp.ArrayLike, + K: jtp.FloatLike, + D: jtp.FloatLike, + ) -> jtp.Array: + + return jnp.where( + inactive_collidable_points[:, jnp.newaxis], + jnp.zeros_like(n), + (K * δ + D * δ_dot)[:, jnp.newaxis] * n, + ) diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py new file mode 100644 index 000000000..86889d2d4 --- /dev/null +++ b/src/jaxsim/rbda/contacts/soft.py @@ -0,0 +1,484 @@ +from __future__ import annotations + +import dataclasses +import functools + +import jax +import jax.numpy as jnp +import jax_dataclasses + +import jaxsim.api as js +import jaxsim.math +import jaxsim.typing as jtp +from jaxsim import logging +from jaxsim.math import STANDARD_GRAVITY +from jaxsim.terrain import Terrain + +from . import common + +try: + from typing import Self +except ImportError: + from typing_extensions import Self + + +@jax_dataclasses.pytree_dataclass +class SoftContactsParams(common.ContactsParams): + """Parameters of the soft contacts model.""" + + K: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(1e6, dtype=float) + ) + + D: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(2000, dtype=float) + ) + + mu: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + p: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + q: jtp.Float = dataclasses.field( + default_factory=lambda: jnp.array(0.5, dtype=float) + ) + + def __hash__(self) -> int: + + from jaxsim.utils.wrappers import HashedNumpyArray + + return hash( + ( + HashedNumpyArray.hash_of_array(self.K), + HashedNumpyArray.hash_of_array(self.D), + HashedNumpyArray.hash_of_array(self.mu), + HashedNumpyArray.hash_of_array(self.p), + HashedNumpyArray.hash_of_array(self.q), + ) + ) + + def __eq__(self, other: SoftContactsParams) -> bool: + + if not isinstance(other, SoftContactsParams): + return NotImplemented + + return hash(self) == hash(other) + + @classmethod + def build( + cls: type[Self], + *, + K: jtp.FloatLike = 1e6, + D: jtp.FloatLike = 2_000, + mu: jtp.FloatLike = 0.5, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> Self: + """ + Create a SoftContactsParams instance with specified parameters. + + Args: + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + mu: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A SoftContactsParams instance with the specified parameters. + """ + + return SoftContactsParams( + K=jnp.array(K, dtype=float), + D=jnp.array(D, dtype=float), + mu=jnp.array(mu, dtype=float), + p=jnp.array(p, dtype=float), + q=jnp.array(q, dtype=float), + ) + + @classmethod + def build_default_from_jaxsim_model( + cls: type[Self], + model: js.model.JaxSimModel, + *, + standard_gravity: jtp.FloatLike = STANDARD_GRAVITY, + static_friction_coefficient: jtp.FloatLike = 0.5, + max_penetration: jtp.FloatLike = 0.001, + number_of_active_collidable_points_steady_state: jtp.IntLike = 1, + damping_ratio: jtp.FloatLike = 1.0, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> SoftContactsParams: + """ + Create a SoftContactsParams instance with good default parameters. + + Args: + model: The target model. + standard_gravity: The standard gravity constant. + static_friction_coefficient: + The static friction coefficient between the model and the terrain. + max_penetration: The maximum penetration depth. + number_of_active_collidable_points_steady_state: + The number of contacts supporting the weight of the model + in steady state. + damping_ratio: The ratio controlling the damping behavior. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A `SoftContactsParams` instance with the specified parameters. + + Note: + The `damping_ratio` parameter allows to operate on the following conditions: + - ξ > 1.0: over-damped + - ξ = 1.0: critically damped + - ξ < 1.0: under-damped + """ + + # Use symbols for input parameters. + ξ = damping_ratio + δ_max = max_penetration + μc = static_friction_coefficient + + # Compute the total mass of the model. + m = jnp.array(model.kin_dyn_parameters.link_parameters.mass).sum() + + # Rename the standard gravity. + g = standard_gravity + + # Compute the average support force on each collidable point. + f_average = m * g / number_of_active_collidable_points_steady_state + + # Compute the stiffness to get the desired steady-state penetration. + # Note that this is dependent on the non-linear exponent used in + # the damping term of the Hunt/Crossley model. + K = f_average / jnp.power(δ_max, 1 + p) + + # Compute the damping using the damping ratio. + critical_damping = 2 * jnp.sqrt(K * m) + D = ξ * critical_damping + + return SoftContactsParams.build(K=K, D=D, mu=μc, p=p, q=q) + + def valid(self) -> jtp.BoolLike: + """ + Check if the parameters are valid. + + Returns: + `True` if the parameters are valid, `False` otherwise. + """ + + return jnp.hstack( + [ + self.K >= 0.0, + self.D >= 0.0, + self.mu >= 0.0, + self.p >= 0.0, + self.q >= 0.0, + ] + ).all() + + +@jax_dataclasses.pytree_dataclass +class SoftContacts(common.ContactModel): + """Soft contacts model.""" + + @classmethod + def build( + cls: type[Self], + model: js.model.JaxSimModel | None = None, + **kwargs, + ) -> Self: + """ + Create a `SoftContacts` instance with specified parameters. + + Args: + model: + The robot model considered by the contact model. + If passed, it is used to estimate good default parameters. + **kwargs: Additional parameters to pass to the contact model. + + Returns: + The `SoftContacts` instance. + """ + + if len(kwargs) != 0: + logging.debug(msg=f"Ignoring extra arguments: {kwargs}") + + return cls(**kwargs) + + @classmethod + def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: + """ + Build zero state variables of the contact model. + """ + + # Initialize the material deformation to zero. + tangential_deformation = jnp.zeros( + shape=(len(model.kin_dyn_parameters.contact_parameters.body), 3), + dtype=float, + ) + + return {"tangential_deformation": tangential_deformation} + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def hunt_crossley_contact_model( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + terrain: Terrain, + K: jtp.FloatLike, + D: jtp.FloatLike, + mu: jtp.FloatLike, + p: jtp.FloatLike = 0.5, + q: jtp.FloatLike = 0.5, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force using the Hunt/Crossley model. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + terrain: The terrain model. + K: The stiffness parameter. + D: The damping parameter of the soft contacts model. + mu: The static friction coefficient. + p: + The exponent p corresponding to the damping-related non-linearity + of the Hunt/Crossley model. + q: + The exponent q corresponding to the spring-related non-linearity + of the Hunt/Crossley model + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ + + # Convert the input vectors to arrays. + W_p_C = jnp.array(position, dtype=float).squeeze() + W_ṗ_C = jnp.array(velocity, dtype=float).squeeze() + m = jnp.array(tangential_deformation, dtype=float).squeeze() + + # Use symbol for the static friction. + μ = mu + + # Compute the penetration depth, its rate, and the considered terrain normal. + δ, δ̇, n̂ = common.compute_penetration_data(p=W_p_C, v=W_ṗ_C, terrain=terrain) + + # There are few operations like computing the norm of a vector with zero length + # or computing the square root of zero that are problematic in an AD context. + # To avoid these issues, we introduce a small tolerance ε to their arguments + # and make sure that we do not check them against zero directly. + ε = jnp.finfo(float).eps + + # Compute the powers of the penetration depth. + # Inject ε to address AD issues in differentiating the square root when + # p and q are fractional. + δp = jnp.power(δ + ε, p) + δq = jnp.power(δ + ε, q) + + # ======================== + # Compute the normal force + # ======================== + + # Non-linear spring-damper model (Hunt/Crossley model). + # This is the force magnitude along the direction normal to the terrain. + force_normal_mag = (K * δp) * δ + (D * δq) * δ̇ + + # Depending on the magnitude of δ̇, the normal force could be negative. + force_normal_mag = jnp.maximum(0.0, force_normal_mag) + + # Compute the 3D linear force in C[W] frame. + f_normal = force_normal_mag * n̂ + + # ============================ + # Compute the tangential force + # ============================ + + # Extract the tangential component of the velocity. + v_tangential = W_ṗ_C - jnp.dot(W_ṗ_C, n̂) * n̂ + + # Extract the normal and tangential components of the material deformation. + m_normal = jnp.dot(m, n̂) * n̂ + m_tangential = m - jnp.dot(m, n̂) * n̂ + + # Compute the tangential force in the sticking case. + # Using the tangential component of the material deformation should not be + # necessary if the sticking-slipping transition occurs in a terrain area + # with a locally constant normal. However, this assumption is not true in + # general, especially for highly uneven terrains. + f_tangential = -((K * δp) * m_tangential + (D * δq) * v_tangential) + + # Detect the contact type (sticking or slipping). + # Note that if there is no contact, sticking is set to True, and this detail + # is exploited in the computation of the `contact_status` variable. + sticking = jnp.logical_or( + δ <= 0, f_tangential.dot(f_tangential) <= (μ * force_normal_mag) ** 2 + ) + + # Compute the direction of the tangential force. + # To prevent dividing by zero, we use a switch statement. + norm = jaxsim.math.safe_norm(f_tangential) + f_tangential_direction = f_tangential / ( + norm + jnp.finfo(float).eps * (norm == 0) + ) + + # Project the tangential force to the friction cone if slipping. + f_tangential = jnp.where( + sticking, + f_tangential, + jnp.minimum(μ * force_normal_mag, norm) * f_tangential_direction, + ) + + # Set the tangential force to zero if there is no contact. + f_tangential = jnp.where(δ <= 0, jnp.zeros(3), f_tangential) + + # ===================================== + # Compute the material deformation rate + # ===================================== + + # Compute the derivative of the material deformation. + # Note that we included an additional relaxation of `m_normal` in the + # sticking case, so that the normal deformation that could have accumulated + # from a previous slipping phase can relax to zero. + ṁ_no_contact = -(K / D) * m + ṁ_sticking = v_tangential - (K / D) * m_normal + ṁ_slipping = -(f_tangential + (K * δp) * m_tangential) / (D * δq) + + # Compute the contact status: + # 0: slipping + # 1: sticking + # 2: no contact + contact_status = sticking.astype(int) + contact_status += (δ <= 0).astype(int) + + # Select the right material deformation rate depending on the contact status. + ṁ = jax.lax.select_n(contact_status, ṁ_slipping, ṁ_sticking, ṁ_no_contact) + + # ========================================== + # Compute and return the final contact force + # ========================================== + + # Sum the normal and tangential forces. + CW_fl = f_normal + f_tangential + + return CW_fl, ṁ + + @staticmethod + @functools.partial(jax.jit, static_argnames=("terrain",)) + def compute_contact_force( + position: jtp.VectorLike, + velocity: jtp.VectorLike, + tangential_deformation: jtp.VectorLike, + parameters: SoftContactsParams, + terrain: Terrain, + ) -> tuple[jtp.Vector, jtp.Vector]: + """ + Compute the contact force. + + Args: + position: The position of the collidable point. + velocity: The velocity of the collidable point. + tangential_deformation: The material deformation of the collidable point. + parameters: The parameters of the soft contacts model. + terrain: The terrain model. + + Returns: + A tuple containing the computed contact force and the derivative of the + material deformation. + """ + + CW_fl, ṁ = SoftContacts.hunt_crossley_contact_model( + position=position, + velocity=velocity, + tangential_deformation=tangential_deformation, + terrain=terrain, + K=parameters.K, + D=parameters.D, + mu=parameters.mu, + p=parameters.p, + q=parameters.q, + ) + + # Pack a mixed 6D force. + CW_f = jnp.hstack([CW_fl, jnp.zeros(3)]) + + # Compute the 6D force transform from the mixed to the inertial-fixed frame. + W_Xf_CW = jaxsim.math.Adjoint.from_quaternion_and_translation( + translation=jnp.array(position), inverse=True + ).T + + # Compute the 6D force in the inertial-fixed frame. + W_f = W_Xf_CW @ CW_f + + return W_f, ṁ + + @staticmethod + @jax.jit + def compute_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the contact forces. + + Args: + model: The model to consider. + data: The data of the considered model. + + Returns: + A tuple containing as first element the computed contact forces, and as + second element a dictionary with derivative of the material deformation. + """ + + # Get the indices of the enabled collidable points. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + + # Compute the position and linear velocities (mixed representation) of + # all the collidable points belonging to the robot and extract the ones + # for the enabled collidable points. + W_p_C, W_ṗ_C = js.contact.collidable_point_kinematics(model=model, data=data) + + # Extract the material deformation corresponding to the collidable points. + m = ( + data.contact_state["tangential_deformation"] + if "tangential_deformation" in data.contact_state + else jnp.zeros_like(W_p_C) + ) + + m_enabled = m[indices_of_enabled_collidable_points] + + # Initialize the tangential deformation rate array for every collidable point. + ṁ = jnp.zeros_like(m) + + # Compute the contact forces only for the enabled collidable points. + # Since we treat them as independent, we can vmap the computation. + W_f, ṁ_enabled = jax.vmap( + lambda p, v, m: SoftContacts.compute_contact_force( + position=p, + velocity=v, + tangential_deformation=m, + parameters=model.contact_params, + terrain=model.terrain, + ) + )(W_p_C, W_ṗ_C, m_enabled) + + ṁ = ṁ.at[indices_of_enabled_collidable_points].set(ṁ_enabled) + + return W_f, dict(m_dot=ṁ) diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 355fe347c..58c230c7d 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -111,32 +111,3 @@ def propagate_kinematics( ) return jax.vmap(Adjoint.to_transform)(W_X_i), W_v_Wi - - -def forward_kinematics( - model: js.model.JaxSimModel, - link_index: jtp.Int, - base_position: jtp.VectorLike, - base_quaternion: jtp.VectorLike, - joint_positions: jtp.VectorLike, -) -> jtp.Matrix: - """ - Compute the forward kinematics of a specific link. - - Args: - model: The model to consider. - link_index: The index of the link to consider. - base_position: The position of the base link. - base_quaternion: The quaternion of the base link. - joint_positions: The positions of the joints. - - Returns: - The SE(3) transform of the link. - """ - - return forward_kinematics_model( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - )[link_index] diff --git a/src/jaxsim/rbda/utils.py b/src/jaxsim/rbda/utils.py index 9b2614a52..b2435dbf4 100644 --- a/src/jaxsim/rbda/utils.py +++ b/src/jaxsim/rbda/utils.py @@ -132,6 +132,12 @@ def process_inputs( if W_Q_B.shape != (4,): raise ValueError(W_Q_B.shape, (4,)) + # Check that the quaternion does not contain NaN values. + exceptions.raise_value_error_if( + condition=jnp.isnan(W_Q_B).any(), + msg="A RBDA received a quaternion that contains NaN values.", + ) + # Check that the quaternion is unary since our RBDAs make this assumption in order # to prevent introducing additional normalizations that would affect AD. exceptions.raise_value_error_if( diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 8943a6b7f..334452821 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -9,6 +9,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim import VelRepr +from jaxsim.rbda.contacts import SoftContacts, SoftContactsParams # All JaxSim algorithms, excluding the variable-step integrators, should support # being automatically differentiated until second order, both in FWD and REV modes. @@ -288,6 +289,58 @@ def test_ad_jacobian( ) +def test_ad_soft_contacts( + jaxsim_models_types: js.model.JaxSimModel, + prng_key: jax.Array, +): + + model = jaxsim_models_types + + with model.editable(validate=False) as model: + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build(model=model) + + _, subkey1, subkey2, subkey3 = jax.random.split(prng_key, num=4) + p = jax.random.uniform(subkey1, shape=(3,), minval=-1) + v = jax.random.uniform(subkey2, shape=(3,), minval=-1) + m = jax.random.uniform(subkey3, shape=(3,), minval=-1) + + # Get the soft contacts parameters. + parameters = js.contact.estimate_good_contact_parameters(model=model) + + # ==== + # Test + # ==== + + # Get a closure exposing only the parameters to be differentiated. + def close_over_inputs_and_parameters( + p: jtp.VectorLike, + v: jtp.VectorLike, + m: jtp.VectorLike, + params: SoftContactsParams, + ) -> tuple[jtp.Vector, jtp.Vector]: + + W_f_Ci, CW_ṁ = SoftContacts.compute_contact_force( + position=p, + velocity=v, + tangential_deformation=m, + parameters=params, + terrain=model.terrain, + ) + + return W_f_Ci, CW_ṁ + + # Check derivatives against finite differences. + check_grads( + f=close_over_inputs_and_parameters, + args=(p, v, m, parameters), + order=AD_ORDER, + modes=["rev", "fwd"], + eps=ε, + # On GPU, the tolerance needs to be increased. + rtol=0.02 if "gpu" in {d.platform for d in p.devices()} else None, + ) + + def test_ad_integration( jaxsim_models_types: js.model.JaxSimModel, prng_key: jax.Array, diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index cf6898f18..26f767652 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -53,15 +53,6 @@ def test_free_floating_bias_forces( ) -@pytest.mark.benchmark -def test_forward_kinematics( - jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size -): - model = jaxsim_model_ergocub_reduced - - benchmark_test_function(js.model.forward_kinematics, model, benchmark, batch_size) - - @pytest.mark.benchmark def test_free_floating_mass_matrix( jaxsim_model_ergocub_reduced: js.model.JaxSimModel, benchmark, batch_size diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 50fa4beb9..19a49649f 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -61,7 +61,7 @@ def test_box_with_external_forces( additive=False, ) - # Initialize the integrator. + # Initialize the simulation horizon. tf = 0.5 T_ns = jnp.arange(start=0, stop=tf * 1e9, step=model.time_step * 1e9, dtype=int) @@ -187,6 +187,107 @@ def run_simulation( return data +def test_simulation_with_soft_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + # Define the maximum penetration of each collidable point at steady state. + max_penetration = 0.001 + + with model.editable(validate=False) as model: + + model.contact_model = jaxsim.rbda.contacts.SoftContacts.build() + model.contact_params = js.contact.estimate_good_contact_parameters( + model=model, + number_of_active_collidable_points_steady_state=4, + static_friction_coefficient=1.0, + damping_ratio=1.0, + max_penetration=max_penetration, + ) + + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) + + +def test_simulation_with_rigid_contacts( + jaxsim_model_box: js.model.JaxSimModel, +): + + model = jaxsim_model_box + + with model.editable(validate=False) as model: + + # In order to achieve almost no penetration, we need to use a fairly large + # Baumgarte stabilization term. + model.contact_model = jaxsim.rbda.contacts.RigidContacts.build( + solver_options={"solver_tol": 1e-3} + ) + model.contact_params = model.contact_model._parameters_class(K=1e5) + + # Enable a subset of the collidable points. + enabled_collidable_points_mask = np.zeros( + len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool + ) + enabled_collidable_points_mask[[0, 1, 2, 3]] = True + model.kin_dyn_parameters.contact_parameters.enabled = tuple( + enabled_collidable_points_mask.tolist() + ) + + assert np.sum(model.kin_dyn_parameters.contact_parameters.enabled) == 4 + + # Initialize the maximum penetration of each collidable point at steady state. + # This model is rigid, so we expect (almost) no penetration. + max_penetration = 0.000 + + # Check jaxsim_model_box@conftest.py. + box_height = 0.1 + + # Build the data of the model. + data_t0 = js.data.JaxSimModelData.build( + model=model, + base_position=jnp.array([0.0, 0.0, box_height * 2]), + velocity_representation=VelRepr.Inertial, + ) + + # =========================================== + # Run the simulation and test the final state + # =========================================== + + data_tf = run_simulation(model=model, data_t0=data_t0, tf=1.0) + + assert data_tf.base_position[0:2] == pytest.approx(data_t0.base_position[0:2]) + assert data_tf.base_position[2] + max_penetration == pytest.approx(box_height / 2) + + def test_simulation_with_relaxed_rigid_contacts( jaxsim_model_box: js.model.JaxSimModel, ): @@ -198,6 +299,8 @@ def test_simulation_with_relaxed_rigid_contacts( model.contact_model = jaxsim.rbda.contacts.RelaxedRigidContacts.build( solver_options={"tol": 1e-3}, ) + model.contact_params = model.contact_model._parameters_class() + # Enable a subset of the collidable points. enabled_collidable_points_mask = np.zeros( len(model.kin_dyn_parameters.contact_parameters.body), dtype=bool