Skip to content

Commit

Permalink
Backwards compat for complex types C code
Browse files Browse the repository at this point in the history
  • Loading branch information
brendan-m-murphy committed Feb 6, 2025
1 parent cb4dec7 commit 7090c1c
Showing 1 changed file with 52 additions and 8 deletions.
60 changes: 52 additions & 8 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,10 +522,43 @@ def c_support_code(self, **kwargs):
# In that case we add the 'int' type to the real types.
real_types.append("int")

# Macros for backwards compatibility with numpy < 2.0
#
# In numpy 2.0+, these are defined in npy_math.h, but
# for early versions, they must be vendored by users (e.g. PyTensor)
backwards_compat_macros = """
#ifndef NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
#define NUMPY_CORE_INCLUDE_NUMPY_NPY_2_COMPLEXCOMPAT_H_
#include <numpy/npy_math.h>
#ifndef NPY_CSETREALF
#define NPY_CSETREALF(c, r) (c)->real = (r)
#endif
#ifndef NPY_CSETIMAGF
#define NPY_CSETIMAGF(c, i) (c)->imag = (i)
#endif
#ifndef NPY_CSETREAL
#define NPY_CSETREAL(c, r) (c)->real = (r)
#endif
#ifndef NPY_CSETIMAG
#define NPY_CSETIMAG(c, i) (c)->imag = (i)
#endif
#ifndef NPY_CSETREALL
#define NPY_CSETREALL(c, r) (c)->real = (r)
#endif
#ifndef NPY_CSETIMAGL
#define NPY_CSETIMAGL(c, i) (c)->imag = (i)
#endif
#endif
"""

def _make_get_set_real_imag(scalar_type: str) -> str:
"""Make overloaded getter/setter functions for real/imag parts of numpy complex types.
The functions called by these getter/setter functions are defining in npy_math.h
The functions called by these getter/setter functions are defining in npy_math.h, or
in the `backward_compat_macros` defined above.
Args:
scalar_type: float, double, or longdouble
Expand All @@ -536,11 +569,11 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
"""
complex_type = "npy_c" + scalar_type
suffix = "" if scalar_type == "double" else scalar_type[0]
return_type = scalar_type

if scalar_type == "longdouble":
scalar_type += "_t"
return_type = "npy_" + return_type
scalar_type = "npy_" + scalar_type

return_type = scalar_type

template = f"""
static inline {return_type} get_real(const {complex_type} z)
Expand All @@ -550,7 +583,7 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
static inline void set_real({complex_type} *z, const {scalar_type} r)
{{
npy_csetreal{suffix}(z, r);
NPY_CSETREAL{suffix.upper()}(z, r);
}}
static inline {return_type} get_imag(const {complex_type} z)
Expand All @@ -560,17 +593,28 @@ def _make_get_set_real_imag(scalar_type: str) -> str:
static inline void set_imag({complex_type} *z, const {scalar_type} i)
{{
npy_csetimag{suffix}(z, i);
NPY_CSETIMAG{suffix.upper()}(z, i);
}}
"""
return template

# TODO: add guard code to prevent this from being defining twice, in case we need to add it somewhere else
get_set_aliases = "\n".join(
_make_get_set_real_imag(stype)
for stype in ["float", "double", "longdouble"]
)

get_set_aliases = backwards_compat_macros + "\n" + get_set_aliases

# Template for defining pytensor_complex64 and pytensor_complex128 structs/classes
#
# The npy_complex64, npy_complex128 types are aliases defined at run time based on
# the size of floats and doubles on the machine. This means that both types are
# not necessarily defined on every machine, but a machine with 32-bit floats and
# 64-bit doubles will have npy_complex64 as an alias of npy_cfloat and npy_complex128
# as an alias of npy_complex128.
#
# In any case, the get/set real/imag functions defined above will always work for
# npy_complex64 and npy_complex128.
template = """
struct pytensor_complex%(nbits)s : public npy_complex%(nbits)s {
typedef pytensor_complex%(nbits)s complex_type;
Expand Down Expand Up @@ -719,7 +763,7 @@ def c_init_code(self, **kwargs):
return ["import_array();"]

def c_code_cache_version(self):
return (15, np.version.git_revision)
return (18, np.version.git_revision)

def get_shape_info(self, obj):
return obj.itemsize
Expand Down

0 comments on commit 7090c1c

Please sign in to comment.