Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute free-floating Jacobian derivative using batched math operations #330

Merged
merged 7 commits into from
Jan 10, 2025
Merged
207 changes: 102 additions & 105 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,152 +736,149 @@ def generalized_free_floating_jacobian_derivative(
# Compute the base transform.
W_H_B = data.base_transform()

@functools.partial(jax.vmap, in_axes=(0, None, None, 0))
def _compute_row(
B_H_L: jtp.Matrix,
B_J_full_WL_B: jtp.Matrix,
W_H_B: jtp.Matrix,
κb: jtp.Matrix,
) -> jtp.Matrix:
# We add the 5 columns of ones to the Jacobian derivative to account for the
# base velocity and acceleration (5 + number of links = 6 + number of joints).
B_J̇_WL_B = (
jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis] * B_J̇_full_WX_B
)
B_J_WL_B = (
jnp.hstack([jnp.ones((κb.shape[0], 5)), κb])[:, jnp.newaxis] * B_J_full_WL_B
)

# =====================================================
# Compute quantities to adjust the input representation
# =====================================================
# =====================================================
# Compute quantities to adjust the input representation
# =====================================================

In = jnp.eye(model.dofs())
On = jnp.zeros(shape=(model.dofs(), model.dofs()))
In = jnp.eye(model.dofs())
On = jnp.zeros(shape=(model.dofs(), model.dofs()))

# Extract the link quantities using the boolean support body array.
B_J̇_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J̇_full_WX_B
B_J_WL_B = jnp.hstack([jnp.ones(5), κb]) * B_J_full_WL_B
match data.velocity_representation:

match data.velocity_representation:
case VelRepr.Inertial:

case VelRepr.Inertial:
B_X_W = jaxsim.math.Adjoint.from_transform(transform=W_H_B, inverse=True)

B_X_W = jaxsim.math.Adjoint.from_transform(
transform=W_H_B, inverse=True
)
W_v_WB = data.base_velocity()
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)

W_v_WB = data.base_velocity()
B_Ẋ_W = -B_X_W @ jaxsim.math.Cross.vx(W_v_WB)
# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_W, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_W, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_W, On)
case VelRepr.Body:

case VelRepr.Body:
B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
translation=jnp.zeros(3), rotation=jnp.eye(3)
)

B_X_B = jaxsim.math.Adjoint.from_rotation_and_translation(
translation=jnp.zeros(3), rotation=jnp.eye(3)
)
B_Ẋ_B = jnp.zeros(shape=(6, 6))

B_Ẋ_B = jnp.zeros(shape=(6, 6))
# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_B, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_B, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_B, On)
case VelRepr.Mixed:

case VelRepr.Mixed:
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxsim.math.Adjoint.from_transform(transform=BW_H_B, inverse=True)

BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = jaxsim.math.Adjoint.from_transform(
transform=BW_H_B, inverse=True
)
BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))

BW_v_WB = data.base_velocity()
BW_v_W_BW = BW_v_WB.at[3:6].set(jnp.zeros(3))
BW_v_BW_B = BW_v_WB - BW_v_W_BW
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)

BW_v_BW_B = BW_v_WB - BW_v_W_BW
B_Ẋ_BW = -B_X_BW @ jaxsim.math.Cross.vx(BW_v_BW_B)
# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_BW, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On)

# Compute the operator to change the representation of ν, and its
# time derivative.
T = jax.scipy.linalg.block_diag(B_X_BW, In)
Ṫ = jax.scipy.linalg.block_diag(B_Ẋ_BW, On)
case _:
raise ValueError(data.velocity_representation)

case _:
raise ValueError(data.velocity_representation)
# ======================================================
# Compute quantities to adjust the output representation
# ======================================================

# ======================================================
# Compute quantities to adjust the output representation
# ======================================================
match output_vel_repr:

match output_vel_repr:
case VelRepr.Inertial:

case VelRepr.Inertial:
O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)

O_X_B = W_X_B = jaxsim.math.Adjoint.from_transform(transform=W_H_B)
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841

O_Ẋ_B = W_Ẋ_B = W_X_B @ jaxsim.math.Cross.vx(B_v_WB) # noqa: F841
case VelRepr.Body:

