Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve utilities provided by JaxsimDataclass #116

Merged
merged 9 commits into from
Mar 22, 2024
2 changes: 1 addition & 1 deletion src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def switch_velocity_representation(
# We run this in a mutable context with restoration so that any exception
# occurring, we restore the original object in case it was modified.
with self.mutable_context(
mutability=self._mutability(), restore_after_exception=True
mutability=self.mutability(), restore_after_exception=True
):
yield self

Expand Down
318 changes: 287 additions & 31 deletions src/jaxsim/utils/jaxsim_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import contextlib
import copy
import dataclasses
from typing import ClassVar, Generator
import functools
from collections.abc import Iterator
from typing import Any, Callable, ClassVar, Sequence, Type

import jax.flatten_util
import jax_dataclasses
Expand All @@ -19,91 +21,345 @@

@jax_dataclasses.pytree_dataclass
class JaxsimDataclass(abc.ABC):
""""""
"""Class extending `jax_dataclasses.pytree_dataclass` instances with utilities."""

# This attribute is set by jax_dataclasses
__mutability__: ClassVar[Mutability] = Mutability.FROZEN

@contextlib.contextmanager
def editable(self: Self, validate: bool = True) -> Generator[Self, None, None]:
""""""
def editable(self: Self, validate: bool = True) -> Iterator[Self]:
"""
Context manager to operate on a mutable copy of the object.

Args:
validate: Whether to validate the output PyTree upon exiting the context.

Yields:
A mutable copy of the object.

Note:
This context manager is useful to operate on an r/w copy of a PyTree making
sure that the output object does not trigger JIT recompilations.
"""

mutability = (
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
)

with JaxsimDataclass.mutable_context(self.copy(), mutability=mutability) as obj:
with self.copy().mutable_context(mutability=mutability) as obj:
yield obj

@contextlib.contextmanager
def mutable_context(
self: Self, mutability: Mutability, restore_after_exception: bool = True
) -> Generator[Self, None, None]:
""""""
) -> Iterator[Self]:
"""
Context manager to temporarily change the mutability of the object.

Args:
mutability: The mutability to set.
restore_after_exception:
Whether to restore the original object in case of an exception
occurring within the context.

Yields:
The object with the new mutability.

Note:
This context manager is useful to operate in place on a PyTree without
the need to make a copy while optionally keeping active the checks on
the PyTree structure, shapes, and dtypes.
"""

if restore_after_exception:
self_copy = self.copy()

original_mutability = self._mutability()
original_mutability = self.mutability()

def restore_self():
self._set_mutability(mutability=Mutability.MUTABLE)
original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
original_structure = jax.tree_util.tree_structure(tree=self)

def restore_self() -> None:
self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION)
for f in dataclasses.fields(self_copy):
setattr(self, f.name, getattr(self_copy, f.name))

try:
self._set_mutability(mutability)
self.set_mutability(mutability)
yield self

if mutability is not Mutability.MUTABLE_NO_VALIDATION:
new_structure = jax.tree_util.tree_structure(tree=self)
if original_structure != new_structure:
msg = "Pytree structure has changed from {} to {}"
raise ValueError(msg.format(original_structure, new_structure))

new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self)
if original_shapes != new_shapes:
msg = "Leaves shapes have changed from {} to {}"
raise ValueError(msg.format(original_shapes, new_shapes))

new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self)
if original_dtypes != new_dtypes:
msg = "Leaves dtypes have changed from {} to {}"
raise ValueError(msg.format(original_dtypes, new_dtypes))

new_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self)
if original_weak_types != new_weak_types:
msg = "Leaves weak types have changed from {} to {}"
raise ValueError(msg.format(original_weak_types, new_weak_types))

except Exception as e:
if restore_after_exception:
restore_self()
self._set_mutability(original_mutability)
self.set_mutability(original_mutability)
raise e

finally:
self._set_mutability(original_mutability)
self.set_mutability(original_mutability)

@staticmethod
def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]:
"""
Helper method to get the leaf shapes of a PyTree.

Args:
tree: The PyTree to consider.

Returns:
A tuple containing the leaf shapes of the PyTree or `None` is the leaf is
not a numpy-like array.
"""

return tuple( # noqa
leaf.shape if hasattr(leaf, "shape") else None
for leaf in jax.tree_util.tree_leaves(tree)
if hasattr(leaf, "shape")
)

@staticmethod
def get_leaf_dtypes(tree: jtp.PyTree) -> tuple:
"""
Helper method to get the leaf dtypes of a PyTree.

Args:
tree: The PyTree to consider.

Returns:
A tuple containing the leaf dtypes of the PyTree or `None` is the leaf is
not a numpy-like array.
"""

return tuple(
leaf.dtype if hasattr(leaf, "dtype") else None
for leaf in jax.tree_util.tree_leaves(tree)
if hasattr(leaf, "dtype")
)

