Use more conservative fastmath flags in numba backend #1147

Merged: 2 commits, Jan 5, 2025
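Background for the change: when Numba is given `fastmath=True` it enables the full LLVM fast-math flag set, including `nnan` and `ninf`, so the compiler may assume NaNs and infinities never occur. This PR keeps the performance-oriented flags but drops those two assumptions by passing an explicit subset. A minimal sketch of the two forms in plain Numba (not the PyTensor wrapper; the function names are only for illustration):

```python
import numba


@numba.njit(fastmath=True)  # full LLVM "fast": arcp, contract, afn, reassoc, nsz, nnan, ninf
def dot_full(a, b):
    acc = 0.0
    for i in range(a.shape[0]):
        acc += a[i] * b[i]
    return acc


# The conservative subset this PR makes the default when config.numba__fastmath is enabled
@numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
def dot_conservative(a, b):
    acc = 0.0
    for i in range(a.shape[0]):
        acc += a[i] * b[i]
    return acc
```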
doc/extending/creating_a_numba_jax_op.rst (4 additions, 4 deletions)
@@ -358,13 +358,13 @@ Here's an example for the `CumOp`\ `Op`:
if mode == "add":
if axis is None or ndim == 1:

-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit()
def cumop(x):
return np.cumsum(x)

else:

-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
@@ -382,13 +382,13 @@ Here's an example for the `CumOp`\ `Op`:
else:
if axis is None or ndim == 1:

-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit()
def cumop(x):
return np.cumprod(x)

else:

-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
pytensor/link/numba/dispatch/basic.py (16 additions, 3 deletions)
@@ -49,10 +49,23 @@
return func


-def numba_njit(*args, **kwargs):
+def numba_njit(*args, fastmath=None, **kwargs):
kwargs.setdefault("cache", config.numba__cache)
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)
+if fastmath is None:
+    if config.numba__fastmath:
+        # Opinionated default on fastmath flags
+        # https://llvm.org/docs/LangRef.html#fast-math-flags
+        fastmath = {
+            "arcp",  # Allow Reciprocal
+            "contract",  # Allow floating-point contraction
+            "afn",  # Approximate functions
+            "reassoc",
+            "nsz",  # no-signed zeros
+        }
+    else:
+        fastmath = False

# Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
@@ -68,9 +81,9 @@
)

if len(args) > 0 and callable(args[0]):
-return numba.njit(*args[1:], **kwargs)(args[0])
+return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])

-return numba.njit(*args, **kwargs)
+return numba.njit(*args, fastmath=fastmath, **kwargs)


def numba_vectorize(*args, **kwargs):
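With the new signature, call sites can simply omit `fastmath` to pick up the config-dependent default, or pass it explicitly to override. A hedged usage sketch of the wrapper defined above (the decorated functions are hypothetical):

```python
from pytensor.link.numba.dispatch import basic as numba_basic


# Bare decorator: fastmath resolves to the conservative flag set when
# config.numba__fastmath is enabled, and to False otherwise.
@numba_basic.numba_njit
def plus_one(x):
    return x + 1.0


# An explicit keyword still wins over the default resolution.
@numba_basic.numba_njit(boundscheck=False, fastmath=False)
def times_two(x):
    return x * 2.0
```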
pytensor/link/numba/dispatch/blockwise.py (0 additions, 1 deletion)
@@ -32,7 +32,6 @@ def numba_funcify_Blockwise(op: BlockwiseWithCoreShape, node, **kwargs):
core_op,
node=core_node,
parent_node=node,
-fastmath=_jit_options["fastmath"],
**kwargs,
)
core_op_fn = store_core_outputs(core_op_fn, nin=nin, nout=nout)
pytensor/link/numba/dispatch/elemwise.py (5 additions, 53 deletions)
@@ -1,15 +1,11 @@
-from collections.abc import Callable
from functools import singledispatch
from textwrap import dedent, indent
-from typing import Any

import numba
import numpy as np
from numba.core.extending import overload
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

-from pytensor import config
-from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
@@ -124,42 +120,6 @@
"""


-def create_vectorize_func(
-    scalar_op_fn: Callable,
-    node: Apply,
-    use_signature: bool = False,
-    identity: Any | None = None,
-    **kwargs,
-) -> Callable:
-    r"""Create a vectorized Numba function from a `Apply`\s Python function."""
-
-    if len(node.outputs) > 1:
-        raise NotImplementedError(
-            "Multi-output Elemwise Ops are not supported by the Numba backend"
-        )
-
-    if use_signature:
-        signature = [create_numba_signature(node, force_scalar=True)]
-    else:
-        signature = []
-
-    target = (
-        getattr(node.tag, "numba__vectorize_target", None)
-        or config.numba__vectorize_target
-    )
-
-    numba_vectorized_fn = numba_basic.numba_vectorize(
-        signature, identity=identity, target=target, fastmath=config.numba__fastmath
-    )
-
-    py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)
-
-    elemwise_fn = numba_vectorized_fn(scalar_op_fn)
-    elemwise_fn.py_scalar_func = py_scalar_func
-
-    return elemwise_fn


def create_multiaxis_reducer(
scalar_op,
identity,
@@ -320,7 +280,6 @@
res = numba_basic.numba_njit(
*args,
boundscheck=False,
-fastmath=config.numba__fastmath,
**kwds,
)(fn)

@@ -354,7 +313,6 @@
op.scalar_op,
node=scalar_node,
parent_node=node,
-fastmath=_jit_options["fastmath"],
**kwargs,
)

@@ -442,13 +400,13 @@

if ndim_input == len(axes):
# Slightly faster than `numba_funcify_CAReduce` for this case
-@numba_njit(fastmath=config.numba__fastmath)
+@numba_njit
def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)

elif len(axes) == 0:
# These cases should be removed by rewrites!
-@numba_njit(fastmath=config.numba__fastmath)
+@numba_njit

def impl_sum(array):
return np.asarray(array, dtype=out_dtype)

@@ -607,9 +565,7 @@
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
)

-jit_fn = numba_basic.numba_njit(
-boundscheck=False, fastmath=config.numba__fastmath
-)
+jit_fn = numba_basic.numba_njit(boundscheck=False)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else:
@@ -641,9 +597,7 @@
add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
)

-jit_fn = numba_basic.numba_njit(
-boundscheck=False, fastmath=config.numba__fastmath
-)
+jit_fn = numba_basic.numba_njit(boundscheck=False)
reduce_sum = jit_fn(reduce_sum_py)
else:
reduce_sum = np.sum
@@ -681,9 +635,7 @@
add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
)

-jit_fn = numba_basic.numba_njit(
-boundscheck=False, fastmath=config.numba__fastmath
-)
+jit_fn = numba_basic.numba_njit(boundscheck=False)
reduce_max = jit_fn(reduce_max_py)
reduce_sum = jit_fn(reduce_sum_py)
else:
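Of the flags in the new default set, `reassoc` and `contract` are the ones that matter most for the reductions and softmax helpers above: reassociation lets LLVM reorder and vectorize accumulation loops, which is also why fast-math results can drift slightly from strict left-to-right summation. A tiny illustration of why reassociation is not a free transformation for floating point (plain Python, not PyTensor code):

```python
# Floating-point addition is not associative, so letting the compiler
# reassociate sums ("reassoc") can change results in the last bits.
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)                  # 0.6000000000000001
print(a + (b + c))                  # 0.6
print((a + b) + c == a + (b + c))   # False
```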
pytensor/link/numba/dispatch/extra_ops.py (4 additions, 5 deletions)
@@ -4,7 +4,6 @@
import numba
import numpy as np

-from pytensor import config
from pytensor.graph import Apply
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
@@ -50,13 +49,13 @@ def numba_funcify_CumOp(op: CumOp, node: Apply, **kwargs):
if mode == "add":
if axis is None or ndim == 1:

-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def cumop(x):
return np.cumsum(x)

else:

-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
@@ -74,13 +73,13 @@ def cumop(x):
else:
if axis is None or ndim == 1:

-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def cumop(x):
return np.cumprod(x)

else:

-@numba_basic.numba_njit(boundscheck=False, fastmath=config.numba__fastmath)
+@numba_basic.numba_njit(boundscheck=False)
def cumop(x):
out_dtype = x.dtype
if x.shape[axis] < 2:
pytensor/link/numba/dispatch/scalar.py (8 additions, 14 deletions)
@@ -2,7 +2,6 @@

import numpy as np

-from pytensor import config
from pytensor.compile.ops import ViewOp
from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic
@@ -137,7 +136,6 @@ def {scalar_op_fn_name}({', '.join(input_names)}):

return numba_basic.numba_njit(
signature,
-fastmath=config.numba__fastmath,
# Functions that call a function pointer can't be cached
cache=False,
)(scalar_op_fn)
@@ -177,19 +175,15 @@ def numba_funcify_Add(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "add", "+")

-return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-nary_add_fn
-)
+return numba_basic.numba_njit(signature)(nary_add_fn)


@numba_funcify.register(Mul)
def numba_funcify_Mul(op, node, **kwargs):
signature = create_numba_signature(node, force_scalar=True)
nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")

-return numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
-nary_add_fn
-)
+return numba_basic.numba_njit(signature)(nary_add_fn)


@numba_funcify.register(Cast)
@@ -239,7 +233,7 @@ def numba_funcify_Composite(op, node, **kwargs):

_ = kwargs.pop("storage_map", None)

-composite_fn = numba_basic.numba_njit(signature, fastmath=config.numba__fastmath)(
+composite_fn = numba_basic.numba_njit(signature)(
numba_funcify(op.fgraph, squeeze_output=True, **kwargs)
)
return composite_fn
@@ -267,7 +261,7 @@ def numba_funcify_Reciprocal(op, node, **kwargs):
return numba_basic.global_numba_func(reciprocal)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def sigmoid(x):
return 1 / (1 + np.exp(-x))

@@ -277,7 +271,7 @@ def numba_funcify_Sigmoid(op, node, **kwargs):
return numba_basic.global_numba_func(sigmoid)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def gammaln(x):
return math.lgamma(x)

@@ -287,7 +281,7 @@ def numba_funcify_GammaLn(op, node, **kwargs):
return numba_basic.global_numba_func(gammaln)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def logp1mexp(x):
if x < np.log(0.5):
return np.log1p(-np.exp(x))
@@ -300,7 +294,7 @@ def numba_funcify_Log1mexp(op, node, **kwargs):
return numba_basic.global_numba_func(logp1mexp)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def erf(x):
return math.erf(x)

@@ -310,7 +304,7 @@ def numba_funcify_Erf(op, **kwargs):
return numba_basic.global_numba_func(erf)


-@numba_basic.numba_njit(fastmath=config.numba__fastmath)
+@numba_basic.numba_njit
def erfc(x):
return math.erfc(x)

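For the scalar helpers above (`sigmoid`, `gammaln`, `erf`, ...), the flags in the default set that matter are `afn` (allow approximate library functions) and `arcp` (allow reciprocals instead of division), while NaN and inf semantics stay untouched. A small sketch in plain Numba, assuming the same flag set the wrapper now defaults to:

```python
import numpy as np
import numba


@numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})
def sigmoid(x):
    return 1 / (1 + np.exp(-x))


print(sigmoid(0.0))     # 0.5 (up to fast-math rounding)
print(sigmoid(np.nan))  # nan propagates; nothing assumes it cannot occur
```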
tests/link/numba/test_basic.py (7 additions, 1 deletion)
@@ -838,7 +838,13 @@ def test_config_options_fastmath():
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__))
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
-assert numba_mul_fn.targetoptions["fastmath"] is True
+assert numba_mul_fn.targetoptions["fastmath"] == {
+    "afn",
+    "arcp",
+    "contract",
+    "nsz",
+    "reassoc",
+}


def test_config_options_cached():
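The assertion above relies on the fact that Numba dispatchers record the jit options they were created with, so the same check works on any hand-written kernel. A small sketch in plain Numba, outside the test suite (the `scale` function is hypothetical):

```python
import numba


@numba.njit(fastmath={"contract", "nsz"})
def scale(x):
    return 2.0 * x


# Dispatcher objects expose the options that were requested at jit time
print(scale.targetoptions["fastmath"])  # e.g. {'contract', 'nsz'} (set order may vary)
```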
tests/link/numba/test_scalar.py (19 additions, 0 deletions)
@@ -9,6 +9,7 @@
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.scalar.basic import Composite
+from pytensor.tensor import tensor
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py, set_test_value

@@ -140,3 +141,21 @@ def test_reciprocal(v, dtype):
if not isinstance(i, SharedVariable | Constant)
],
)


+@pytest.mark.parametrize("composite", (False, True))
+def test_isnan(composite):
+    # Testing with tensor just to make sure Elemwise does not revert the scalar behavior of fastmath
+    x = tensor(shape=(2,), dtype="float64")
+
+    if composite:
+        x_scalar = psb.float64()
+        scalar_out = ~psb.isnan(x_scalar)
+        out = Elemwise(Composite([x_scalar], [scalar_out]))(x)
+    else:
+        out = pt.isnan(x)
+
+    compare_numba_and_py(
+        ([x], [out]),
+        [np.array([1, 0], dtype="float64")],
+    )
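The new test guards the behavior that motivated this PR: with the full `fastmath=True` set, the `nnan` flag lets LLVM assume NaNs do not exist, so an `isnan` check can be folded away, whereas the conservative subset keeps it intact. A hedged standalone sketch of that difference in plain Numba (the exact outcome of the `fastmath=True` variant depends on the optimizer):

```python
import numpy as np
import numba


@numba.njit(fastmath=True)  # includes "nnan": the NaN check may be optimized out
def any_nan_full(x):
    return np.isnan(x).any()


@numba.njit(fastmath={"arcp", "contract", "afn", "reassoc", "nsz"})  # no "nnan"/"ninf"
def any_nan_conservative(x):
    return np.isnan(x).any()


x = np.array([1.0, np.nan])
print(any_nan_full(x))          # may wrongly report False
print(any_nan_conservative(x))  # True
```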