Skip to content

Commit

Permalink
Merge branch 'sprint/caching' of github.com:ami-iit/jaxsim into sprin…
Browse files Browse the repository at this point in the history
…t/branch-integration
  • Loading branch information
flferretti committed Jan 31, 2025
2 parents 0234c4f + 3a0dc6c commit d4c73cf
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 19 deletions.
9 changes: 5 additions & 4 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
from collections.abc import Sequence
from typing import override

try:
from typing import override
except ImportError:
from typing_extensions import override

import jax
import jax.numpy as jnp
import jax.scipy.spatial.transform
Expand Down Expand Up @@ -402,7 +407,6 @@ def replace(
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
base_position: jtp.Vector | None = None,
velocity_representation: VelRepr | None = None,
*,
contact_state: dict[str, jtp.Array] | None = None,
validate: bool = False,
Expand All @@ -418,8 +422,6 @@ def replace(
base_quaternion = self.base_quaternion
if base_position is None:
base_position = self.base_position
if velocity_representation is None:
velocity_representation = self.velocity_representation

joint_positions = jnp.atleast_1d(joint_positions.squeeze()).astype(float)
joint_velocities = jnp.atleast_1d(joint_velocities.squeeze()).astype(float)
Expand Down Expand Up @@ -462,7 +464,6 @@ def replace(
)

return super().replace(
velocity_representation=velocity_representation,
_joint_positions=joint_positions,
_joint_velocities=joint_velocities,
_base_quaternion=base_quaternion,
Expand Down
25 changes: 13 additions & 12 deletions src/jaxsim/api/integrators.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import dataclasses

import jax
import jax.numpy as jnp

Expand All @@ -17,8 +19,7 @@ def semi_implicit_euler_integration(
extended_contact_state: jtp.Vector,
) -> JaxSimModelData:
"""Integrate the system state using the semi-implicit Euler method."""
# Step the dynamics forward.
velocity_representation = data.velocity_representation

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):

dt = model.time_step
Expand Down Expand Up @@ -68,20 +69,20 @@ def semi_implicit_euler_integration(
is_leaf=lambda x: x is None,
)

data = data.replace(
model=model,
validate=True,
velocity_representation=velocity_representation,
base_quaternion=new_base_quaternion,
base_position=new_base_position,
joint_positions=new_joint_position,
joint_velocities=new_joint_velocities,
base_linear_velocity=base_lin_velocity_inertial,
# TODO: Avoid double replace, e.g. by computing cached value here
data = dataclasses.replace(
data,
_base_quaternion=new_base_quaternion,
_base_position=new_base_position,
_joint_positions=new_joint_position,
_joint_velocities=new_joint_velocities,
_base_linear_velocity=base_lin_velocity_inertial,
# Here we use the base angular velocity in mixed representation since
# it's equivalent to the one in inertial representation
# See: S. Traversaro and A. Saccon, “Multibody Dynamics Notation (Version 2), pg.9
base_angular_velocity=base_ang_velocity_mixed,
_base_angular_velocity=base_ang_velocity_mixed,
contact_state=integrated_contact_state,
)
data = data.replace(model=model) # update cache

return data
5 changes: 2 additions & 3 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2158,10 +2158,9 @@ def step(

# ne parliamo dopo
# Restore the input velocity representation
data_tf = data_tf.replace(
model=model,
data_tf = dataclasses.replace(
data_tf,
velocity_representation=data.velocity_representation,
validate=False,
)

return data_tf

0 comments on commit d4c73cf

Please sign in to comment.