Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to use the link frame when setting or getting forces using the references helper #164

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 79 additions & 10 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,13 @@ def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
# Extract quantities
# ==================

@functools.partial(jax.jit, static_argnames=["link_names"])
@functools.partial(jax.jit, static_argnames=["link_names", "use_link_frame"])
def link_forces(
self,
model: js.model.JaxSimModel | None = None,
data: js.data.JaxSimModelData | None = None,
link_names: tuple[str, ...] | None = None,
use_link_frame: bool = False,
) -> jtp.Matrix:
"""
Return the link forces expressed in the frame of the active representation.
Expand All @@ -155,6 +156,9 @@ def link_forces(
model: The model to consider.
data: The data of the considered model.
link_names: The names of the links corresponding to the forces.
use_link_frame:
Whether to consider the frame of the link instead of the frame of the
base in body-fixed and mixed representations.

Returns:
If no model and no link names are provided, the link forces as a
Expand Down Expand Up @@ -202,18 +206,44 @@ def link_forces(
if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# Helper function to convert a single 6D force to the active representation.
def convert(f_L: jtp.Vector) -> jtp.Vector:
# Helper function to convert a single 6D force to the active representation
# considering as body the base link (i.e. B_f_L and BW_f_L).
def convert_using_base_frame(f_L: jtp.Vector) -> jtp.Vector:

return JaxSimModelReferences.inertial_to_other_representation(
array=f_L,
other_representation=self.velocity_representation,
transform=data.base_transform(),
is_force=True,
)

# Helper function to convert a single 6D force to the active representation
# considering as body the link (i.e. L_f_L and LW_f_L).
# Contrarily to convert_using_base_frame, this is already vectorized.
def convert_using_link_frame(
W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
) -> jtp.Matrix:

return jax.vmap(
lambda W_f_L, W_H_L: JaxSimModelReferences.inertial_to_other_representation(
array=W_f_L,
other_representation=self.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(W_f_L, W_H_L)

# Convert to the desired representation.
f_L = jax.vmap(convert)(W_f_L[link_idxs, :])
if use_link_frame:
# The f_L output is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
f_L = convert_using_link_frame(
W_f_L=W_f_L[link_idxs, :], W_H_L=W_H_L[link_idxs, :, :]
)
return f_L

# The f_L output is either B_f_L or BW_f_L, depending on the representation.
f_L = jax.vmap(convert_using_base_frame)(W_f_L[link_idxs, :])
return f_L

def joint_force_references(
Expand Down Expand Up @@ -313,14 +343,17 @@ def replace(forces: jtp.VectorLike) -> JaxSimModelReferences:

return replace(forces=self.input.physics_model.tau.at[joint_idxs].set(forces))

@functools.partial(jax.jit, static_argnames=["link_names", "additive"])
@functools.partial(
jax.jit, static_argnames=["link_names", "additive", "use_link_frame"]
)
def apply_link_forces(
self,
forces: jtp.MatrixLike,
model: js.model.JaxSimModel | None = None,
data: js.data.JaxSimModelData | None = None,
link_names: tuple[str, ...] | None = None,
link_names: tuple[str, ...] | str | None = None,
additive: bool = False,
use_link_frame: bool = False,
) -> Self:
"""
Apply the link forces.
Expand All @@ -336,6 +369,9 @@ def apply_link_forces(
link_names: The names of the links corresponding to the forces.
additive:
Whether to add the forces to the existing ones instead of replacing them.
use_link_frame:
Whether to consider the frame of the link instead of the frame of the
base in body-fixed and mixed representations.

Returns:
A new `JaxSimModelReferences` object with the given link forces.
Expand All @@ -345,7 +381,7 @@ def apply_link_forces(
Then, we always convert and store forces in inertial-fixed representation.
"""

f_L = jnp.array(forces)
f_L = jnp.atleast_2d(forces).astype(float)

# Helper function to replace the link forces.
def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
Expand Down Expand Up @@ -380,6 +416,15 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:

# If we have the model, we can extract the link names if not provided.
link_names = link_names if link_names is not None else model.link_names()

# Make sure that the link names are a tuple if they are provided by the user.
link_names = (link_names,) if isinstance(link_names, str) else link_names

if len(link_names) != f_L.shape[0]:
msg = "The number of link names ({}) must match the number of forces ({})"
raise ValueError(msg.format(len(link_names), f_L.shape[0]))

# Extract the link indices.
link_idxs = js.link.names_to_idxs(link_names=link_names, model=model)

# Compute the bias depending on whether we either set or add the link forces.
Expand All @@ -405,16 +450,40 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
if not_tracing(forces) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# Helper function to convert a single 6D force to the inertial representation.
def convert(f_L: jtp.Vector) -> jtp.Vector:
# Helper function to convert a single 6D force to the inertial representation
# considering as body the base link (i.e. B_f_L and BW_f_L).
def convert_using_base_frame(f_L: jtp.Vector) -> jtp.Vector:

return JaxSimModelReferences.other_representation_to_inertial(
array=f_L,
other_representation=self.velocity_representation,
transform=data.base_transform(),
is_force=True,
)

W_f_L = jax.vmap(convert)(f_L)
# Helper function to convert a single 6D force to the inertial representation
# considering as body the link (i.e. L_f_L and LW_f_L).
# Contrarily to convert_using_base_frame, this is already vectorized.
def convert_using_link_frame(
f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike
) -> jtp.Matrix:

return jax.vmap(
lambda f_L, W_H_L: JaxSimModelReferences.other_representation_to_inertial(
array=f_L,
other_representation=self.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(f_L, W_H_L)

if use_link_frame:
# The f_L input is either L_f_L or LW_f_L, depending on the representation.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_f_L = convert_using_link_frame(f_L=f_L, W_H_L=W_H_L[link_idxs, :, :])
else:
# The f_L input is either B_f_L or BW_f_L, depending on the representation.
W_f_L = jax.vmap(convert_using_base_frame)(f_L)

return replace(
forces=self.input.physics_model.f_ext.at[link_idxs, :].set(W_f0_L + W_f_L)
Expand Down
Loading