Avoid manipulation of deprecated _mpm_cheap
The internal API changed in numba 0.61.

Existing benchmarks show no difference in performance.
ricardoV94 committed Feb 3, 2025
1 parent 42e31c4 commit 2f2d0d3
Showing 2 changed files with 4 additions and 66 deletions.
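
For context, here is a minimal sketch (not part of the commit or of PyTensor's benchmark suite) of how the "no difference in performance" claim could be spot-checked: compile a nested-loop reduction with a plain numba.njit(boundscheck=False) call, the pattern the call sites below switch to, and time it. The function and array names are illustrative.

import timeit

import numba
import numpy as np


@numba.njit(boundscheck=False)
def sum_last_axis(x):
    # CAReduce-style nested-loop reduction over the last axis.
    out = np.zeros(x.shape[0])
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            out[i] += x[i, j]
    return out


x = np.random.default_rng(0).random((1000, 1000))
sum_last_axis(x)  # warm up so compilation time is excluded from the measurement
print(timeit.timeit(lambda: sum_last_axis(x), number=100))
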
18 changes: 0 additions & 18 deletions pytensor/link/numba/dispatch/basic.py
@@ -1,7 +1,6 @@
import operator
import sys
import warnings
from contextlib import contextmanager
from copy import copy
from functools import singledispatch
from textwrap import dedent
@@ -362,23 +361,6 @@ def create_arg_string(x):
return args


@contextmanager
def use_optimized_cheap_pass(*args, **kwargs):
"""Temporarily replace the cheap optimization pass with a better one."""
from numba.core.registry import cpu_target

context = cpu_target.target_context._internal_codegen
old_pm = context._mpm_cheap
new_pm = context._module_pass_manager(
loop_vectorize=True, slp_vectorize=True, opt=3, cost="cheap"
)
context._mpm_cheap = new_pm
try:
yield
finally:
context._mpm_cheap = old_pm


@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
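
The helper removed above reached into numba internals (cpu_target.target_context._internal_codegen._mpm_cheap) to swap the "cheap" module pass manager, which is the internal API that changed in numba 0.61. As a hedged aside, and not something this commit does: the knobs it toggled (optimization level, loop and SLP vectorization) correspond, to the best of my knowledge, to documented numba environment variables, which must be set before numba is imported. The toy function is illustrative.

import os

# Believed-documented numba settings; they only take effect if set before
# numba is imported.
os.environ.setdefault("NUMBA_OPT", "3")             # LLVM optimization level
os.environ.setdefault("NUMBA_LOOP_VECTORIZE", "1")  # enable loop vectorization
os.environ.setdefault("NUMBA_SLP_VECTORIZE", "1")   # enable SLP vectorization

import numba  # noqa: E402  (imported after configuring the environment)


@numba.njit(boundscheck=False)
def dot1d(a, b):
    acc = 0.0
    for i in range(a.shape[0]):
        acc += a[i] * b[i]
    return acc
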
52 changes: 4 additions & 48 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -9,10 +9,8 @@
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_numba_signature,
numba_funcify,
numba_njit,
use_optimized_cheap_pass,
)
from pytensor.link.numba.dispatch.vectorize_codegen import (
_jit_options,
@@ -245,47 +243,6 @@ def {careduce_fn_name}(x):
return careduce_fn


def jit_compile_reducer(
node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
):
"""Compile Python source for reduction loops using additional optimizations.
Parameters
==========
node
An node from which the signature can be derived.
fn
The Python function object to compile.
reduce_to_scalar: bool, default False
Whether to reduce output to a scalar (instead of 0d array)
infer_signature: bool: default True
Whether to try and infer the function signature from the Apply node.
kwds
Extra keywords to be added to the :func:`numba.njit` function.
Returns
=======
A :func:`numba.njit`-compiled function.
"""
if infer_signature:
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
args = (signature,)
else:
args = ()

# Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions.
with use_optimized_cheap_pass():
res = numba_basic.numba_njit(
*args,
boundscheck=False,
**kwds,
)(fn)

return res


def create_axis_apply_fn(fn, axis, ndim, dtype):
axis = normalize_axis_index(axis, ndim)

@@ -448,7 +405,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
np.dtype(node.outputs[0].type.dtype),
)

careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False)
careduce_fn = numba_njit(careduce_py_fn, boundscheck=False)
return careduce_fn
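
With jit_compile_reducer gone, the reducer is no longer compiled eagerly against a signature inferred from the Apply node; numba now specializes it lazily on the first call. A minimal sketch of that difference, using a toy function rather than repository code:

import numba
import numpy as np


def py_sum(x):
    acc = 0.0
    for i in range(x.shape[0]):
        acc += x[i]
    return acc


# Explicit signature: compiled eagerly, right here.
eager = numba.njit("float64(float64[:])", boundscheck=False)(py_sum)
# No signature: compilation is deferred until the first call.
lazy = numba.njit(boundscheck=False)(py_sum)

x = np.arange(5.0)
print(eager(x), lazy(x))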


@@ -579,7 +536,7 @@ def softmax_py_fn(x):
sm = e_x / w
return sm

softmax = jit_compile_reducer(node, softmax_py_fn)
softmax = numba_njit(softmax_py_fn, boundscheck=False)

return softmax

@@ -608,8 +565,7 @@ def softmax_grad_py_fn(dy, sm):
dx = dy_times_sm - sum_dy_times_sm * sm
return dx

# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)
softmax_grad = numba_njit(softmax_grad_py_fn, boundscheck=False)

return softmax_grad
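
The removed comment noted that the inferred signature was wrong when dy is a constant (readonly=True). With lazy compilation there is no hand-built signature to get wrong: numba types the readonly array when it first sees it. A hedged sketch of that behaviour with made-up inputs:

import numba
import numpy as np


@numba.njit(boundscheck=False)
def scale(dy, sm):
    return dy * sm


dy = np.ones((3, 4))
dy.setflags(write=False)  # mimic a readonly constant input
sm = np.full((3, 4), 0.25)
print(scale(dy, sm))  # the readonly array type is inferred at this first call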

@@ -647,7 +603,7 @@ def log_softmax_py_fn(x):
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm

log_softmax = jit_compile_reducer(node, log_softmax_py_fn)
log_softmax = numba_njit(log_softmax_py_fn, boundscheck=False)
return log_softmax


