Skip to content

Commit

Permalink
make cached attrs not settable
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti authored and younik committed Jan 29, 2025
2 parents 15e402e + 7aff901 commit 7f926e7
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 30 deletions.
2 changes: 1 addition & 1 deletion docs/guide/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ The logging and exceptions configurations is controlled by the following environ

*Default:* ``DEBUG`` for development, ``WARNING`` for production.

- ``JAXSIM_DISABLE_EXCEPTIONS``: Disables the runtime checks and exceptions.
- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required.

*Default:* ``False``.

Expand Down
60 changes: 44 additions & 16 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
base_position: jtp.Vector

# Cached computations.
base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)
_base_transform: jtp.Matrix = dataclasses.field(repr=False, default=None)
_joint_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
_link_transforms: jtp.Matrix = dataclasses.field(repr=False, default=None)
_link_velocities: jtp.Matrix = dataclasses.field(repr=False, default=None)

@staticmethod
def build(
Expand Down Expand Up @@ -173,10 +173,10 @@ def build(
base_angular_velocity=v_WB[3:6],
joint_velocities=joint_velocities,
velocity_representation=velocity_representation,
base_transform=W_H_B,
joint_transforms=joint_transforms,
link_transforms=link_transforms,
link_velocities=link_velocities,
_base_transform=W_H_B,
_joint_transforms=joint_transforms,
_link_transforms=link_transforms,
_link_velocities=link_velocities,
)

if not model_data.valid(model=model):
Expand Down Expand Up @@ -209,6 +209,34 @@ def zero(
# Extract quantities
# ==================

@property
def base_transform(self) -> jtp.Matrix:
"""
Get the base transform.
"""
return self._base_transform

@property
def joint_transforms(self) -> jtp.Matrix:
"""
Get the joint transforms.
"""
return self._joint_transforms

@property
def link_transforms(self) -> jtp.Matrix:
"""
Get the link transforms.
"""
return self._link_transforms

@property
def link_velocities(self) -> jtp.Matrix:
"""
Get the link velocities.
"""
return self._link_velocities

@js.common.named_scope
@functools.partial(jax.jit, static_argnames=["dcm"])
def base_orientation(self, dcm: jtp.BoolLike = False) -> jtp.Vector | jtp.Matrix:
Expand Down Expand Up @@ -254,7 +282,7 @@ def base_velocity(self) -> jtp.Vector:
]
)

W_H_B = self.base_transform
W_H_B = self._base_transform

return (
JaxSimModelData.inertial_to_other_representation(
Expand All @@ -278,7 +306,7 @@ def generalized_position(self) -> tuple[jtp.Matrix, jtp.Vector]:
A tuple containing the base transform and the joint positions.
"""

return self.base_transform, self.joint_positions
return self._base_transform, self.joint_positions

@js.common.named_scope
@jax.jit
Expand Down Expand Up @@ -540,7 +568,7 @@ def reset_base_velocity(
W_v_WB = self.other_representation_to_inertial(
array=jnp.atleast_1d(base_velocity.squeeze()).astype(float),
other_representation=velocity_representation,
transform=self.base_transform,
transform=self._base_transform,
is_force=False,
)

Expand Down Expand Up @@ -594,7 +622,7 @@ def update_cached(self, model: js.model.JaxSimModel) -> JaxSimModelData:
)

joint_transforms = model.kin_dyn_parameters.joint_transforms(
joint_positions=self.joint_positions, base_transform=self.base_transform
joint_positions=self.joint_positions, base_transform=self._base_transform
)

link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
Expand All @@ -608,10 +636,10 @@ def update_cached(self, model: js.model.JaxSimModel) -> JaxSimModelData:
)

return self.replace(
base_transform=base_transform,
joint_transforms=joint_transforms,
link_transforms=link_transforms,
link_velocities=link_velocities,
_base_transform=base_transform,
_joint_transforms=joint_transforms,
_link_transforms=link_transforms,
_link_velocities=link_velocities,
)


Expand Down
14 changes: 7 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ class JaxSimModel(JaxsimDataclass):

model_name: Static[str]

time_step: jtp.FloatLike = dataclasses.field(
default_factory=lambda: jnp.array(0.001, dtype=float),
time_step: float = dataclasses.field(
default=0.001,
)

terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
Expand Down Expand Up @@ -91,7 +91,7 @@ def __hash__(self) -> int:
return hash(
(
hash(self.model_name),
hash(float(self.time_step)),
hash(self.time_step),
hash(self.kin_dyn_parameters),
hash(self.contact_model),
)
Expand Down Expand Up @@ -222,7 +222,7 @@ def build(
time_step = (
time_step
if time_step is not None
else JaxSimModel.__dataclass_fields__["time_step"].default_factory()
else JaxSimModel.__dataclass_fields__["time_step"].default
)

# Create the default contact model.
Expand Down Expand Up @@ -317,7 +317,7 @@ def floating_base(self) -> bool:
True if the model is floating-base, False otherwise.
"""

return bool(self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6)
return self.kin_dyn_parameters.joint_model.joint_dofs[0] == 6

def base_link(self) -> str:
"""
Expand Down Expand Up @@ -348,7 +348,7 @@ def dofs(self) -> int:
the number of joints. In the future, this could be different.
"""

return int(sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:]))
return sum(self.kin_dyn_parameters.joint_model.joint_dofs[1:])

def joint_names(self) -> tuple[str, ...]:
"""
Expand Down Expand Up @@ -431,7 +431,7 @@ def reduce(
for joint_name in set(model.joint_names()) - set(considered_joints):
j = intermediate_description.joints_dict[joint_name]
with j.mutable_context():
j.initial_position = float(locked_joint_positions.get(joint_name, 0.0))
j.initial_position = locked_joint_positions.get(joint_name, 0.0)

# Reduce the model description.
# If `considered_joints` contains joints not existing in the model,
Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ def raise_if(

# Disable host callback if running on unsupported hardware or if the user
# explicitly disabled it.
if jax.devices()[0].platform in {"tpu", "METAL"} or os.environ.get(
"JAXSIM_DISABLE_EXCEPTIONS", 0
if jax.devices()[0].platform in {"tpu", "METAL"} or not os.environ.get(
"JAXSIM_ENABLE_EXCEPTIONS", 0
):
return

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os

os.environ["JAXSIM_ENABLE_EXCEPTIONS"] = "1"

import pathlib
import subprocess

Expand Down
2 changes: 1 addition & 1 deletion tests/test_api_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_model_creation_and_reduction(
locked_joint_positions=dict(
zip(
model_full.joint_names(),
data_full.joint_positions,
data_full.joint_positions.tolist(),
strict=True,
)
),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def step(
model=model,
data=data_x0,
joint_force_references=τ,
link_forces=W_f_L,
link_forces_inertial=W_f_L,
)

xf_W_p_B = data_xf.base_position
Expand Down
4 changes: 2 additions & 2 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_box_with_external_forces(
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
link_forces_inertial=references._link_forces,
)

# Check that the box didn't move.
Expand Down Expand Up @@ -148,7 +148,7 @@ def test_box_with_zero_gravity(
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
link_forces_inertial=references.link_forces(model=model, data=data),
)

# Check that the box moved as expected.
Expand Down

0 comments on commit 7f926e7

Please sign in to comment.