-
Notifications
You must be signed in to change notification settings - Fork 67
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add fp8 types exposed in jax.numpy. #251
Conversation
Thank you for the contribution! These all look good to me. Do you have a link to any documentation on these? |
Yes. They are defined in https://github.com/jax-ml/jax/blob/main/jax/_src/dtypes.py#L92-L97. I have updated the patch to include a link. |
Hi @patrick-kidger, can we merge this change into main? |
Is there any proper documentation, though? Not just their existence in the source code. I'm a little hesitant to expand our own public API to include undocumented features. (Advanced users such as yourself can continue to subclass
As with many open-source projects this is a volunteer effort that happens primarily in my evenings and weekends. Please have a little patience. |
Unfortunately I cannot find any public document about fp8. jax-ml/jax@d203926 is the first commit adding fp8 support to JAX, and there's no additional info attached. |
Kindly reminder on this PR. |
Since these dtypes are not yet public in JAX then I don't think we should make them public either, I'm afraid. That might change in the future, though :) |
Hi Patrick – these are public symbols in JAX, so it should be safe to add them here. Thanks! |
What greater confirmation could we ask for :) |
No description provided.