Skip to content

Commit

Permalink
Add fp8 types exposed in jax.numpy.
Browse files Browse the repository at this point in the history
  • Loading branch information
liudangyi committed Oct 1, 2024
1 parent f4ca4c9 commit 36e3831
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 1 deletion.
10 changes: 10 additions & 0 deletions jaxtyping/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
Complex64 as Complex64,
Complex128 as Complex128,
Float as Float,
Float8e4m3b11fnuz as Float8e4m3b11fnuz,
Float8e4m3fn as Float8e4m3fn,
Float8e4m3fnuz as Float8e4m3fnuz,
Float8e5m2 as Float8e5m2,
Float8e5m2fnuz as Float8e5m2fnuz,
Float16 as Float16,
Float32 as Float32,
Float64 as Float64,
Expand Down Expand Up @@ -110,6 +115,11 @@
Complex64 as Complex64,
Complex128 as Complex128,
Float as Float,
Float8e4m3b11fnuz as Float8e4m3b11fnuz,
Float8e4m3fn as Float8e4m3fn,
Float8e4m3fnuz as Float8e4m3fnuz,
Float8e5m2 as Float8e5m2,
Float8e5m2fnuz as Float8e5m2fnuz,
Float16 as Float16,
Float32 as Float32,
Float64 as Float64,
Expand Down
19 changes: 18 additions & 1 deletion jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,6 +733,11 @@ def __init_subclass__(cls, **kwargs):
_int16 = "int16"
_int32 = "int32"
_int64 = "int64"
_float8_e4m3b11fnuz = "float8_e4m3b11fnuz"
_float8_e4m3fn = "float8_e4m3fn"
_float8_e4m3fnuz = "float8_e4m3fnuz"
_float8_e5m2 = "float8_e5m2"
_float8_e5m2fnuz = "float8_e5m2fnuz"
_bfloat16 = "bfloat16"
_float16 = "float16"
_float32 = "float32"
Expand Down Expand Up @@ -764,6 +769,11 @@ class _Cls(AbstractDtype):
Int16 = _make_dtype(_int16, "Int16")
Int32 = _make_dtype(_int32, "Int32")
Int64 = _make_dtype(_int64, "Int64")
Float8e4m3b11fnuz = _make_dtype(_float8_e4m3b11fnuz, "Float8e4m3b11fnuz")
Float8e4m3fn = _make_dtype(_float8_e4m3fn, "Float8e4m3fn")
Float8e4m3fnuz = _make_dtype(_float8_e4m3fnuz, "Float8e4m3fnuz")
Float8e5m2 = _make_dtype(_float8_e5m2, "Float8e5m2")
Float8e5m2fnuz = _make_dtype(_float8_e5m2fnuz, "Float8e5m2fnuz")
BFloat16 = _make_dtype(_bfloat16, "BFloat16")
Float16 = _make_dtype(_float16, "Float16")
Float32 = _make_dtype(_float32, "Float32")
Expand All @@ -774,7 +784,14 @@ class _Cls(AbstractDtype):
bools = [_bool, _bool_]
uints = [_uint4, _uint8, _uint16, _uint32, _uint64]
ints = [_int4, _int8, _int16, _int32, _int64]
floats = [_bfloat16, _float16, _float32, _float64]
float8 = [
_float8_e4m3b11fnuz,
_float8_e4m3fn,
_float8_e4m3fnuz,
_float8_e5m2,
_float8_e5m2fnuz,
]
floats = float8 + [_bfloat16, _float16, _float32, _float64]
complexes = [_complex64, _complex128]

# We match NumPy's type hierarachy in what types to provide. See the diagram at
Expand Down
5 changes: 5 additions & 0 deletions jaxtyping/_indirection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
Annotated as Complex64, # noqa: F401
Annotated as Complex128, # noqa: F401
Annotated as Float, # noqa: F401
Annotated as Float8e4m3b11fnuz, # noqa: F401
Annotated as Float8e4m3fn, # noqa: F401
Annotated as Float8e4m3fnuz, # noqa: F401
Annotated as Float8e5m2, # noqa: F401
Annotated as Float8e5m2fnuz, # noqa: F401
Annotated as Float16, # noqa: F401
Annotated as Float32, # noqa: F401
Annotated as Float64, # noqa: F401
Expand Down
5 changes: 5 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def test_dtypes():
Complex64,
Complex128,
Float,
Float8e4m3b11fnuz,
Float8e4m3fn,
Float8e4m3fnuz,
Float8e5m2,
Float8e5m2fnuz,
Float16,
Float32,
Float64,
Expand Down

0 comments on commit 36e3831

Please sign in to comment.