Skip to content

Commit

Permalink
Remove wrappers module and custom hash methods
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jan 28, 2025
1 parent ccb7f27 commit 10dac08
Show file tree
Hide file tree
Showing 15 changed files with 338 additions and 670 deletions.
46 changes: 13 additions & 33 deletions src/jaxsim/api/kin_dyn_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Inertia, JointModel, supported_joint_motion
from jaxsim.parsers.descriptions import JointDescription, JointType, ModelDescription
from jaxsim.utils import HashedNumpyArray, JaxsimDataclass
from jaxsim.utils import JaxsimDataclass


@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
Expand All @@ -34,9 +34,9 @@ class KinDynParameters(JaxsimDataclass):

# Static
link_names: Static[tuple[str]]
_parent_array: Static[HashedNumpyArray]
_support_body_array_bool: Static[HashedNumpyArray]
_motion_subspaces: Static[HashedNumpyArray]
_parent_array: Static[tuple[int]]
_support_body_array_bool: Static[tuple[int]]
_motion_subspaces: Static[tuple[float]]

# Links
link_parameters: LinkParameters
Expand All @@ -56,21 +56,21 @@ def motion_subspaces(self) -> jtp.Matrix:
r"""
Return the motion subspaces :math:`\mathbf{S}(s)` of the joints.
"""
return self._motion_subspaces.get()
return jnp.array(self._motion_subspaces, dtype=float)

@property
def parent_array(self) -> jtp.Vector:
r"""
Return the parent array :math:`\lambda(i)` of the model.
"""
return self._parent_array.get()
return jnp.array(self._parent_array, dtype=int)

@property
def support_body_array_bool(self) -> jtp.Matrix:
r"""
Return the boolean support parent array :math:`\kappa_{b}(i)` of the model.
"""
return self._support_body_array_bool.get()
return jnp.array(self._support_body_array_bool, dtype=int)

@staticmethod
def build(model_description: ModelDescription) -> KinDynParameters:
Expand Down Expand Up @@ -227,8 +227,8 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:

S = {
JointType.Fixed: np.zeros(shape=(6, 1)),
JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis.axis])),
JointType.Prismatic: np.vstack(np.hstack([axis.axis, np.zeros(3)])),
JointType.Revolute: np.vstack(np.hstack([np.zeros(3), axis])),
JointType.Prismatic: np.vstack(np.hstack([axis, np.zeros(3)])),
}

return S[joint_type]
Expand All @@ -254,36 +254,16 @@ def motion_subspace(joint_type: int, axis: npt.ArrayLike) -> npt.ArrayLike:

return KinDynParameters(
link_names=tuple(l.name for l in ordered_links),
_parent_array=HashedNumpyArray(array=parent_array),
_support_body_array_bool=HashedNumpyArray(array=support_body_array_bool),
_motion_subspaces=HashedNumpyArray(array=motion_subspaces),
_parent_array=tuple(parent_array.tolist()),
_support_body_array_bool=tuple(support_body_array_bool.tolist()),
_motion_subspaces=tuple(motion_subspaces.tolist()),
link_parameters=link_parameters,
joint_model=joint_model,
joint_parameters=joint_parameters,
contact_parameters=contact_parameters,
frame_parameters=frame_parameters,
)

def __eq__(self, other: KinDynParameters) -> bool:

if not isinstance(other, KinDynParameters):
return False

return hash(self) == hash(other)

def __hash__(self) -> int:

return hash(
(
hash(self.number_of_links()),
hash(self.number_of_joints()),
hash(self.frame_parameters.name),
hash(self.frame_parameters.body),
hash(self._parent_array),
hash(self._support_body_array_bool),
)
)

