JAX FP8 matmul (fusion) Jupyter Notebook tutorial.
In this tutorial notebook, we investigate how the ML stack JAX + XLA handles the specifics of FP8 matmuls while still generating an optimal fused kernel call, including:

* FP8 input scaling;
* FP8 output scaling & clamping;
* Non-linearity & bias fusing;
* Abs-max output capture.

Note: some open questions remain on bias and GELU fusing.
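As a rough illustration (not the notebook's exact code), the sketch below shows the dequantize → dot → abs-max capture → clamp → requantize pattern that XLA can pattern-match into a single fused FP8 matmul call. The scale values, matrix shapes, and the `float8_e4m3fn` dtype choice are illustrative assumptions.

```python
# Minimal sketch of a scaled FP8 matmul in JAX, assuming a recent JAX
# version with float8_e4m3fn support. Shapes and scales are placeholders.
import jax
import jax.numpy as jnp

E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def fp8_matmul(a_fp8, b_fp8, a_scale, b_scale, out_scale):
    # Dequantize the FP8 inputs to a higher precision before the dot;
    # XLA can fuse this dequant -> dot -> quant chain into one kernel.
    a_hp = a_fp8.astype(jnp.bfloat16) * a_scale
    b_hp = b_fp8.astype(jnp.bfloat16) * b_scale
    out = jax.lax.dot(a_hp, b_hp)
    # Capture the abs-max of the output (used to update the output scale
    # in delayed-scaling schemes).
    out_amax = jnp.max(jnp.abs(out))
    # Scale and clamp before casting the output back to FP8.
    out = jnp.clip(out / out_scale, -E4M3_MAX, E4M3_MAX)
    return out.astype(jnp.float8_e4m3fn), out_amax

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (128, 128), jnp.bfloat16).astype(jnp.float8_e4m3fn)
b = jax.random.normal(key, (128, 128), jnp.bfloat16).astype(jnp.float8_e4m3fn)
scales = (jnp.float32(1.0), jnp.float32(1.0), jnp.float32(1.0))
out, amax = jax.jit(fp8_matmul)(a, b, *scales)
```

Inspecting the compiled HLO of such a function is one way to check whether the scaling, clamping, and abs-max capture actually end up fused into the matmul custom call.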