Enhance maintainability and performance #78
Conversation
Looks good, thanks! It makes sense to unify all the ode_integration_* functions, since these integrators evolved to have the same APIs (compatible with jax.experimental.ode, by the way). It could be done similarly to the odeint_* functions, but I'm considering refactoring this file soon with the new resources from #72, so let's keep them as they are, at least for now.
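For reference, such a unification could look roughly like the sketch below; the function names are hypothetical and do not reflect jaxsim's actual ode_integration_* API.

```python
import jax.numpy as jnp

def _euler(func, y0, t):
    # Fixed-step forward Euler over the time grid t.
    ys = [y0]
    for t0, t1 in zip(t[:-1], t[1:]):
        ys.append(ys[-1] + (t1 - t0) * func(ys[-1], t0))
    return jnp.stack(ys)

_INTEGRATORS = {"euler": _euler}

def ode_integration(func, y0, t, *, method="euler"):
    # Single dispatcher replacing per-method public functions, keeping the
    # (func, y0, t) signature style of jax.experimental.ode.odeint.
    return _INTEGRATORS[method](func, y0, t)

# Example: integrate dy/dt = -y from y(0) = 1 on a uniform grid.
ts = jnp.linspace(0.0, 1.0, 11)
ys = ode_integration(lambda y, t: -y, jnp.array(1.0), ts, method="euler")
```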
For the rest, just minor changes.
A comment: I'm not really sure whether the performance analysis that usually applies to plain Python code is also valid here with jitted code. For example, one way to check the dict construction is to create two simple functions and compare the resulting IR with jax.make_jaxpr. I suspect it outputs the same expression regardless of how the dict is constructed. However, perhaps the literal runs more efficiently while the expression is being traced, so it may not be entirely useless.
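A minimal, self-contained version of that check could look like the following; the variable names are illustrative and not taken from the PR diff.

```python
import jax
import jax.numpy as jnp

v = jnp.int32(42)

# dict(k=v) builds {"k": v}: the key is the literal string "k".
print(jax.make_jaxpr(lambda v: dict(k=v))(v))

# A dict literal with the same string key yields an identical jaxpr.
print(jax.make_jaxpr(lambda v: {"k": v})(v))
```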
Co-authored-by: Diego Ferigo <[email protected]>
Force-pushed from 73f8ca6 to d3c65ac
You were right, the resulting jaxpr is the same in both cases:

In [3]: jax.make_jaxpr(lambda k,v: dict(k=v))(k,v)
Out[3]: { lambda ; a:i32[] b:i32[]. let in (b,) }

In [4]: jax.make_jaxpr(lambda k,v: {k:v})(k,v)
Out[4]: { lambda ; a:i32[] b:i32[]. let in (b,) }

The literal does run slightly faster, but I'm pretty sure it won't make a huge difference when the code is compiled:

In [3]: %timeit dict(k=v)
63.3 ns ± 4.07 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [4]: %timeit {k:v}
52.8 ns ± 3.91 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

In [5]: %timeit jax.make_jaxpr(lambda k,v: dict(k=v))(k,v)
186 µs ± 6.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [6]: %timeit jax.make_jaxpr(lambda k,v: {k:v})(k,v)
179 µs ± 2.36 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
This PR aims to improve code readability, efficiency, and maintainability. In particular:
- Rewrite high_level.Model.valid() with generators (a sketch follows below)
- Replace dict() with the {} literal for better performance (see Performance Analysis of Python’s dict() and {})

📚 Documentation preview 📚: https://jaxsim--78.org.readthedocs.build//78/
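The generator-based pattern referenced in the first bullet could look roughly like this; the class and attribute names are hypothetical and not the actual jaxsim implementation.

```python
from dataclasses import dataclass

@dataclass
class Link:
    mass: float

    def valid(self) -> bool:
        return self.mass > 0.0

@dataclass
class Model:
    links: list

    def valid(self) -> bool:
        # all() with a generator expression short-circuits on the first
        # invalid link and avoids materializing an intermediate list.
        return all(link.valid() for link in self.links)

assert Model(links=[Link(1.0), Link(2.5)]).valid()
```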