Skip to content

Commit

Permalink
[wip] Prepare transition to jaxsim.api from jaxsim.{high_level|physics}
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Mar 13, 2024
1 parent 9568eb7 commit 35df971
Show file tree
Hide file tree
Showing 26 changed files with 119 additions and 108 deletions.
7 changes: 3 additions & 4 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def _is_editable() -> bool:
del _np_options
del _is_editable

from . import high_level, logging, math, simulation, sixd
from .high_level.common import VelRepr
from .simulation.ode_integration import IntegratorType
from .simulation.simulator import JaxSim
from . import terrain # isort:skip
from . import api, integrators, logging, math, rbda
from .api.common import VelRepr
3 changes: 2 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import common # isort:skip
from . import model, data # isort:skip
from . import common, contact, joint, kin_dyn_parameters, link, ode, references
from . import contact, joint, kin_dyn_parameters, link, ode, ode_data, references
13 changes: 12 additions & 1 deletion src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import contextlib
import dataclasses
import enum
import functools
from typing import ContextManager

Expand All @@ -11,7 +12,6 @@
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.utils import JaxsimDataclass, Mutability

try:
Expand All @@ -20,6 +20,17 @@
from typing_extensions import Self


@enum.unique
class VelRepr(enum.IntEnum):
"""
Enumeration of all supported 6D velocity representations.
"""

Body = enum.auto()
Mixed = enum.auto()
Inertial = enum.auto()


@jax_dataclasses.pytree_dataclass
class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
"""
Expand Down
29 changes: 16 additions & 13 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.physics.algos import soft_contacts


@jax.jit
Expand All @@ -28,9 +28,9 @@ def collidable_point_kinematics(
the linear component of the mixed 6D frame velocity.
"""

from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel
from jaxsim.rbda import soft_contacts