case VelRepr.Body:
O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
transform=B_H_L, inverse=True
)

O_X_B = L_X_B = jaxsim.math.Adjoint.from_transform(
transform=B_H_L, inverse=True
)
B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)

B_X_L = jaxsim.math.Adjoint.inverse(adjoint=L_X_B)
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
L_v_WL = jnp.einsum(
"b6j,j->b6", L_X_B @ B_J_WL_B, data.generalized_velocity()
)

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
L_v_WL = L_X_B @ B_J_WL_B @ data.generalized_velocity()
O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
jnp.einsum("bij,bj->bi", B_X_L, L_v_WL) - B_v_WB
)

O_Ẋ_B = L_Ẋ_B = -L_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_L @ L_v_WL - B_v_WB
)
case VelRepr.Mixed:

case VelRepr.Mixed:
W_H_L = W_H_B @ B_H_L
LW_H_L = W_H_L.at[:, 0:3, 3].set(jnp.zeros_like(W_H_L[:, 0:3, 3]))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)

W_H_L = W_H_B @ B_H_L
LW_H_L = W_H_L.at[0:3, 3].set(jnp.zeros(3))
LW_H_B = LW_H_L @ jaxsim.math.Transform.inverse(B_H_L)
O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)

O_X_B = LW_X_B = jaxsim.math.Adjoint.from_transform(transform=LW_H_B)
B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)

B_X_LW = jaxsim.math.Adjoint.inverse(adjoint=LW_X_B)
with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()

with data.switch_velocity_representation(VelRepr.Body):
B_v_WB = data.base_velocity()
with data.switch_velocity_representation(VelRepr.Mixed):
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
LW_v_WL = jnp.einsum(
"bij,bj->bi",
LW_X_B,
B_J_WL_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
@ data.generalized_velocity(),
)

with data.switch_velocity_representation(VelRepr.Mixed):
BW_H_B = W_H_B.at[0:3, 3].set(jnp.zeros(3))
B_X_BW = Adjoint.from_transform(transform=BW_H_B, inverse=True)
LW_v_WL = LW_X_B @ (
B_J_WL_B
@ jax.scipy.linalg.block_diag(B_X_BW, jnp.eye(model.dofs()))
@ data.generalized_velocity()
)
LW_v_W_LW = LW_v_WL.at[3:6].set(jnp.zeros(3))
LW_v_W_LW = LW_v_WL.at[:, 3:6].set(jnp.zeros_like(LW_v_WL[:, 3:6]))

LW_v_LW_L = LW_v_WL - LW_v_W_LW
LW_v_B_LW = LW_v_WL - LW_X_B @ B_v_WB - LW_v_LW_L
LW_v_LW_L = LW_v_WL - LW_v_W_LW
LW_v_B_LW = LW_v_WL - jnp.einsum("bij,j->bi", LW_X_B, B_v_WB) - LW_v_LW_L

O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
B_X_LW @ LW_v_B_LW
)
case _:
raise ValueError(output_vel_repr)
O_Ẋ_B = LW_Ẋ_B = -LW_X_B @ jaxsim.math.Cross.vx( # noqa: F841
jnp.einsum("bij,bj->bi", B_X_LW, LW_v_B_LW)
)

# =============================================================
# Express the Jacobian derivative in the target representations
# =============================================================
case _:
raise ValueError(output_vel_repr)

# Sum all the components that form the Jacobian derivative in the target
# input/output velocity representations.
O_J̇_WL_I = jnp.zeros(shape=(6, 6 + model.dofs()))
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ
# =============================================================
# Express the Jacobian derivative in the target representations
# =============================================================

return O_J̇_WL_I
# Sum all the components that form the Jacobian derivative in the target
# input/output velocity representations.
O_J̇_WL_I = jnp.zeros_like(B_J̇_WL_B)
O_J̇_WL_I += O_Ẋ_B @ B_J_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J̇_WL_B @ T
O_J̇_WL_I += O_X_B @ B_J_WL_B @ Ṫ

return _compute_row(B_H_L, B_J_full_WL_B, W_H_B, κb)
return O_J̇_WL_I


