From 4134881fea71798858b0ebc2bfa5519e86169ae7 Mon Sep 17 00:00:00 2001
From: HarshvirSandhu
Date: Wed, 10 Jul 2024 00:15:27 +0530
Subject: [PATCH] Implement indexing operations in PyTorch

Co-authored-by: Ricardo Vieira <28983449+ricardov94@users.noreply.github.com>
---
 pytensor/compile/mode.py                    |   1 +
 pytensor/link/pytorch/dispatch/__init__.py  |   3 +-
 pytensor/link/pytorch/dispatch/basic.py     |  34 +++-
 pytensor/link/pytorch/dispatch/subtensor.py | 124 +++++++++++++
 tests/link/pytorch/test_basic.py            |   6 +-
 tests/link/pytorch/test_subtensor.py        | 186 ++++++++++++++++++++
 6 files changed, 345 insertions(+), 9 deletions(-)
 create mode 100644 pytensor/link/pytorch/dispatch/subtensor.py
 create mode 100644 tests/link/pytorch/test_subtensor.py

diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 16019d4187..152ad3554d 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -471,6 +471,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
             "BlasOpt",
             "fusion",
             "inplace",
+            "local_uint_constant_indices",
         ],
     ),
 )
diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py
index 0295a12e8e..fddded525a 100644
--- a/pytensor/link/pytorch/dispatch/__init__.py
+++ b/pytensor/link/pytorch/dispatch/__init__.py
@@ -7,7 +7,8 @@
 import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.math
 import pytensor.link.pytorch.dispatch.extra_ops
+import pytensor.link.pytorch.dispatch.nlinalg
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
-import pytensor.link.pytorch.dispatch.nlinalg
+import pytensor.link.pytorch.dispatch.subtensor
 # isort: on
diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index c71e1606bf..2cbb3631a9 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -1,24 +1,40 @@
 from functools import singledispatch
 from types import NoneType
 
+import numpy as np
 import torch
 
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python
 from pytensor.raise_op import CheckAndRaise
-from pytensor.tensor.basic import Alloc, AllocEmpty, ARange, Eye, Join, MakeVector
+from pytensor.tensor.basic import (
+    Alloc,
+    AllocEmpty,
+    ARange,
+    Eye,
+    Join,
+    MakeVector,
+    TensorFromScalar,
+)
 
 
 @singledispatch
-def pytorch_typify(data, dtype=None, **kwargs):
-    r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
+def pytorch_typify(data, **kwargs):
+    raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
+
+
+@pytorch_typify.register(np.ndarray)
+@pytorch_typify.register(torch.Tensor)
+def pytorch_typify_tensor(data, dtype=None, **kwargs):
     return torch.as_tensor(data, dtype=dtype)
 
 
+@pytorch_typify.register(slice)
 @pytorch_typify.register(NoneType)
-def pytorch_typify_None(data, **kwargs):
-    return None
+@pytorch_typify.register(np.number)
+def pytorch_typify_no_conversion_needed(data, **kwargs):
+    return data
 
 
 @singledispatch
@@ -132,3 +148,11 @@ def makevector(*x):
         return torch.tensor(x, dtype=torch_dtype)
 
     return makevector
+
+
+@pytorch_funcify.register(TensorFromScalar)
+def pytorch_funcify_TensorFromScalar(op, **kwargs):
+    def tensorfromscalar(x):
+        return torch.as_tensor(x)
+
+    return tensorfromscalar
diff --git a/pytensor/link/pytorch/dispatch/subtensor.py b/pytensor/link/pytorch/dispatch/subtensor.py
new file mode 100644
index 0000000000..4f53ec29f7
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/subtensor.py
@@ -0,0 +1,124 @@
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
+from pytensor.tensor.subtensor import (
+    AdvancedIncSubtensor,
+    AdvancedIncSubtensor1,
+    AdvancedSubtensor,
+    AdvancedSubtensor1,
+    IncSubtensor,
+    Subtensor,
+    indices_from_subtensor,
+)
+from pytensor.tensor.type_other import MakeSlice, SliceType
+
+
+def check_negative_steps(indices):
+    for index in indices:
+        if isinstance(index, slice):
+            if index.step is not None and index.step < 0:
+                raise NotImplementedError(
+                    "Negative step sizes are not supported in PyTorch"
+                )
+
+
+@pytorch_funcify.register(Subtensor)
+def pytorch_funcify_Subtensor(op, node, **kwargs):
+    idx_list = op.idx_list
+
+    def subtensor(x, *flattened_indices):
+        indices = indices_from_subtensor(flattened_indices, idx_list)
+        check_negative_steps(indices)
+        return x[indices]
+
+    return subtensor
+
+
+@pytorch_funcify.register(MakeSlice)
+def pytorch_funcify_makeslice(op, **kwargs):
+    def makeslice(*x):
+        return slice(*x)
+
+    return makeslice
+
+
+@pytorch_funcify.register(AdvancedSubtensor1)
+@pytorch_funcify.register(AdvancedSubtensor)
+def pytorch_funcify_AdvSubtensor(op, node, **kwargs):
+    def advsubtensor(x, *indices):
+        check_negative_steps(indices)
+        return x[indices]
+
+    return advsubtensor
+
+
+@pytorch_funcify.register(IncSubtensor)
+def pytorch_funcify_IncSubtensor(op, node, **kwargs):
+    idx_list = op.idx_list
+    inplace = op.inplace
+    if op.set_instead_of_inc:
+
+        def set_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] = y
+            return x
+
+        return set_subtensor
+
+    else:
+
+        def inc_subtensor(x, y, *flattened_indices):
+            indices = indices_from_subtensor(flattened_indices, idx_list)
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] += y
+            return x
+
+        return inc_subtensor
+
+
+@pytorch_funcify.register(AdvancedIncSubtensor)
+@pytorch_funcify.register(AdvancedIncSubtensor1)
+def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
+    inplace = op.inplace
+    ignore_duplicates = getattr(op, "ignore_duplicates", False)
+
+    if op.set_instead_of_inc:
+
+        def adv_set_subtensor(x, y, *indices):
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] = y.type_as(x)
+            return x
+
+        return adv_set_subtensor
+
+    elif ignore_duplicates:
+
+        def adv_inc_subtensor_no_duplicates(x, y, *indices):
+            check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x[indices] += y.type_as(x)
+            return x
+
+        return adv_inc_subtensor_no_duplicates
+
+    else:
+        if any(isinstance(idx.type, SliceType) for idx in node.inputs[2:]):
+            raise NotImplementedError(
+                "IncSubtensor with potentially duplicate indices and slice indexing is not implemented in PyTorch"
+            )
+
+        def adv_inc_subtensor(x, y, *indices):
+            # Not needed because slices aren't supported
+            # check_negative_steps(indices)
+            if not inplace:
+                x = x.clone()
+            x.index_put_(indices, y.type_as(x), accumulate=True)
+            return x
+
+        return adv_inc_subtensor
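For context, a minimal sketch of how this dispatch module is exercised end to end, assuming a build of this branch with the PYTORCH mode registered (the variable names here are illustrative, not part of the patch):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor3("x")
    # set_subtensor on a basic slice lowers to IncSubtensor(set_instead_of_inc=True),
    # so pytorch_funcify_IncSubtensor clones the input and assigns into the view.
    out = pt.set_subtensor(x[1:, 0], 0.0)
    fn = pytensor.function([x], out, mode="PYTORCH")

    x_val = np.arange(24.0).reshape((2, 3, 4)).astype(pytensor.config.floatX)
    fn(x_val)  # the block x[1:, 0] is zeroed; the input array is left untouched

The clone-before-write in the non-inplace paths keeps the backend functionally pure: PyTensor assumes an op does not mutate its inputs unless it is explicitly marked inplace.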
diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py
index 27c1b1bd6a..89e6d8553d 100644
--- a/tests/link/pytorch/test_basic.py
+++ b/tests/link/pytorch/test_basic.py
@@ -66,10 +66,10 @@ def compare_pytorch_and_py(
     py_res = pytensor_py_fn(*test_inputs)
 
     if len(fgraph.outputs) > 1:
-        for j, p in zip(pytorch_res, py_res):
-            assert_fn(j.cpu(), p)
+        for pytorch_res_i, py_res_i in zip(pytorch_res, py_res):
+            assert_fn(pytorch_res_i.detach().cpu().numpy(), py_res_i)
     else:
-        assert_fn([pytorch_res[0].cpu()], py_res)
+        assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
 
     return pytensor_torch_fn, pytorch_res
 
diff --git a/tests/link/pytorch/test_subtensor.py b/tests/link/pytorch/test_subtensor.py
new file mode 100644
index 0000000000..fb2b3390d3
--- /dev/null
+++ b/tests/link/pytorch/test_subtensor.py
@@ -0,0 +1,186 @@
+import contextlib
+
+import numpy as np
+import pytest
+
+import pytensor.scalar as ps
+import pytensor.tensor as pt
+from pytensor.configdefaults import config
+from pytensor.graph.fg import FunctionGraph
+from pytensor.tensor import inc_subtensor, set_subtensor
+from pytensor.tensor import subtensor as pt_subtensor
+from tests.link.pytorch.test_basic import compare_pytorch_and_py
+
+
+def test_pytorch_Subtensor():
+    shape = (3, 4, 5)
+    x_pt = pt.tensor("x", shape=shape, dtype="int")
+    x_np = np.arange(np.prod(shape)).reshape(shape)
+
+    out_pt = x_pt[1, 2, 0]
+    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    out_pt = x_pt[1:, 1, :]
+    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    out_pt = x_pt[:2, 1, :]
+    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    out_pt = x_pt[1:2, 1, :]
+    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    # Symbolic index
+    a_pt = ps.int64("a")
+    a_np = 1
+    out_pt = x_pt[a_pt, 2, a_pt:2]
+    assert isinstance(out_pt.owner.op, pt_subtensor.Subtensor)
+    out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np, a_np])
+
+    with pytest.raises(
+        NotImplementedError, match="Negative step sizes are not supported in PyTorch"
+    ):
+        out_pt = x_pt[::-1]
+        out_fg = FunctionGraph([x_pt], [out_pt])
+        compare_pytorch_and_py(out_fg, [x_np])
+
+
+def test_pytorch_AdvSubtensor():
+    shape = (3, 4, 5)
+    x_pt = pt.tensor("x", shape=shape, dtype="int")
+    x_np = np.arange(np.prod(shape)).reshape(shape)
+
+    out_pt = pt_subtensor.advanced_subtensor1(x_pt, [1, 2])
+    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor1)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    out_pt = x_pt[[1, 2], [2, 3]]
+    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    out_pt = x_pt[[1, 2], 1:]
+    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    out_pt = x_pt[[1, 2], :, [3, 4]]
+    assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    out_pt = x_pt[[1, 2], None]
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    a_pt = ps.int64("a")
+    a_np = 2
+    out_pt = x_pt[[1, a_pt], a_pt]
+    out_fg = FunctionGraph([x_pt, a_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np, a_np])
+
+    # Boolean indices
+    out_pt = x_pt[np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool)]
+    out_fg = FunctionGraph([x_pt], [out_pt])
+    compare_pytorch_and_py(out_fg, [x_np])
+
+    a_pt = pt.tensor3("a", dtype="bool")
dtype="bool") + a_np = np.random.binomial(1, 0.5, size=(3, 4, 5)).astype(bool) + out_pt = x_pt[a_pt] + out_fg = FunctionGraph([x_pt, a_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_np, a_np]) + + with pytest.raises( + NotImplementedError, match="Negative step sizes are not supported in Pytorch" + ): + out_pt = x_pt[[1, 2], ::-1] + out_fg = FunctionGraph([x_pt], [out_pt]) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor) + compare_pytorch_and_py(out_fg, [x_np]) + + +@pytest.mark.parametrize("subtensor_op", [set_subtensor, inc_subtensor]) +def test_pytorch_IncSubtensor(subtensor_op): + x_pt = pt.tensor3("x") + x_test = np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX) + + st_pt = pt.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) + out_pt = subtensor_op(x_pt[1, 2, 3], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_test]) + + # Test different type update + st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32")) + out_pt = subtensor_op(x_pt[:2, 0, 0], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_test]) + + out_pt = subtensor_op(x_pt[0, 1:3, 0], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.IncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_test]) + + +def inc_subtensor_ignore_duplicates(x, y): + return inc_subtensor(x, y, ignore_duplicates=True) + + +@pytest.mark.parametrize( + "advsubtensor_op", [set_subtensor, inc_subtensor, inc_subtensor_ignore_duplicates] +) +def test_pytorch_AvdancedIncSubtensor(advsubtensor_op): + rng = np.random.default_rng(42) + + x_pt = pt.tensor3("x") + x_test = (np.arange(3 * 4 * 5) + 1).reshape((3, 4, 5)).astype(config.floatX) + + st_pt = pt.as_tensor_variable( + rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) + ) + out_pt = advsubtensor_op(x_pt[np.r_[0, 2]], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_test]) + + # Repeated indices + out_pt = advsubtensor_op(x_pt[np.r_[0, 0]], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_test]) + + # Mixing advanced and basic indexing + if advsubtensor_op is inc_subtensor: + # PyTorch does not support `np.add.at` equivalent with slices + expectation = pytest.raises(NotImplementedError) + else: + expectation = contextlib.nullcontext() + st_pt = pt.as_tensor_variable(x_test[[0, 2], 0, :3]) + out_pt = advsubtensor_op(x_pt[[0, 0], 0, :3], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + with expectation: + compare_pytorch_and_py(out_fg, [x_test]) + + # Test different dtype update + st_pt = pt.as_tensor_variable(np.r_[-1.0, 0.0].astype("float32")) + out_pt = advsubtensor_op(x_pt[[0, 2], 0, 0], st_pt) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_test]) + + # Boolean indices + out_pt = advsubtensor_op(x_pt[x_pt > 5], 1.0) + assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([x_pt], [out_pt]) + compare_pytorch_and_py(out_fg, [x_test])