W_p_Ci, W_ṗ_Ci = collidable_points_pos_vel(
W_p_Ci, W_ṗ_Ci = soft_contacts.collidable_points_pos_vel(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
Expand Down Expand Up @@ -101,9 +101,9 @@ def in_contact(
if set(link_names) - set(model.link_names()) != set():
raise ValueError("One or more link names are not part of the model")

from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel
from jaxsim.rbda import soft_contacts

W_p_Ci, _ = collidable_points_pos_vel(
W_p_Ci, _ = soft_contacts.collidable_points_pos_vel(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
Expand Down Expand Up @@ -134,7 +134,7 @@ def estimate_good_soft_contacts_parameters(
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> soft_contacts.SoftContactsParams:
) -> jaxsim.rbda.soft_contacts.SoftContactsParams:
"""
Estimate good soft contacts parameters for the given model.
Expand Down Expand Up @@ -162,7 +162,8 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""

zero_data = js.data.JaxSimModelData.build(
model=model, soft_contacts_params=soft_contacts.SoftContactsParams()
model=model,
soft_contacts_params=jaxsim.rbda.soft_contacts.SoftContactsParams(),
)

W_pz_CoM = js.model.com_position(model=model, data=zero_data)[2]
Expand All @@ -181,12 +182,14 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:

nc = number_of_active_collidable_points_steady_state

sc_parameters = soft_contacts.SoftContactsParams.build_default_from_physics_model(
physics_model=model.physics_model,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
sc_parameters = (
jaxsim.rbda.soft_contacts.SoftContactsParams.build_default_from_physics_model(
physics_model=model.physics_model,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
)
)

return sc_parameters
25 changes: 11 additions & 14 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,13 @@
import numpy as np

import jaxsim.api as js
import jaxsim.physics.algos.aba
import jaxsim.physics.algos.crba
import jaxsim.physics.algos.forward_kinematics
import jaxsim.physics.algos.rnea
import jaxsim.physics.model.physics_model
import jaxsim.physics.model.physics_model_state
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.physics.algos import soft_contacts
from jaxsim.simulation.ode_data import ODEState
from jaxsim.utils import Mutability

from . import common
from .common import VelRepr
from .ode_data import ODEState

try:
from typing import Self
Expand All @@ -41,9 +35,10 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

gravity: jtp.Array

soft_contacts_params: soft_contacts.SoftContactsParams = dataclasses.field(
repr=False
soft_contacts_params: jaxsim.rbda.soft_contacts.SoftContactsParams = (
dataclasses.field(repr=False)
)

time_ns: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
)
Expand Down Expand Up @@ -96,8 +91,10 @@ def build(
base_angular_velocity: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
gravity: jtp.Vector | None = None,
soft_contacts_state: soft_contacts.SoftContactsState | None = None,
soft_contacts_params: soft_contacts.SoftContactsParams | None = None,
soft_contacts_state: jaxsim.rbda.soft_contacts.SoftContactsState | None = None,
soft_contacts_params: (
jaxsim.rbda.soft_contacts.SoftContactsParams | None
) = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
) -> JaxSimModelData:
Expand Down Expand Up @@ -186,7 +183,7 @@ def build(

ode_state = ODEState.build(
physics_model=model.physics_model,
physics_model_state=jaxsim.physics.model.physics_model_state.PhysicsModelState(
physics_model_state=js.ode_data.PhysicsModelState(
base_position=base_position.astype(float),
base_quaternion=base_quaternion.astype(float),
joint_positions=joint_positions.astype(float),
Expand Down
12 changes: 7 additions & 5 deletions src/jaxsim/api/link.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import numpy as np

import jaxsim.api as js
import jaxsim.physics.algos.jacobian
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr

from .common import VelRepr

# =======================
# Index-related functions
Expand Down Expand Up @@ -210,11 +211,12 @@ def jacobian(
velocity representation.
"""

if output_vel_repr is None:
output_vel_repr = data.velocity_representation
output_vel_repr = (
output_vel_repr if output_vel_repr is not None else data.velocity_representation
)

# Compute the doubly left-trivialized free-floating jacobian
L_J_WL_B = jaxsim.physics.algos.jacobian.jacobian(
L_J_WL_B = jaxsim.rbda.jacobian.jacobian(
model=model.physics_model,
body_index=link_index,
q=data.joint_positions(),
Expand Down
19 changes: 10 additions & 9 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@

import jaxsim.api as js
import jaxsim.parsers.descriptions
import jaxsim.physics.algos.aba
import jaxsim.physics.algos.crba
import jaxsim.physics.algos.forward_kinematics
import jaxsim.physics.algos.rnea
import jaxsim.physics.model.physics_model
import jaxsim.physics.model.physics_model_state
import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.utils import HashlessObject, JaxsimDataclass, Mutability

from .common import VelRepr


@jax_dataclasses.pytree_dataclass
class JaxSimModel(JaxsimDataclass):
Expand All @@ -37,8 +34,8 @@ class JaxSimModel(JaxsimDataclass):
repr=False, compare=False, hash=False
)

terrain: Static[Terrain] = dataclasses.field(
default=FlatTerrain(), repr=False, compare=False, hash=False
terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False, compare=False, hash=False
)

built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
Expand Down Expand Up @@ -388,7 +385,7 @@ def forward_kinematics(model: JaxSimModel, data: js.data.JaxSimModelData) -> jtp
The first axis is the link index.
"""

W_H_LL = jaxsim.physics.algos.forward_kinematics.forward_kinematics_model(
W_H_LL = jaxsim.rbda.forward_kinematics.forward_kinematics_model(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
xfb=data.state.physics_model.xfb(),
Expand Down Expand Up @@ -719,6 +716,8 @@ def free_floating_mass_matrix(
The free-floating mass matrix of the model.
"""

import jaxsim.physics.algos.crba

M_body = jaxsim.physics.algos.crba.crba(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
Expand Down Expand Up @@ -852,6 +851,8 @@ def to_inertial(C_v̇_WB, W_H_C, C_v_WB, W_vl_WC):
velocity_representation=data.velocity_representation,
)

import jaxsim.physics.algos.rnea

# Compute RNEA
with references.switch_velocity_representation(VelRepr.Inertial):
W_f_B, τ = jaxsim.physics.algos.rnea.rnea(
Expand Down
14 changes: 6 additions & 8 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,15 @@

import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.physics.algos.soft_contacts
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim import VelRepr, integrators
from jaxsim.integrators.common import Time
from jaxsim.integrators import Time
from jaxsim.math.quaternion import Quaternion
from jaxsim.physics.algos.soft_contacts import SoftContactsState
from jaxsim.physics.model.physics_model_state import PhysicsModelState
from jaxsim.simulation.ode_data import ODEState

from .common import VelRepr
from .ode_data import ODEState, PhysicsModelState, SoftContactsState


class SystemDynamicsFromModelAndData(Protocol):
Expand Down Expand Up @@ -127,7 +125,7 @@ def system_velocity_dynamics(

# Compute the 3D forces applied to each collidable point.
W_f_Ci, = jax.vmap(
lambda p, , m: jaxsim.physics.algos.soft_contacts.SoftContacts(
lambda p, , m: jaxsim.rbda.soft_contacts.SoftContacts(
parameters=data.soft_contacts_params, terrain=model.terrain
).contact_model(position=p, velocity=, tangential_deformation=m)
)(W_p_Ci, W_ṗ_Ci, data.state.soft_contacts.tangential_deformation.T)
Expand Down
6 changes: 6 additions & 0 deletions src/jaxsim/api/ode_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from jaxsim.physics.algos.soft_contacts import SoftContactsState
from jaxsim.physics.model.physics_model_state import (
PhysicsModelInput,
PhysicsModelState,
)
from jaxsim.simulation.ode_data import ODEInput, ODEState
3 changes: 2 additions & 1 deletion src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim import VelRepr
from jaxsim.simulation.ode_data import ODEInput

from .common import VelRepr

try:
from typing import Self
except ImportError:
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/high_level/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from ..api.common import VelRepr
from . import common, joint, link, model
from .common import VelRepr
11 changes: 0 additions & 11 deletions src/jaxsim/high_level/common.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/jaxsim/integrators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import fixed_step
from .common import Integrator, Time, TimeStep
from .common import Integrator, SystemDynamics, Time, TimeStep
14 changes: 7 additions & 7 deletions src/jaxsim/integrators/fixed_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import jax_dataclasses
import jaxlie

from jaxsim.simulation.ode_data import ODEState
import jaxsim.api as js

from .common import ExplicitRungeKutta, PyTreeType, Time, TimeStep

ODEStateDerivative = ODEState
ODEStateDerivative = js.ode_data.ODEState


# =====================================================
Expand Down Expand Up @@ -97,8 +97,8 @@ class ExplicitRungeKuttaSO3Mixin:

@classmethod
def post_process_state(
cls, x0: ODEState, t0: Time, xf: ODEState, dt: TimeStep
) -> ODEState:
cls, x0: js.ode_data.ODEState, t0: Time, xf: js.ode_data.ODEState, dt: TimeStep
) -> js.ode_data.ODEState:

# Indices to convert quaternions between serializations.
to_xyzw = jnp.array([1, 2, 3, 0])
Expand Down Expand Up @@ -130,15 +130,15 @@ def post_process_state(


@jax_dataclasses.pytree_dataclass
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[ODEState]):
class ForwardEulerSO3(ExplicitRungeKuttaSO3Mixin, ForwardEuler[js.ode_data.ODEState]):
pass


@jax_dataclasses.pytree_dataclass
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[ODEState]):
class Heun2SO3(ExplicitRungeKuttaSO3Mixin, Heun2[js.ode_data.ODEState]):
pass


@jax_dataclasses.pytree_dataclass
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[ODEState]):
class RungeKutta4SO3(ExplicitRungeKuttaSO3Mixin, RungeKutta4[js.ode_data.ODEState]):
pass
Loading

0 comments on commit 35df971

Please sign in to comment.