diff --git a/environment.yml b/environment.yml index cc6a4023b..972ef02c1 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,7 @@ dependencies: - jaxlie >= 1.3.0 - jax-dataclasses >= 1.4.0 - pptree - - rod >= 0.2.0 + - rod >= 0.3.0 - typing_extensions # python<3.12 # ==================================== # Optional dependencies from setup.cfg @@ -41,18 +41,15 @@ dependencies: - pip - sphinx - sphinx-autodoc-typehints + - sphinx-book-theme - sphinx-copybutton - sphinx-design - sphinx_fontawesome - sphinx-jinja2-compat - sphinx-multiversion - sphinx_rtd_theme - - sphinx-book-theme - sphinx-toolbox # ======================================== # Other dependencies for GitHub Codespaces # ======================================== - # System dependencies to run the tests - - gz-sim7 - # Other packages - ipython diff --git a/setup.cfg b/setup.cfg index 887110b8a..0061982ed 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,7 +60,7 @@ install_requires = jaxlie >= 1.3.0 jax_dataclasses >= 1.4.0 pptree - rod >= 0.2.0 + rod >= 0.3.0 typing_extensions ; python_version < '3.12' [options.packages.find] diff --git a/src/jaxsim/parsers/kinematic_graph.py b/src/jaxsim/parsers/kinematic_graph.py index 3e7c3e2ac..42705f7b4 100644 --- a/src/jaxsim/parsers/kinematic_graph.py +++ b/src/jaxsim/parsers/kinematic_graph.py @@ -121,7 +121,8 @@ def __post_init__(self): # Also here, we assume the model is fixed-base, therefore the first frame will # have last_link_idx + 1. These frames are not part of the physics model. for index, frame in enumerate(self.frames): - frame.index = index + len(self.link_names()) + with frame.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): + frame.index = int(index + len(self.link_names())) # Number joints so that their index matches their child link index links_dict = {l.name: l for l in iter(self)} diff --git a/src/jaxsim/parsers/rod/parser.py b/src/jaxsim/parsers/rod/parser.py index c2e396d57..376e20506 100644 --- a/src/jaxsim/parsers/rod/parser.py +++ b/src/jaxsim/parsers/rod/parser.py @@ -6,6 +6,7 @@ import numpy as np import rod +import jaxsim.utils from jaxsim import logging from jaxsim.math.quaternion import Quaternion from jaxsim.parsers import descriptions, kinematic_graph @@ -25,6 +26,7 @@ class SDFData(NamedTuple): link_descriptions: List[descriptions.LinkDescription] joint_descriptions: List[descriptions.JointDescription] + frame_descriptions: List[descriptions.LinkDescription] collision_shapes: List[descriptions.CollisionShape] sdf_model: rod.Model | None = None @@ -70,6 +72,8 @@ def extract_model_data( # Jaxsim supports only models compatible with URDF, i.e. those having all links # directly attached to their parent joint without additional roto-translations. + # Furthermore, the following switch also post-processes frames such that their + # pose is expressed wrt the parent link they are rigidly attached to. sdf_model.switch_frame_convention(frame_convention=rod.FrameConvention.Urdf) # Log type of base link @@ -113,6 +117,23 @@ def extract_model_data( # Create a dictionary to find easily links links_dict: Dict[str, descriptions.LinkDescription] = {l.name: l for l in links} + # ============ + # Parse frames + # ============ + + # Parse the frames (unconnected) + frames = [ + descriptions.LinkDescription( + name=f.name, + mass=jnp.array(0.0, dtype=float), + inertia=jnp.zeros(shape=(3, 3)), + parent=links_dict[f.attached_to], + pose=f.pose.transform() if f.pose is not None else jnp.eye(4), + ) + for f in sdf_model.frames() + if f.attached_to in links_dict + ] + # ========================= # Process fixed-base models # ========================= @@ -309,6 +330,7 @@ def extract_model_data( model_name=sdf_model.name, link_descriptions=links, joint_descriptions=joints, + frame_descriptions=frames, collision_shapes=collisions, fixed_base=sdf_model.is_fixed_base(), base_link_name=sdf_model.get_canonical_link(), @@ -338,10 +360,14 @@ def build_model_description( model_description=model_description, model_name=None, is_urdf=is_urdf ) - # Build the model description. + # Build the intermediate representation used for building a JaxSim model. + # This process, beyond other operations, removes the fixed joints. # Note: if the model is fixed-base, the fixed joint between world and the first # link is removed and the pose of the first link is updated. - model = descriptions.ModelDescription.build_model_from( + # + # The whole process is: + # URDF/SDF ⟶ rod.Model ⟶ ModelDescription ⟶ JaxSimModel. + graph = descriptions.ModelDescription.build_model_from( name=sdf_data.model_name, links=sdf_data.link_descriptions, joints=sdf_data.joint_descriptions, @@ -356,7 +382,47 @@ def build_model_description( ], ) + # Depending on how the model is reduced due to the removal of fixed joints, + # there might be frames that are no longer attached to existing links. + # We need to change the link to which they are attached to, and update their pose. + frames_with_no_parent_link = ( + f for f in sdf_data.frame_descriptions if f.parent.name not in graph + ) + + # Build the object to compute forward kinematics. + fk = kinematic_graph.KinematicGraphTransforms(graph=graph) + + for frame in frames_with_no_parent_link: + # Get the original data of the frame. + original_pose = frame.pose + original_parent_link = frame.parent.name + + # The parent link, that has been removed, became a frame. + assert original_parent_link in graph.frames_dict, (frame, original_parent_link) + + # Get the new parent of the frame corresponding to the removed parent link. + new_parent_link = graph.frames_dict[original_parent_link].parent.name + logging.debug(f"Frame '{frame.name}' is now attached to '{new_parent_link}'") + + # Get the transform from the new parent link to the original parent link. + # The original pose is expressed wrt the original parent link. + F_H_P = fk.relative_transform( + relative_to=new_parent_link, name=original_parent_link + ) + + # Update the frame with the updated data. + with frame.mutable_context( + mutability=jaxsim.utils.Mutability.MUTABLE_NO_VALIDATION + ): + frame.parent = graph.links_dict[new_parent_link] + frame.pose = np.array(F_H_P @ original_pose) + + # Include the SDF frames originally stored in the SDF. + graph = dataclasses.replace( + graph, frames=sdf_data.frame_descriptions + graph.frames + ) + # Store the parsed SDF tree as extra info - model = dataclasses.replace(model, extra_info={"sdf_model": sdf_data.sdf_model}) + graph = dataclasses.replace(graph, extra_info={"sdf_model": sdf_data.sdf_model}) - return model + return graph