From 8f077145122ea935adefcfff3e2f9d310ad848f3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 3 Jan 2025 14:43:06 +0100 Subject: [PATCH 1/2] Remove unused numba dispatch function --- pytensor/link/numba/dispatch/elemwise.py | 39 ------------------------ 1 file changed, 39 deletions(-) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 2759422bf6..3559117d8a 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -1,7 +1,5 @@ -from collections.abc import Callable from functools import singledispatch from textwrap import dedent, indent -from typing import Any import numba import numpy as np @@ -9,7 +7,6 @@ from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple from pytensor import config -from pytensor.graph.basic import Apply from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -124,42 +121,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): """ -def create_vectorize_func( - scalar_op_fn: Callable, - node: Apply, - use_signature: bool = False, - identity: Any | None = None, - **kwargs, -) -> Callable: - r"""Create a vectorized Numba function from a `Apply`\s Python function.""" - - if len(node.outputs) > 1: - raise NotImplementedError( - "Multi-output Elemwise Ops are not supported by the Numba backend" - ) - - if use_signature: - signature = [create_numba_signature(node, force_scalar=True)] - else: - signature = [] - - target = ( - getattr(node.tag, "numba__vectorize_target", None) - or config.numba__vectorize_target - ) - - numba_vectorized_fn = numba_basic.numba_vectorize( - signature, identity=identity, target=target, fastmath=config.numba__fastmath - ) - - py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn) - - elemwise_fn = numba_vectorized_fn(scalar_op_fn) - elemwise_fn.py_scalar_func = py_scalar_func - - return elemwise_fn - - def create_multiaxis_reducer( scalar_op, identity, From 769d134a092766713150e18708fc771798f35cfb Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 2 Jan 2025 15:11:10 +0100 Subject: [PATCH 2/2] Use more specific Numba fastmath flags everywhere --- doc/extending/creating_a_numba_jax_op.rst | 8 ++++---- pytensor/link/numba/dispatch/basic.py | 19 ++++++++++++++++--- pytensor/link/numba/dispatch/blockwise.py | 1 - pytensor/link/numba/dispatch/elemwise.py | 19 +++++-------------- pytensor/link/numba/dispatch/extra_ops.py | 9 ++++----- pytensor/link/numba/dispatch/scalar.py | 22 ++++++++-------------- tests/link/numba/test_basic.py | 8 +++++++- tests/link/numba/test_scalar.py | 19 +++++++++++++++++++ 8 files changed, 63 insertions(+), 42 deletions(-) diff --git a/doc/extending/creating_a_numba_jax_op.rst b/doc/extending/creating_a_numba_jax_op.rst index 23faea9465..8be08b4953 100644 --- a/doc/extending/creating_a_numba_jax_op.rst +++ b/doc/extending/creating_a_numba_jax_op.rst @@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`: if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit() def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`: else: if axis is None or ndim == 1: - 
@numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit() def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 8bf827b52f..843a4dbf1f 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -49,10 +49,23 @@ def global_numba_func(func): return func -def numba_njit(*args, **kwargs): +def numba_njit(*args, fastmath=None, **kwargs): kwargs.setdefault("cache", config.numba__cache) kwargs.setdefault("no_cpython_wrapper", True) kwargs.setdefault("no_cfunc_wrapper", True) + if fastmath is None: + if config.numba__fastmath: + # Opinionated default on fastmath flags + # https://llvm.org/docs/LangRef.html#fast-math-flags + fastmath = { + "arcp", # Allow Reciprocal + "contract", # Allow floating-point contraction + "afn", # Approximate functions + "reassoc", + "nsz", # no-signed zeros + } + else: + fastmath = False # Suppress cache warning for internal functions # We have to add an ansi escape code for optional bold text by numba @@ -68,9 +81,9 @@ def numba_njit(*args, **kwargs): ) if len(args) > 0 and callable(args[0]): - return numba.njit(*args[1:], **kwargs)(args[0]) + return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0]) - return numba.njit(*args, **kwargs) + return numba.njit(*args, fastmath=fastmath, **kwargs) def numba_vectorize(*args, **kwargs): diff --git a/pytensor/link/numba/dispatch/blockwise.py b/pytensor/link/numba/dispatch/blockwise.py index 131788e843..b7481bd5a3 100644 --- a/pytensor/link/numba/dispatch/blockwise.py +++ b/pytensor/link/numba/dispatch/blockwise.py @@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs): core_op, node=core_node, parent_node=node, - fastmath=_jit_options["fastmath"], **kwargs, ) core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout) diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py index 3559117d8a..ae5ef3dcb1 100644 --- a/pytensor/link/numba/dispatch/elemwise.py +++ b/pytensor/link/numba/dispatch/elemwise.py @@ -6,7 +6,6 @@ from numba.core.extending import overload from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple -from pytensor import config from pytensor.graph.op import Op from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import ( @@ -281,7 +280,6 @@ def jit_compile_reducer( res = numba_basic.numba_njit( *args, boundscheck=False, - fastmath=config.numba__fastmath, **kwds, )(fn) @@ -315,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs): op.scalar_op, node=scalar_node, parent_node=node, - fastmath=_jit_options["fastmath"], **kwargs, ) @@ -403,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs): if ndim_input == len(axes): # Slightly faster than `numba_funcify_CAReduce` for this case - @numba_njit(fastmath=config.numba__fastmath) + @numba_njit def impl_sum(array): return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype) elif len(axes) == 0: # These cases should be removed by rewrites! 
- @numba_njit(fastmath=config.numba__fastmath) + @numba_njit def impl_sum(array): return np.asarray(array, dtype=out_dtype) @@ -568,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: @@ -602,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs): add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_sum = jit_fn(reduce_sum_py) else: reduce_sum = np.sum @@ -642,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs): add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True ) - jit_fn = numba_basic.numba_njit( - boundscheck=False, fastmath=config.numba__fastmath - ) + jit_fn = numba_basic.numba_njit(boundscheck=False) reduce_max = jit_fn(reduce_max_py) reduce_sum = jit_fn(reduce_sum_py) else: diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index 3629b0e44c..1f0a33e595 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -4,7 +4,6 @@ import numba import numpy as np -from pytensor import config from pytensor.graph import Apply from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify @@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs): if mode == "add": if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit def cumop(x): return np.cumsum(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: @@ -74,13 +73,13 @@ def cumop(x): else: if axis is None or ndim == 1: - @numba_basic.numba_njit(fastmath=config.numba__fastmath) + @numba_basic.numba_njit def cumop(x): return np.cumprod(x) else: - @numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath) + @numba_basic.numba_njit(boundscheck=False) def cumop(x): out_dtype = x.dtype if x.shape[axis] < 2: diff --git a/pytensor/link/numba/dispatch/scalar.py b/pytensor/link/numba/dispatch/scalar.py index 82ee380029..e9b637b00f 100644 --- a/pytensor/link/numba/dispatch/scalar.py +++ b/pytensor/link/numba/dispatch/scalar.py @@ -2,7 +2,6 @@ import numpy as np -from pytensor import config from pytensor.compile.ops import ViewOp from pytensor.graph.basic import Variable from pytensor.link.numba.dispatch import basic as numba_basic @@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}): return numba_basic.numba_njit( signature, - fastmath=config.numba__fastmath, # Functions that call a function pointer can't be cached cache=False, )(scalar_op_fn) @@ -177,9 +175,7 @@ def numba_funcify_Add(op, node, **kwargs): signature = create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "add", "+") - return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( - nary_add_fn - ) + return numba_basic.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Mul) @@ -187,9 +183,7 @@ def numba_funcify_Mul(op, node, **kwargs): signature = 
create_numba_signature(node, force_scalar=True) nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*") - return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( - nary_add_fn - ) + return numba_basic.numba_njit(signature)(nary_add_fn) @numba_funcify.register(Cast) @@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs): _ = kwargs.pop("storage_map", None) - composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)( + composite_fn = numba_basic.numba_njit(signature)( numba_funcify(op.fgraph, squeeze_output=True, **kwargs) ) return composite_fn @@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs): return numba_basic.global_numba_func(reciprocal) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def sigmoid(x): return 1 / (1 + np.exp(-x)) @@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs): return numba_basic.global_numba_func(sigmoid) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def gammaln(x): return math.lgamma(x) @@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs): return numba_basic.global_numba_func(gammaln) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def logp1mexp(x): if x < np.log(0.5): return np.log1p(-np.exp(x)) @@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs): return numba_basic.global_numba_func(logp1mexp) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def erf(x): return math.erf(x) @@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs): return numba_basic.global_numba_func(erf) -@numba_basic.numba_njit(fastmath=config.numba__fastmath) +@numba_basic.numba_njit def erfc(x): return math.erfc(x) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 9a3e96c858..1b0fa8fd52 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -838,7 +838,13 @@ def test_config_options_fastmath(): pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode) print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__)) numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"] - assert numba_mul_fn.targetoptions["fastmath"] is True + assert numba_mul_fn.targetoptions["fastmath"] == { + "afn", + "arcp", + "contract", + "nsz", + "reassoc", + } def test_config_options_cached(): diff --git a/tests/link/numba/test_scalar.py b/tests/link/numba/test_scalar.py index 437956bdc0..655e507da6 100644 --- a/tests/link/numba/test_scalar.py +++ b/tests/link/numba/test_scalar.py @@ -9,6 +9,7 @@ from pytensor.graph.basic import Constant from pytensor.graph.fg import FunctionGraph from pytensor.scalar.basic import Composite +from pytensor.tensor import tensor from pytensor.tensor.elemwise import Elemwise from tests.link.numba.test_basic import compare_numba_and_py, set_test_value @@ -140,3 +141,21 @@ def test_reciprocal(v, dtype): if not isinstance(i, SharedVariable | Constant) ], ) + + +@pytest.mark.parametrize("composite", (False, True)) +def test_isnan(composite): + # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath + x = tensor(shape=(2,), dtype="float64") + + if composite: + x_scalar = psb.float64() + scalar_out = ~psb.isnan(x_scalar) + out = Elemwise(Composite([x_scalar], [scalar_out]))(x) + else: + out = pt.isnan(x) + + compare_numba_and_py( + ([x], [out]), + [np.array([1, 0], dtype="float64")], + )
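
Note (not part of the patch): the behaviour introduced in `numba_njit` above can be illustrated outside of PyTensor with a minimal sketch. `my_njit` and `use_fastmath` below are hypothetical stand-ins for the patched `numba_njit` helper and the `config.numba__fastmath` flag; the flag set and the `targetoptions` check mirror what the patch itself uses:

    import numba
    import numpy as np

    use_fastmath = True  # stand-in for config.numba__fastmath

    def my_njit(*args, fastmath=None, **kwargs):
        if fastmath is None:
            if use_fastmath:
                # Opinionated subset of LLVM fast-math flags; nnan/ninf are
                # deliberately left out so NaN/inf checks keep their meaning.
                fastmath = {"arcp", "contract", "afn", "reassoc", "nsz"}
            else:
                fastmath = False
        if args and callable(args[0]):
            # Support bare @my_njit usage (no parentheses), as the patch does.
            return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
        return numba.njit(*args, fastmath=fastmath, **kwargs)

    @my_njit
    def sigmoid(x):
        return 1 / (1 + np.exp(-x))

    print(sigmoid(np.array([0.0, 1.0, np.nan])))
    print(sigmoid.targetoptions["fastmath"])  # set containing 'arcp', 'contract', 'afn', 'reassoc', 'nsz'

Because `nnan` and `ninf` are excluded from the set, operations such as `isnan` still see NaNs, which is what the new `test_isnan` in `tests/link/numba/test_scalar.py` exercises; an explicit `fastmath=` argument (e.g. `fastmath=False`) still overrides the opinionated default on a per-function basis.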