From 361280c871a622c22edbbf41b69e8052da29bc2b Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Fri, 5 Apr 2024 11:16:08 +0200
Subject: [PATCH 01/19] Update numpy deprecated imports

- replaced np.AxisError with np.exceptions.AxisError
- the `numpy.core` submodule has been renamed to `numpy._core`
- some parts of `numpy.core` have been moved to `numpy.lib.array_utils`

Except for `AxisError`, the updated imports are conditional on the version
of numpy, so the imports should work for numpy >= 1.26.

The conditional imports have been added to `npy_2_compat.py`, so the
imports elsewhere are unconditional.
---
 pytensor/link/c/basic.py                 |   7 +-
 pytensor/link/numba/dispatch/elemwise.py |   2 +-
 pytensor/npy_2_compat.py                 | 275 +++++++++++++++++++++++
 pytensor/tensor/__init__.py              |   2 +-
 pytensor/tensor/basic.py                 |   8 +-
 pytensor/tensor/conv/abstract_conv.py    |   3 +-
 pytensor/tensor/einsum.py                |   9 +-
 pytensor/tensor/elemwise.py              |   5 +-
 pytensor/tensor/extra_ops.py             |  12 +-
 pytensor/tensor/math.py                  |   2 +-
 pytensor/tensor/nlinalg.py               |   2 +-
 pytensor/tensor/shape.py                 |   2 +-
 pytensor/tensor/slinalg.py               |   3 +-
 pytensor/tensor/subtensor.py             |   2 +
 pytensor/tensor/utils.py                 |   6 +-
 tests/tensor/test_elemwise.py            |   2 +-
 tests/tensor/test_extra_ops.py           |   2 +-
 tests/tensor/test_io.py                  |   2 +-
 18 files changed, 311 insertions(+), 35 deletions(-)
 create mode 100644 pytensor/npy_2_compat.py

diff --git a/pytensor/link/c/basic.py b/pytensor/link/c/basic.py
index d7f43e7377..d509bd1d76 100644
--- a/pytensor/link/c/basic.py
+++ b/pytensor/link/c/basic.py
@@ -10,8 +10,6 @@
 from io import StringIO
 from typing import TYPE_CHECKING, Any, Optional
 
-import numpy as np
-
 from pytensor.compile.compilelock import lock_ctx
 from pytensor.configdefaults import config
 from pytensor.graph.basic import (
@@ -33,6 +31,7 @@
 from pytensor.link.c.cmodule import get_module_cache as _get_module_cache
 from pytensor.link.c.interface import CLinkerObject, CLinkerOp, CLinkerType
 from pytensor.link.utils import gc_helper, map_storage, raise_with_op, streamline
+from pytensor.npy_2_compat import ndarray_c_version
 from pytensor.utils import difference, uniq
 
 
@@ -1367,10 +1366,6 @@ def cmodule_key_(
 
     # We must always add the numpy ABI version here as
     # DynamicModule always add the include <numpy/arrayobject.h>
-    if np.lib.NumpyVersion(np.__version__) < "1.16.0a":
-        ndarray_c_version = np.core.multiarray._get_ndarray_c_version()
-    else:
-        ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()
     sig.append(f"NPY_ABI_VERSION=0x{ndarray_c_version:X}")
     if c_compiler:
         sig.append("c_compiler_str=" + c_compiler.version_str())
diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 2a98985efe..03c7084a8f 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -4,7 +4,6 @@
 import numba
 import numpy as np
 from numba.core.extending import overload
-from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
 
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
@@ -19,6 +18,7 @@
     store_core_outputs,
 )
 from pytensor.link.utils import compile_function_src
+from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar.basic import (
     AND,
     OR,
diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py
new file mode 100644
index 0000000000..30214154a2
--- /dev/null
+++ b/pytensor/npy_2_compat.py
@@ -0,0 +1,275 @@
+from textwrap import dedent
+
+import numpy as np
+
+
+# Conditional numpy imports for numpy 1.26 and 2.x compatibility
+try:
+    from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
+except ModuleNotFoundError:
+    # numpy < 2.0
+    from numpy.core.multiarray import normalize_axis_index  # type: ignore[no-redef]
+    from numpy.core.numeric import normalize_axis_tuple  # type: ignore[no-redef]
+
+
+try:
+    from numpy._core.einsumfunc import (  # type: ignore[attr-defined]
+        _find_contraction,
+        _parse_einsum_input,
+    )
+except ModuleNotFoundError:
+    from numpy.core.einsumfunc import (  # type: ignore[no-redef]
+        _find_contraction,
+        _parse_einsum_input,
+    )
+
+
+# suppress linting warning by "using" the imports here:
+__all__ = [
+    "_find_contraction",
+    "_parse_einsum_input",
+    "normalize_axis_index",
+    "normalize_axis_tuple",
+]
+
+
+numpy_version_tuple = tuple(int(n) for n in np.__version__.split(".")[:2])
+numpy_version = np.lib.NumpyVersion(
+    np.__version__
+)  # used to compare with version strings, e.g. numpy_version < "1.16.0"
+using_numpy_2 = numpy_version >= "2.0.0rc1"
+
+
+if using_numpy_2:
+    ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()
+else:
+    ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()  # type: ignore[attr-defined]
+
+
+if using_numpy_2:
+    UintOverflowError = OverflowError
+else:
+    UintOverflowError = TypeError
+
+
+def npy_2_compat_header() -> str:
+    """Compatibility header that NumPy suggests vendoring in code that must support both NumPy 1.x and NumPy 2.x."""
+    return dedent("""
+        #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
+        #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_
+
+
+        /*
+         * This header is meant to be included by downstream directly for 1.x compat.
+         * In that case we need to ensure that users first included the full headers
+         * and not just `ndarraytypes.h`.
+         */
+
+        #ifndef NPY_FEATURE_VERSION
+          #error "The NumPy 2 compat header requires `import_array()` for which " \\
+                 "the `ndarraytypes.h` header include is not sufficient.  Please " \\
+                 "include it after `numpy/ndarrayobject.h` or similar." \\
+                 "" \\
+                 "To simplify inclusion, you may use `PyArray_ImportNumPy()` " \\
+                 "which is defined in the compat header and is lightweight (can be)."
+        #endif
+
+        #if NPY_ABI_VERSION < 0x02000000
+          /*
+           * Define 2.0 feature version as it is needed below to decide whether we
+           * compile for both 1.x and 2.x (defining it guarantees 1.x only).
+           */
+          #define NPY_2_0_API_VERSION 0x00000012
+          /*
+           * If we are compiling with NumPy 1.x, `PyArray_RUNTIME_VERSION` does not
+           * exist, so we pretend that `PyArray_RUNTIME_VERSION` is `NPY_FEATURE_VERSION`.
+           * This allows downstream to use `PyArray_RUNTIME_VERSION` if they need to.
+           */
+          #define PyArray_RUNTIME_VERSION NPY_FEATURE_VERSION
+          /* Compiling on NumPy 1.x where these are the same: */
+          #define PyArray_DescrProto PyArray_Descr
+        #endif
+
+
+        /*
+         * Define a better way to call `_import_array()` to simplify backporting as
+         * we now require imports more often (necessary to make ABI flexible).
+         */
+        #ifdef import_array1
+
+        static inline int
+        PyArray_ImportNumPyAPI()
+        {
+            if (NPY_UNLIKELY(PyArray_API == NULL)) {
+                import_array1(-1);
+            }
+            return 0;
+        }
+
+        #endif  /* import_array1 */
+
+
+        /*
+         * NPY_DEFAULT_INT
+         *
+         * The default integer has changed, `NPY_DEFAULT_INT` is available at runtime
+         * for use as type number, e.g. `PyArray_DescrFromType(NPY_DEFAULT_INT)`.
+         *
+         * NPY_RAVEL_AXIS
+         *
+         * This was introduced in NumPy 2.0 to allow indicating that an axis should be
+         * raveled in an operation. Before NumPy 2.0, NPY_MAXDIMS was used for this purpose.
+         *
+         * NPY_MAXDIMS
+         *
+         * A constant indicating the maximum number of dimensions allowed when creating
+         * an ndarray.
+         *
+         * NPY_NTYPES_LEGACY
+         *
+         * The number of built-in NumPy dtypes.
+         */
+        #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
+            #define NPY_DEFAULT_INT NPY_INTP
+            #define NPY_RAVEL_AXIS NPY_MIN_INT
+            #define NPY_MAXARGS 64
+
+        #elif NPY_ABI_VERSION < 0x02000000
+            #define NPY_DEFAULT_INT NPY_LONG
+            #define NPY_RAVEL_AXIS 32
+            #define NPY_MAXARGS 32
+
+            /* Aliases of 2.x names to 1.x only equivalent names */
+            #define NPY_NTYPES NPY_NTYPES_LEGACY
+            #define PyArray_DescrProto PyArray_Descr
+            #define _PyArray_LegacyDescr PyArray_Descr
+            /* NumPy 2 definition always works, but add it for 1.x only */
+            #define PyDataType_ISLEGACY(dtype) (1)
+        #else
+            #define NPY_DEFAULT_INT  \\
+                (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? NPY_INTP : NPY_LONG)
+            #define NPY_RAVEL_AXIS  \\
+                (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? -1 : 32)
+            #define NPY_MAXARGS  \\
+                (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION ? 64 : 32)
+        #endif
+
+
+        /*
+         * Access inline functions for descriptor fields.  Except for the first
+         * few fields, these needed to be moved (elsize, alignment) for
+         * additional space.  Or they are descriptor specific and are not generally
+         * available anymore (metadata, c_metadata, subarray, names, fields).
+         *
+         * Most of these are defined via the `DESCR_ACCESSOR` macro helper.
+         */
+        #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION || NPY_ABI_VERSION < 0x02000000
+        /* Compiling for 1.x or 2.x only, direct field access is OK: */
+
+        static inline void
+        PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
+        {
+            dtype->elsize = size;
+        }
+
+        static inline npy_uint64
+        PyDataType_FLAGS(const PyArray_Descr *dtype)
+        {
+        #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION
+            return dtype->flags;
+        #else
+            return (unsigned char)dtype->flags;  /* Need unsigned cast on 1.x */
+        #endif
+        }
+
+        #define DESCR_ACCESSOR(FIELD, field, type, legacy_only)  \\
+            static inline type  \\
+            PyDataType_##FIELD(const PyArray_Descr *dtype) {  \\
+                if (legacy_only && !PyDataType_ISLEGACY(dtype)) {  \\
+                    return (type)0;  \\
+                }  \\
+                return ((_PyArray_LegacyDescr *)dtype)->field;  \\
+            }
+        #else  /* compiling for both 1.x and 2.x */
+
+        static inline void
+        PyDataType_SET_ELSIZE(PyArray_Descr *dtype, npy_intp size)
+        {
+            if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
+                ((_PyArray_DescrNumPy2 *)dtype)->elsize = size;
+            }
+            else {
+                ((PyArray_DescrProto *)dtype)->elsize = (int)size;
+            }
+        }
+
+        static inline npy_uint64
+        PyDataType_FLAGS(const PyArray_Descr *dtype)
+        {
+            if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {
+                return ((_PyArray_DescrNumPy2 *)dtype)->flags;
+            }
+            else {
+                return (unsigned char)((PyArray_DescrProto *)dtype)->flags;
+            }
+        }
+
+        /* Cast to LegacyDescr always fine but needed when `legacy_only` */
+        #define DESCR_ACCESSOR(FIELD, field, type, legacy_only)  \\
+            static inline type  \\
+            PyDataType_##FIELD(const PyArray_Descr *dtype) {  \\
+                if (legacy_only && !PyDataType_ISLEGACY(dtype)) {  \\
+                    return (type)0;  \\
+                }  \\
+                if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) {  \\
+                    return ((_PyArray_LegacyDescr *)dtype)->field;  \\
+                }  \\
+                else {  \\
+                    return ((PyArray_DescrProto *)dtype)->field;  \\
+                }  \\
+            }
+        #endif
+
+        DESCR_ACCESSOR(ELSIZE, elsize, npy_intp, 0)
+        DESCR_ACCESSOR(ALIGNMENT, alignment, npy_intp, 0)
+        DESCR_ACCESSOR(METADATA, metadata, PyObject *, 1)
+        DESCR_ACCESSOR(SUBARRAY, subarray, PyArray_ArrayDescr *, 1)
+        DESCR_ACCESSOR(NAMES, names, PyObject *, 1)
+        
DESCR_ACCESSOR(FIELDS, fields, PyObject *, 1) + DESCR_ACCESSOR(C_METADATA, c_metadata, NpyAuxData *, 1) + + #undef DESCR_ACCESSOR + + + #if !(defined(NPY_INTERNAL_BUILD) && NPY_INTERNAL_BUILD) + #if NPY_FEATURE_VERSION >= NPY_2_0_API_VERSION + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return _PyDataType_GetArrFuncs(descr); + } + #elif NPY_ABI_VERSION < 0x02000000 + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + return descr->f; + } + #else + static inline PyArray_ArrFuncs * + PyDataType_GetArrFuncs(const PyArray_Descr *descr) + { + if (PyArray_RUNTIME_VERSION >= NPY_2_0_API_VERSION) { + return _PyDataType_GetArrFuncs(descr); + } + else { + return ((PyArray_DescrProto *)descr)->f; + } + } + #endif + + + #endif /* not internal build */ + + #endif /* NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPAT_H_ */ + + """) diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 67b6ab071e..88d3f33199 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -123,7 +123,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: # isort: on # Allow accessing numpy constants from pytensor.tensor -from numpy import e, euler_gamma, inf, infty, nan, newaxis, pi +from numpy import e, euler_gamma, inf, nan, newaxis, pi from pytensor.tensor.basic import * from pytensor.tensor.blas import batched_dot, batched_tensordot diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 26bd34692b..061a159fc2 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -14,8 +14,7 @@ from typing import cast as type_cast import numpy as np -from numpy.core.multiarray import normalize_axis_index -from numpy.core.numeric import normalize_axis_tuple +from numpy.exceptions import AxisError import pytensor import pytensor.scalar.sharedvar @@ -32,6 +31,7 @@ from pytensor.graph.type import HasShape, Type from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.raise_op import CheckAndRaise from pytensor.scalar import int32 @@ -228,7 +228,7 @@ def constant(x, name=None, ndim=None, dtype=None) -> TensorConstant: elif x_.ndim > ndim: try: x_ = np.squeeze(x_, axis=tuple(range(x_.ndim - ndim))) - except np.AxisError: + except AxisError: raise ValueError( f"ndarray could not be cast to constant with {int(ndim)} dimensions" ) @@ -4405,7 +4405,7 @@ def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVa axis = (axis,) out_ndim = len(axis) + a.ndim - axis = np.core.numeric.normalize_axis_tuple(axis, out_ndim) + axis = normalize_axis_tuple(axis, out_ndim) if not axis: return a diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index d1dfe44b90..fc937bf404 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -8,6 +8,7 @@ from math import gcd import numpy as np +from numpy.exceptions import ComplexWarning try: @@ -2338,7 +2339,7 @@ def conv( bval = _bvalfromboundary("fill") with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) for b in range(img.shape[0]): for g in range(self.num_groups): for n in range(output_channel_offset): diff --git a/pytensor/tensor/einsum.py 
b/pytensor/tensor/einsum.py index cba40ec6f8..88a6257c9c 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -6,13 +6,14 @@ from typing import cast import numpy as np -from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore -from numpy.core.numeric import ( # type: ignore + +from pytensor.compile.builders import OpFromGraph +from pytensor.npy_2_compat import ( + _find_contraction, + _parse_einsum_input, normalize_axis_index, normalize_axis_tuple, ) - -from pytensor.compile.builders import OpFromGraph from pytensor.tensor import TensorLike from pytensor.tensor.basic import ( arange, diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index c37597906a..a07ec0d9dd 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -4,7 +4,6 @@ from typing import Literal import numpy as np -from numpy.core.numeric import normalize_axis_tuple import pytensor.tensor.basic from pytensor.configdefaults import config @@ -17,6 +16,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp from pytensor.link.c.params_type import ParamsType from pytensor.misc.frozendict import frozendict +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type from pytensor.scalar.basic import bool as scalar_bool @@ -41,9 +41,6 @@ from pytensor.utils import uniq -_numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] - - class DimShuffle(ExternalCOp): """ Allows to reorder the dimensions of a tensor or insert or remove diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 27eabc5ba4..e9d06ae9c2 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2,7 +2,7 @@ from collections.abc import Collection, Iterable import numpy as np -from numpy.core.multiarray import normalize_axis_index +from numpy.exceptions import AxisError import pytensor import pytensor.scalar.basic as ps @@ -17,6 +17,10 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.link.c.type import EnumList, Generic +from pytensor.npy_2_compat import ( + normalize_axis_index, + normalize_axis_tuple, +) from pytensor.raise_op import Assert from pytensor.scalar import int32 as int_t from pytensor.scalar import upcast @@ -596,9 +600,9 @@ def squeeze(x, axis=None): # scalar inputs are treated as 1D regarding axis in this `Op` try: - axis = np.core.numeric.normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=_x.ndim) + axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) + except AxisError: + raise AxisError(axis, ndim=_x.ndim) if not axis: # Nothing to do diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 4dbf30685d..c4f3dc50a5 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Optional import numpy as np -from numpy.core.numeric import normalize_axis_tuple from pytensor import config, printing from pytensor import scalar as ps @@ -14,6 +13,7 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.scalar.basic import BinaryScalarOp diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index 
a9d7016099..ee33f6533c 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -4,13 +4,13 @@ from typing import Literal, cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore from pytensor import scalar as ps from pytensor.compile.builders import OpFromGraph from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 1c23a21347..e839ac1f08 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -6,7 +6,6 @@ from typing import cast as typing_cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.gradient import DisconnectedType @@ -16,6 +15,7 @@ from pytensor.graph.type import HasShape from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.scalar import int32 from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import basic as ptb diff --git a/pytensor/tensor/slinalg.py b/pytensor/tensor/slinalg.py index f101315172..94973810fd 100644 --- a/pytensor/tensor/slinalg.py +++ b/pytensor/tensor/slinalg.py @@ -6,6 +6,7 @@ import numpy as np import scipy.linalg as scipy_linalg +from numpy.exceptions import ComplexWarning import pytensor import pytensor.tensor as pt @@ -767,7 +768,7 @@ def perform(self, node, inputs, outputs): Y = U.dot(V.T.dot(gA).dot(U) * X).dot(V.T) with warnings.catch_warnings(): - warnings.simplefilter("ignore", np.ComplexWarning) + warnings.simplefilter("ignore", ComplexWarning) out[0] = Y.astype(A.dtype) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index a3a81f63bd..46b9cc06fd 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,6 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType +from pytensor.npy_2_compat import numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -2522,6 +2523,7 @@ def c_code(self, node, name, input_names, output_names, sub): numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] if bool(numpy_ver < [1, 8]): raise NotImplementedError + x, y, idx = input_names out = output_names[0] copy_of_x = self.copy_of_x(x) diff --git a/pytensor/tensor/utils.py b/pytensor/tensor/utils.py index e6451c9236..9ce12296cd 100644 --- a/pytensor/tensor/utils.py +++ b/pytensor/tensor/utils.py @@ -3,10 +3,10 @@ from typing import cast import numpy as np -from numpy.core.numeric import normalize_axis_tuple # type: ignore import pytensor from pytensor.graph import FunctionGraph, Variable +from pytensor.npy_2_compat import normalize_axis_tuple from pytensor.utils import hash_from_code @@ -236,8 +236,8 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] 
| None: if axis is not None: try: axis = normalize_axis_tuple(axis, ndim=max(1, ndim)) - except np.AxisError: - raise np.AxisError(axis, ndim=ndim) + except np.exceptions.AxisError: + raise np.exceptions.AxisError(axis, ndim=ndim) # TODO: If axis tuple is equivalent to None, return None for more canonicalization? return cast(tuple, axis) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index bd208c5848..8555a1d29f 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -672,7 +672,7 @@ def test_scalar_input(self): assert self.op(ps.add, axis=(-1,))(x).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (-2,) is out of bounds for array of dimension 0"), ): self.op(ps.add, axis=(-2,))(x) diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index e4f4945393..8bf689bc15 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -469,7 +469,7 @@ def test_scalar_input(self): assert squeeze(x, axis=(0,)).eval({x: 5}) == 5 with pytest.raises( - np.AxisError, + np.exceptions.AxisError, match=re.escape("axis (1,) is out of bounds for array of dimension 0"), ): squeeze(x, axis=1) diff --git a/tests/tensor/test_io.py b/tests/tensor/test_io.py index cece2af277..4c5e5655fe 100644 --- a/tests/tensor/test_io.py +++ b/tests/tensor/test_io.py @@ -49,7 +49,7 @@ def test_memmap(self): path = Variable(Generic(), None) x = load(path, "int32", (None,), mmap_mode="c") fn = function([path], x) - assert isinstance(fn(self.filename), np.core.memmap) + assert isinstance(fn(self.filename), np.memmap) def teardown_method(self): (pytensor.config.compiledir / "_test.npy").unlink() From e6c26b23f38265ebeac78d03abee7e0c4753f34f Mon Sep 17 00:00:00 2001 From: Virgile Andreani Date: Wed, 3 Apr 2024 10:53:07 -0400 Subject: [PATCH 02/19] Changes for numpy 2.0 deprecations - Replace np.cast with np.asarray: in numpy 2.0, `np.cast[new_dtype](arr)` is deprecated. The literal replacement is `np.asarray(arr, dtype=new_dtype)`. - Replace np.sctype2char and np.obj2sctype. Added try/except to handle change in behavior of `np.dtype` - Replace np.find_common_type with np.result_type Further changes to `TensorType`: TensorType.dtype must be a string, so the code has been changed from `self.dtype = np.dtype(dtype).type`, where the right-hand side is of type `np.generic`, to `self.dtype = str(np.dtype(dtype))`, where the right-hand side is a string that satisfies: `self.dtype == str(np.dtype(self.dtype))` This doesn't change the behavior of `np.array(..., dtype=self.dtype)` etc. 
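
As an illustration of the replacement pattern (a standalone sketch with
made-up array and dtype values, not code taken from this diff):

    import numpy as np

    arr = np.ones(3, dtype="float32")

    # np.cast[new_dtype](arr) becomes:
    converted = np.asarray(arr, dtype="int64")

    # np.find_common_type([arr.dtype], [np.float16]) becomes:
    out_dtype = np.result_type(np.float16, arr)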
--- pytensor/scalar/basic.py | 22 +++++++++++----------- pytensor/tensor/elemwise.py | 2 +- pytensor/tensor/type.py | 27 +++++++++++++++------------ tests/scan/test_rewriting.py | 2 +- tests/tensor/test_extra_ops.py | 6 +++--- tests/tensor/utils.py | 2 +- tests/test_gradient.py | 28 +++++++++++++++------------- tests/typed_list/test_basic.py | 8 ++++---- 8 files changed, 51 insertions(+), 46 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index c13afbd6fa..94039f8091 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -2966,7 +2966,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (x * np.asarray(math.log(2.0)).astype(x.dtype)),) + return (gz / (x * np.array(math.log(2.0), dtype=x.dtype)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3009,7 +3009,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (x * np.asarray(math.log(10.0)).astype(x.dtype)),) + return (gz / (x * np.array(math.log(10.0), dtype=x.dtype)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3124,7 +3124,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * exp2(x) * log(np.cast[x.type](2)),) + return (gz * exp2(x) * log(np.array(2, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3263,7 +3263,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * np.asarray(np.pi / 180, gz.type),) + return (gz * np.array(np.pi / 180, dtype=gz.type),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3298,7 +3298,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz * np.asarray(180.0 / np.pi, gz.type),) + return (gz * np.array(180.0 / np.pi, dtype=gz.type),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3371,7 +3371,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (-gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (-gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3445,7 +3445,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(np.cast[x.type](1) - sqr(x)),) + return (gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3517,7 +3517,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) + sqr(x)),) + return (gz / (np.array(1, dtype=x.type) + sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3640,7 +3640,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) - np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) - np.array(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3717,7 +3717,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / sqrt(sqr(x) + np.cast[x.type](1)),) + return (gz / sqrt(sqr(x) + np.array(1, dtype=x.type)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs @@ -3795,7 +3795,7 @@ def L_op(self, inputs, outputs, gout): else: return [x.zeros_like()] - return (gz / (np.cast[x.type](1) - sqr(x)),) + return (gz / (np.array(1, dtype=x.type) - sqr(x)),) def c_code(self, node, name, inputs, outputs, sub): (x,) = inputs diff --git 
a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index a07ec0d9dd..37acfc8e86 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -668,7 +668,7 @@ def prepare_node(self, node, storage_map, compute_map, impl): and isinstance(self.nfunc, np.ufunc) and node.inputs[0].dtype in discrete_dtypes ): - char = np.sctype2char(out_dtype) + char = np.dtype(out_dtype).char sig = char * node.nin + "->" + char * node.nout node.tag.sig = sig node.tag.fake_node = Apply( diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 0f99fa48aa..d48a7a6f08 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Literal, Optional import numpy as np +import numpy.typing as npt import pytensor from pytensor import scalar as ps @@ -69,7 +70,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): def __init__( self, - dtype: str | np.dtype, + dtype: str | npt.DTypeLike, shape: Iterable[bool | int | None] | None = None, name: str | None = None, broadcastable: Iterable[bool] | None = None, @@ -101,11 +102,11 @@ def __init__( if str(dtype) == "floatX": self.dtype = config.floatX else: - if np.obj2sctype(dtype) is None: + try: + self.dtype = str(np.dtype(dtype)) + except TypeError: raise TypeError(f"Invalid dtype: {dtype}") - self.dtype = np.dtype(dtype).name - def parse_bcast_and_shape(s): if isinstance(s, bool | np.bool_): return 1 if s else None @@ -789,14 +790,16 @@ def tensor( **kwargs, ) -> "TensorVariable": if name is not None: - # Help catching errors with the new tensor API - # Many single letter strings are valid sctypes - if str(name) == "floatX" or (len(str(name)) > 1 and np.obj2sctype(name)): - np.obj2sctype(name) - raise ValueError( - f"The first and only positional argument of tensor is now `name`. Got {name}.\n" - "This name looks like a dtype, which you should pass as a keyword argument only." - ) + try: + # Help catching errors with the new tensor API + # Many single letter strings are valid sctypes + if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type): + raise ValueError( + f"The first and only positional argument of tensor is now `name`. Got {name}.\n" + "This name looks like a dtype, which you should pass as a keyword argument only." + ) + except TypeError: + pass if dtype is None: dtype = config.floatX diff --git a/tests/scan/test_rewriting.py b/tests/scan/test_rewriting.py index 6f77625f2f..fd9c43b129 100644 --- a/tests/scan/test_rewriting.py +++ b/tests/scan/test_rewriting.py @@ -673,7 +673,7 @@ def test_machine_translation(self): zi = tensor3("zi") zi_value = x_value - init = pt.alloc(np.cast[config.floatX](0), batch_size, dim) + init = pt.alloc(np.asarray(0, dtype=config.floatX), batch_size, dim) def rnn_step1( # sequences diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 8bf689bc15..54bb7f4333 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -708,7 +708,7 @@ def test_perform(self, shp): y = scalar() f = function([x, y], fill_diagonal(x, y)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val) # We can't use np.fill_diagonal as it is bugged. 
assert np.allclose(np.diag(out), val) @@ -720,7 +720,7 @@ def test_perform_3d(self): x = tensor3() y = scalar() f = function([x, y], fill_diagonal(x, y)) - val = np.cast[config.floatX](rng.random() + 10) + val = rng.random(dtype=config.floatX) + 10 out = f(a, val) # We can't use np.fill_diagonal as it is bugged. assert out[0, 0, 0] == val @@ -782,7 +782,7 @@ def test_perform(self, test_offset, shp): f = function([x, y, z], fill_diagonal_offset(x, y, z)) a = rng.random(shp).astype(config.floatX) - val = np.cast[config.floatX](rng.random()) + val = rng.random(dtype=config.floatX) out = f(a, val, test_offset) # We can't use np.fill_diagonal as it is bugged. assert np.allclose(np.diag(out, test_offset), val) diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index 9eb06f28a3..b94750ffe2 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -152,7 +152,7 @@ def upcast_float16_ufunc(fn): """ def ret(*args, **kwargs): - out_dtype = np.find_common_type([a.dtype for a in args], [np.float16]) + out_dtype = np.result_type(np.float16, *args) if out_dtype == "float16": # Force everything to float32 sig = "f" * fn.nin + "->" + "f" * fn.nout diff --git a/tests/test_gradient.py b/tests/test_gradient.py index 79c55caf44..24f5964c92 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -481,12 +481,12 @@ def make_grad_func(X): int_type = imatrix().dtype float_type = "float64" - X = np.cast[int_type](rng.standard_normal((m, d)) * 127.0) - W = np.cast[W.dtype](rng.standard_normal((d, n))) - b = np.cast[b.dtype](rng.standard_normal(n)) + X = np.asarray(rng.standard_normal((m, d)) * 127.0, dtype=int_type) + W = rng.standard_normal((d, n), dtype=W.dtype) + b = rng.standard_normal(n, dtype=b.dtype) int_result = int_func(X, W, b) - float_result = float_func(np.cast[float_type](X), W, b) + float_result = float_func(np.asarray(X, dtype=float_type), W, b) assert np.allclose(int_result, float_result), (int_result, float_result) @@ -508,7 +508,7 @@ def test_grad_disconnected(self): # the output f = pytensor.function([x], g) rng = np.random.default_rng([2012, 9, 5]) - x = np.cast[x.dtype](rng.standard_normal(3)) + x = rng.standard_normal(3, dtype=x.dtype) g = f(x) assert np.allclose(g, np.ones(x.shape, dtype=x.dtype)) @@ -631,7 +631,8 @@ def test_known_grads(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(10), rng.integers(10), rng.standard_normal()] values = [ - np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) ] true_grads = grad(cost, inputs, disconnected_inputs="ignore") @@ -679,7 +680,7 @@ def test_known_grads_integers(): f = pytensor.function([g_expected], g_grad) x = -3 - gv = np.cast[config.floatX](0.6) + gv = np.asarray(0.6, dtype=config.floatX) g_actual = f(gv) @@ -746,7 +747,8 @@ def test_subgraph_grad(): rng = np.random.default_rng([2012, 11, 15]) values = [rng.standard_normal(2), rng.standard_normal(3)] values = [ - np.cast[ipt.dtype](value) for ipt, value in zip(inputs, values, strict=True) + np.asarray(value, dtype=ipt.dtype) + for ipt, value in zip(inputs, values, strict=True) ] wrt = [w2, w1] @@ -1031,21 +1033,21 @@ def test_jacobian_scalar(): # test when the jacobian is called with a tensor as wrt Jx = jacobian(y, x) f = pytensor.function([x], Jx) - vx = np.cast[pytensor.config.floatX](rng.uniform()) + vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX) assert np.allclose(f(vx), 2) # test when the jacobian is called 
with a tuple as wrt
     Jx = jacobian(y, (x,))
     assert isinstance(Jx, tuple)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)
 
     # test when the jacobian is called with a list as wrt
     Jx = jacobian(y, [x])
     assert isinstance(Jx, list)
     f = pytensor.function([x], Jx[0])
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     assert np.allclose(f(vx), 2)
 
     # test when the jacobian is called with a list of two elements
     z = scalar()
     y = x * z
     Jx = jacobian(y, [x, z])
     f = pytensor.function([x, z], Jx)
-    vx = np.cast[pytensor.config.floatX](rng.uniform())
-    vz = np.cast[pytensor.config.floatX](rng.uniform())
+    vx = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
+    vz = np.asarray(rng.uniform(), dtype=pytensor.config.floatX)
     vJx = f(vx, vz)
 
     assert np.allclose(vJx[0], vz)
diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py
index 466bdc865d..19598bfb21 100644
--- a/tests/typed_list/test_basic.py
+++ b/tests/typed_list/test_basic.py
@@ -577,10 +577,10 @@ def test_correct_answer(self):
         x = tensor3()
         y = tensor3()
 
-        A = np.cast[pytensor.config.floatX](np.random.random((5, 3)))
-        B = np.cast[pytensor.config.floatX](np.random.random((7, 2)))
-        X = np.cast[pytensor.config.floatX](np.random.random((5, 6, 1)))
-        Y = np.cast[pytensor.config.floatX](np.random.random((1, 9, 3)))
+        A = np.random.random((5, 3)).astype(pytensor.config.floatX)
+        B = np.random.random((7, 2)).astype(pytensor.config.floatX)
+        X = np.random.random((5, 6, 1)).astype(pytensor.config.floatX)
+        Y = np.random.random((1, 9, 3)).astype(pytensor.config.floatX)
 
         make_list((3.0, 4.0))
         c = make_list((a, b))

From 910b27c00ba93ead70a9994d8b651fa179c19380 Mon Sep 17 00:00:00 2001
From: Brendan Murphy
Date: Wed, 5 Feb 2025 10:19:16 +0000
Subject: [PATCH 03/19] Updated lazylinker C code

Some macros were removed from npy_3kcompat.h.  Following numpy, I
updated the affected functions to the Python 3 names, and removed
support for Python 2.

Also updated lazylinker_c version to indicate substantial changes to
the C code.
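
Schematically, the renames follow this pattern (an illustrative C
snippet, not taken from the diff; the starting value is hypothetical):

    /* Previously (Python 2 spellings, provided as macros by npy_3kcompat.h):
     *   long icount = PyInt_AsLong(count);
     *   PyObject *next = PyInt_FromLong(icount + 1);
     */
    PyObject *count = PyLong_FromLong(41);         /* hypothetical value */
    long icount = PyLong_AsLong(count);            /* plain Python 3 C-API name */
    PyObject *next = PyLong_FromLong(icount + 1);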
--- pytensor/link/c/c_code/lazylinker_c.c | 53 +++++++------------- pytensor/link/c/c_code/pytensor_mod_helper.h | 8 +-- pytensor/link/c/lazylinker_c.py | 2 +- 3 files changed, 21 insertions(+), 42 deletions(-) diff --git a/pytensor/link/c/c_code/lazylinker_c.c b/pytensor/link/c/c_code/lazylinker_c.c index a64614a908..08f3e4d0fb 100644 --- a/pytensor/link/c/c_code/lazylinker_c.c +++ b/pytensor/link/c/c_code/lazylinker_c.c @@ -5,9 +5,6 @@ #if PY_VERSION_HEX >= 0x03000000 #include "numpy/npy_3kcompat.h" -#define PyCObject_AsVoidPtr NpyCapsule_AsVoidPtr -#define PyCObject_GetDesc NpyCapsule_GetDesc -#define PyCObject_Check NpyCapsule_Check #endif #ifndef Py_TYPE @@ -323,9 +320,9 @@ static int CLazyLinker_init(CLazyLinker *self, PyObject *args, PyObject *kwds) { if (PyObject_HasAttrString(thunk, "cthunk")) { PyObject *cthunk = PyObject_GetAttrString(thunk, "cthunk"); // new reference - assert(cthunk && PyCObject_Check(cthunk)); - self->thunk_cptr_fn[i] = PyCObject_AsVoidPtr(cthunk); - self->thunk_cptr_data[i] = PyCObject_GetDesc(cthunk); + assert(cthunk && NpyCapsule_Check(cthunk)); + self->thunk_cptr_fn[i] = NpyCapsule_AsVoidPtr(cthunk); + self->thunk_cptr_data[i] = NpyCapsule_GetDesc(cthunk); Py_DECREF(cthunk); // cthunk is kept alive by membership in self->thunks } @@ -487,8 +484,8 @@ static PyObject *pycall(CLazyLinker *self, Py_ssize_t node_idx, int verbose) { PyList_SetItem(self->call_times, node_idx, PyFloat_FromDouble(t1 - t0 + ti)); PyObject *count = PyList_GetItem(self->call_counts, node_idx); - long icount = PyInt_AsLong(count); - PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount + 1)); + long icount = PyLong_AsLong(count); + PyList_SetItem(self->call_counts, node_idx, PyLong_FromLong(icount + 1)); } } else { if (verbose) { @@ -512,8 +509,8 @@ static int c_call(CLazyLinker *self, Py_ssize_t node_idx, int verbose) { PyList_SetItem(self->call_times, node_idx, PyFloat_FromDouble(t1 - t0 + ti)); PyObject *count = PyList_GetItem(self->call_counts, node_idx); - long icount = PyInt_AsLong(count); - PyList_SetItem(self->call_counts, node_idx, PyInt_FromLong(icount + 1)); + long icount = PyLong_AsLong(count); + PyList_SetItem(self->call_counts, node_idx, PyLong_FromLong(icount + 1)); } else { err = fn(self->thunk_cptr_data[node_idx]); } @@ -774,20 +771,20 @@ static PyObject *CLazyLinker_call(PyObject *_self, PyObject *args, output_subset = (char *)calloc(self->n_output_vars, sizeof(char)); for (int it = 0; it < output_subset_size; ++it) { PyObject *elem = PyList_GetItem(output_subset_ptr, it); - if (!PyInt_Check(elem)) { + if (!PyLong_Check(elem)) { err = 1; PyErr_SetString(PyExc_RuntimeError, "Some elements of output_subset list are not int"); } - output_subset[PyInt_AsLong(elem)] = 1; + output_subset[PyLong_AsLong(elem)] = 1; } } } self->position_of_error = -1; // create constants used to fill the var_compute_cells - PyObject *one = PyInt_FromLong(1); - PyObject *zero = PyInt_FromLong(0); + PyObject *one = PyLong_FromLong(1); + PyObject *zero = PyLong_FromLong(0); // pre-allocate our return value Py_INCREF(Py_None); @@ -942,11 +939,8 @@ static PyMemberDef CLazyLinker_members[] = { }; static PyTypeObject lazylinker_ext_CLazyLinkerType = { -#if defined(NPY_PY3K) PyVarObject_HEAD_INIT(NULL, 0) -#else - PyObject_HEAD_INIT(NULL) 0, /*ob_size*/ -#endif + "lazylinker_ext.CLazyLinker", /*tp_name*/ sizeof(CLazyLinker), /*tp_basicsize*/ 0, /*tp_itemsize*/ @@ -987,7 +981,7 @@ static PyTypeObject lazylinker_ext_CLazyLinkerType = { }; static PyObject *get_version(PyObject *dummy, 
PyObject *args) { - PyObject *result = PyFloat_FromDouble(0.212); + PyObject *result = PyFloat_FromDouble(0.3); return result; } @@ -996,7 +990,7 @@ static PyMethodDef lazylinker_ext_methods[] = { {NULL, NULL, 0, NULL} /* Sentinel */ }; -#if defined(NPY_PY3K) + static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT, "lazylinker_ext", NULL, @@ -1006,28 +1000,19 @@ static struct PyModuleDef moduledef = {PyModuleDef_HEAD_INIT, NULL, NULL, NULL}; -#endif -#if defined(NPY_PY3K) -#define RETVAL m + PyMODINIT_FUNC PyInit_lazylinker_ext(void) { -#else -#define RETVAL -PyMODINIT_FUNC initlazylinker_ext(void) { -#endif + PyObject *m; lazylinker_ext_CLazyLinkerType.tp_new = PyType_GenericNew; if (PyType_Ready(&lazylinker_ext_CLazyLinkerType) < 0) - return RETVAL; -#if defined(NPY_PY3K) + return NULL; + m = PyModule_Create(&moduledef); -#else - m = Py_InitModule3("lazylinker_ext", lazylinker_ext_methods, - "Example module that creates an extension type."); -#endif Py_INCREF(&lazylinker_ext_CLazyLinkerType); PyModule_AddObject(m, "CLazyLinker", (PyObject *)&lazylinker_ext_CLazyLinkerType); - return RETVAL; + return m; } diff --git a/pytensor/link/c/c_code/pytensor_mod_helper.h b/pytensor/link/c/c_code/pytensor_mod_helper.h index d3e4b29a2b..2f857e6775 100644 --- a/pytensor/link/c/c_code/pytensor_mod_helper.h +++ b/pytensor/link/c/c_code/pytensor_mod_helper.h @@ -18,14 +18,8 @@ #define PYTENSOR_EXTERN #endif -#if PY_MAJOR_VERSION < 3 -#define PYTENSOR_RTYPE void -#else -#define PYTENSOR_RTYPE PyObject * -#endif - /* We need to redefine PyMODINIT_FUNC to add MOD_PUBLIC in the middle */ #undef PyMODINIT_FUNC -#define PyMODINIT_FUNC PYTENSOR_EXTERN MOD_PUBLIC PYTENSOR_RTYPE +#define PyMODINIT_FUNC PYTENSOR_EXTERN MOD_PUBLIC PyObject * #endif diff --git a/pytensor/link/c/lazylinker_c.py b/pytensor/link/c/lazylinker_c.py index 679cb4e290..ce67190342 100644 --- a/pytensor/link/c/lazylinker_c.py +++ b/pytensor/link/c/lazylinker_c.py @@ -14,7 +14,7 @@ _logger = logging.getLogger(__file__) force_compile = False -version = 0.212 # must match constant returned in function get_version() +version = 0.3 # must match constant returned in function get_version() lazylinker_ext: ModuleType | None = None From 92d96ff372933608f7e175d5c29deed60be41d91 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 5 Apr 2024 11:42:14 +0200 Subject: [PATCH 04/19] Changes for deprecations in numpy 2.0 C-API - replace `->elsize` by `PyArray_ITEMSIZE` - don't use deprecated PyArray_MoveInto --- pytensor/sparse/basic.py | 20 +++---- pytensor/sparse/rewriting.py | 94 ++++++++++++++++----------------- pytensor/tensor/blas.py | 14 ++--- pytensor/tensor/blas_headers.py | 4 +- tests/compile/test_debugmode.py | 6 +-- 5 files changed, 69 insertions(+), 69 deletions(-) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index c590bc804a..7f200b2a7c 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -3610,7 +3610,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3647,11 +3647,11 @@ def c_code(self, node, name, inputs, outputs, sub): npy_intp nnz = PyArray_DIMS({_indices})[0]; npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices 
= PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; @@ -3744,7 +3744,7 @@ def perform(self, node, inputs, outputs): out[0] = g_a_data def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inputs, outputs, sub): (_indices, _indptr, _d, _g) = inputs @@ -3782,11 +3782,11 @@ def c_code(self, node, name, inputs, outputs, sub): // extract number of rows npy_intp N = PyArray_DIMS({_indptr})[0]-1; //TODO: error checking with this - npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_DESCR({_indices})->elsize; - npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_DESCR({_indptr})->elsize; + npy_intp Sindices = PyArray_STRIDES({_indices})[0]/PyArray_ITEMSIZE({_indices}); + npy_intp Sindptr = PyArray_STRIDES({_indptr})[0]/PyArray_ITEMSIZE({_indptr}); - const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_DESCR({_d})->elsize; - const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_DESCR({_g})->elsize; + const npy_intp Sd1 = PyArray_STRIDES({_d})[1]/PyArray_ITEMSIZE({_d}); + const npy_intp Sg1 = PyArray_STRIDES({_g})[1]/PyArray_ITEMSIZE({_g}); const npy_intp K = PyArray_DIMS({_d})[1]; diff --git a/pytensor/sparse/rewriting.py b/pytensor/sparse/rewriting.py index bf6d6f0bc6..13735d2aca 100644 --- a/pytensor/sparse/rewriting.py +++ b/pytensor/sparse/rewriting.py @@ -158,8 +158,8 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{y}* ydata = (dtype_{y}*)PyArray_DATA({y}); dtype_{z}* zdata = (dtype_{z}*)PyArray_DATA({z}); - npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_DESCR({y})->elsize; - npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; + npy_intp Yi = PyArray_STRIDES({y})[0]/PyArray_ITEMSIZE({y}); + npy_intp Yj = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); npy_intp pos; if ({format} == 0){{ @@ -186,7 +186,7 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[3]] def c_code_cache_version(self): - return (2,) + return (3,) @node_rewriter([sparse.AddSD]) @@ -361,13 +361,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + //npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / 
PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -436,7 +436,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3,) + return (4,) sd_csc = StructuredDotCSC() @@ -555,13 +555,13 @@ def c_code(self, node, name, inputs, outputs, sub): {{PyErr_SetString(PyExc_NotImplementedError, "array too big (overflows int32 index)"); {fail};}} // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_DESCR({b})->elsize; - npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_DESCR({b})->elsize; - npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; + npy_intp Szm = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Sbm = PyArray_STRIDES({b})[0] / PyArray_ITEMSIZE({b}); + npy_intp Sbn = PyArray_STRIDES({b})[1] / PyArray_ITEMSIZE({b}); + npy_intp Sval = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -614,7 +614,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (2,) + return (3,) sd_csr = StructuredDotCSR() @@ -845,12 +845,12 @@ def c_code(self, node, name, inputs, outputs, sub): const npy_int32 * __restrict__ Dptr = (npy_int32*)PyArray_DATA({x_ptr}); const dtype_{alpha} alpha = ((dtype_{alpha}*)PyArray_DATA({alpha}))[0]; - npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_DESCR({z})->elsize; - npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_DESCR({zn})->elsize; - npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_DESCR({x_val})->elsize; - npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_DESCR({x_ind})->elsize; - npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_DESCR({x_ptr})->elsize; - npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_DESCR({y})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[1] / PyArray_ITEMSIZE({z}); + npy_intp Szn = PyArray_STRIDES({zn})[1] / PyArray_ITEMSIZE({zn}); + npy_intp Sval = PyArray_STRIDES({x_val})[0] / PyArray_ITEMSIZE({x_val}); + npy_intp Sind = PyArray_STRIDES({x_ind})[0] / PyArray_ITEMSIZE({x_ind}); + npy_intp Sptr = PyArray_STRIDES({x_ptr})[0] / PyArray_ITEMSIZE({x_ptr}); + npy_intp Sy = PyArray_STRIDES({y})[1] / PyArray_ITEMSIZE({y}); // blas expects ints; convert here (rather than just making N etc ints) to avoid potential overflow in the negative-stride correction if ((N > 0x7fffffffL)||(Sy > 0x7fffffffL)||(Szn > 0x7fffffffL)||(Sy < -0x7fffffffL)||(Szn < -0x7fffffffL)) @@ -896,7 +896,7 @@ def c_code(self, node, name, inputs, outputs, sub): return rval def c_code_cache_version(self): - return (3, blas.blas_header_version()) + return (4, blas.blas_header_version()) usmm_csc_dense = UsmmCscDense(inplace=False) @@ -1035,13 +1035,13 @@ def 
c_code(self, node, name, inputs, outputs, sub): npy_intp sp_dim = (M == a_dim_0)?a_dim_1:a_dim_0; // strides tell you how many bytes to skip to go to next column/row entry - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; - npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_DESCR({a_val})->elsize; - npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_DESCR({a_ind})->elsize; - npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_DESCR({a_ptr})->elsize; - npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_DESCR({b_val})->elsize; - npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_DESCR({b_ind})->elsize; - npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_DESCR({b_ptr})->elsize; + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); + npy_intp Sa_val = PyArray_STRIDES({a_val})[0] / PyArray_ITEMSIZE({a_val}); + npy_intp Sa_ind = PyArray_STRIDES({a_ind})[0] / PyArray_ITEMSIZE({a_ind}); + npy_intp Sa_ptr = PyArray_STRIDES({a_ptr})[0] / PyArray_ITEMSIZE({a_ptr}); + npy_intp Sb_val = PyArray_STRIDES({b_val})[0] / PyArray_ITEMSIZE({b_val}); + npy_intp Sb_ind = PyArray_STRIDES({b_ind})[0] / PyArray_ITEMSIZE({b_ind}); + npy_intp Sb_ptr = PyArray_STRIDES({b_ptr})[0] / PyArray_ITEMSIZE({b_ptr}); // pointers to access actual data in the arrays passed as params. dtype_{z}* __restrict__ Dz = (dtype_{z}*)PyArray_DATA({z}); @@ -1086,7 +1086,7 @@ def c_code(self, node, name, inputs, outputs, sub): """ def c_code_cache_version(self): - return (3,) + return (4,) csm_grad_c = CSMGradC() @@ -1482,7 +1482,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (2,) + return (3,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1544,7 +1544,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over rows for (npy_intp j = 0; j < N; ++j) @@ -1655,7 +1655,7 @@ def make_node(self, a_data, a_indices, a_indptr, b): ) def c_code_cache_version(self): - return (3,) + return (4,) def c_code(self, node, name, inputs, outputs, sub): ( @@ -1723,7 +1723,7 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{_zout} * const __restrict__ zout = (dtype_{_zout}*)PyArray_DATA({_zout}); - const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_DESCR({_b})->elsize; + const npy_intp Sb = PyArray_STRIDES({_b})[0] / PyArray_ITEMSIZE({_b}); // loop over columns for (npy_intp j = 0; j < N; ++j) @@ -1868,7 +1868,7 @@ def make_node(self, x, y, p_data, p_ind, p_ptr, p_ncols): ) def c_code_cache_version(self): - return (4, blas.blas_header_version()) + return (5, blas.blas_header_version()) def c_support_code(self, **kwargs): return blas.blas_header_text() @@ -1995,14 +1995,14 @@ def c_code(self, node, name, inputs, outputs, sub): dtype_{z_ind}* __restrict__ Dzi = (dtype_{z_ind}*)PyArray_DATA({z_ind}); dtype_{z_ptr}* __restrict__ Dzp = (dtype_{z_ptr}*)PyArray_DATA({z_ptr}); - const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_DESCR({x})->elsize; - const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_DESCR({y})->elsize; - const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_DESCR({p_data})->elsize; - const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_DESCR({p_ind})->elsize; - const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_DESCR({p_ptr})->elsize; - const 
npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_DESCR({z_data})->elsize; - const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_DESCR({z_ind})->elsize; - const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_DESCR({z_ptr})->elsize; + const npy_intp Sdx = PyArray_STRIDES({x})[1]/PyArray_ITEMSIZE({x}); + const npy_intp Sdy = PyArray_STRIDES({y})[1]/PyArray_ITEMSIZE({y}); + const npy_intp Sdpd = PyArray_STRIDES({p_data})[0] / PyArray_ITEMSIZE({p_data}); + const npy_intp Sdpi = PyArray_STRIDES({p_ind})[0] / PyArray_ITEMSIZE({p_ind}); + const npy_intp Sdpp = PyArray_STRIDES({p_ptr})[0] / PyArray_ITEMSIZE({p_ptr}); + const npy_intp Sdzd = PyArray_STRIDES({z_data})[0] / PyArray_ITEMSIZE({z_data}); + const npy_intp Sdzi = PyArray_STRIDES({z_ind})[0] / PyArray_ITEMSIZE({z_ind}); + const npy_intp Sdzp = PyArray_STRIDES({z_ptr})[0] / PyArray_ITEMSIZE({z_ptr}); memcpy(Dzi, Dpi, PyArray_DIMS({p_ind})[0]*sizeof(dtype_{p_ind})); memcpy(Dzp, Dpp, PyArray_DIMS({p_ptr})[0]*sizeof(dtype_{p_ptr})); diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index d0f524e413..592a4ba27c 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -498,7 +498,7 @@ def c_header_dirs(self, **kwargs): int unit = 0; int type_num = PyArray_DESCR(%(_x)s)->type_num; - int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(%(_x)s); // in bytes npy_intp* Nx = PyArray_DIMS(%(_x)s); npy_intp* Ny = PyArray_DIMS(%(_y)s); @@ -789,7 +789,7 @@ def build_gemm_call(self): ) def build_gemm_version(self): - return (13, blas_header_version()) + return (14, blas_header_version()) class Gemm(GemmRelated): @@ -1030,7 +1030,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(x_new, %(_x)s) == -1) + if(PyArray_CopyInto(x_new, %(_x)s) == -1) { %(fail)s } @@ -1056,7 +1056,7 @@ def infer_shape(self, fgraph, node, input_shapes): %(fail)s } - if(PyArray_MoveInto(y_new, %(_y)s) == -1) + if(PyArray_CopyInto(y_new, %(_y)s) == -1) { %(fail)s } @@ -1102,7 +1102,7 @@ def c_code(self, node, name, inp, out, sub): def c_code_cache_version(self): gv = self.build_gemm_version() if gv: - return (7, *gv) + return (8, *gv) else: return gv @@ -1538,7 +1538,7 @@ def contiguous(var, ndim): return f""" int type_num = PyArray_DESCR({_x})->type_num; - int type_size = PyArray_DESCR({_x})->elsize; // in bytes + int type_size = PyArray_ITEMSIZE({_x}); // in bytes if (PyArray_NDIM({_x}) != 3) {{ PyErr_Format(PyExc_NotImplementedError, @@ -1598,7 +1598,7 @@ def contiguous(var, ndim): def c_code_cache_version(self): from pytensor.tensor.blas_headers import blas_header_version - return (5, blas_header_version()) + return (6, blas_header_version()) def grad(self, inp, grads): x, y = inp diff --git a/pytensor/tensor/blas_headers.py b/pytensor/tensor/blas_headers.py index 645f04bfb3..5d49b70ec4 100644 --- a/pytensor/tensor/blas_headers.py +++ b/pytensor/tensor/blas_headers.py @@ -1053,7 +1053,7 @@ def openblas_threads_text(): def blas_header_version(): # Version for the base header - version = (9,) + version = (10,) if detect_macos_sdot_bug(): if detect_macos_sdot_bug.fix_works: # Version with fix @@ -1071,7 +1071,7 @@ def ____gemm_code(check_ab, a_init, b_init): const char * error_string = NULL; int type_num = PyArray_DESCR(_x)->type_num; - int type_size = PyArray_DESCR(_x)->elsize; // in bytes + int type_size = PyArray_ITEMSIZE(_x); // in bytes npy_intp* Nx = PyArray_DIMS(_x); npy_intp* Ny = PyArray_DIMS(_y); diff --git a/tests/compile/test_debugmode.py 
b/tests/compile/test_debugmode.py index 95e52d6b53..fae76fab0d 100644 --- a/tests/compile/test_debugmode.py +++ b/tests/compile/test_debugmode.py @@ -146,7 +146,7 @@ def dontuse_perform(self, node, inp, out_): raise ValueError(self.behaviour) def c_code_cache_version(self): - return (1,) + return (2,) def c_code(self, node, name, inp, out, sub): (a,) = inp @@ -165,8 +165,8 @@ def c_code(self, node, name, inp, out, sub): prep_vars = f""" //the output array has size M x N npy_intp M = PyArray_DIMS({a})[0]; - npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_DESCR({a})->elsize; - npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_DESCR({z})->elsize; + npy_intp Sa = PyArray_STRIDES({a})[0] / PyArray_ITEMSIZE({a}); + npy_intp Sz = PyArray_STRIDES({z})[0] / PyArray_ITEMSIZE({z}); npy_double * Da = (npy_double*)PyArray_BYTES({a}); npy_double * Dz = (npy_double*)PyArray_BYTES({z}); From b20f4015943ff46c160675d6ab79f05e3bffe581 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 13 Feb 2025 10:59:55 +0000 Subject: [PATCH 05/19] Update type hint for c_code_cache_version Anything `Hashable` should work, but I've made the return type `tuple[Hashable]` in keeping with the current style. This means, e.g., we can use strings in the cache version. --- pytensor/link/c/interface.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/link/c/interface.py b/pytensor/link/c/interface.py index 7e281af947..e9375d2511 100644 --- a/pytensor/link/c/interface.py +++ b/pytensor/link/c/interface.py @@ -1,7 +1,7 @@ import typing import warnings from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Hashable from typing import Optional from pytensor.graph.basic import Apply, Constant @@ -155,7 +155,7 @@ def c_init_code(self, **kwargs) -> list[str]: """Return a list of code snippets to be inserted in module initialization.""" return [] - def c_code_cache_version(self) -> tuple[int, ...]: + def c_code_cache_version(self) -> tuple[Hashable, ...]: """Return a tuple of integers indicating the version of this `Op`. An empty tuple indicates an "unversioned" `Op` that will not be cached @@ -223,7 +223,7 @@ def c_code( """ raise NotImplementedError() - def c_code_cache_version_apply(self, node: Apply) -> tuple[int, ...]: + def c_code_cache_version_apply(self, node: Apply) -> tuple[Hashable, ...]: """Return a tuple of integers indicating the version of this `Op`. An empty tuple indicates an "unversioned" `Op` that will not be From 69713deef2a7d07a86bb0521478238e2057b493f Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Mon, 29 Jul 2024 09:42:41 +0100 Subject: [PATCH 06/19] Make complex scalars work with numpy 2.0 This is done using C++ generic functions to get/set the real/imag parts of complex numbers. This gives us an easy way to support Numpy v < 2.0, and allows the type underlying the bit width types, like pytensor_complex128, to be correctly inferred from the numpy complex types they inherit from. Updated pytensor_complex struct to use get/set real/imag aliases defined above. Also updated operators such as `Abs` to use get_real, get_imag. Macros have been added to ensure compatibility with numpy < 2.0. Note: redefining the complex arithmetic here means that we aren't treating NaNs and infinities as carefully as the C99 standard suggests (see Appendix G of the standard). The code has been like this since it was added to Theano, so we're keeping the existing behavior.
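To make the Appendix G caveat concrete, here is a small illustrative Python sketch (not part of the patch; the helper name is invented) of the same textbook division formula that pytensor_complex's operator/ implements:

    import numpy as np

    def naive_complex_div(x: complex, y: complex) -> complex:
        # textbook formula, mirroring pytensor_complex's operator/
        denom = y.real * y.real + y.imag * y.imag
        return complex(
            (x.real * y.real + x.imag * y.imag) / denom,
            (x.imag * y.real - x.real * y.imag) / denom,
        )

    # A finite value divided by an infinite one is mathematically zero, and
    # C99 Annex G recovers that; the textbook formula propagates nan instead.
    print(naive_complex_div(1 + 1j, complex(np.inf, np.inf)))  # (nan+nanj)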
--- pytensor/scalar/basic.py | 225 ++++++++++++++++++++++----------- 1 file changed, 161 insertions(+), 64 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 94039f8091..d7d719e2f4 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -349,6 +349,8 @@ def c_headers(self, c_compiler=None, **kwargs): # we declare them here and they will be re-used by TensorType l.append("") l.append("") + l.append("") + if config.lib__amdlibm and c_compiler.supports_amdlibm: l += [""] return l @@ -517,73 +519,167 @@ def c_support_code(self, **kwargs): # In that case we add the 'int' type to the real types. real_types.append("int") + # Macros for backwards compatibility with numpy < 2.0 + # + # In numpy 2.0+, these are defined in npy_math.h, but + # for early versions, they must be vendored by users (e.g. PyTensor) + backwards_compat_macros = """ + #ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_ + #define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_ + + #include + + #ifndef NPY_CSETREALF + #define NPY_CSETREALF(c, r) (c)->real = (r) + #endif + #ifndef NPY_CSETIMAGF + #define NPY_CSETIMAGF(c, i) (c)->imag = (i) + #endif + #ifndef NPY_CSETREAL + #define NPY_CSETREAL(c, r) (c)->real = (r) + #endif + #ifndef NPY_CSETIMAG + #define NPY_CSETIMAG(c, i) (c)->imag = (i) + #endif + #ifndef NPY_CSETREALL + #define NPY_CSETREALL(c, r) (c)->real = (r) + #endif + #ifndef NPY_CSETIMAGL + #define NPY_CSETIMAGL(c, i) (c)->imag = (i) + #endif + + #endif + """ + + def _make_get_set_real_imag(scalar_type: str) -> str: + """Make overloaded getter/setter functions for real/imag parts of numpy complex types. + + The functions called by these getter/setter functions are defined in npy_math.h, or + in the `backwards_compat_macros` defined above. + + Args: + scalar_type: float, double, or longdouble + + Returns: + C++ code for defining set_real, set_imag, get_real, and get_imag, overloaded for the + given type. + """ + complex_type = "npy_c" + scalar_type + suffix = "" if scalar_type == "double" else scalar_type[0] + + if scalar_type == "longdouble": + scalar_type = "npy_" + scalar_type + + return_type = scalar_type + + template = f""" + static inline {return_type} get_real(const {complex_type} z) + {{ + return npy_creal{suffix}(z); + }} + + static inline void set_real({complex_type} *z, const {scalar_type} r) + {{ + NPY_CSETREAL{suffix.upper()}(z, r); + }} + + static inline {return_type} get_imag(const {complex_type} z) + {{ + return npy_cimag{suffix}(z); + }} + + static inline void set_imag({complex_type} *z, const {scalar_type} i) + {{ + NPY_CSETIMAG{suffix.upper()}(z, i); + }} + """ + return template + + get_set_aliases = "\n".join( + _make_get_set_real_imag(stype) + for stype in ["float", "double", "longdouble"] + ) + + get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases + + # Template for defining pytensor_complex64 and pytensor_complex128 structs/classes + # + # The npy_complex64, npy_complex128 types are aliases defined at run time based on + # the size of floats and doubles on the machine. This means that both types are + # not necessarily defined on every machine, but a machine with 32-bit floats and + # 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128 + # as an alias of npy_cdouble. + # + # In any case, the get/set real/imag functions defined above will always work for + # npy_complex64 and npy_complex128.
template = """ - struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s - { - typedef pytensor_complex%(nbits)s complex_type; - typedef npy_float%(half_nbits)s scalar_type; - - complex_type operator +(const complex_type &y) const { - complex_type ret; - ret.real = this->real + y.real; - ret.imag = this->imag + y.imag; - return ret; - } - - complex_type operator -() const { - complex_type ret; - ret.real = -this->real; - ret.imag = -this->imag; - return ret; - } - bool operator ==(const complex_type &y) const { - return (this->real == y.real) && (this->imag == y.imag); - } - bool operator ==(const scalar_type &y) const { - return (this->real == y) && (this->imag == 0); - } - complex_type operator -(const complex_type &y) const { - complex_type ret; - ret.real = this->real - y.real; - ret.imag = this->imag - y.imag; - return ret; - } - complex_type operator *(const complex_type &y) const { - complex_type ret; - ret.real = this->real * y.real - this->imag * y.imag; - ret.imag = this->real * y.imag + this->imag * y.real; - return ret; - } - complex_type operator /(const complex_type &y) const { - complex_type ret; - scalar_type y_norm_square = y.real * y.real + y.imag * y.imag; - ret.real = (this->real * y.real + this->imag * y.imag) / y_norm_square; - ret.imag = (this->imag * y.real - this->real * y.imag) / y_norm_square; - return ret; - } - template - complex_type& operator =(const T& y); - - pytensor_complex%(nbits)s() {} - - template - pytensor_complex%(nbits)s(const T& y) { *this = y; } - - template - pytensor_complex%(nbits)s(const TR& r, const TI& i) { this->real=r; this->imag=i; } + struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s { + typedef pytensor_complex%(nbits)s complex_type; + typedef npy_float%(half_nbits)s scalar_type; + + complex_type operator+(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) + get_real(y)); + set_imag(&ret, get_imag(*this) + get_imag(y)); + return ret; + } + + complex_type operator-() const { + complex_type ret; + set_real(&ret, -get_real(*this)); + set_imag(&ret, -get_imag(*this)); + return ret; + } + bool operator==(const complex_type &y) const { + return (get_real(*this) == get_real(y)) && (get_imag(*this) == get_imag(y)); + } + bool operator==(const scalar_type &y) const { + return (get_real(*this) == y) && (get_imag(*this) == 0); + } + complex_type operator-(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) - get_real(y)); + set_imag(&ret, get_imag(*this) - get_imag(y)); + return ret; + } + complex_type operator*(const complex_type &y) const { + complex_type ret; + set_real(&ret, get_real(*this) * get_real(y) - get_imag(*this) * get_imag(y)); + set_imag(&ret, get_imag(*this) * get_real(y) + get_real(*this) * get_imag(y)); + return ret; + } + complex_type operator/(const complex_type &y) const { + complex_type ret; + scalar_type y_norm_square = get_real(y) * get_real(y) + get_imag(y) * get_imag(y); + set_real(&ret, (get_real(*this) * get_real(y) + get_imag(*this) * get_imag(y)) / y_norm_square); + set_imag(&ret, (get_imag(*this) * get_real(y) - get_real(*this) * get_imag(y)) / y_norm_square); + return ret; + } + template complex_type &operator=(const T &y); + + + pytensor_complex%(nbits)s() {} + + template pytensor_complex%(nbits)s(const T &y) { *this = y; } + + template + pytensor_complex%(nbits)s(const TR &r, const TI &i) { + set_real(this, r); + set_imag(this, i); + } }; """ def operator_eq_real(mytype, othertype): return f""" template <> {mytype} &
{mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y; this->imag=0; return *this; }} + {{ set_real(this, y); set_imag(this, 0); return *this; }} """ def operator_eq_cplx(mytype, othertype): return f""" template <> {mytype} & {mytype}::operator=<{othertype}>(const {othertype} & y) - {{ this->real=y.real; this->imag=y.imag; return *this; }} + {{ set_real(this, get_real(y)); set_imag(this, get_imag(y)); return *this; }} """ operator_eq = "".join( @@ -605,10 +701,10 @@ def operator_eq_cplx(mytype, othertype): def operator_plus_real(mytype, othertype): return f""" const {mytype} operator+(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real+y, x.imag); }} + {{ return {mytype}(get_real(x) + y, get_imag(x)); }} const {mytype} operator+(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real+y, x.imag); }} + {{ return {mytype}(get_real(x) + y, get_imag(x)); }} """ operator_plus = "".join( @@ -620,10 +716,10 @@ def operator_plus_real(mytype, othertype): def operator_minus_real(mytype, othertype): return f""" const {mytype} operator-(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real-y, x.imag); }} + {{ return {mytype}(get_real(x) - y, get_imag(x)); }} const {mytype} operator-(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(y-x.real, -x.imag); }} + {{ return {mytype}(y - get_real(x), -get_imag(x)); }} """ operator_minus = "".join( @@ -635,10 +731,10 @@ def operator_minus_real(mytype, othertype): def operator_mul_real(mytype, othertype): return f""" const {mytype} operator*(const {mytype} &x, const {othertype} &y) - {{ return {mytype}(x.real*y, x.imag*y); }} + {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }} const {mytype} operator*(const {othertype} &y, const {mytype} &x) - {{ return {mytype}(x.real*y, x.imag*y); }} + {{ return {mytype}(get_real(x) * y, get_imag(x) * y); }} """ operator_mul = "".join( @@ -648,7 +744,8 @@ def operator_mul_real(mytype, othertype): ) return ( - template % dict(nbits=64, half_nbits=32) + get_set_aliases + + template % dict(nbits=64, half_nbits=32) + template % dict(nbits=128, half_nbits=64) + operator_eq + operator_plus @@ -663,7 +760,7 @@ def c_init_code(self, **kwargs): return ["import_array();"] def c_code_cache_version(self): - return (13, np.__version__) + return (14, np.__version__) def get_shape_info(self, obj): return obj.itemsize @@ -2567,7 +2664,7 @@ def c_code(self, node, name, inputs, outputs, sub): if type in float_types: return f"{z} = fabs({x});" if type in complex_types: - return f"{z} = sqrt({x}.real*{x}.real + {x}.imag*{x}.imag);" + return f"{z} = sqrt(get_real({x}) * get_real({x}) + get_imag({x}) * get_imag({x}));" if node.outputs[0].type == bool: return f"{z} = ({x}) ? 1 : 0;" if type in uint_types: From 9416df28376888596ac3bc8719064b2df11bd381 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Mon, 29 Jul 2024 09:37:56 +0100 Subject: [PATCH 07/19] Use Python implementation for AdvancedIncSubtensor1 MapIter was removed from the public numpy C-API in version 2.0, so we raise a NotImplementedError to fall back to the Python code for AdvancedIncSubtensor1. The Python version, defined in `AdvancedIncSubtensor1.perform`, calls `np.add.at`, which uses `MapIter` behind the scenes. There is active development on Numpy to improve the efficiency of `np.add.at`. To skip the C implementation and use the Python implementation, we raise a NotImplementedError for this op's c code if numpy>=2.0.
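As a standalone illustration (not from the patch) of why the Python fallback relies on `np.add.at` rather than buffered fancy-index assignment:

    import numpy as np

    x = np.zeros(3)
    idx = np.array([0, 0, 1])

    x[idx] += 1           # buffered: the repeated index 0 is applied only once
    print(x)              # [1. 1. 0.]

    x = np.zeros(3)
    np.add.at(x, idx, 1)  # unbuffered: duplicates accumulate, as inc_subtensor requires
    print(x)              # [2. 1. 0.]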
--- pytensor/tensor/subtensor.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 46b9cc06fd..51e6dba0d8 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -2520,8 +2520,7 @@ def gen_num(typen): return code def c_code(self, node, name, input_names, output_names, sub): - numpy_ver = [int(n) for n in np.__version__.split(".")[:2]] - if bool(numpy_ver < [1, 8]): + if numpy_version < "1.8.0" or using_numpy_2: raise NotImplementedError x, y, idx = input_names From f4f58a4c42f200a4137946c4a4e70508f6017d29 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 6 Aug 2024 11:59:24 +0100 Subject: [PATCH 08/19] Changed copy to deepcopy for rng This was done for the python linker and numba linker. deepcopy seems to be the recommended method for copying a numpy Generator. After this numpy PR: https://github.com/numpy/numpy/pull/26293/commits/44ba7ca07984557f2006f9a6916adb8e3ecfca61 `copy` didn't seem to actually make an independent copy of the `np.random.Generator` objects spawned by `RandomStream`. This was causing the "test values" computed by e.g. `RandomStream.uniform` to increment the RNG state, which was causing tests that rely on `RandomStream` to fail. Here is some related discussion: https://github.com/numpy/numpy/issues/24086 I didn't see any official documentation about a change in numpy that would make copy stop working. --- pytensor/link/numba/dispatch/random.py | 4 ++-- pytensor/tensor/random/op.py | 4 ++-- tests/tensor/random/test_basic.py | 6 ++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index e80a033c82..e20d99c605 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from copy import copy +from copy import copy, deepcopy from functools import singledispatch from textwrap import dedent @@ -34,7 +34,7 @@ def copy_NumPyRandomGenerator(rng): def impl(rng): # TODO: Open issue on Numba? with numba.objmode(new_rng=types.npy_rng): - new_rng = copy(rng) + new_rng = deepcopy(rng) return new_rng diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index c76d250c9e..a8b67dee4f 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Sequence -from copy import copy +from copy import deepcopy from typing import Any, cast import numpy as np @@ -395,7 +395,7 @@ def perform(self, node, inputs, outputs): # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise. 
if not self.inplace: - rng = copy(rng) + rng = deepcopy(rng) outputs[0][0] = rng outputs[1][0] = np.asarray( diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 23d1b87020..4192a6c473 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -1,6 +1,6 @@ import pickle import re -from copy import copy +from copy import deepcopy import numpy as np import pytest @@ -114,7 +114,9 @@ def test_fn(*args, random_state=None, **kwargs): pt_rng = shared(rng, borrow=True) - numpy_res = np.asarray(test_fn(*param_vals, random_state=copy(rng), **kwargs_vals)) + numpy_res = np.asarray( + test_fn(*param_vals, random_state=deepcopy(rng), **kwargs_vals) + ) pytensor_res = rv(*params, rng=pt_rng, **kwargs) From 2944552d6216954fac4ec923132d05c199e86474 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Wed, 29 Jan 2025 11:11:52 +0000 Subject: [PATCH 09/19] Change rng.__getstate__ to rng.bit_generator.state numpy.random.Generator.__getstate__() now returns none; to see the state of the bit generator, you need to use Generator.bit_generator.state. This change affects `RandomGeneratorType`, and several of the random tests (including some for Jax.) --- pytensor/link/jax/dispatch/random.py | 2 +- pytensor/tensor/random/type.py | 4 ++-- tests/link/jax/test_random.py | 4 +++- tests/tensor/random/test_type.py | 10 +++++----- tests/tensor/random/test_utils.py | 12 +++++++++--- 5 files changed, 20 insertions(+), 12 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index d66ddc049d..8a33dfac13 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -56,7 +56,7 @@ def assert_size_argument_jax_compatible(node): @jax_typify.register(Generator) def jax_typify_Generator(rng, **kwargs): - state = rng.__getstate__() + state = rng.bit_generator.state state["bit_generator"] = numpy_bit_gens[state["bit_generator"]] # XXX: Is this a reasonable approach? 
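A quick standalone example (not part of the diff) of the replacement API used throughout this patch:

    import numpy as np

    rng = np.random.default_rng(123)
    # Generator.__getstate__() no longer exposes the state on recent numpy;
    # the bit generator's state dict is the supported way to inspect it.
    state = rng.bit_generator.state
    print(state["bit_generator"])  # 'PCG64'
    print(sorted(state))           # ['bit_generator', 'has_uint32', 'state', 'uinteger']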
diff --git a/pytensor/tensor/random/type.py b/pytensor/tensor/random/type.py index 88d5e6197f..df8e3b691d 100644 --- a/pytensor/tensor/random/type.py +++ b/pytensor/tensor/random/type.py @@ -87,8 +87,8 @@ def filter(self, data, strict=False, allow_downcast=None): @staticmethod def values_eq(a, b): - sa = a if isinstance(a, dict) else a.__getstate__() - sb = b if isinstance(b, dict) else b.__getstate__() + sa = a if isinstance(a, dict) else a.bit_generator.state + sb = b if isinstance(b, dict) else b.bit_generator.state def _eq(sa, sb): for key in sa: diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 2c0e4231c8..fa25f3aac0 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -63,7 +63,9 @@ def test_random_updates(rng_ctor): assert all( a == b if not isinstance(a, np.ndarray) else np.array_equal(a, b) for a, b in zip( - rng.get_value().__getstate__(), original_value.__getstate__(), strict=True + rng.get_value().bit_generator.state, + original_value.bit_generator.state, + strict=True, ) ) diff --git a/tests/tensor/random/test_type.py b/tests/tensor/random/test_type.py index d289862347..d358f2a93a 100644 --- a/tests/tensor/random/test_type.py +++ b/tests/tensor/random/test_type.py @@ -52,7 +52,7 @@ def test_filter(self): with pytest.raises(TypeError): rng_type.filter(1) - rng_dict = rng.__getstate__() + rng_dict = rng.bit_generator.state assert rng_type.is_valid_value(rng_dict) is False assert rng_type.is_valid_value(rng_dict, strict=False) @@ -88,13 +88,13 @@ def test_values_eq(self): assert rng_type.values_eq(bitgen_g, bitgen_h) assert rng_type.is_valid_value(bitgen_a, strict=True) - assert rng_type.is_valid_value(bitgen_b.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_b.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_c, strict=True) - assert rng_type.is_valid_value(bitgen_d.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_d.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_e, strict=True) - assert rng_type.is_valid_value(bitgen_f.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_f.bit_generator.state, strict=False) assert rng_type.is_valid_value(bitgen_g, strict=True) - assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False) + assert rng_type.is_valid_value(bitgen_h.bit_generator.state, strict=False) def test_may_share_memory(self): bg_a = np.random.PCG64() diff --git a/tests/tensor/random/test_utils.py b/tests/tensor/random/test_utils.py index 70e8a710e9..f7d8731c1b 100644 --- a/tests/tensor/random/test_utils.py +++ b/tests/tensor/random/test_utils.py @@ -165,14 +165,20 @@ def test_seed(self, rng_ctor): state_rng = random.state_updates[0][0].get_value(borrow=True) if hasattr(state_rng, "get_state"): - ref_state = ref_rng.get_state() random_state = state_rng.get_state() + + # hack to try to get something reasonable for ref_rng + try: + ref_state = ref_rng.get_state() + except AttributeError: + ref_state = list(ref_rng.bit_generator.state.values()) + assert np.array_equal(random_state[1], ref_state[1]) assert random_state[0] == ref_state[0] assert random_state[2:] == ref_state[2:] else: - ref_state = ref_rng.__getstate__() - random_state = state_rng.__getstate__() + ref_state = ref_rng.bit_generator.state + random_state = state_rng.bit_generator.state assert random_state["bit_generator"] == ref_state["bit_generator"] assert random_state["state"] == ref_state["state"] From 
0aa10c0bccf2b573e906073962cb6ef15eea4eb8 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Wed, 7 Aug 2024 10:22:12 +0100 Subject: [PATCH 10/19] Replace use of `np.MAXDIMS` `np.MAXDIMS` was removed from the public API and no replacement is given in the migration docs. In numpy <= 1.26, the value of `np.MAXDIMS` was 32. This was often used as a flag to mean `axis=None`. In numpy >= 2.0, the maximum number of dims of an array has been increased to 64; simultaneously, a constant `NPY_RAVEL_AXIS` was added to the C-API to indicate that `axis=None`. In most cases, the use of `np.MAXDIMS` to check for `axis=None` can be replaced by the new constant `NPY_RAVEL_AXIS`. To make this constant accessible when using numpy <= 1.26, I added a function to insert `npy_2_compat.h` into the support code for the affected ops. --- pytensor/npy_2_compat.py | 15 ++++++-- pytensor/tensor/extra_ops.py | 47 ++++++++++++++++--------- pytensor/tensor/math.py | 14 ++++++-- pytensor/tensor/special.py | 66 +++++++++++++++++++++++------------ pytensor/tensor/subtensor.py | 10 +++--- tests/tensor/test_elemwise.py | 4 ++- 6 files changed, 106 insertions(+), 50 deletions(-) diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py index 30214154a2..facc3b8865 100644 --- a/pytensor/npy_2_compat.py +++ b/pytensor/npy_2_compat.py @@ -46,10 +46,21 @@ ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version() # type: ignore[attr-defined] +# used in tests: the type of error thrown if a value is too large for the specified +# numpy data type is different in numpy 2.x +UintOverflowError = OverflowError if using_numpy_2 else TypeError + + +# to patch up some of the C code, we need to use these special values... if using_numpy_2: - UintOverflowError = OverflowError + numpy_axis_is_none_flag = np.iinfo(np.int32).min # the value of "NPY_RAVEL_AXIS" else: - UintOverflowError = TypeError + # 32 is the value used to mark axis = None in Numpy C-API prior to version 2.0 + numpy_axis_is_none_flag = 32 + + +# max number of dims is 64 in numpy 2.x; 32 in older versions +numpy_maxdims = 64 if using_numpy_2 else 32 def npy_2_compat_header() -> str: diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index e9d06ae9c2..7c6dfb9876 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -2,7 +2,6 @@ from collections.abc import Collection, Iterable import numpy as np -from numpy.exceptions import AxisError import pytensor import pytensor.scalar.basic as ps @@ -19,10 +18,11 @@ from pytensor.link.c.type import EnumList, Generic from pytensor.npy_2_compat import ( normalize_axis_index, - normalize_axis_tuple, + npy_2_compat_header, + numpy_axis_is_none_flag, ) from pytensor.raise_op import Assert -from pytensor.scalar import int32 as int_t +from pytensor.scalar import int64 as int_t from pytensor.scalar import upcast from pytensor.tensor import TensorLike, as_tensor_variable from pytensor.tensor import basic as ptb @@ -47,6 +47,7 @@ from pytensor.tensor.shape import Shape_i from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector +from pytensor.tensor.utils import normalize_reduce_axis from pytensor.tensor.variable import TensorVariable from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH @@ -302,7 +303,11 @@ def __init__(self, axis: int | None = None, mode="add"): self.axis = axis self.mode = mode - c_axis = property(lambda self: np.MAXDIMS if self.axis is None else 
self.axis) + @property + def c_axis(self) -> int: + if self.axis is None: + return numpy_axis_is_none_flag + return self.axis def make_node(self, x): x = ptb.as_tensor_variable(x) @@ -359,24 +364,37 @@ def infer_shape(self, fgraph, node, shapes): return shapes + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inames, onames, sub): (x,) = inames (z,) = onames fail = sub["fail"] params = sub["params"] - code = f""" - int axis = {params}->c_axis; + if self.axis is None: + axis_code = "int axis = NPY_RAVEL_AXIS;\n" + else: + axis_code = f"int axis = {params}->c_axis;\n" + + code = ( + axis_code + + f""" + if (axis == 0 && PyArray_NDIM({x}) == 1) - axis = NPY_MAXDIMS; + axis = NPY_RAVEL_AXIS; npy_intp shape[1] = {{ PyArray_SIZE({x}) }}; - if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0])) + if(axis == NPY_RAVEL_AXIS && !({z} && PyArray_DIMS({z})[0] == shape[0])) {{ Py_XDECREF({z}); - {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x})); + {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE({x})); }} - else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) + else if(axis != NPY_RAVEL_AXIS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x})))) {{ Py_XDECREF({z}); {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x})); @@ -403,11 +421,12 @@ def c_code(self, node, name, inames, onames, sub): Py_XDECREF(t); }} """ + ) return code def c_code_cache_version(self): - return (8,) + return (9,) def __str__(self): return f"{self.__class__.__name__}{{{self.axis}, {self.mode}}}" @@ -598,11 +617,7 @@ def squeeze(x, axis=None): elif not isinstance(axis, Collection): axis = (axis,) - # scalar inputs are treated as 1D regarding axis in this `Op` - try: - axis = normalize_axis_tuple(axis, ndim=max(1, _x.ndim)) - except AxisError: - raise AxisError(axis, ndim=_x.ndim) + axis = normalize_reduce_axis(axis, ndim=_x.ndim) if not axis: # Nothing to do diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index c4f3dc50a5..a88d678392 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -13,7 +13,11 @@ from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.npy_2_compat import normalize_axis_tuple +from pytensor.npy_2_compat import ( + normalize_axis_tuple, + npy_2_compat_header, + numpy_axis_is_none_flag, +) from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.scalar.basic import BinaryScalarOp @@ -160,7 +164,7 @@ def get_params(self, node): c_axis = np.int64(self.axis[0]) else: # The value here doesn't matter, it won't be used - c_axis = np.int64(-1) + c_axis = numpy_axis_is_none_flag return self.params_type.get_params(c_axis=c_axis) def make_node(self, x): @@ -203,13 +207,17 @@ def perform(self, node, inp, outs): max_idx[0] = np.asarray(np.argmax(reshaped_x, axis=-1), dtype="int64") + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (argmax,) = out fail = sub["fail"] params = sub["params"] if self.axis is None: axis_code = "axis =
NPY_MAXDIMS;" + axis_code = "axis = NPY_RAVEL_AXIS;" else: if len(self.axis) != 1: raise NotImplementedError() diff --git a/pytensor/tensor/special.py b/pytensor/tensor/special.py index a2f02fabd8..5b05ad03f4 100644 --- a/pytensor/tensor/special.py +++ b/pytensor/tensor/special.py @@ -6,6 +6,7 @@ from pytensor.graph.basic import Apply from pytensor.graph.replace import _vectorize_node from pytensor.link.c.op import COp +from pytensor.npy_2_compat import npy_2_compat_header from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.math import gamma, gammaln, log, neg, sum @@ -60,12 +61,16 @@ def infer_shape(self, fgraph, node, shape): return [shape[1]] def c_code_cache_version(self): - return (4,) + return (5,) + + def c_support_code_apply(self, node: Apply, name: str) -> str: + # return super().c_support_code_apply(node, name) + return npy_2_compat_header() def c_code(self, node, name, inp, out, sub): dy, sm = inp (dx,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -79,7 +84,7 @@ def c_code(self, node, name, inp, out, sub): int sm_ndim = PyArray_NDIM({sm}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || sm_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || sm_ndim == 1); // Validate inputs if ((PyArray_TYPE({dy}) != NPY_DOUBLE) && @@ -95,13 +100,15 @@ def c_code(self, node, name, inp, out, sub): {fail}; }} - if (axis < 0) axis = sm_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); - {fail}; + if (axis < 0) axis = sm_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > sm_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in SoftmaxGrad"); + {fail}; + }} }} - if (({dx} == NULL) || !(PyArray_CompareLists(PyArray_DIMS({dx}), PyArray_DIMS({sm}), sm_ndim))) {{ @@ -289,10 +296,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return ["", ""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] # dtype = node.inputs[0].type.dtype_specs()[1] # TODO: put this into a templated function, in the support code @@ -309,7 +320,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -319,11 +330,14 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in Softmax"); + {fail} + }} }} // Allocate Output Array @@ -481,7 +495,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (4,) + return (5,) def softmax(c, axis=None): @@ 
-541,10 +555,14 @@ def infer_shape(self, fgraph, node, shape): def c_headers(self, **kwargs): return [""] + def c_support_code_apply(self, node: Apply, name: str) -> str: + """Needed to define NPY_RAVEL_AXIS""" + return npy_2_compat_header() + def c_code(self, node, name, inp, out, sub): (x,) = inp (sm,) = out - axis = self.axis if self.axis is not None else np.MAXDIMS + axis = self.axis if self.axis is not None else "NPY_RAVEL_AXIS" fail = sub["fail"] return dedent( @@ -558,7 +576,7 @@ def c_code(self, node, name, inp, out, sub): int x_ndim = PyArray_NDIM({x}); int axis = {axis}; - int iterate_axis = !(axis == NPY_MAXDIMS || x_ndim == 1); + int iterate_axis = !(axis == NPY_RAVEL_AXIS || x_ndim == 1); // Validate inputs if ((PyArray_TYPE({x}) != NPY_DOUBLE) && @@ -568,13 +586,15 @@ def c_code(self, node, name, inp, out, sub): {fail} }} - if (axis < 0) axis = x_ndim + axis; - if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + if (iterate_axis) {{ - PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); - {fail} + if (axis < 0) axis = x_ndim + axis; + if ((axis < 0) || (iterate_axis && (axis > x_ndim))) + {{ + PyErr_SetString(PyExc_ValueError, "invalid axis in LogSoftmax"); + {fail} + }} }} - // Allocate Output Array if (({sm}) == NULL || !(PyArray_CompareLists(PyArray_DIMS({sm}), PyArray_DIMS({x}), x_ndim))) {{ @@ -730,7 +750,7 @@ def c_code(self, node, name, inp, out, sub): @staticmethod def c_code_cache_version(): - return (1,) + return (2,) def log_softmax(c, axis=None): diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 51e6dba0d8..c1fdb463b6 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -18,7 +18,7 @@ from pytensor.graph.utils import MethodNotDefined from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType -from pytensor.npy_2_compat import numpy_version, using_numpy_2 +from pytensor.npy_2_compat import npy_2_compat_header, numpy_version, using_numpy_2 from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.tensor import ( @@ -2149,7 +2149,7 @@ def infer_shape(self, fgraph, node, ishapes): def c_support_code(self, **kwargs): # In some versions of numpy, NPY_MIN_INTP is defined as MIN_LONG, # which is not defined. It should be NPY_MIN_LONG instead in that case. 
- return dedent( + return npy_2_compat_header() + dedent( """\ #ifndef MIN_LONG #define MIN_LONG NPY_MIN_LONG @@ -2174,7 +2174,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (!PyArray_CanCastSafely(i_type, NPY_INTP) && PyArray_SIZE({i_name}) > 0) {{ npy_int64 min_val, max_val; - PyObject* py_min_val = PyArray_Min({i_name}, NPY_MAXDIMS, + PyObject* py_min_val = PyArray_Min({i_name}, NPY_RAVEL_AXIS, NULL); if (py_min_val == NULL) {{ {fail}; @@ -2184,7 +2184,7 @@ def c_code(self, node, name, input_names, output_names, sub): if (min_val == -1 && PyErr_Occurred()) {{ {fail}; }} - PyObject* py_max_val = PyArray_Max({i_name}, NPY_MAXDIMS, + PyObject* py_max_val = PyArray_Max({i_name}, NPY_RAVEL_AXIS, NULL); if (py_max_val == NULL) {{ {fail}; @@ -2243,7 +2243,7 @@ def c_code(self, node, name, input_names, output_names, sub): """ def c_code_cache_version(self): - return (0, 1, 2) + return (0, 1, 2, 3) advanced_subtensor1 = AdvancedSubtensor1() diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 8555a1d29f..45a7f53c2c 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -18,6 +18,7 @@ from pytensor.graph.replace import vectorize_node from pytensor.link.basic import PerformLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker +from pytensor.npy_2_compat import numpy_maxdims from pytensor.tensor import as_tensor_variable from pytensor.tensor.basic import get_scalar_constant_value, second from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise @@ -121,7 +122,8 @@ def test_infer_shape(self): def test_too_big_rank(self): x = self.type(self.dtype, shape=())() - y = x.dimshuffle(("x",) * (np.MAXDIMS + 1)) + y = x.dimshuffle(("x",) * (numpy_maxdims + 1)) + with pytest.raises(ValueError): y.eval({x: 0}) From b349a9a763f2a054016b152a05bca2c8bfc1cc77 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 27 Aug 2024 13:22:14 +0100 Subject: [PATCH 11/19] Fixed failed test due to uint8 overflow In numpy 2.0, -1 as uint8 is out of bounds, whereas previously it would be converted to 255. This affected the test helper function `reduced_bitwise_and`. The helper function was changed to use 255 instead of -1 if the dtype was uint8, since this is what is needed to match the behavior of the "bitwise and" op. `reduced_bitwise_and` was only used by `TestCAReduce` in `tests/tensor/test_elemwise.py`, so it was moved there from `tests/tensor/test_math.py` --- tests/compile/function/test_function.py | 9 +++++---- tests/compile/function/test_pfunc.py | 17 ++++++++++------- tests/tensor/test_elemwise.py | 22 +++++++++++++++++++++- tests/tensor/test_math.py | 16 ---------------- 4 files changed, 36 insertions(+), 28 deletions(-) diff --git a/tests/compile/function/test_function.py b/tests/compile/function/test_function.py index f835953b19..9f75ef15d8 100644 --- a/tests/compile/function/test_function.py +++ b/tests/compile/function/test_function.py @@ -11,6 +11,7 @@ from pytensor.compile.function import function, function_dump from pytensor.compile.io import In from pytensor.configdefaults import config +from pytensor.npy_2_compat import UintOverflowError from pytensor.tensor.type import ( bscalar, bvector, @@ -166,12 +167,12 @@ def test_in_allow_downcast_int(self): # Value too big for a, silently ignored assert np.array_equal(f([2**20], np.ones(1, dtype="int8"), 1), [2]) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError (in numpy >= 2.0... 
TypeError in numpy < 2.0) + with pytest.raises(UintOverflowError): f([3], [312], 1) - # Value too big for c, raises TypeError - with pytest.raises(TypeError): + # Value too big for c, raises OverflowError + with pytest.raises(UintOverflowError): f([3], [6], 806) def test_in_allow_downcast_floatX(self): diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 0a9bda9846..249f230d81 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -9,6 +9,7 @@ from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.graph.utils import MissingInputError +from pytensor.npy_2_compat import UintOverflowError from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.type import ( bscalar, @@ -237,12 +238,12 @@ def test_param_allow_downcast_int(self): # Value too big for a, silently ignored assert np.all(f([2**20], np.ones(1, dtype="int8"), 1) == 2) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): f([3], [312], 1) - # Value too big for c, raises TypeError - with pytest.raises(TypeError): + # Value too big for c, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): f([3], [6], 806) def test_param_allow_downcast_floatX(self): @@ -327,8 +328,8 @@ def test_allow_input_downcast_int(self): with pytest.raises(TypeError): g([3], np.array([6], dtype="int16"), 0) - # Value too big for b, raises TypeError - with pytest.raises(TypeError): + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): g([3], [312], 0) h = pfunc([a, b, c], (a + b + c)) # Default: allow_input_downcast=None @@ -336,7 +337,9 @@ def test_allow_input_downcast_int(self): assert np.all(h([3], [6], 0) == 9) with pytest.raises(TypeError): h([3], np.array([6], dtype="int16"), 0) - with pytest.raises(TypeError): + + # Value too big for b, raises OverflowError in numpy >= 2.0, TypeError in numpy <2.0 + with pytest.raises(UintOverflowError): h([3], [312], 0) def test_allow_downcast_floatX(self): diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 45a7f53c2c..5ce533d3a3 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -40,7 +40,27 @@ ) from tests import unittest_tools from tests.link.test_link import make_function -from tests.tensor.test_math import reduce_bitwise_and + + +def reduce_bitwise_and(x, axis=-1, dtype="int8"): + """Helper function for TestCAReduce""" + if dtype == "uint8": + # in numpy version >= 2.0, out of bounds uint8 values are not converted + identity = np.array((255,), dtype=dtype)[0] + else: + identity = np.array((-1,), dtype=dtype)[0] + + shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis) + if 0 in shape_without_axis: + return np.empty(shape=shape_without_axis, dtype=x.dtype) + + def custom_reduce(a): + out = identity + for i in range(a.size): + out = np.bitwise_and(a[i], out) + return out + + return np.apply_along_axis(custom_reduce, axis, x) class TestDimShuffle(unittest_tools.InferShapeTester): diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 40c505b7b4..64af7057a5 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -3444,22 +3444,6 @@ def test_var_axes(self): x.var(a) -def reduce_bitwise_and(x, 
axis=-1, dtype="int8"): - identity = np.array((-1,), dtype=dtype)[0] - - shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis) - if 0 in shape_without_axis: - return np.empty(shape=shape_without_axis, dtype=x.dtype) - - def custom_reduce(a): - out = identity - for i in range(a.size): - out = np.bitwise_and(a[i], out) - return out - - return np.apply_along_axis(custom_reduce, axis, x) - - def test_clip_grad(): # test the gradient of clip def func(x, y, z): From 9e919c74d65d9b27445765b7211ec18d20be246b Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Fri, 30 Aug 2024 12:01:19 +0100 Subject: [PATCH 12/19] Changes due to new numpy scalar promotion rules 1. Changed autocaster due to new promotion rules With "weak promotion" of python types in Numpy 2.0, the statement `1.1 == np.asarray(1.1).astype('float32')` is True, whereas in Numpy 1.26, it was false. However, in numpy 1.26, `1.1 == np.asarray([1.1]).astype('float32')` was true, so the scalar behavior and array behavior are the same in Numpy 2.0, while they were different in numpy 1.26. Essentially, in Numpy 2.0, if python floats are used in operations with numpy floats or arrays, then the type of the numpy object will be used (i.e. the python value will be treated as the type of the numpy objects). To preserve the behavior of `NumpyAutocaster` from numpy <= 1.26, I've added an explicit conversion of the value to be converted to a numpy type using `np.asarray` during the check that decides what dtype to cast to. 2. Updates due to new numpy conversion rules for out-of-bounds python ints In numpy 2.0, out of bounds python ints will not be automatically converted, and will raise an `OverflowError` instead. For instance, converting 255 to int8 will raise an error, instead of returning -1. To explicitly force conversion, we must use `np.asarray(value).astype(dtype)`, rather than `np.asarray(value, dtype=dtype)`. The code in `TensorType.filter` has been changed to the new recommended way to downcast, and the error type caught by some tests has been changed to OverflowError from TypeError --- pytensor/scalar/basic.py | 4 +++- pytensor/tensor/type.py | 2 +- tests/compile/function/test_pfunc.py | 1 + tests/tensor/test_basic.py | 1 - 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index d7d719e2f4..f8ecabd7b2 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -183,7 +183,9 @@ def __call__(self, x): for dtype in try_dtypes: x_ = np.asarray(x).astype(dtype=dtype) - if np.all(x == x_): + if np.all( + np.asarray(x) == x_ + ): # use np.asarray(x) to match TensorType.filter break # returns either an exact x_==x, or the last cast x_ return x_ diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index d48a7a6f08..b96113c8e3 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -178,7 +178,7 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray: else: if allow_downcast: # Convert to self.dtype, regardless of the type of data - data = np.asarray(data, dtype=self.dtype) + data = np.asarray(data).astype(self.dtype) # TODO: consider to pad shape with ones to make it consistent # with self.broadcastable... 
like vector->row type thing else: diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index 249f230d81..3e23b12f74 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -335,6 +335,7 @@ def test_allow_input_downcast_int(self): h = pfunc([a, b, c], (a + b + c)) # Default: allow_input_downcast=None # Everything here should behave like with False assert np.all(h([3], [6], 0) == 9) + with pytest.raises(TypeError): h([3], np.array([6], dtype="int16"), 0) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 6b5ec48112..467dc66407 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3198,7 +3198,6 @@ def test_autocast_custom(): assert (dvector() + 1.1).dtype == "float64" assert (fvector() + np.float32(1.1)).dtype == "float32" assert (fvector() + np.float64(1.1)).dtype == "float64" - assert (fvector() + 1.1).dtype == config.floatX assert (lvector() + np.int64(1)).dtype == "int64" assert (lvector() + np.int32(1)).dtype == "int64" assert (lvector() + np.int16(1)).dtype == "int64" From bce361323bddf4be19280c4b364184b748eca372 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Sun, 10 Nov 2024 14:57:16 +0000 Subject: [PATCH 13/19] Fix for NameError in test I was getting a NameError from the list comprehensions saying that e.g. `pytensor_scalar` was not defined. I'm not sure why, but this is another (more verbose) way to do the same thing. --- tests/tensor/test_math.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 64af7057a5..374a22ab5d 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -2492,11 +2492,22 @@ def pytensor_i_scalar(dtype): def numpy_i_scalar(dtype): return numpy_scalar(dtype) + pytensor_funcs = { + "scalar": pytensor_scalar, + "array": pytensor_array, + "i_scalar": pytensor_i_scalar, + } + numpy_funcs = { + "scalar": numpy_scalar, + "array": numpy_array, + "i_scalar": numpy_i_scalar, + } + with config.change_flags(cast_policy="numpy+floatX"): # We will test all meaningful combinations of # scalar and array operations. - pytensor_args = [eval(f"pytensor_{c}") for c in combo] - numpy_args = [eval(f"numpy_{c}") for c in combo] + pytensor_args = [pytensor_funcs[c] for c in combo] + numpy_args = [numpy_funcs[c] for c in combo] pytensor_arg_1 = pytensor_args[0](a_type) pytensor_arg_2 = pytensor_args[1](b_type) pytensor_dtype = op( From 45c3a0182cc394eda1f3ac05b9d545e50d6e8a4b Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Fri, 24 Jan 2025 15:43:23 +0000 Subject: [PATCH 14/19] Updated doctests From numpy PR https://github.com/numpy/numpy/pull/22449, the repr of scalar values has changed, e.g. from "1" to "np.int64(1)", which caused two doctests to fail. --- pytensor/tensor/einsum.py | 2 +- pytensor/tensor/subtensor.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/einsum.py b/pytensor/tensor/einsum.py index 88a6257c9c..660c16d387 100644 --- a/pytensor/tensor/einsum.py +++ b/pytensor/tensor/einsum.py @@ -256,7 +256,7 @@ def _general_dot( .. 
. testoutput:: - (3, 4, 2) + (np.int64(3), np.int64(4), np.int64(2)) """ # Shortcut for non batched case if not batch_axes[0] and not batch_axes[1]: diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index c1fdb463b6..3a2304eb7b 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -757,13 +757,15 @@ def get_constant_idx( Example usage where `v` and `a` are appropriately typed PyTensor variables : >>> from pytensor.scalar import int64 >>> from pytensor.tensor import matrix + >>> import numpy as np + >>> >>> v = int64("v") >>> a = matrix("a") >>> b = a[v, 1:3] >>> b.owner.op.idx_list (ScalarType(int64), slice(ScalarType(int64), ScalarType(int64), None)) >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs, allow_partial=True) - [v, slice(1, 3, None)] + [v, slice(np.int64(1), np.int64(3), None)] >>> get_constant_idx(b.owner.op.idx_list, b.owner.inputs) Traceback (most recent call last): pytensor.tensor.exceptions.NotScalarConstantError From 2bfe6dd86be5957c5469d3ef112bbd18db3edc01 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 30 Jan 2025 14:02:45 +0000 Subject: [PATCH 15/19] Preserve numpy < 2.0 Unique inverse output shape In numpy 2.0, if axis=None, then np.unique does not flatten the inverse indices returned if return_inverse=True A helper function has been added to npy_2_compat.py to mimic the output of `np.unique` from versions of numpy before 2.0 --- pytensor/npy_2_compat.py | 22 ++++++++++++++++++++++ pytensor/tensor/extra_ops.py | 19 ++++++++++++++++--- tests/tensor/test_extra_ops.py | 17 +++++++++-------- 3 files changed, 47 insertions(+), 11 deletions(-) diff --git a/pytensor/npy_2_compat.py b/pytensor/npy_2_compat.py index facc3b8865..667a5c074e 100644 --- a/pytensor/npy_2_compat.py +++ b/pytensor/npy_2_compat.py @@ -63,6 +63,28 @@ numpy_maxdims = 64 if using_numpy_2 else 32 +# function that replicates np.unique from numpy < 2.0 +def old_np_unique( + arr, return_index=False, return_inverse=False, return_counts=False, axis=None +): + """Replicate np.unique from numpy versions < 2.0""" + if not return_inverse or not using_numpy_2: + return np.unique(arr, return_index, return_inverse, return_counts, axis) + + outs = list(np.unique(arr, return_index, return_inverse, return_counts, axis)) + + inv_idx = 2 if return_index else 1 + + if axis is None: + outs[inv_idx] = np.ravel(outs[inv_idx]) + else: + inv_shape = (arr.shape[axis],) + outs[inv_idx] = outs[inv_idx].reshape(inv_shape) + + return tuple(outs) + + +# compatibility header for C code def npy_2_compat_header() -> str: """Compatibility header that Numpy suggests is vendored with code that uses Numpy < 2.0 and Numpy 2.x""" return dedent(""" diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 7c6dfb9876..7a1bc75b0b 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -20,6 +20,7 @@ normalize_axis_index, npy_2_compat_header, numpy_axis_is_none_flag, + old_np_unique, ) from pytensor.raise_op import Assert from pytensor.scalar import int64 as int_t @@ -1226,6 +1227,9 @@ class Unique(Op): """ Wraps `numpy.unique`. + The indices returned when `return_inverse` is True are ravelled + to match the behavior of `numpy.unique` from before numpy version 2.0.
+ Examples -------- >>> import numpy as np @@ -1271,17 +1275,21 @@ def make_node(self, x): outputs = [TensorType(dtype=x.dtype, shape=out_shape)()] typ = TensorType(dtype="int64", shape=(None,)) + if self.return_index: outputs.append(typ()) + if self.return_inverse: outputs.append(typ()) + if self.return_counts: outputs.append(typ()) + return Apply(self, [x], outputs) def perform(self, node, inputs, output_storage): [x] = inputs - outs = np.unique( + outs = old_np_unique( x, return_index=self.return_index, return_inverse=self.return_inverse, @@ -1306,9 +1314,14 @@ def infer_shape(self, fgraph, node, i0_shapes): out_shapes[0] = tuple(shape) if self.return_inverse: - shape = prod(x_shape) if self.axis is None else x_shape[axis] return_index_out_idx = 2 if self.return_index else 1 - out_shapes[return_index_out_idx] = (shape,) + + if self.axis is not None: + shape = (x_shape[axis],) + else: + shape = (prod(x_shape),) + + out_shapes[return_index_out_idx] = shape return out_shapes diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 54bb7f4333..6a93f3c7fd 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -9,6 +9,7 @@ from pytensor.compile.mode import Mode from pytensor.configdefaults import config from pytensor.graph.basic import Constant, applys_between, equal_computations +from pytensor.npy_2_compat import old_np_unique from pytensor.raise_op import Assert from pytensor.tensor import alloc from pytensor.tensor.elemwise import DimShuffle @@ -899,14 +900,14 @@ def setup_method(self): ) def test_basic_vector(self, x, inp, axis): list_outs_expected = [ - np.unique(inp, axis=axis), - np.unique(inp, True, axis=axis), - np.unique(inp, False, True, axis=axis), - np.unique(inp, True, True, axis=axis), - np.unique(inp, False, False, True, axis=axis), - np.unique(inp, True, False, True, axis=axis), - np.unique(inp, False, True, True, axis=axis), - np.unique(inp, True, True, True, axis=axis), + old_np_unique(inp, axis=axis), + old_np_unique(inp, True, axis=axis), + old_np_unique(inp, False, True, axis=axis), + old_np_unique(inp, True, True, axis=axis), + old_np_unique(inp, False, False, True, axis=axis), + old_np_unique(inp, True, False, True, axis=axis), + old_np_unique(inp, False, True, True, axis=axis), + old_np_unique(inp, True, True, True, axis=axis), ] for params, outs_expected in zip( self.op_params, list_outs_expected, strict=True From cd75f954fb11b502da41a73bd0046b3ef2019b4d Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Thu, 30 Jan 2025 14:55:29 +0000 Subject: [PATCH 16/19] Fix test for neg on unsigned Due to changes in numpy conversion rules (NEP 50), overflows are not ignored; in particular, negating an unsigned int causes an overflow error. The test for `neg` has been changed to check that this error is raised.
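A rough sketch of the behavior the new `bad_compile` entries below exercise, using only the public PyTensor API (the exact failure mode is hedged in the comments, since it depends on the installed numpy):

    import numpy as np
    import pytensor

    x = pytensor.shared(np.array([1, 2, 3], dtype="uint8"))
    try:
        f = pytensor.function([], -x)
        print(f())                      # numpy < 2.0: silently wrapped result
    except Exception as exc:
        # numpy >= 2.0: building/compiling neg on a uint input is expected
        # to fail rather than wrap around
        print(type(exc).__name__, exc)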
--- tests/tensor/test_math.py | 12 +++++++++++- tests/tensor/utils.py | 21 +++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 374a22ab5d..f2331be62e 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -23,6 +23,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.replace import vectorize_node from pytensor.link.c.basic import DualLinker +from pytensor.npy_2_compat import using_numpy_2 from pytensor.printing import pprint from pytensor.raise_op import Assert from pytensor.tensor import blas, blas_c @@ -391,11 +392,20 @@ def test_maximum_minimum_grad(): grad=_grad_broadcast_unary_normal, ) + +# in numpy >= 2.0, negating a uint raises an error +neg_good = _good_broadcast_unary_normal.copy() +if using_numpy_2: + neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")} +else: + neg_bad = None + TestNegBroadcast = makeBroadcastTester( op=neg, expected=lambda x: -x, - good=_good_broadcast_unary_normal, + good=neg_good, grad=_grad_broadcast_unary_normal, + bad_compile=neg_bad, ) TestSgnBroadcast = makeBroadcastTester( diff --git a/tests/tensor/utils.py b/tests/tensor/utils.py index b94750ffe2..1a8b2455ec 100644 --- a/tests/tensor/utils.py +++ b/tests/tensor/utils.py @@ -339,6 +339,7 @@ def makeTester( good=None, bad_build=None, bad_runtime=None, + bad_compile=None, grad=None, mode=None, grad_rtol=None, @@ -373,6 +374,7 @@ def makeTester( _test_memmap = test_memmap _check_name = check_name _grad_eps = grad_eps + _bad_compile = bad_compile or {} class Checker: op = staticmethod(_op) @@ -382,6 +384,7 @@ class Checker: good = _good bad_build = _bad_build bad_runtime = _bad_runtime + bad_compile = _bad_compile grad = _grad mode = _mode skip = skip_ @@ -539,6 +542,24 @@ def test_bad_build(self): # instantiated on the following bad inputs: %s" # % (self.op, testname, node, inputs)) + @config.change_flags(compute_test_value="off") + @pytest.mark.skipif(skip, reason="Skipped") + def test_bad_compile(self): + for testname, inputs in self.bad_compile.items(): + inputrs = [shared(input) for input in inputs] + try: + node = safe_make_node(self.op, *inputrs) + except Exception as exc: + err_msg = ( + f"Test {self.op}::{testname}: Error occurred while trying" + f" to make a node with inputs {inputs}" + ) + exc.args += (err_msg,) + raise + + with pytest.raises(Exception): + inplace_func([], node.outputs, mode=mode, name="test_bad_compile") + @config.change_flags(compute_test_value="off") + @pytest.mark.skipif(skip, reason="Skipped") def test_bad_runtime(self): From 93dd7c8d60a1908192fd00d45699e08e12ecb4d5 Mon Sep 17 00:00:00 2001 From: Brendan Murphy Date: Tue, 4 Feb 2025 13:56:51 +0000 Subject: [PATCH 17/19] Split up TestMinMax::test_uint I split this test up to test uint64 separately, since this is the case discussed in Issue #770. I also added a test for the exact example used in that issue. The uint dtypes with lower precision should pass. The uint64 case started passing for me locally on Mac OSX, but still fails on CI. I'm not sure why this is, but at least the test will be more specific now if it fails in the future.
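Plain numpy handles the same data correctly; the sketch below (not part of the patch) shows the values the test uses and why uint64 is the fragile case:

    import numpy as np

    itype = np.iinfo("uint64")
    data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], "uint64")

    # numpy itself reduces these exactly...
    assert data.max() == itype.max and data.min() == itype.min

    # ...but any detour through float64 loses the low bits of large uint64
    # values, which is the kind of precision failure issue #770 describes
    assert int(np.float64(itype.max)) != itype.max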
---
 tests/tensor/test_math.py | 41 ++++++++++++++++++++++++++++-----------
 1 file changed, 30 insertions(+), 11 deletions(-)

diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py
index f2331be62e..9ab4fd104d 100644
--- a/tests/tensor/test_math.py
+++ b/tests/tensor/test_math.py
@@ -1403,18 +1403,37 @@ def _grad_list(self):
         # check_grad_max(data, eval_outputs(grad(max_and_argmax(n,
         # axis=1)[0], n)),axis=1)
 
+    @pytest.mark.parametrize(
+        "dtype",
+        (
+            "uint8",
+            "uint16",
+            "uint32",
+            pytest.param("uint64", marks=pytest.mark.xfail(reason="Fails due to #770")),
+        ),
+    )
+    def test_uint(self, dtype):
+        itype = np.iinfo(dtype)
+        data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
+        n = as_tensor_variable(data)
+
+        assert min(n).dtype == dtype
+        i_min = eval_outputs(min(n))
+        assert i_min == itype.min
+
+        assert max(n).dtype == dtype
+        i_max = eval_outputs(max(n))
+        assert i_max == itype.max
+
     @pytest.mark.xfail(reason="Fails due to #770")
-    def test_uint(self):
-        for dtype in ("uint8", "uint16", "uint32", "uint64"):
-            itype = np.iinfo(dtype)
-            data = np.array([itype.min + 3, itype.min, itype.max - 5, itype.max], dtype)
-            n = as_tensor_variable(data)
-            assert min(n).dtype == dtype
-            i = eval_outputs(min(n))
-            assert i == itype.min
-            assert max(n).dtype == dtype
-            i = eval_outputs(max(n))
-            assert i == itype.max
+    def test_uint64_special_value(self):
+        """Example from issue #770"""
+        dtype = "uint64"
+        data = np.array([0, 9223372036854775], dtype=dtype)
+        n = as_tensor_variable(data)
+
+        i_max = eval_outputs(max(n))
+        assert i_max == data.max()
 
     def test_bool(self):
         data = np.array([True, False], "bool")

From 720568cfbd936ca4d0436281c9c477fe30ab2dd8 Mon Sep 17 00:00:00 2001
From: Brendan Murphy
Date: Tue, 27 Aug 2024 11:23:43 +0100
Subject: [PATCH 18/19] Unpinned numpy

Also added ruff's numpy-2 transition rule (NPY201).
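For illustration, the kind of numpy-1-only aliases the NPY201 rule
flags (hypothetical example code, not taken from this repository):

    import numpy as np

    # ruff --select NPY201 reports these and suggests the numpy 2
    # replacements:
    np.row_stack([[1], [2]])  # NPY201: use np.vstack instead
    np.in1d([1, 2], [2, 3])   # NPY201: use np.isin instead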
---
 environment-osx-arm64.yml | 2 +-
 environment.yml           | 2 +-
 pyproject.toml            | 4 ++--
 3 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/environment-osx-arm64.yml b/environment-osx-arm64.yml
index 13a68faaaa..c9dc703dcc 100644
--- a/environment-osx-arm64.yml
+++ b/environment-osx-arm64.yml
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0,<2
+  - numpy>=1.17.0
   - scipy>=1,<2
   - filelock>=3.15
   - etuples
diff --git a/environment.yml b/environment.yml
index 1571ae0d11..9bdddfb6f6 100644
--- a/environment.yml
+++ b/environment.yml
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0,<2
+  - numpy>=1.17.0
   - scipy>=1,<2
   - filelock>=3.15
   - etuples
diff --git a/pyproject.toml b/pyproject.toml
index e82c42753a..e796e35a10 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,7 @@ keywords = [
 dependencies = [
     "setuptools>=59.0.0",
     "scipy>=1,<2",
-    "numpy>=1.17.0,<2",
+    "numpy>=1.17.0",
     "filelock>=3.15",
     "etuples",
     "logical-unification",
@@ -129,7 +129,7 @@ exclude = ["doc/", "pytensor/_version.py"]
 docstring-code-format = true
 
 [tool.ruff.lint]
-select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20"]
+select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
 ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
 unfixable = [
     # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead

From b633bcacd6c3b68d56b482bccd91cbcb65df84d5 Mon Sep 17 00:00:00 2001
From: Brendan Murphy
Date: Tue, 4 Feb 2025 15:29:05 +0000
Subject: [PATCH 19/19] Added numpy 1.26.* to CI

Remaining tests now run on the latest numpy, except for the Numba jobs,
which are pinned to numpy ~=2.1.0.

---
 .github/workflows/test.yml | 31 ++++++++++++++++++++++++++++---
 1 file changed, 28 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 53f1e16606..5bb416f893 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -65,7 +65,7 @@ jobs:
       - uses: pre-commit/action@v3.0.1
 
   test:
-    name: "${{ matrix.os }} test py${{ matrix.python-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
+    name: "${{ matrix.os }} test py${{ matrix.python-version }} numpy${{ matrix.numpy-version }} : fast-compile ${{ matrix.fast-compile }} : float32 ${{ matrix.float32 }} : ${{ matrix.part }}"
     needs:
       - changes
       - style
@@ -76,6 +76,7 @@
       matrix:
         os: ["ubuntu-latest"]
        python-version: ["3.10", "3.12"]
+        numpy-version: ["~=1.26.0", ">=2.0"]
        fast-compile: [0, 1]
        float32: [0, 1]
        install-numba: [0]
@@ -105,45 +106,68 @@
            float32: 1
          - part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
            fast-compile: 1
+          - numpy-version: "~=1.26.0"
+            fast-compile: 1
+          - numpy-version: "~=1.26.0"
+            float32: 1
+          - numpy-version: "~=1.26.0"
+            python-version: "3.12"
+          - numpy-version: "~=1.26.0"
+            part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
        include:
          - install-numba: 1
            os: "ubuntu-latest"
            python-version: "3.10"
+            numpy-version: "~=2.1.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/numba"
          - install-numba: 1
            os: "ubuntu-latest"
            python-version: "3.12"
+            numpy-version: "~=2.1.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/numba"
          - install-jax: 1
            os: "ubuntu-latest"
            python-version: "3.10"
+            numpy-version: ">=2.0"
            fast-compile: 0
            float32: 0
            part: "tests/link/jax"
          - install-jax: 1
            os: "ubuntu-latest"
"ubuntu-latest" python-version: "3.12" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/jax" - install-torch: 1 os: "ubuntu-latest" python-version: "3.10" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 part: "tests/link/pytorch" - os: macos-15 python-version: "3.12" + numpy-version: ">=2.0" fast-compile: 0 float32: 0 install-numba: 0 install-jax: 0 install-torch: 0 part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py" + - os: "ubuntu-latest" + python-version: "3.10" + numpy-version: "~=1.26.0" + fast-compile: 0 + float32: 0 + install-numba: 0 + install-jax: 0 + install-torch: 0 + part: "tests/tensor/test_math.py" steps: - uses: actions/checkout@v4 @@ -174,9 +198,9 @@ jobs: run: | if [[ $OS == "macos-15" ]]; then - micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" numpy scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate; + micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate; else - micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; + micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi @@ -193,6 +217,7 @@ jobs: fi env: PYTHON_VERSION: ${{ matrix.python-version }} + NUMPY_VERSION: ${{ matrix.numpy-version }} INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_TORCH: ${{ matrix.install-torch}}