Make blockwise perform method node dependent #1048

Merged 2 commits on Oct 24, 2024
63 changes: 36 additions & 27 deletions pytensor/tensor/blockwise.py
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from copy import copy
 from typing import Any, cast
 
 import numpy as np
@@ -79,7 +78,6 @@ def __init__(
         self.name = name
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
-        self._gufunc = None
         if destroy_map is not None:
             self.destroy_map = destroy_map
             if self.destroy_map != core_op.destroy_map:
@@ -91,11 +89,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def __getstate__(self):
-        d = copy(self.__dict__)
-        d["_gufunc"] = None
-        return d
-
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -296,32 +289,46 @@ def L_op(self, inputs, outs, ograds):
 
         return rval
 
-    def _create_gufunc(self, node):
+    def _create_node_gufunc(self, node) -> None:
+        """Define (or retrieve) the node gufunc used in `perform`.
+
+        If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
+        Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
+
+        The gufunc is stored in the tag of the node.
+        """
         gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
 
         if gufunc_spec is not None:
-            self._gufunc = import_func_from_string(gufunc_spec[0])
-            if self._gufunc:
-                return self._gufunc
-            else:
+            gufunc = import_func_from_string(gufunc_spec[0])
+            if gufunc is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
 
-        n_outs = len(self.outputs_sig)
-        core_node = self._create_dummy_core_node(node.inputs)
-
-        def core_func(*inner_inputs):
-            inner_outputs = [[None] for _ in range(n_outs)]
-
-            inner_inputs = [np.asarray(inp) for inp in inner_inputs]
-            self.core_op.perform(core_node, inner_inputs, inner_outputs)
-
-            if len(inner_outputs) == 1:
-                return inner_outputs[0][0]
-            else:
-                return tuple(r[0] for r in inner_outputs)
-
-        self._gufunc = np.vectorize(core_func, signature=self.signature)
-        return self._gufunc
+        else:
+            # Wrap core_op perform method in numpy vectorize
+            n_outs = len(self.outputs_sig)
+            core_node = self._create_dummy_core_node(node.inputs)
+            inner_outputs_storage = [[None] for _ in range(n_outs)]
+
+            def core_func(
+                *inner_inputs,
+                core_node=core_node,
+                inner_outputs_storage=inner_outputs_storage,
+            ):
+                self.core_op.perform(
+                    core_node,
+                    [np.asarray(inp) for inp in inner_inputs],
+                    inner_outputs_storage,
+                )
+
+                if n_outs == 1:
+                    return inner_outputs_storage[0][0]
+                else:
+                    return tuple(r[0] for r in inner_outputs_storage)
+
+            gufunc = np.vectorize(core_func, signature=self.signature)
+
+        node.tag.gufunc = gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
@@ -340,10 +347,12 @@ def _check_runtime_broadcast(self, node, inputs):
                 )
 
     def perform(self, node, inputs, output_storage):
-        gufunc = self._gufunc
+        gufunc = getattr(node.tag, "gufunc", None)
 
         if gufunc is None:
-            gufunc = self._create_gufunc(node)
+            # Cache it once per node
+            self._create_node_gufunc(node)
+            gufunc = node.tag.gufunc
 
         self._check_runtime_broadcast(node, inputs)
 
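
The core change above is where the vectorized callable is cached: instead of living on the shared Op instance (`self._gufunc`), it now lives on each Apply node (`node.tag.gufunc`), so one Blockwise Op reused across several nodes gets a gufunc specialized to each node. Below is a minimal standalone sketch of that caching pattern; `FakeTag`/`FakeNode` are made-up stand-ins for PyTensor's Apply/Tag classes, included only so the sketch runs on its own:

import numpy as np


class FakeTag:
    """An empty, mutable namespace, like the `tag` on a PyTensor Apply node."""


class FakeNode:
    def __init__(self, dtype):
        self.dtype = dtype  # stands in for node-dependent type information
        self.tag = FakeTag()


def create_node_gufunc(node):
    # Build the vectorized callable from *node* information and cache it on
    # the node itself, not on a shared Op instance.
    def core(x):
        return x + 1 if node.dtype.startswith("float") else x - 1

    node.tag.gufunc = np.vectorize(core, signature="()->()")


def perform(node, x):
    gufunc = getattr(node.tag, "gufunc", None)
    if gufunc is None:
        create_node_gufunc(node)  # built once per node, reused afterwards
        gufunc = node.tag.gufunc
    return gufunc(x)


float_node = FakeNode("float32")
int_node = FakeNode("int32")
print(perform(float_node, np.zeros(3, dtype="float32")))  # [1. 1. 1.]
print(perform(int_node, np.zeros(3, dtype="int32")))      # [-1 -1 -1]
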
35 changes: 35 additions & 0 deletions tests/tensor/test_blockwise.py
@@ -28,6 +28,41 @@
 from pytensor.tensor.utils import _parse_gufunc_signature
 
 
+def test_perform_method_per_node():
+    """Confirm that Blockwise uses one perform method per node.
+
+    This is important if the perform method requires node information (such as dtypes)
+    """
+
+    class NodeDependentPerformOp(Op):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, outputs):
+            [x] = inputs
+            if node.inputs[0].type.dtype.startswith("float"):
+                y = x + 1
+            else:
+                y = x - 1
+            outputs[0][0] = y
+
+    blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")
+    x = tensor("x", shape=(3,), dtype="float32")
+    y = tensor("y", shape=(3,), dtype="int32")
+
+    out_x = blockwise_op(x)
+    out_y = blockwise_op(y)
+    fn = pytensor.function([x, y], [out_x, out_y])
+    [op1, op2] = [node.op for node in fn.maker.fgraph.apply_nodes]
+    # Confirm both nodes have the same Op
+    assert op1 is blockwise_op
+    assert op1 is op2
+
+    res_out_x, res_out_y = fn(np.zeros(3, dtype="float32"), np.zeros(3, dtype="int32"))
+    np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
+    np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))
+
+
 def test_vectorize_blockwise():
     mat = tensor(shape=(None, None))
     tns = tensor(shape=(None, None, None))
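
As background for the fallback path exercised by this test: when no `gufunc_spec` is available, the diff wraps the core `perform` in `np.vectorize` with the Op's gufunc signature, which loops the core function over any leading batch dimensions. A quick self-contained demonstration of that mechanism, with an illustrative matmul-like core (not PyTensor code):

import numpy as np


def core_matmul(a, b):
    # The core function sees only the core dimensions: (m,n) @ (n,p) -> (m,p)
    return a @ b


gufunc = np.vectorize(core_matmul, signature="(m,n),(n,p)->(m,p)")

A = np.ones((5, 2, 3))  # a batch of five 2x3 matrices
B = np.ones((5, 3, 4))  # a batch of five 3x4 matrices
out = gufunc(A, B)      # core_matmul is called once per batch element
assert out.shape == (5, 2, 4)
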