Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up various functions avoiding using jax.vmap on indexes #235

Merged
merged 7 commits into from
Sep 18, 2024

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented Sep 18, 2024

This PR removes the usage of jax.vmap on functions that can directly extract from arrays using the numpy-like indexing, e.g. using an index of an array for extracting and then stacking the resulting extracted components. This speeds up the compilation and the runtime performance of the concerned methods.

A couple benchmarks on stickbot model with 23 joints:

In [61]: %timeit js.contact.transforms(model=model, data=data) # previous implementation
11.9 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [66]: %timeit js.contact.transforms(model=model, data=data) # current implementation
1.63 ms ± 45.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [98]: %timeit -r 1 -n 1 js.model.total_mass(model) # previous implementation
5.98 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

In [100]: %timeit -r 1 -n 1 js.model.total_mass(model) # current implementation
1.61 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)

Apart from having a noticeable impact on the single methods, this should speed up all the functionalities that depend on the concerned functions.


📚 Documentation preview 📚: https://jaxsim--235.org.readthedocs.build//235/

@flferretti flferretti self-assigned this Sep 18, 2024
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @flferretti, all these tiny performance improvements are more than welcome! All good, except the change in the usage of jaxsim.api.link.mass() that puzzles me a bit. I suggested a low-level workaround.

tests/test_api_link.py Outdated Show resolved Hide resolved
tests/test_api_link.py Outdated Show resolved Hide resolved
tests/test_api_link.py Outdated Show resolved Hide resolved
src/jaxsim/api/model.py Outdated Show resolved Hide resolved
@flferretti flferretti force-pushed the fix/avoid_vmap_on_index branch from fd224ec to 14dca49 Compare September 18, 2024 14:11
@flferretti flferretti merged commit 3587b33 into main Sep 18, 2024
24 checks passed
@flferretti flferretti deleted the fix/avoid_vmap_on_index branch September 18, 2024 14:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants