Add singularity function for maximum and minimum (#1642)
beverlylytle authored and riccardofelluga committed Jan 27, 2025
1 parent 35e68d5 commit d64c021
Showing 2 changed files with 33 additions and 20 deletions.
40 changes: 27 additions & 13 deletions thunder/tests/opinfos.py
@@ -25,7 +25,7 @@
 from thunder.core.symbol import Symbol
 import thunder.executors as executors
 from thunder.tests.framework import _all_devicetypes, JAX_AVAILABLE, custom_comparator, IS_WINDOWS
-from thunder.tests.make_tensor import make_tensor
+from thunder.tests.make_tensor import make_tensor, make_tensor_like
 import thunder.tests.bf16
 import thunder.torch as ltorch

@@ -60,16 +60,6 @@ def round_remainder(x, y):
     return x - torch.round(x / y) * y
 
 
-def push_away_from_singularities(x, singularity_fn, eps):
-    """This function takes a tensor and moves individual values away
-    from singularities in `eps` increments, until they are further than
-    `eps` away from them. The `singularity_fn` returns the (signed)
-    distance from `x` to the nearest singularity."""
-    x_dist = singularity_fn(x)
-    x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
-    return torch.where((x_dist <= 0) & (x_dist > -eps), x_ - eps, x_)
-
-
 # Randomly select a fraction of the elements in a tensor and set them to specified value
 def replace_random_percentage(a: torch.Tensor, value: Number, percentage: float) -> torch.Tensor:
     flat = torch.flatten(a.detach().clone())
@@ -208,10 +198,24 @@ def _to(x):
         args, kwargs = tree_map(_to, self.args), tree_map(_to, self.kwargs)
         return SampleInput(*args, **kwargs)
 
-    def remove_singularities(self, singularity_fn, eps):
+    def remove_singularities(self, op, eps):
+
+        singularity_fn = op.singularity_fn_producer(self)
+        if singularity_fn is None:
+            return self
+
+        def _push_away_from_singularities(x, dist_fn, eps):
+            """This function takes a tensor and moves individual values away
+            from singularities in `eps` increments, until they are further than
+            `eps` away from them. The `dist_fn` returns the (signed)
+            distance from `x` to the nearest singularity."""
+            x_dist = dist_fn(x)
+            x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
+            return torch.where((x_dist < 0) & (x_dist > -eps), x_ - eps, x_)
+
         def _remove_singularities(x):
             if isinstance(x, torch.Tensor) and datatypes.is_float_dtype(datatypes.to_dtype(x)):
-                return push_away_from_singularities(x, singularity_fn, eps)
+                return _push_away_from_singularities(x, singularity_fn, eps)
 
             return x
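To make the relocated helper concrete, here is a standalone sketch of its nudging rule (torch only; the zero-at-origin distance function and the eps value are invented for the demo): any element whose signed distance to a singularity is smaller than eps in magnitude is shifted by eps away from it, with elements sitting exactly on a singularity pushed upward.

    import torch

    def push_away(x, dist_fn, eps):
        # same logic as _push_away_from_singularities above
        x_dist = dist_fn(x)
        x_ = torch.where((x_dist >= 0) & (x_dist < eps), x + eps, x)
        return torch.where((x_dist < 0) & (x_dist > -eps), x_ - eps, x_)

    # hypothetical singularity at 0, so the signed distance is x itself
    x = torch.tensor([-0.05, 0.0, 0.03, 0.5])
    print(push_away(x, lambda t: t, eps=0.1))
    # tensor([-0.1500,  0.1000,  0.1300,  0.5000])

Note the strict < 0 in the second torch.where: the deleted module-level version used <= 0, so an element exactly on a singularity (distance 0) was first nudged up by eps and then nudged back down onto the singularity by the second branch. The nested rewrite fixes that.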

@@ -2195,18 +2199,28 @@ def fmod_sample_input_generator(op, device, dtype, requires_grad, **kwargs):
 )
 elementwise_binary_ops.append(lt_opinfo)
 
+
+def min_max_singularity_fn_producer(sample):
+    a, b = sample.args
+    if a.shape == b.shape or b.shape == ():
+        return lambda x: x - b if x is a else make_tensor_like(x, low=1)
+    return lambda x: x - a if x is b else make_tensor_like(x, low=1)
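The producer returns one distance function that behaves differently per operand: the operand chosen for nudging reports its true signed distance to a tie (x minus the other operand), while the other operand reports a fabricated distance of at least 1 via make_tensor_like(x, low=1), so remove_singularities never moves it. A rough illustration of the a.shape == b.shape branch, with torch.full_like standing in for make_tensor_like:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([1.0, 5.0, 3.0])

    dist_fn = lambda x: x - b if x is a else torch.full_like(x, 1.0)

    print(dist_fn(a))  # tensor([ 0., -3.,  0.])  zeros mark the ties with b
    print(dist_fn(b))  # tensor([1., 1., 1.])     reported as "far", so b stays put

Nudging only one operand matters: shifting both a and b by the same eps would leave the tie in place. The shape test appears to pick the operand whose shape survives broadcasting, so that x minus the other operand has the same shape as x.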


 maximum_opinfo = OpInfo(
     clang.maximum,
     sample_input_generator=partial(elementwise_binary_generator, no_rhs_numbers=True),
     torch_reference=torch.maximum,
     supports_grad=True,
+    singularity_fn_producer=min_max_singularity_fn_producer,
 )
 elementwise_binary_ops.append(maximum_opinfo)
 
 minimum_opinfo = OpInfo(
     clang.minimum,
     sample_input_generator=partial(elementwise_binary_generator, no_rhs_numbers=True),
     torch_reference=torch.minimum,
+    singularity_fn_producer=min_max_singularity_fn_producer,
 )
 elementwise_binary_ops.append(minimum_opinfo)
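maximum and minimum need this machinery because their derivative jumps at a tie: the derivative of max(a, b) with respect to a is 0 for a < b and 1 for a > b, so any finite-difference gradient estimate taken within eps of a == b blends the two branches. A small self-contained demonstration (the evaluation point 2.0001 and step 1e-3 are arbitrary):

    import torch

    b = torch.tensor(2.0)
    f = lambda x: torch.maximum(x, b)

    # the true derivative at x = 2.0001 is 1.0, but a central difference
    # straddles the kink at x == 2 and averages the 0 and 1 branches
    x, h = 2.0001, 1e-3
    fd = (f(torch.tensor(x + h)) - f(torch.tensor(x - h))) / (2 * h)
    print(fd)  # ~0.55 rather than 1.0

This is why the grad tests below push sample inputs at least eps away from ties before comparing analytic and numerical gradients.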

13 changes: 6 additions & 7 deletions thunder/tests/test_grad.py
@@ -32,7 +32,7 @@
     version_between,
 )
 from thunder.tests.make_tensor import make_tensor, make_tensor_like
-from thunder.tests.opinfos import opinfos, push_away_from_singularities, tensor_creation_ops, get_opinfo
+from thunder.tests.opinfos import get_opinfo, opinfos, tensor_creation_ops
 
 # TODO: Move this to thunder.tests.opinfos
 op_skip = {
@@ -409,14 +409,14 @@ def test_vjp_correctness(op, device, dtype, executor, comp):
         # for non-differentiable arguments like dtypes so that the test will
         # execute properly.
         sample = sample.thunder()  # converts torch.dtype to thunder.dtype
+        sample = sample.remove_singularities(op, eps)
 
         flat_op, flat_args, spec = flatten_func(op.op, sample.args, sample.kwargs)
 
         filtered_op, filtered_args = _make_differentiable_wrapper(flat_op, flat_args)
         if len(filtered_args) == 0:
             continue
-        if (singularity_fn := op.singularity_fn_producer(sample)) is not None:
-            filtered_args = [push_away_from_singularities(arg, singularity_fn, eps) for arg in filtered_args]
 
         at_least_one_differentiable_input = True
         result = run_snippet(
             snippet_vjp_correctness,
@@ -1302,9 +1302,7 @@ def func(a):
 # TODO Add more module tests
 
 
-def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp, singularity_fn):
-    if singularity_fn:
-        sample = sample.remove_singularities(singularity_fn, 1e-2)
+def snippet_phantom_grad_vs_torch_consistency(op, torch_op, sample, comp):
 
     args, kwargs = sample.args, sample.kwargs

@@ -1404,6 +1402,8 @@ def test_phantom_grad_vs_torch_consistency(op, device: str, dtype: dtypes.dtype,
     for sample in op.sample_inputs(device, dtype, requires_grad=True):
         comp = sample.comp if sample.comp is not None else comp
 
+        sample = sample.remove_singularities(op, 1e-2)
+
         result = run_snippet(
             snippet_phantom_grad_vs_torch_consistency,
             op,
@@ -1413,7 +1413,6 @@ def test_phantom_grad_vs_torch_consistency(op, device: str, dtype: dtypes.dtype,
             op.torch_reference,
             sample,
             lambda a, b, **kwargs: comp(a, b, equal_nan=True, **kwargs),
-            op.singularity_fn_producer(sample),
         )
 
     # See [NOTE] dynamo reset
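The net effect on the test side: instead of each call site fetching a singularity function and nudging arguments itself, every sample now flows through SampleInput.remove_singularities(op, eps), which is a no-op for ops that define no singularity_fn_producer. A toy mock of that control flow (MockOp and MockSample are hypothetical stand-ins, not thunder classes):

    import torch

    class MockOp:
        # stand-in for OpInfo; the producer returns None by default
        def __init__(self, producer=lambda sample: None):
            self.singularity_fn_producer = producer

    class MockSample:
        def __init__(self, args):
            self.args = args

        def remove_singularities(self, op, eps):
            singularity_fn = op.singularity_fn_producer(self)
            if singularity_fn is None:
                return self  # ops without singularities pass through untouched
            new_args = []
            for x in self.args:
                d = singularity_fn(x)  # distance computed once, from the original x
                x = torch.where((d >= 0) & (d < eps), x + eps, x)
                x = torch.where((d < 0) & (d > -eps), x - eps, x)
                new_args.append(x)
            return MockSample(new_args)

    a, b = torch.tensor([1.0, 2.0]), torch.tensor([1.0, 9.0])
    producer = lambda sample: (lambda x: x - b if x is a else torch.full_like(x, 1.0))
    sample = MockSample([a, b]).remove_singularities(MockOp(producer), eps=1e-2)
    print(sample.args[0])  # tensor([1.0100, 2.0000])  the tie at index 0 is broken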