# =============================
# Helpers to extract parameters
# =============================
Expand Down Expand Up @@ -409,7 +389,7 @@ def joint_transforms(
pre_H_suc_J = jax.vmap(supported_joint_motion)(
joint_types=jnp.array(self.joint_model.joint_types[1:]).astype(int),
joint_positions=jnp.array(joint_positions),
joint_axes=jnp.array([j.axis for j in self.joint_model.joint_axis]),
joint_axes=jnp.array(self.joint_model.joint_axis),
)

# Extract the transforms and motion subspaces of the joints.
Expand Down
46 changes: 10 additions & 36 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jaxsim.typing as jtp
from jaxsim.math import Adjoint, Cross
from jaxsim.parsers.descriptions import ModelDescription
from jaxsim.utils import JaxsimDataclass, Mutability, wrappers
from jaxsim.utils import JaxsimDataclass, Mutability

from .common import VelRepr

Expand Down Expand Up @@ -59,43 +59,16 @@ class JaxSimModel(JaxsimDataclass):
default=None, repr=False
)

_description: Static[wrappers.HashlessObject[ModelDescription | None]] = (
dataclasses.field(default=None, repr=False)
_description: Static[ModelDescription | None] = dataclasses.field(
default=None, repr=False
)

@property
def description(self) -> ModelDescription:
"""
Return the model description.
"""
return self._description.get()

def __eq__(self, other: JaxSimModel) -> bool:

if not isinstance(other, JaxSimModel):
return False

if self.model_name != other.model_name:
return False

if self.time_step != other.time_step:
return False

if self.kin_dyn_parameters != other.kin_dyn_parameters:
return False

return True

def __hash__(self) -> int:

return hash(
(
hash(self.model_name),
hash(self.time_step),
hash(self.kin_dyn_parameters),
hash(self.contact_model),
)
)
return self._description

# ========================
# Initialization and state
Expand Down Expand Up @@ -252,7 +225,7 @@ def build(
# don't want to trigger recompilation if it changes. All relevant parameters
# needed to compute kinematics and dynamics quantities are stored in the
# kin_dyn_parameters attribute.
_description=wrappers.HashlessObject(obj=model_description),
_description=model_description,
)

return model
Expand Down Expand Up @@ -423,15 +396,16 @@ def reduce(

# Operate on a deep copy of the model description in order to prevent problems
# when mutable attributes are updated.
intermediate_description = copy.deepcopy(model.description)
intermediate_description = copy.deepcopy(model._description)

# Update the initial position of the joints.
# This is necessary to compute the correct pose of the link pairs connected
# to removed joints.
for joint_name in set(model.joint_names()) - set(considered_joints):
j = intermediate_description.joints_dict[joint_name]
with j.mutable_context():
j.initial_position = locked_joint_positions.get(joint_name, 0.0)
intermediate_description.joints_dict[joint_name] = dataclasses.replace(
intermediate_description.joints_dict[joint_name],
_initial_position=float(locked_joint_positions.get(joint_name, 0.0)),
)

# Reduce the model description.
# If `considered_joints` contains joints not existing in the model,
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/math/joint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import jaxsim.typing as jtp
from jaxsim.math import Rotation
from jaxsim.parsers.descriptions import JointGenericAxis, JointType, ModelDescription
from jaxsim.parsers.descriptions import JointType, ModelDescription
from jaxsim.parsers.kinematic_graph import KinematicGraphTransforms


Expand Down Expand Up @@ -39,7 +39,7 @@ class JointModel:
joint_dofs: Static[tuple[int, ...]]
joint_names: Static[tuple[str, ...]]
joint_types: Static[tuple[int, ...]]
joint_axis: Static[tuple[JointGenericAxis, ...]]
joint_axis: Static[tuple[tuple[int]]]

@staticmethod
def build(description: ModelDescription) -> JointModel:
Expand Down Expand Up @@ -108,7 +108,7 @@ def build(description: ModelDescription) -> JointModel:
joint_dofs=tuple([base_dofs] + [1 for _ in ordered_joints]),
joint_names=tuple(["world_to_base"] + [j.name for j in ordered_joints]),
joint_types=tuple([JointType.Fixed] + [j.jtype for j in ordered_joints]),
joint_axis=tuple(JointGenericAxis(axis=j.axis) for j in ordered_joints),
joint_axis=tuple(tuple(j.axis.tolist()) for j in ordered_joints),
)

def parent_H_predecessor(self, joint_index: jtp.IntLike) -> jtp.Matrix:
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/parsers/descriptions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
MeshCollision,
SphereCollision,
)
from .joint import JointDescription, JointGenericAxis, JointType
from .joint import JointDescription, JointType
from .link import LinkDescription
from .model import ModelDescription
99 changes: 29 additions & 70 deletions src/jaxsim/parsers/descriptions/collision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import abc
import dataclasses

