Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce new functional APIs #88

Merged
merged 57 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
f062ed5
Initialize new jaxsim.api package
diegoferigo Feb 19, 2024
9be8148
Add JaxSimModelData class
diegoferigo Feb 19, 2024
9d546e7
Add JaxSimModel class
diegoferigo Feb 28, 2024
ba3e3c6
Add model.reduce
diegoferigo Feb 19, 2024
bb1a6db
Extend PhysicsModel with new temporary attributes for functional APIs
diegoferigo Feb 23, 2024
52e7db6
Update PhysicsModel{State|Input} classes with builder methods
diegoferigo Feb 19, 2024
2602cbb
Initialize jaxsim.api.link module with index-related functions
diegoferigo Feb 20, 2024
13603fd
Initialize jaxsim.api.joint module with index-related functions
diegoferigo Feb 20, 2024
b6409b2
Initialize jaxsim.api.contacts with kinematics of collidable points
diegoferigo Feb 19, 2024
d9996af
Update SoftContactsState class with builder methods
diegoferigo Feb 19, 2024
969fe43
Add logic to build model-specific SoftContactsParams
diegoferigo Feb 19, 2024
061a729
Add contact.in_contact
diegoferigo Feb 19, 2024
1530336
Add contact.estimate_good_soft_contact_parameters
diegoferigo Feb 19, 2024
f35b395
Add link.mass and link.spatial_inertia
diegoferigo Feb 19, 2024
a5086d9
Add joint.position_limit and joint.position_limits
diegoferigo Feb 19, 2024
bed8056
Add joint.random_joint_positions
diegoferigo Feb 19, 2024
8b7eecb
Add data.random_model_data
diegoferigo Feb 19, 2024
2d36df2
Add link.transform
diegoferigo Feb 19, 2024
414d038
Add link.com_position
diegoferigo Feb 19, 2024
dcb3b6c
Add link.jacobian
diegoferigo Feb 19, 2024
6e01e45
Add model.total_mass
diegoferigo Feb 19, 2024
5bf2004
Add model.com_position
diegoferigo Feb 19, 2024
9125a05
Add model.forward_kinematics
diegoferigo Feb 19, 2024
f729b07
Add model.generalized_free_floating_jacobian
diegoferigo Feb 19, 2024
dd3ee3f
Add forward dynamics functions (ABA and CRB)
diegoferigo Feb 20, 2024
e94cc1d
Add model.free_floating_mass_matrix function (CRBA)
diegoferigo Feb 20, 2024
6e3c4c5
Add model.inverse_dynamics function (RNEA)
diegoferigo Feb 20, 2024
1512103
Add model.free_floating_gravity_forces function
diegoferigo Feb 20, 2024
4f71e0c
Add model.free_floating_bias_forces function
diegoferigo Feb 20, 2024
620d8d8
Add model.total_momentum function
diegoferigo Feb 20, 2024
57cd41e
Add functions to compute the potential, kinetic, and mechanical energy
diegoferigo Feb 20, 2024
e8c47b5
Import modules into jaxsim.api package
diegoferigo Feb 20, 2024
e4b6732
Initialize new jaxsim.integrators package
diegoferigo Feb 21, 2024
7435056
Add integrators.common module with Integrator base class
diegoferigo Feb 21, 2024
ce12b50
Add ExplicitRungeKutta base class supporting generic pytrees
diegoferigo Feb 21, 2024
6bcea25
Add ForwardEuler, Heun, and RungeKutta4 fixed-step integrators
diegoferigo Feb 21, 2024
9023f9e
Add missing includes in api.model
diegoferigo Feb 23, 2024
fc0beec
Update signature and notation in api.model
diegoferigo Feb 23, 2024
93efda9
Add model.step function
diegoferigo Feb 23, 2024
8a5d3eb
Update type hints of integrators
diegoferigo Feb 23, 2024
c7a25cc
Fix RK stage for pytrees with multidimensional leaves
diegoferigo Feb 23, 2024
fd2e53a
Allow integrating having different State and StateDerivative classes
diegoferigo Feb 23, 2024
f11c318
Allow applying a custom transformation to output states of RK schemes
diegoferigo Feb 23, 2024
443049f
Add fixed-step integrators working on ODEState integrating over SO(3)
diegoferigo Feb 23, 2024
64409ff
Update integrators.__init__.py
diegoferigo Feb 23, 2024
004c1b4
Add api.ode module with the system dynamics to be integrated
diegoferigo Feb 23, 2024
1aaf7bb
Update ODE{State|Input} classes with builder methods
diegoferigo Feb 23, 2024
be0ffee
Fix link.transform to use right function
diegoferigo Feb 23, 2024
0892195
Add typing_extensions dependency for Python < 3.12
diegoferigo Feb 23, 2024
5b2394f
Apply suggestions from code review
diegoferigo Feb 23, 2024
ab7b2ff
Add missing key when generating random joint positions
diegoferigo Feb 23, 2024
2e6d82e
Remove weak ref of model stored inside data
diegoferigo Feb 29, 2024
f0dfe66
Add data.time method
diegoferigo Feb 29, 2024
45ba219
Add data.reset_* metods
diegoferigo Feb 29, 2024
599746a
Enhance utilities to change the representation of 6D quantities
diegoferigo Mar 1, 2024
2b98640
Add missing jax.jit decorators in data module
diegoferigo Feb 29, 2024
e5dc210
Import Self compatibly with Python 3.10
diegoferigo Feb 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies:
- jax-dataclasses >= 1.4.0
- pptree
- rod
- typing_extensions # python<3.12
# Optional dependencies from setup.cfg
# [style]
- black
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ install_requires =
jax_dataclasses >= 1.4.0
pptree
rod
typing_extensions ; python_version < '3.12'