@staticmethod
def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]:
"""
Helper method to get the leaf weak types of a PyTree.

Args:
tree: The PyTree to consider.

Returns:
A tuple marking whether the leaf contains a JAX array with weak type.
"""

return tuple(
leaf.weak_type if hasattr(leaf, "weak_type") else False
for leaf in jax.tree_util.tree_leaves(tree)
if hasattr(leaf, "weak_type")
)

@staticmethod
def check_compatibility(*trees: Sequence[Any]) -> None:
"""
Check whether the PyTrees are compatible in structure, shape, and dtype.

Args:
*trees: The PyTrees to compare.

Raises:
ValueError: If the PyTrees have incompatible structures, shapes, or dtypes.
"""

target_structure = jax.tree_util.tree_structure(trees[0])

compatible_structure = functools.reduce(
lambda compatible, tree: compatible
and jax.tree_util.tree_structure(tree) == target_structure,
trees[1:],
True,
)

if not compatible_structure:
raise ValueError("Pytrees have incompatible structures.")

target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0])

compatible_shapes = functools.reduce(
lambda compatible, tree: compatible
and JaxsimDataclass.get_leaf_shapes(tree) == target_shapes,
trees[1:],
True,
)

if not compatible_shapes:
raise ValueError("Pytrees have incompatible shapes.")

target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0])

compatible_dtypes = functools.reduce(
lambda compatible, tree: compatible
and JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes,
trees[1:],
True,
)

if not compatible_dtypes:
raise ValueError("Pytrees have incompatible dtypes.")

def is_mutable(self, validate: bool = False) -> bool:
""""""
"""
Check whether the object is mutable.

Args:
validate: Additionally checks if the object also has validation enabled.

Returns:
True if the object is mutable, False otherwise.
"""

return (
self.__mutability__ is Mutability.MUTABLE
if validate
else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION
)

def set_mutability(self, mutable: bool = True, validate: bool = False) -> None:
if not mutable:
mutability = Mutability.FROZEN
else:
mutability = (
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
)
def mutability(self) -> Mutability:
"""
Get the mutability type of the object.

self._set_mutability(mutability=mutability)
Returns:
The mutability type of the object.
"""

def _mutability(self) -> Mutability:
return self.__mutability__

def _set_mutability(self, mutability: Mutability) -> None:
def set_mutability(self, mutability: Mutability) -> None:
"""
Set the mutability of the object in-place.

Args:
mutability: The desired mutability type.
"""

jax_dataclasses._copy_and_mutate._mark_mutable(
self, mutable=mutability, visited=set()
)

def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self:
self.set_mutability(mutable=mutable, validate=validate)
"""
Return a mutable reference of the object.

Args:
mutable: Whether to make the object mutable.
validate: Whether to enable validation on the object.

Returns:
A mutable reference of the object.
"""

if mutable:
mutability = (
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
)
else:
mutability = Mutability.FROZEN

self.set_mutability(mutability=mutability)
return self

def copy(self: Self) -> Self:
"""
Return a copy of the object.

Returns:
A copy of the object.
"""

# Make a copy calling tree_map.
obj = jax.tree_util.tree_map(lambda leaf: leaf, self)
obj._set_mutability(mutability=self._mutability())

# Make sure that the copied object and all the copied leaves have the same
# mutability of the original object.
obj.set_mutability(mutability=self.mutability())

return obj

def replace(self: Self, validate: bool = True, **kwargs) -> Self:
with self.editable(validate=validate) as obj:
_ = [obj.__setattr__(k, copy.copy(v)) for k, v in kwargs.items()]
"""
Return a new object replacing in-place the specified fields with new values.

obj._set_mutability(mutability=self._mutability())
return obj
Args:
validate:
Whether to validate that the new fields do not alter the PyTree.
**kwargs: The fields to replace.

Returns:
A reference of the object with the specified fields replaced.
"""

mutability = (
Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION
)

with self.mutable_context(mutability=mutability) as obj:
_ = [obj.__setattr__(k, v) for k, v in kwargs.items()]

# Make sure that all the new leaves have the same mutability of the object.
obj.set_mutability(mutability=self.mutability())

# Return a shallow copy of the object with the new fields replaced.
# Note that the shallow copy of the original object contains exactly the same
# attributes of the original object (in other words, with the same id).
return copy.copy(obj)

def flatten(self) -> jtp.VectorJax:
return jax.flatten_util.ravel_pytree(self)[0]
"""
Flatten the object into a 1D vector.

Returns:
A 1D vector containing the flattened object.
"""

return self.flatten_fn()(self)

@classmethod
def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]:
"""
Return a function to flatten the object into a 1D vector.

Returns:
A function to flatten the object into a 1D vector.
"""

return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0]

def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]:
"""
Return a function to unflatten a 1D vector into the object.

Returns:
A function to unflatten a 1D vector into the object.

Notes:
Due to JAX internals, the function to unflatten a PyTree needs to be
created from an existing instance of the PyTree.
"""
return jax.flatten_util.ravel_pytree(self)[1]
Loading