-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add function to compute the jacobian derivative of collidable points #213
Conversation
4e40904
to
ddd08e8
Compare
285b2e3
to
70dc778
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @flferretti!! For me it's good to go! 🚀
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Thanks for the new functionality. I left some minor suggestion. In addition of what written here below.
I had a hunch of a possible speed-up. Right now you compute for no reason the
I suggest to compute the Jacobians of all links first, and then only get the one corresponding to the link you need. For a simple sphere with 250 points, the following refactor switched from 110 ms to 2.5 ms of runtime.
Suggestion
@functools.partial(jax.jit, static_argnames=["output_vel_repr"])
def jacobian_derivative(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
output_vel_repr: VelRepr | None = None,
) -> jtp.Matrix:
r"""
Compute the derivative of the free-floating jacobian of the contact points.
Args:
model: The model to consider.
data: The data of the considered model.
output_vel_repr:
The output velocity representation of the free-floating jacobian derivative.
Returns:
The derivative of the :math:`6 \times (6+n)` free-floating jacobian of the contact points.
Note:
The input representation of the free-floating jacobian derivative is the active
velocity representation.
"""
output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)
# Get the index of the parent link and the position of the collidable point.
parent_link_idxs = jnp.array(model.kin_dyn_parameters.contact_parameters.body)
L_p_Ci = jnp.array(model.kin_dyn_parameters.contact_parameters.point)
contact_idxs = jnp.arange(L_p_Ci.shape[0])
# =====================================================
# Compute quantities to adjust the input representation
# =====================================================
def compute_T(model: js.model.JaxSimModel, X: jtp.Matrix) -> jtp.Matrix:
In = jnp.eye(model.dofs())
T = jax.scipy.linalg.block_diag(X, In)
return T
def compute_Ṫ(model: js.model.JaxSimModel, Ẋ: jtp.Matrix) -> jtp.Matrix:
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
Ṫ = jax.scipy.linalg.block_diag(Ẋ, On)
return Ṫ
# Compute the operator to change the representation of ν, and its
# time derivative.
match data.velocity_representation:
case VelRepr.Inertial:
W_H_W = jnp.eye(4)
W_X_W = Adjoint.from_transform(transform=W_H_W)
W_Ẋ_W = jnp.zeros((6, 6))
T = compute_T(model=model, X=W_X_W)
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_W)
case VelRepr.Body:
W_H_B = data.base_transform()
W_X_B = Adjoint.from_transform(transform=W_H_B)
B_v_WB = data.base_velocity()
B_vx_WB = Cross.vx(B_v_WB)
W_Ẋ_B = W_X_B @ B_vx_WB
T = compute_T(model=model, X=W_X_B)
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_B)
case VelRepr.Mixed:
W_H_B = data.base_transform()
W_H_BW = W_H_B.at[0:3, 0:3].set(jnp.eye(3))
W_X_BW = Adjoint.from_transform(transform=W_H_BW)
BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
BW_vx_W_BW = Cross.vx(BW_v_W_BW)
W_Ẋ_BW = W_X_BW @ BW_vx_W_BW
T = compute_T(model=model, X=W_X_BW)
Ṫ = compute_Ṫ(model=model, Ẋ=W_Ẋ_BW)
case _:
raise ValueError(data.velocity_representation)
# =====================================================
# Compute quantities to adjust the output representation
# =====================================================
with data.switch_velocity_representation(VelRepr.Inertial):
# Compute the Jacobian of the parent link in inertial representation.
W_J_WL_W = js.model.generalized_free_floating_jacobian(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)
# Compute the Jacobian derivative of the parent link in inertial representation.
W_J̇_WL_W = js.model.generalized_free_floating_jacobian_derivative(
model=model,
data=data,
output_vel_repr=VelRepr.Inertial,
)
def compute_O_J̇_WC_I(
L_p_C: jtp.Vector,
contact_idx: jtp.Int,
) -> jtp.Matrix:
parent_link_idx = parent_link_idxs[contact_idx]
match output_vel_repr:
case VelRepr.Inertial:
O_X_W = W_X_W = Adjoint.from_transform(transform=jnp.eye(4))
O_Ẋ_W = W_Ẋ_W = jnp.zeros((6, 6))
case VelRepr.Body:
W_H_L = js.link.transform(
model=model, data=data, link_index=parent_link_idx
)
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
W_H_C = W_H_L @ L_H_C
O_X_W = C_X_W = Adjoint.from_transform(transform=W_H_C, inverse=True)
with data.switch_velocity_representation(VelRepr.Inertial):
W_nu = data.generalized_velocity()
W_v_WC = W_J_WL_W[parent_link_idx] @ W_nu
W_vx_WC = Cross.vx(W_v_WC)
O_Ẋ_W = C_Ẋ_W = -C_X_W @ W_vx_WC
case VelRepr.Mixed:
W_H_L = js.link.transform(
model=model, data=data, link_index=parent_link_idx
)
L_H_C = Transform.from_rotation_and_translation(translation=L_p_C)
W_H_C = W_H_L @ L_H_C
W_H_CW = W_H_C.at[0:3, 0:3].set(jnp.eye(3))
CW_H_W = Transform.inverse(W_H_CW)
O_X_W = CW_X_W = Adjoint.from_transform(transform=CW_H_W)
with data.switch_velocity_representation(VelRepr.Mixed):
CW_J_WC_CW = jacobian(
model=model,
data=data,
output_vel_repr=VelRepr.Mixed,
)[contact_idx]
CW_v_WC = CW_J_WC_CW @ data.generalized_velocity()
W_v_W_CW = jnp.zeros(6).at[0:3].set(CW_v_WC[0:3])
W_vx_W_CW = Cross.vx(W_v_W_CW)
O_Ẋ_W = CW_Ẋ_W = -CW_X_W @ W_vx_W_CW
case _:
raise ValueError(output_vel_repr)
O_J̇_WC_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WC_I += O_Ẋ_W @ W_J_WL_W[parent_link_idx] @ T
O_J̇_WC_I += O_X_W @ W_J̇_WL_W[parent_link_idx] @ T
O_J̇_WC_I += O_X_W @ W_J_WL_W[parent_link_idx] @ Ṫ
return O_J̇_WC_I
O_J̇_WC = jax.vmap(compute_O_J̇_WC_I)(L_p_Ci, contact_idxs)
return O_J̇_WC
Thanks a lot @xela-95 and @diegoferigo! I also added |
Yep, I didn't realize the I had it just locally from old experiments :) |
1da04cb
to
88c5d4a
Compare
885d7d2
to
3ce64f3
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Last minor suggestions, not necessary for merging.
Co-authored-by: Diego Ferigo <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>
Co-authored-by: Diego Ferigo <[email protected]>
c3bbe35
to
7f82bac
Compare
This PR is basically a copy-paste of the logic introduced by @xela-95 in #208. The only difference is that here the computation is vmapped over all the collidable points.
📚 Documentation preview 📚: https://jaxsim--213.org.readthedocs.build//213/