From 7cd89d9e0c663c2fcfea771f79f68d5b3a7ddfb9 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 23 Oct 2024 17:31:56 +0200 Subject: [PATCH 1/2] Make `FrameParameters.body` `Static[tuple[int]]` --- src/jaxsim/api/frame.py | 2 +- src/jaxsim/api/kin_dyn_parameters.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/api/frame.py b/src/jaxsim/api/frame.py index c85a6cec7..28b6e1dc5 100644 --- a/src/jaxsim/api/frame.py +++ b/src/jaxsim/api/frame.py @@ -40,7 +40,7 @@ def idx_of_parent_link( idx=frame_index, ) - return model.kin_dyn_parameters.frame_parameters.body[ + return jnp.array(model.kin_dyn_parameters.frame_parameters.body)[ frame_index - model.number_of_links() ] diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 62dcd1946..f37bf2583 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -825,7 +825,7 @@ class FrameParameters(JaxsimDataclass): name: Static[tuple[str, ...]] = dataclasses.field(default_factory=tuple) - body: jtp.Vector = dataclasses.field(default_factory=lambda: jnp.array([])) + body: Static[tuple[int, ...]] = dataclasses.field(default_factory=tuple) transform: jtp.Array = dataclasses.field(default_factory=lambda: jnp.array([])) @@ -862,7 +862,7 @@ def build_from(model_description: ModelDescription) -> FrameParameters: fp = FrameParameters( name=names, transform=transforms.astype(float), - body=jnp.array(parent_link_index_of_frames).astype(int), + body=parent_link_index_of_frames, ) assert fp.transform.shape[1:] == (4, 4), fp.transform.shape[1:] From 1855cc1442768a7f217cf9986a7f5120d28289a9 Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Wed, 23 Oct 2024 17:32:23 +0200 Subject: [PATCH 2/2] Update `FrameParameters.__hash__` --- src/jaxsim/api/kin_dyn_parameters.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index f37bf2583..1d4ce3491 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -238,7 +238,7 @@ def __hash__(self) -> int: hash(self.number_of_links()), hash(self.number_of_joints()), hash(self.frame_parameters.name), - hash(tuple(self.frame_parameters.body.tolist())), + hash(self.frame_parameters.body), hash(self._parent_array), hash(self._support_body_array_bool), )