Skip to content

Commit

Permalink
Merge branch 'sprint/reintroduce_contact_models' of github.com:ami-ii…
Browse files Browse the repository at this point in the history
…t/jaxsim into sprint/branch-integration
  • Loading branch information
flferretti committed Jan 31, 2025
2 parents e5f5b65 + ccb1c2e commit ad9618d
Show file tree
Hide file tree
Showing 14 changed files with 1,310 additions and 63 deletions.
80 changes: 72 additions & 8 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,11 @@ def collidable_point_kinematics(
the linear component of the mixed 6D frame velocity.
"""

# Switch to inertial-fixed since the RBDAs expect velocities in this representation.
with data.switch_velocity_representation(VelRepr.Inertial):

W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
link_transforms=data._link_transforms,
link_velocities=data._link_velocities,
)
W_p_Ci, W_ṗ_Ci = jaxsim.rbda.collidable_points.collidable_points_pos_vel(
model=model,
link_transforms=data._link_transforms,
link_velocities=data._link_velocities,
)

return W_p_Ci, W_ṗ_Ci

Expand Down Expand Up @@ -164,15 +161,24 @@ def estimate_good_soft_contacts_parameters(
def estimate_good_contact_parameters(
model: js.model.JaxSimModel,
*,
standard_gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
static_friction_coefficient: jtp.FloatLike = 0.5,
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
**kwargs,
) -> jaxsim.rbda.contacts.ContactParamsTypes:
"""
Estimate good contact parameters.
Args:
model: The model to consider.
standard_gravity: The standard gravity acceleration.
static_friction_coefficient: The static friction coefficient.
number_of_active_collidable_points_steady_state:
The number of active collidable points in steady state.
damping_ratio: The damping ratio.
max_penetration: The maximum penetration allowed.
kwargs:
Additional model-specific parameters passed to the builder method of
the parameters class.
Expand All @@ -190,8 +196,66 @@ def estimate_good_contact_parameters(
specific application.
"""

def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
"""
Displacement between the CoM and the lowest collidable point using zero
joint positions.
"""

zero_data = js.data.JaxSimModelData.build(
model=model,
)

W_pz_CoM = js.com.com_position(model=model, data=zero_data)[2]

if model.floating_base():
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
return 2 * (W_pz_CoM - W_pz_C.min())

return 2 * W_pz_CoM

max_δ = (
max_penetration
if max_penetration is not None
# Consider as default a 0.5% of the model height.
else 0.005 * estimate_model_height(model=model)
)

nc = number_of_active_collidable_points_steady_state

match model.contact_model:

case contacts.SoftContacts():
assert isinstance(model.contact_model, contacts.SoftContacts)

parameters = contacts.SoftContactsParams.build_default_from_jaxsim_model(
model=model,
standard_gravity=standard_gravity,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
**kwargs,
)

case contacts.RigidContacts():
assert isinstance(model.contact_model, contacts.RigidContacts)

# Disable Baumgarte stabilization by default since it does not play
# well with the forward Euler integrator.
K = kwargs.get("K", 0.0)

parameters = contacts.RigidContactsParams.build(
mu=static_friction_coefficient,
**(
dict(
K=K,
D=2 * jnp.sqrt(K),
)
| kwargs
),
)

case contacts.RelaxedRigidContacts():
assert isinstance(model.contact_model, contacts.RelaxedRigidContacts)

Expand Down
14 changes: 9 additions & 5 deletions src/jaxsim/api/contact_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.rbda.contacts import SoftContacts


@jax.jit
Expand All @@ -15,7 +16,7 @@ def link_contact_forces(
*,
link_forces: jtp.MatrixLike | None = None,
joint_torques: jtp.VectorLike | None = None,
) -> jtp.Matrix:
) -> tuple[jtp.Matrix, dict[str, jtp.Matrix]]:
"""
Compute the 6D contact forces of all links of the model in inertial representation.
Expand All @@ -33,11 +34,14 @@ def link_contact_forces(
"""

# Compute the contact forces for each collidable point with the active contact model.
W_f_C, _ = model.contact_model.compute_contact_forces(
W_f_C, extended_contact_state = model.contact_model.compute_contact_forces(
model=model,
data=data,
link_forces=link_forces,
joint_force_references=joint_torques,
**(
dict(link_forces=link_forces, joint_force_references=joint_torques)
if not isinstance(model.contact_model, SoftContacts)
else {}
),
)

# Compute the 6D forces applied to the links equivalent to the forces applied
Expand All @@ -46,7 +50,7 @@ def link_contact_forces(
model=model, data=data, contact_forces=W_f_C
)

return W_f_L
return W_f_L, extended_contact_state


@staticmethod
Expand Down
30 changes: 25 additions & 5 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
_link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
_link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)

# Extended state for soft and rigid contact models.
contact_state: dict[str, jtp.Array] = dataclasses.field(default=None)

@staticmethod
def build(
model: js.model.JaxSimModel,
Expand All @@ -70,6 +73,8 @@ def build(
base_angular_velocity: jtp.VectorLike | None = None,
joint_velocities: jtp.VectorLike | None = None,
velocity_representation: VelRepr = VelRepr.Mixed,
*,
contact_state: dict[str, jtp.Array] | None = None,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with the given state.
Expand All @@ -85,6 +90,7 @@ def build(
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.
contact_state: The optional contact state.
Returns:
A `JaxSimModelData` initialized with the given state.
Expand Down Expand Up @@ -167,18 +173,29 @@ def build(
)
)

contact_state = (
{
"tangential_deformation": jnp.zeros_like(
model.kin_dyn_parameters.contact_parameters.point
)
}
if isinstance(model.contact_model, jaxsim.rbda.contacts.SoftContacts)
else contact_state or {}
)

model_data = JaxSimModelData(
velocity_representation=velocity_representation,
_base_quaternion=base_quaternion,
_base_position=base_position,
_joint_positions=joint_positions,
_base_linear_velocity=v_WB[0:3],
_base_angular_velocity=v_WB[3:6],
_base_linear_velocity=W_v_WB[0:3],
_base_angular_velocity=W_v_WB[3:6],
_joint_velocities=joint_velocities,
_base_transform=W_H_B,
_joint_transforms=joint_transforms,
_link_transforms=link_transforms,
_link_velocities=link_velocities,
_link_velocities=link_velocities_inertial,
contact_state=contact_state or {},
)

if not model_data.valid(model=model):
Expand Down Expand Up @@ -386,6 +403,8 @@ def replace(
base_angular_velocity: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
velocity_representation: VelRepr | None = None,
*,
contact_state: dict[str, jtp.Array] | None = None,
validate: bool = False,
) -> Self:
"""
Expand Down Expand Up @@ -438,8 +457,8 @@ def replace(
base_quaternion=base_quaternion,
joint_positions=joint_positions,
joint_velocities=joint_velocities,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
base_linear_velocity_inertial=base_linear_velocity,
base_angular_velocity_inertial=base_angular_velocity,
)

return super().replace(
Expand All @@ -454,6 +473,7 @@ def replace(
_joint_transforms=joint_transforms,
_link_transforms=link_transforms,
_link_velocities=link_velocities,
contact_state=contact_state,
validate=validate,
)

Expand Down
9 changes: 9 additions & 0 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import jax
import jax.numpy as jnp

import jaxsim
Expand All @@ -12,6 +13,8 @@ def semi_implicit_euler_integration(
data: js.data.JaxSimModelData,
base_acceleration_inertial: jtp.Vector,
joint_accelerations: jtp.Vector,
*,
extended_contact_state: jtp.Vector,
) -> JaxSimModelData:
"""Integrate the system state using the semi-implicit Euler method."""
# Step the dynamics forward.
Expand Down Expand Up @@ -57,6 +60,11 @@ def semi_implicit_euler_integration(

new_joint_position = data.joint_positions + dt * new_joint_velocities

# Integrate the leaves of the contact state PyTree.
integrated_contact_state = jax.tree.map(
lambda x, x_dot: x + dt * x_dot, data.contact_state, extended_contact_state
)

data = data.replace(
model=model,
validate=True,
Expand All @@ -70,6 +78,7 @@ def semi_implicit_euler_integration(
# it's equivalent to the one in inertial representation
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
base_angular_velocity=base_ang_velocity_mixed,
contact_state=integrated_contact_state,
)

return data
Loading

0 comments on commit ad9618d

Please sign in to comment.