[options.packages.find]
where = src
Expand Down
1 change: 1 addition & 0 deletions src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import contact, data, joint, link, model, ode
194 changes: 194 additions & 0 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import functools

import jax
import jax.numpy as jnp

import jaxsim.typing as jtp
from jaxsim.physics.algos import soft_contacts

from . import data as Data
from . import model as Model


@jax.jit
def collidable_point_kinematics(
model: Model.JaxSimModel, data: Data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
"""
Compute the position and 3D velocity of the collidable points in the world frame.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The position and velocity of the collidable points in the world frame.

Note:
The collidable point velocity is the plain coordinate derivative of the position.
If we attach a frame C = (p_C, [C]) to the collidable point, it corresponds to
the linear component of the mixed 6D frame velocity.
"""

from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

W_p_Ci, W_ṗ_Ci = collidable_points_pos_vel(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
xfb=data.state.physics_model.xfb(),
)

return W_p_Ci.T, W_ṗ_Ci.T


@jax.jit
def collidable_point_positions(
model: Model.JaxSimModel, data: Data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the position of the collidable points in the world frame.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The position of the collidable points in the world frame.
"""

return collidable_point_kinematics(model=model, data=data)[0]


@jax.jit
def collidable_point_velocities(
model: Model.JaxSimModel, data: Data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the 3D velocity of the collidable points in the world frame.

Args:
model: The model to consider.
data: The data of the considered model.

Returns:
The 3D velocity of the collidable points.
"""

return collidable_point_kinematics(model=model, data=data)[1]


@functools.partial(jax.jit, static_argnames=["link_names"])
def in_contact(
model: Model.JaxSimModel,
data: Data.JaxSimModelData,
*,
link_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
"""
Return whether the links are in contact with the terrain.

Args:
model: The model to consider.
data: The data of the considered model.
link_names:
The names of the links to consider. If None, all links are considered.

Returns:
A boolean vector indicating whether the links are in contact with the terrain.
"""

link_names = link_names if link_names is not None else model.link_names()

if set(link_names) - set(model.link_names()) != set():
raise ValueError("One or more link names are not part of the model")

from jaxsim.physics.algos.soft_contacts import collidable_points_pos_vel

W_p_Ci, _ = collidable_points_pos_vel(
model=model.physics_model,
q=data.state.physics_model.joint_positions,
qd=data.state.physics_model.joint_velocities,
xfb=data.state.physics_model.xfb(),
)

terrain_height = jax.vmap(lambda x, y: model.terrain.height(x=x, y=y))(
W_p_Ci[0, :], W_p_Ci[1, :]
)

below_terrain = W_p_Ci[2, :] <= terrain_height

links_in_contact = jax.vmap(
lambda link_index: jnp.where(
model.physics_model.gc.body == link_index,
below_terrain,
jnp.zeros_like(below_terrain, dtype=bool),
).any()
)(jnp.arange(model.number_of_links()))

return links_in_contact


@jax.jit
def estimate_good_soft_contacts_parameters(
model: Model.JaxSimModel,
static_friction_coefficient: jtp.FloatLike = 0.5,
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
max_penetration: jtp.FloatLike | None = None,
) -> soft_contacts.SoftContactsParams:
"""
Estimate good soft contacts parameters for the given model.

Args:
model: The model to consider.
static_friction_coefficient: The static friction coefficient.
number_of_active_collidable_points_steady_state:
The number of active collidable points in steady state supporting
the weight of the robot.
damping_ratio: The damping ratio.
max_penetration:
The maximum penetration allowed in steady state when the robot is
supported by the configured number of active collidable points.

Returns:
The estimated good soft contacts parameters.

Note:
This method provides a good starting point for the soft contacts parameters.
The user is encouraged to fine-tune the parameters based on the
specific application.
"""

def estimate_model_height(model: Model.JaxSimModel) -> jtp.Float:
""""""

zero_data = Data.JaxSimModelData.build(
model=model, soft_contacts_params=soft_contacts.SoftContactsParams()
)

W_p_CoM = Model.com_position(model=model, data=zero_data)

if model.physics_model.is_floating_base:
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
return 2 * (W_p_CoM[2] - W_pz_C.min())

return 2 * W_p_CoM

max_δ = (
max_penetration
if max_penetration is not None
else 0.005 * estimate_model_height(model=model)
)

nc = number_of_active_collidable_points_steady_state

sc_parameters = soft_contacts.SoftContactsParams.build_default_from_physics_model(
physics_model=model.physics_model,
static_friction_coefficient=static_friction_coefficient,
max_penetration=max_δ,
number_of_active_collidable_points_steady_state=nc,
damping_ratio=damping_ratio,
)

return sc_parameters
Loading
Loading