From 5864232e21cffde948a23eadb971deec267f2c1d Mon Sep 17 00:00:00 2001
From: Protonu
Date: Fri, 10 Jan 2025 13:04:00 -0500
Subject: [PATCH] Adding a decomposition and a test for triu (#1631)

---
Reviewer notes (not part of the commit message):
 * torchex.py: registers torch.triu; the checker accepts ltorch.triu only when fill_value is None.
 * opinfos.py: adds triu_opinfo, reusing tril_sample_generator and torch.triu as the reference.
 * torch/__init__.py: adds triu (mask is col - row >= diagonal) and the in-place triu_.
 * default_torch_ops.py: removes the torch.triu and torch.Tensor.triu entries.

 thunder/executors/torchex.py       | 11 +++++++++++
 thunder/tests/opinfos.py           | 24 ++++++++++++++++++++++++
 thunder/torch/__init__.py          | 23 +++++++++++++++++++++++
 thunder/torch/default_torch_ops.py |  2 --
 4 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/thunder/executors/torchex.py b/thunder/executors/torchex.py
index 6de644204d..5798506804 100644
--- a/thunder/executors/torchex.py
+++ b/thunder/executors/torchex.py
@@ -1099,6 +1099,7 @@ def _lerp_checker(start: TensorLike, end: TensorLike, weight: Number | TensorLik
 where = _register_torch_operation("where")
 masked_fill = _register_torch_operation("masked_fill", module=torch.Tensor)
 tril = _register_torch_operation("tril")
+triu = _register_torch_operation("triu")
 
 
 def _where_prim_checker(pred: Number | TensorProxy, a: Number | TensorProxy, b: Number | TensorProxy) -> bool:
@@ -1133,11 +1134,21 @@ def _tril_transform(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | N
     return tril(a, diagonal)
 
 
+# NOTE PyTorch's triu like tril does not have a fill_value parameter
+def _triu_checker(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> bool:
+    return fill_value is None
+
+
+def _triu_transform(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> TensorLike:
+    return triu(a, diagonal)
+
+
 _register_implementation(prims.where, where, checker=_where_prim_checker)
 
 _register_implementation(ltorch.clamp, clamp, checker=_always_executable)
 _register_implementation(ltorch.masked_fill, masked_fill, checker=_masked_fill_checker)
 _register_implementation(ltorch.tril, checker=_tril_checker, execution_transform=_tril_transform)
+_register_implementation(ltorch.triu, checker=_triu_checker, execution_transform=_triu_transform)
 _register_implementation(ltorch.where, where, checker=_always_executable)
 
 
diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index c63c4667b3..d4b0197240 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -2913,6 +2913,30 @@ def tril_sample_generator(op, device, dtype, requires_grad, **kwargs):
 )
 conditional_and_mask_ops.append(tril_opinfo)
 
+triu_opinfo = OpInfo(
+    ltorch.triu,
+    sample_input_generator=tril_sample_generator,
+    torch_reference=torch.triu,
+    test_directives=(
+        # Not all PyTorch versions support complex32 triu
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            dtypes=(datatypes.complex32,),
+        ),
+        # PyTorch 2.0 doesn't support CUDA bfloat16 triu
+        DecorateInfo(
+            pytest.mark.xfail,
+            "test_core_vs_torch_consistency",
+            devicetypes=(devices.DeviceType.CUDA,),
+            dtypes=(datatypes.bfloat16,),
+            active_if=(LooseVersion(torch.__version__) < "2.1"),
+        ),
+    ),
+)
+
+conditional_and_mask_ops.append(triu_opinfo)
+
 # Puts all elementwise ternary opinfos into the "opinfos" list
 opinfos.extend(conditional_and_mask_ops)
 
diff --git a/thunder/torch/__init__.py b/thunder/torch/__init__.py
index eb771f05e3..0d10f03563 100644
--- a/thunder/torch/__init__.py
+++ b/thunder/torch/__init__.py
@@ -2409,6 +2409,29 @@ def tril_(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = No
     return prims.copy_(tril(a, diagonal, fill_value=fill_value), a)
 
 
+# NOTE triu is the same as tril except that we modify the inequality to return the upper triangular
+# NOTE matrix instead of the lower triangular matrix.
+@torchsymbol(torch.triu, is_method=True)
+def triu(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> TensorLike:
+    utils.check(a.ndim >= 2, lambda: f"triu: a ({a.ndim=}) must have at least two dimensions")
+
+    nrows, ncols = a.shape[-2:]
+    row_numbers = arange(nrows, device=a.device).unsqueeze(-1)
+    col_numbers = arange(ncols, device=a.device).unsqueeze(-2)
+
+    mask = (col_numbers - row_numbers) >= diagonal
+
+    if fill_value is None:
+        fill_value = 0
+
+    return _mask_tensor(a, mask, fill_value)
+
+
+@torchsymbol(torch.Tensor.triu_, is_method=True, tags=(prims.OpTags.IN_PLACE,))
+def triu_(a: TensorLike, /, diagonal: int = 0, *, fill_value: None | Number = None) -> TensorLike:
+    return prims.copy_(triu(a, diagonal, fill_value=fill_value), a)
+
+
 @torchsymbol(torch.where, is_method=True)
 def where(
     pred: TensorLike,
diff --git a/thunder/torch/default_torch_ops.py b/thunder/torch/default_torch_ops.py
index e6b56ece4b..009fe5aba6 100644
--- a/thunder/torch/default_torch_ops.py
+++ b/thunder/torch/default_torch_ops.py
@@ -293,7 +293,6 @@
     torch.trapz,
     torch.triangular_solve,
     torch.triplet_margin_loss,
-    torch.triu,
     torch.unbind_copy,
     torch.unfold_copy,
     torch.unique_consecutive,
@@ -609,7 +608,6 @@
     torch.Tensor.tolist,
     torch.Tensor.trace,
     torch.Tensor.triangular_solve,
-    torch.Tensor.triu,
     torch.Tensor.unique,
     torch.Tensor.unique_consecutive,
     torch.Tensor.unsafe_chunk,