Skip to content

Commit

Permalink
Store ModelDescription directly inside JaxSimModel
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Mar 12, 2024
1 parent 65eaa0c commit 6f570bb
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from jax_dataclasses import Static

import jaxsim.api as js
import jaxsim.parsers.descriptions
import jaxsim.physics.algos.aba
import jaxsim.physics.algos.crba
import jaxsim.physics.algos.forward_kinematics
Expand All @@ -21,7 +22,7 @@
import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.utils import JaxsimDataclass, Mutability
from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability


@jax_dataclasses.pytree_dataclass
Expand All @@ -44,6 +45,10 @@ class JaxSimModel(JaxsimDataclass):
default=None, repr=False, compare=False, hash=False
)

description: Static[
HashlessObject[jaxsim.parsers.descriptions.ModelDescription | None]
] = dataclasses.field(default=None, repr=False, compare=False, hash=False)

kin_dyn_parameters: js.kin_dyn_parameters.KynDynParameters | None = (
dataclasses.field(default=None, repr=False, compare=False, hash=False)
)
Expand Down Expand Up @@ -133,15 +138,14 @@ def build(
)

# Build the model
model = JaxSimModel(physics_model=physics_model, model_name=model_name) # noqa

# Create and store the KinDynParameters
with model.mutable_context(
mutability=Mutability.MUTABLE_NO_VALIDATION, restore_after_exception=False
):
model.kin_dyn_parameters = js.kin_dyn_parameters.KynDynParameters.build(
model=model
)
model = JaxSimModel(
physics_model=physics_model,
model_name=model_name,
description=HashlessObject(obj=physics_model.description),
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=physics_model.description
),
)

return model

Expand Down Expand Up @@ -277,7 +281,7 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode
# Reduce the model description.
# If considered_joints contains joints not existing in the model, the method
# will raise an exception.
reduced_intermediate_description = model.physics_model.description.reduce(
reduced_intermediate_description = model.description.obj.reduce(
considered_joints=list(considered_joints)
)

Expand All @@ -295,6 +299,7 @@ def reduce(model: JaxSimModel, considered_joints: tuple[str, ...]) -> JaxSimMode
# Store the origin of the model, in case downstream logic needs it
with reduced_model.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
reduced_model.built_from = model.built_from
reduced_model.description = HashlessObject(obj=physics_model.description)

return reduced_model

Expand Down

0 comments on commit 6f570bb

Please sign in to comment.