-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize Jacobian algorithm #121
Conversation
84588d5
to
e9302ad
Compare
e9302ad
to
906a49d
Compare
906a49d
to
18f9553
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! Did you did any benchmark on the difference before and after this change?
I was running them while you were typing your comment :) Here below a benchmark on a 58-DoFs ErgoCub model on CPU. Before
After
Overview
scriptimport jaxsim.api as js
import resolve_robotics_uri_py
import rod
# Find the urdf file.
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(
uri="model://ergoCubSN001/model.urdf"
)
# Build the ROD model.
rod_sdf = rod.Sdf.load(sdf=urdf_path)
# Build the model.
model = js.model.JaxSimModel.build_from_model_description(
model_description=rod_sdf.model,
)
# Create random data.
data0 = js.data.random_model_data(
model=model,
base_pos_bounds=((0, 0, 0.85), (0, 0, 0.85)),
joint_vel_bounds=(0, 0),
base_vel_lin_bounds=(0, 0),
base_vel_ang_bounds=(0, 0),
)
%timeit -r1 -n1 _ = js.model.generalized_free_floating_jacobian(model, data0)
%timeit -r 10 -n 1000 _ = js.model.generalized_free_floating_jacobian(model, data0)
%timeit -r1 -n1 _ = js.model.forward_dynamics_crb(model, data0)
%timeit -r10 -n1000 _ = js.model.forward_dynamics_crb(model, data0) |
In other words, on such a large model, with this change we can compute FD with CRB in less than 1ms with JAX running on CPU. Not too bad. As a comparison, the equivalent computation with ABA runs 3x faster:
|
cc @ami-iit/vertical_control-oriented-learning |
This PR enhances the computation of the Jacobians of all links. Before this PR, we computed in parallel with${}^L J_{W,L/B}$ . Then, we were adjusting the input and output representations to match the desired ones.
jax.vmap
the free-floating left-trivialized Jacobian of all linksThis PR, instead of calling a parallelized version of${}^B J_{W,\text{\_}/B}$ (note that there is a $\text{\_}$ instead of $L$ ), defined as the free-floating Jacobian with all rows filled. The free-floating doubly-left Jacobian of the i-th link ${}^B J_{W,L/B}$ is then computed by filtering the columns of the full Jacobian with the support parent array $\kappa(i)$ of the link, that sets to zero the columns corresponding to all links not part of the path $\pi_B(L)$ .
jaxsim.rbda.jacobian
, computes only once a full doubly-left JacobianAll of this allows to compute the Jacobians of all links by only vmapping the full Jacobian using$\kappa(i)$ , therefore the expensive algorithm to compute the full Jacobian is executed only once. The vmapped operation is just a column filtering, therefore its cost is almost zero.
This enhancement should speed up the computation of
jaxsim.api.model.forward_dynamics_crb
. While at this stage is always better to usejaxsim.api.model.forward_dynamics_aba
, the CRB version might become useful in all experiments that extend the forward dynamics with e.g. actuators, friction models, or any other stateless second-order dynamics elements1. In these cases, operating on ABA could become quite difficult, and prototyping using the CRB version might be pretty helpful.cc @traversaro @DanielePucci
📚 Documentation preview 📚: https://jaxsim--121.org.readthedocs.build//121/
Footnotes
Under the assumption that they do not extend the integrated state vector. In that case, things become more complex since also the integrator has to properly updated. ↩