Skip to content

Commit

Permalink
Compute Jacobian derivative using batched math
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jan 7, 2025
1 parent 52c9a20 commit 1839965
Showing 1 changed file with 105 additions and 105 deletions.
210 changes: 105 additions & 105 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,152 +733,152 @@ def generalized_free_floating_jacobian_derivative(
# Compute the base transform.
W_H_B = data.base_transform()

@functools.partial(jax.vmap, in_axes=(0, None, None, 0))
def _compute_row(
B_H_L: jtp.Matrix,
B_J_full_WL_B: jtp.Matrix,
W_H_B: jtp.Matrix,
κb: jtp.Matrix,
) -> jtp.Matrix:
B_J_WL_B = jnp.where(
jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis],
B_J_full_WL_B[jnp.newaxis, :],
0.0,
)

# =====================================================
# Compute quantities to adjust the input representation
# =====================================================
B_J̇_WL_B = jnp.where(
jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis],
B_J̇_full_WX_B[jnp.newaxis, :],
0.0,
)

In = jnp.eye(model.dofs())
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
# =====================================================
# Compute quantities to adjust the input representation
# =====================================================

# Extract the link quantities using the boolean support body array.
B_J̇_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J̇_full_WX_B
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WL_B
In = jnp.eye(model.dofs())
On = jnp.zeros(shape=(model.dofs(), model.dofs()))

match data.velocity_representation:
match data.velocity_representation:

case VelRepr.Inertial:
case VelRepr.Inertial:

B_X_W = jaxsim.math.Adjoint.from_transform(
transform=W_H_B, inverse=True
)
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)

W_v_WB = data.base_velocity()
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_W, In)
= jax.scipy.linalg.block_diag(B_Ẋ_W, On)
W_v_WB = data.base_velocity()
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)

case VelRepr.Body:
# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_W, In)
= jax.scipy.linalg.block_diag(B_Ẋ_W, On)

B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
translation=jnp.zeros(3), rotation=jnp.eye(3)
)
case VelRepr.Body:

B_Ẋ_B = jnp.zeros(shape=(6, 6))
B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
translation=jnp.zeros(3), rotation=jnp.eye(3)
)

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_B, In)
= jax.scipy.linalg.block_diag(B_Ẋ_B, On)
B_Ẋ_B = jnp.zeros(shape=(6, 6))

case VelRepr.Mixed:
# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_B, In)
= jax.scipy.linalg.block_diag(B_Ẋ_B, On)

BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxsim.math.Adjoint.from_transform(
transform=BW_H_B, inverse=True
)
case VelRepr.Mixed:

BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)

BW_v_BW_B = BW_v_WB - BW_v_W_BW
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_BW, In)
= jax.scipy.linalg.block_diag(B_Ẋ_BW, On)
BW_v_BW_B = BW_v_WB - BW_v_W_BW
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)

case _:
raise ValueError(data.velocity_representation)
# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_BW, In)
= jax.scipy.linalg.block_diag(B_Ẋ_BW, On)

# ======================================================
# Compute quantities to adjust the output representation
# ======================================================
case _:
raise ValueError(data.velocity_representation)

match output_vel_repr:
# ======================================================
# Compute quantities to adjust the output representation
# ======================================================

case VelRepr.Inertial:
match output_vel_repr:

O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
case VelRepr.Inertial:

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)

O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

case VelRepr.Body:
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841

O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
transform=B_H_L, inverse=True
)
case VelRepr.Body:

B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
transform=B_H_L, inverse=True
)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
L_v_WL = L_X_B @ B_J_WL_B @ data.generalized_velocity()
B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)

O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_L @ L_v_WL - B_v_WB
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
L_v_WL = jnp.einsum(
"b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity()
)

case VelRepr.Mixed:

W_H_L = W_H_B @ B_H_L
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
jnp.einsum("bij,bj->bi", B_X_L, L_v_WL) - B_v_WB
)

O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)
case VelRepr.Mixed:

B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
W_H_L = W_H_B @ B_H_L
LW_H_L = W_H_L.at[:, 0:3, 3].set(jnp.zeros_like(W_H_L[:, 0:3, 3]))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)

with data.switch_velocity_representation(VelRepr.Mixed):
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
LW_v_WL = LW_X_B @ (
B_J_WL_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
@ data.generalized_velocity()
)
LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3))
B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)

LW_v_LW_L = LW_v_WL - LW_v_W_LW
LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_LW @ LW_v_B_LW
with data.switch_velocity_representation(VelRepr.Mixed):
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
LW_v_WL = jnp.einsum(
"bij,bj->bi",
LW_X_B,
B_J_WL_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
@ data.generalized_velocity(),
)
case _:
raise ValueError(output_vel_repr)

# =============================================================
# Express the Jacobian derivative in the target representations
# =============================================================
LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6]))

LW_v_LW_L = LW_v_WL - LW_v_W_LW
LW_v_B_LW = LW_v_WL - jnp.einsum("bij,j->bi", LW_X_B, B_v_WB) - LW_v_LW_L

O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
jnp.einsum("bij,bj->bi", B_X_LW, LW_v_B_LW)
)

case _:
raise ValueError(output_vel_repr)

# Sum all the components that form the Jacobian derivative in the target
# input/output velocity representations.
O_J̇_WL_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J_WL_B @
# =============================================================
# Express the Jacobian derivative in the target representations
# =============================================================

return O_J̇_WL_I
# Sum all the components that form the Jacobian derivative in the target
# input/output velocity representations.
O_J̇_WL_I = jnp.zeros_like(B_J̇_WL_B)
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J_WL_B @

return _compute_row(B_H_L, B_J_full_WL_B, W_H_B, κb)
return O_J̇_WL_I


@functools.partial(jax.jit, static_argnames=["prefer_aba"])
Expand Down

0 comments on commit 1839965

Please sign in to comment.