From 32617915e2df53b70cb326905a4af02076ac0894 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Tue, 17 Sep 2024 10:12:18 +0200 Subject: [PATCH] Fix jaxsim.api.model.link_contact_forces This was a leftover from when we updated the velocity representation of the model Jacobian to not use B and B[W] in body-fixed and mixed in favor of L and L[W]. --- src/jaxsim/api/model.py | 47 ++++++++++++++++++++++------------------- src/jaxsim/api/ode.py | 7 ++++++ 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index d9ea4679e..195dcd391 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1747,14 +1747,18 @@ def link_contact_forces( data: The data of the considered model. Returns: - A (nL, 6) array containing the stacked 6D contact forces of the links, + A `(nL, 6)` array containing the stacked 6D contact forces of the links, expressed in the frame corresponding to the active representation. """ + # Note: the following code should be kept in sync with the function + # `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since + # there we need to get also aux_data. + # Compute the 6D forces applied to each collidable point expressed in the # inertial frame. with data.switch_velocity_representation(VelRepr.Inertial): - W_f_Ci = js.contact.collidable_point_forces(model=model, data=data) + W_f_C = js.contact.collidable_point_forces(model=model, data=data) # Construct the vector defining the parent link index of each collidable point. # We use this vector to sum the 6D forces of all collidable points rigidly @@ -1763,29 +1767,28 @@ def link_contact_forces( model.kin_dyn_parameters.contact_parameters.body, dtype=int ) + # Create the mask that associate each collidable point to their parent link. + # We use this mask to sum the collidable points to the right link. + mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( + model.number_of_links() + ) + # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_Ci are expressed in the world frame, + # Since the contact forces W_f_C are expressed in the world frame, # we don't need any coordinate transformation. - W_f_Li = jax.vmap( - lambda nc: ( - jnp.vstack( - jnp.equal(parent_link_index_of_collidable_points, nc).astype(int) - ) - * W_f_Ci - ).sum(axis=0) - )(jnp.arange(model.number_of_links())) - - # Convert the 6D forces to the active representation. - f_Li = jax.vmap( - lambda W_f_L: data.inertial_to_other_representation( - array=W_f_L, - other_representation=data.velocity_representation, - transform=data.base_transform(), - is_force=True, - ) - )(W_f_Li) + W_f_L = mask.T @ W_f_C + + # Create a references object to store the link forces. + references = js.references.JaxSimModelReferences.build( + model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial + ) + + # Use the references object to convert the link forces to the velocity + # representation of data. + with references.switch_velocity_representation(data.velocity_representation): + f_L = references.link_forces(model=model, data=data) - return f_Li + return f_L # ====== diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index d9e85d021..bb4911960 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -132,9 +132,16 @@ def system_velocity_dynamics( # with the terrain. W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float) + # Initialize a dictionary of auxiliary data. + # This dictionary is used to store additional data computed by the contact model. aux_data = {} + if len(model.kin_dyn_parameters.contact_parameters.body) > 0: + # Note: the following code should be kept in sync with the function + # `jaxsim.api.model.link_contact_forces`. We cannot merge them since + # here we need to get also aux_data. + # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point # along with contact-specific auxiliary states. with data.switch_velocity_representation(VelRepr.Inertial):