From d64c021f310e606c853aab3b146aec567eb4e7f9 Mon Sep 17 00:00:00 2001
From: beverlylytle <57254617+beverlylytle@users.noreply.github.com>
Date: Tue, 14 Jan 2025 10:52:04 +0100
Subject: [PATCH] Add singularity function for maximum and minimum (#1642)

---
 thunder/tests/opinfos.py   | 40 +++++++++++++++++++++++++-------------
 thunder/tests/test_grad.py | 13 ++++++-------
 2 files changed, 33 insertions(+), 20 deletions(-)

diff --git a/thunder/tests/opinfos.py b/thunder/tests/opinfos.py
index d4b0197240..07919c0ef4 100644
--- a/thunder/tests/opinfos.py
+++ b/thunder/tests/opinfos.py
@@ -25,7 +25,7 @@
 from thunder.core.symbol import Symbol
 import thunder.executors as executors
 from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
-from thunder.tests.make_tensor import make_tensor
+from thunder.tests.make_tensor import make_tensor, make_tensor_like
 import thunder.tests.bf16
 import thunder.torch as ltorch
 
@@ -60,16 +60,6 @@ def round_remainder(x, y):
     return x - torch.round(x / y) * y
 
 
-def push_away_from_singularities(x, singularity_fn, eps):
-    """This function takes a tensor and moves individual values away
-    from singularities in `eps` increments, until they are further than
-    `eps` away from them. The `singularity_fn` returns the (signed)
-    distance from `x` to the nearest singularity."""
-    x_dist = singularity_fn(x)
-    x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
-    return torch.where((x_dist <= 0) & (x_dist > -eps), x_ - eps, x_)
-
-
 # Randomly select a fraction of the elements in a tensor and set them to specified value
 def replace_random_percentage(a: torch.Tensor, value: Number, percentage: float) -> torch.Tensor:
     flat = torch.flatten(a.detach().clone())
@@ -208,10 +198,24 @@ def _to(x):
         args, kwargs = tree_map(_to, self.args), tree_map(_to, self.kwargs)
         return SampleInput(*args, **kwargs)
 
-    def remove_singularities(self, singularity_fn, eps):
+    def remove_singularities(self, op, eps):
+
+        singularity_fn = op.singularity_fn_producer(self)
+        if singularity_fn is None:
+            return self
+
+        def _push_away_from_singularities(x, dist_fn, eps):
+            """This function takes a tensor and moves individual values away
+            from singularities in `eps` increments, until they are further than
+            `eps` away from them. The `dist_fn` returns the (signed)
+            distance from `x` to the nearest singularity."""
+            x_dist = dist_fn(x)
+            x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
+            return torch.where((x_dist < 0) & (x_dist > -eps), x_ - eps, x_)
+
         def _remove_singularities(x):
             if isinstance(x, torch.Tensor) and datatypes.is_float_dtype(datatypes.to_dtype(x)):
-                return push_away_from_singularities(x, singularity_fn, eps)
+                return _push_away_from_singularities(x, singularity_fn, eps)
             return x
 
@@ -2195,11 +2199,20 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
 )
 elementwise_binary_ops.append(lt_opinfo)
 
+
+def min_max_singularity_fn_producer(sample):
+    a, b = sample.args
+    if a.shape == b.shape or b.shape == ():
+        return lambda x: x - b if x is a else make_tensor_like(x, low=1)
+    return lambda x: x - a if x is b else make_tensor_like(x, low=1)
+
+
 maximum_opinfo = OpInfo(
     clang.maximum,
     sample_input_generator=partial(elementwise_binary_generator, no_rhs_numbers=True),
     torch_reference=torch.maximum,
     supports_grad=True,
+    singularity_fn_producer=min_max_singularity_fn_producer,
 )
 elementwise_binary_ops.append(maximum_opinfo)
 
@@ -2207,6 +2220,7 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
     clang.minimum,
     sample_input_generator=partial(elementwise_binary_generator, no_rhs_numbers=True),
     torch_reference=torch.minimum,
+    singularity_fn_producer=min_max_singularity_fn_producer,
 )
 elementwise_binary_ops.append(minimum_opinfo)
 
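[Editor's note: illustration only, not part of the patch.] The gradients of `maximum` and `minimum` are discontinuous wherever the two operands are equal, so comparing analytical gradients against numerical ones right at that set is unreliable. `min_max_singularity_fn_producer` hands the test machinery a per-sample function measuring each input's signed distance to that singular set, and `SampleInput.remove_singularities` then nudges any value closer than `eps` at least `eps` away. A minimal standalone sketch of the mechanism (plain PyTorch; the tensors and the inline distance function below are invented for illustration):

    import torch

    def push_away_from_singularities(x, dist_fn, eps):
        # Same logic as the patched _push_away_from_singularities: values whose
        # signed distance to the singular set lies in [0, eps) are bumped up by
        # eps, and values in (-eps, 0) are bumped down by eps.
        x_dist = dist_fn(x)
        x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
        return torch.where((x_dist < 0) & (x_dist > -eps), x_ - eps, x_)

    eps = 1e-2
    a = torch.tensor([0.0, 1.0, 2.0])
    b = torch.tensor([0.005, 1.5, 2.0])  # entries 0 and 2 lie within eps of a

    # For same-shape operands the producer returns x -> x - b for the first
    # argument: the signed distance to the set {a == b}.
    a_safe = push_away_from_singularities(a, lambda x: x - b, eps)
    print((a_safe - b).abs())  # every entry now sits roughly eps or more from b

As I read the broadcast branch, `make_tensor_like(x, low=1)` returns an everywhere-positive "distance", i.e. it marks the non-matching operand as nowhere near a singularity, so it passes through unchanged.
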
diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 406af29943..d3ee7bd42e 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -32,7 +32,7 @@
     version_between,
 )
 from thunder.tests.make_tensor import make_tensor, make_tensor_like
-from thunder.tests.opinfos import opinfos, push_away_from_singularities, tensor_creation_ops, get_opinfo
+from thunder.tests.opinfos import get_opinfo, opinfos, tensor_creation_ops
 
 # TODO: Move this to thunder.tests.opinfos
 op_skip = {
@@ -409,14 +409,14 @@ def test_vjp_correctness(op, device, dtype, executor, comp):
         # for non-differentiable arguments like dtypes so that the test will
         # execute properly.
         sample = sample.thunder()  # converts torch.dtype to thunder.dtype
+        sample = sample.remove_singularities(op, eps)
 
         flat_op, flat_args, spec = flatten_func(op.op, sample.args, sample.kwargs)
 
         filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
         if len(filtered_args) == 0:
             continue
-        if (singularity_fn := op.singularity_fn_producer(sample)) is not None:
-            filtered_args = [push_away_from_singularities(arg, singularity_fn, eps) for arg in filtered_args]
+
         at_least_one_differentiable_input = True
         result = run_snippet(
             snippet_vjp_correctness,
@@ -1302,9 +1302,7 @@ def func(a):
 # TODO Add more module tests
 
 
-def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp, singularity_fn):
-    if singularity_fn:
-        sample = sample.remove_singularities(singularity_fn, 1e-2)
+def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp):
 
     args, kwargs = sample.args, sample.kwargs
 
@@ -1404,6 +1402,8 @@ def test_phantom_grad_vs_torch_consistency(op, device: str, dtype: dtypes.dtype,
     for sample in op.sample_inputs(device, dtype, requires_grad=True):
         comp = sample.comp if sample.comp is not None else comp
 
+        sample = sample.remove_singularities(op, 1e-2)
+
         result = run_snippet(
             snippet_phantom_grad_vs_torch_consistency,
             op,
@@ -1413,7 +1413,6 @@ def test_phantom_grad_vs_torch_consistency(op, device: str, dtype: dtypes.dtype,
             op.torch_reference,
             sample,
             lambda a, b, **kwargs: comp(a, b, equal_nan=True, **kwargs),
-            op.singularity_fn_producer(sample),
         )
 
 # See [NOTE] dynamo reset
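
[Editor's note: illustration only, not part of the patch.] On the test side, the module-level helper import disappears and both the VJP-correctness and phantom-grad tests funnel through the same per-sample call. Schematically (op, device, and dtype come from the test harness; as the None check in remove_singularities suggests, the nudge is a no-op for ops whose producer yields None):

    eps = 1e-2
    for sample in op.sample_inputs(device, dtype, requires_grad=True):
        sample = sample.thunder()  # converts torch.dtype to thunder.dtype
        # rewrites floating-point tensor args using the op's singularity
        # function, or returns the sample unchanged when there is none
        sample = sample.remove_singularities(op, eps)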