Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jaxsim.api.model.link_contact_forces #232

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 25 additions & 22 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,14 +1747,18 @@ def link_contact_forces(
data: The data of the considered model.

Returns:
A (nL, 6) array containing the stacked 6D contact forces of the links,
A `(nL, 6)` array containing the stacked 6D contact forces of the links,
expressed in the frame corresponding to the active representation.
"""

# Note: the following code should be kept in sync with the function
# `jaxsim.api.ode.system_velocity_dynamics`. We cannot merge them since
# there we need to get also aux_data.

# Compute the 6D forces applied to each collidable point expressed in the
# inertial frame.
with data.switch_velocity_representation(VelRepr.Inertial):
W_f_Ci = js.contact.collidable_point_forces(model=model, data=data)
W_f_C = js.contact.collidable_point_forces(model=model, data=data)

# Construct the vector defining the parent link index of each collidable point.
# We use this vector to sum the 6D forces of all collidable points rigidly
Expand All @@ -1763,29 +1767,28 @@ def link_contact_forces(
model.kin_dyn_parameters.contact_parameters.body, dtype=int
)

# Create the mask that associate each collidable point to their parent link.
# We use this mask to sum the collidable points to the right link.
mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange(
model.number_of_links()
)

# Sum the forces of all collidable points rigidly attached to a body.
# Since the contact forces W_f_Ci are expressed in the world frame,
# Since the contact forces W_f_C are expressed in the world frame,
# we don't need any coordinate transformation.
W_f_Li = jax.vmap(
lambda nc: (
jnp.vstack(
jnp.equal(parent_link_index_of_collidable_points, nc).astype(int)
)
* W_f_Ci
).sum(axis=0)
)(jnp.arange(model.number_of_links()))

# Convert the 6D forces to the active representation.
f_Li = jax.vmap(
lambda W_f_L: data.inertial_to_other_representation(
array=W_f_L,
other_representation=data.velocity_representation,
transform=data.base_transform(),
is_force=True,
)
)(W_f_Li)
W_f_L = mask.T @ W_f_C

# Create a references object to store the link forces.
references = js.references.JaxSimModelReferences.build(
model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial
)

# Use the references object to convert the link forces to the velocity
# representation of data.
with references.switch_velocity_representation(data.velocity_representation):
f_L = references.link_forces(model=model, data=data)

return f_Li
return f_L


# ======
Expand Down
7 changes: 7 additions & 0 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,16 @@ def system_velocity_dynamics(
# with the terrain.
W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float)

# Initialize a dictionary of auxiliary data.
# This dictionary is used to store additional data computed by the contact model.
aux_data = {}

if len(model.kin_dyn_parameters.contact_parameters.body) > 0:

# Note: the following code should be kept in sync with the function
# `jaxsim.api.model.link_contact_forces`. We cannot merge them since
# here we need to get also aux_data.

# Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point
# along with contact-specific auxiliary states.
with data.switch_velocity_representation(VelRepr.Inertial):
Expand Down
Loading