import jax.numpy as jnp
import numpy as np
import numpy.typing as npt

import jaxsim.typing as jtp
from jaxsim import logging

from .link import LinkDescription
Expand All @@ -25,8 +23,28 @@ class CollidablePoint:
"""

parent_link: LinkDescription
position: npt.NDArray = dataclasses.field(default_factory=lambda: np.zeros(3))
enabled: bool = True
_position: tuple[float] = dataclasses.field(default=(0.0, 0.0, 0.0))

@property
def position(self) -> npt.NDArray:
"""
Get the position of the collidable point.
Returns:
The position of the collidable point.
"""
return np.array(self._position)

@position.setter
def position(self, value: npt.NDArray) -> None:
"""
Set the position of the collidable point.
Args:
value: The new position of the collidable point.
"""
self._position = tuple(value.tolist())

def change_link(
self, new_link: LinkDescription, new_H_old: npt.NDArray
Expand All @@ -35,8 +53,8 @@ def change_link(
Move the collidable point to a new parent link.
Args:
new_link (LinkDescription): The new parent link to which the collidable point is moved.
new_H_old (npt.NDArray): The transformation matrix from the new link's frame to the old link's frame.
new_link: The new parent link to which the collidable point is moved.
new_H_old: The transformation matrix from the new link's frame to the old link's frame.
Returns:
CollidablePoint: A new collidable point associated with the new parent link.
Expand All @@ -47,27 +65,12 @@ def change_link(

return CollidablePoint(
parent_link=new_link,
position=(new_H_old @ jnp.hstack([self.position, 1.0])).squeeze()[0:3],
_position=tuple(
(new_H_old @ np.hstack([self.position, 1.0])).squeeze()[0:3].tolist()
),
enabled=self.enabled,
)

def __hash__(self) -> int:

return hash(
(
hash(self.parent_link),
hash(tuple(self.position.tolist())),
hash(self.enabled),
)
)

def __eq__(self, other: CollidablePoint) -> bool:

if not isinstance(other, CollidablePoint):
return False

return hash(self) == hash(other)

def __str__(self) -> str:
return (
f"{self.__class__.__name__}("
Expand Down Expand Up @@ -107,22 +110,7 @@ class BoxCollision(CollisionShape):
center: The center of the box in the local frame of the collision shape.
"""

center: jtp.VectorLike

def __hash__(self) -> int:
return hash(
(
hash(super()),
hash(tuple(self.center.tolist())),
)
)

def __eq__(self, other: BoxCollision) -> bool:

if not isinstance(other, BoxCollision):
return False

return hash(self) == hash(other)
center: tuple[float, float, float]


@dataclasses.dataclass
Expand All @@ -134,22 +122,7 @@ class SphereCollision(CollisionShape):
center: The center of the sphere in the local frame of the collision shape.
"""

center: jtp.VectorLike

def __hash__(self) -> int:
return hash(
(
hash(super()),
hash(tuple(self.center.tolist())),
)
)

def __eq__(self, other: BoxCollision) -> bool:

if not isinstance(other, BoxCollision):
return False

return hash(self) == hash(other)
center: tuple[float, float, float]


@dataclasses.dataclass
Expand All @@ -161,18 +134,4 @@ class MeshCollision(CollisionShape):
center: The center of the mesh in the local frame of the collision shape.
"""

center: jtp.VectorLike

def __hash__(self) -> int:
return hash(
(
hash(tuple(self.center.tolist())),
hash(self.collidable_points),
)
)

def __eq__(self, other: MeshCollision) -> bool:
if not isinstance(other, MeshCollision):
return False

return hash(self) == hash(other)
center: tuple[float, float, float]
Loading

0 comments on commit 10dac08

Please sign in to comment.