Add a workaround for LLVM bug in codegen for bf16 vector cast. (triton-lang#4)

Signed-off-by: Ilya Enkovich <[email protected]>
ienkovich committed Dec 6, 2024
1 parent 693a9f8 commit 633e190
Showing 2 changed files with 12 additions and 0 deletions.
7 changes: 7 additions & 0 deletions python/test/unit/language/test_core.py
@@ -35,6 +35,7 @@
     is_hip_mi300,
     is_xpu,
     get_arch,
+    is_cpu,
     torch_float8_dtypes,
     torch_dtypes,
     numpy_random,
@@ -1668,6 +1669,12 @@ def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device):
     if is_hip() and (dtype_z in ("bfloat16", "float8_e4m3fn") or dtype_x == "float8_e4m3fn"):
         pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.')

+    # bf16 vector cast is broken in LLVM for large vectors:
+    # https://github.com/llvm/llvm-project/issues/92471
+    # TODO: Remove the change after the bug is fixed.
+    if is_cpu() and dtype_x == 'bfloat16' and size > 128:
+        size = 128
+
     torch.manual_seed(0)
     # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints.
     if dtype_x.startswith('bfloat'):
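For context, a minimal sketch of the kind of bf16 cast this test exercises (not part of this commit; the kernel name, launch grid, and sizes below are illustrative assumptions). On the CPU backend, bf16 casts over blocks larger than 128 elements can hit the LLVM codegen issue linked above, which is why the test clamps `size`:

import torch
import triton
import triton.language as tl

# Cast a bf16 block to fp32; this bf16 vector cast is the codegen path affected by the bug.
@triton.jit
def cast_kernel(src_ptr, dst_ptr, SIZE: tl.constexpr):
    offs = tl.arange(0, SIZE)
    x = tl.load(src_ptr + offs)
    tl.store(dst_ptr + offs, x.to(tl.float32))

# Keeping SIZE at 128 or below mirrors the clamp applied in test_cast above.
src = torch.randn(128, dtype=torch.bfloat16)
dst = torch.empty(128, dtype=torch.float32)
cast_kernel[(1, )](src, dst, SIZE=128)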
5 changes: 5 additions & 0 deletions python/triton/_internal_testing.py
@@ -70,6 +70,11 @@ def get_arch():
     return "" if target is None else str(target.arch)


+def is_cpu():
+    return not is_interpreter() and \
+        triton.runtime.driver.active.get_current_target().backend == "cpu"
+
+
 def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, high=None):
     """
     Override `rs` if you're calling this function twice and don't want the same
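The new predicate can be used like the existing is_hip()/is_xpu() checks in the test suite; a hypothetical example (the test name and skip message are illustrative, not from this commit):

import pytest
from triton._internal_testing import is_cpu

def test_bf16_vector_cast(device):
    # Hypothetical gate: skip on the CPU backend until the LLVM bf16 vector cast bug is fixed.
    if is_cpu():
        pytest.skip("bf16 vector cast hits an LLVM codegen bug on the CPU backend")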