@functools.partial(jax.jit, static_argnames=["prefer_aba"])
Expand Down
24 changes: 14 additions & 10 deletions src/jaxsim/math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,13 @@ def from_transform(transform: jtp.MatrixLike, inverse: bool = False) -> jtp.Matr
The 6x6 adjoint matrix.
"""

A_H_B = jnp.array(transform).astype(float)
assert transform.shape == (4, 4)
A_H_B = jnp.reshape(transform, (-1, 4, 4))

return (
jaxlie.SE3.from_matrix(matrix=A_H_B).adjoint()
if not inverse
else jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().adjoint()
)
).reshape(transform.shape[:-2] + (6, 6))

@staticmethod
def from_rotation_and_translation(
Expand Down Expand Up @@ -145,13 +144,18 @@ def inverse(adjoint: jtp.Matrix) -> jtp.Matrix:
Returns:
jtp.Matrix: The inverse adjoint matrix.
"""
A_X_B = adjoint
A_X_B = adjoint.reshape(-1, 6, 6)

A_R_B = A_X_B[0:3, 0:3]
A_R_B_T = jnp.swapaxes(A_X_B[..., 0:3, 0:3], -2, -1)
A_T_B = A_X_B[..., 0:3, 3:6]

return jnp.vstack(
return jnp.concatenate(
[
jnp.block([A_R_B.T, -A_R_B.T @ A_X_B[0:3, 3:6] @ A_R_B.T]),
jnp.block([jnp.zeros(shape=(3, 3)), A_R_B.T]),
]
)
jnp.concatenate(
[A_R_B_T, -A_R_B_T @ A_T_B @ A_R_B_T],
axis=-1,
),
jnp.concatenate([jnp.zeros_like(A_R_B_T), A_R_B_T], axis=-1),
],
axis=-2,
).reshape(adjoint.shape)
15 changes: 10 additions & 5 deletions src/jaxsim/math/cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@ def vx(velocity_sixd: jtp.Vector) -> jtp.Matrix:
Raises:
ValueError: If the input vector does not have a size of 6.
"""
v, ω = jnp.split(velocity_sixd.squeeze(), 2)
velocity_sixd = velocity_sixd.reshape(-1, 6)

v_cross = jnp.vstack(
v, ω = jnp.split(velocity_sixd, 2, axis=-1)

v_cross = jnp.concatenate(
[
jnp.block([Skew.wedge(vector=ω), Skew.wedge(vector=v)]),
jnp.block([jnp.zeros(shape=(3, 3)), Skew.wedge(vector=ω)]),
]
jnp.concatenate(
[Skew.wedge(ω), jnp.zeros((ω.shape[0], 3, 3)).squeeze()], axis=-2
),
jnp.concatenate([Skew.wedge(v), Skew.wedge(ω)], axis=-2),
],
axis=-1,
)

return v_cross
Expand Down
17 changes: 14 additions & 3 deletions src/jaxsim/math/skew.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,20 @@ def wedge(vector: jtp.Vector) -> jtp.Matrix:
jtp.Matrix: The skew-symmetric matrix corresponding to the input vector.

"""
vector = vector.squeeze()
x, y, z = vector
skew = jnp.array([[0, -z, y], [z, 0, -x], [-y, x, 0]])

vector = vector.reshape(-1, 3)

x, y, z = jnp.split(vector, 3, axis=-1)

skew = jnp.stack(
[
jnp.concatenate([jnp.zeros_like(x), -z, y], axis=-1),
jnp.concatenate([z, jnp.zeros_like(x), -x], axis=-1),
jnp.concatenate([-y, x, jnp.zeros_like(x)], axis=-1),
],
axis=-2,
).squeeze()

return skew

@staticmethod
Expand Down
10 changes: 7 additions & 3 deletions src/jaxsim/math/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,11 @@ def inverse(transform: jtp.MatrixLike) -> jtp.Matrix:
The 4x4 inverse transformation matrix.
"""

A_H_B = jnp.array(transform).astype(float)
assert A_H_B.shape == (4, 4)
A_H_B = jnp.reshape(transform, (-1, 4, 4))

return jaxlie.SE3.from_matrix(matrix=A_H_B).inverse().as_matrix()
return (
jaxlie.SE3.from_matrix(matrix=A_H_B)
.inverse()
.as_matrix()
.reshape(transform.shape[:-2] + (4, 4))
)