From 58ef3ac05624faa71cbe5ce4eeaee45dc4b15c0d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Feb 2025 17:39:11 +0100 Subject: [PATCH] Simplify implementation of tile Deprecate obscure ndim kwarg --- pytensor/tensor/basic.py | 117 ++++++++--------- tests/tensor/test_basic.py | 258 +++++++++++++------------------------ 2 files changed, 144 insertions(+), 231 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 401642ddb9..5418cde744 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -10,7 +10,7 @@ from collections.abc import Sequence from functools import partial from numbers import Number -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union from typing import cast as type_cast import numpy as np @@ -33,7 +33,7 @@ from pytensor.link.c.op import COp from pytensor.link.c.params_type import ParamsType from pytensor.printing import Printer, min_informative_str, pprint, set_precedence -from pytensor.raise_op import CheckAndRaise, assert_op +from pytensor.raise_op import CheckAndRaise from pytensor.scalar import int32 from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable from pytensor.tensor import ( @@ -3084,7 +3084,9 @@ def flatten(x, ndim=1): return x_reshaped -def tile(x, reps, ndim=None): +def tile( + A: "TensorLike", reps: Union[Sequence[int, "TensorLike"], "TensorLike"] +) -> TensorVariable: """ Tile input array `x` according to `reps`. @@ -3094,77 +3096,62 @@ def tile(x, reps, ndim=None): symbolic scalar (e.g. tensor.iscalar()), symbolic vector (e.g. tensor.ivector()) or a list of symbolic scalar (e.g. [tensor.iscalar(), tensor.iscalar()]). - ndim is the number of the dimensions of the output, if it is provided, ndim - should be equal or larger than x.ndim and len(reps), otherwise, we will use - max(x.ndim, len(reps)) as ndim. If reps is symbolic vector, the ndim has to - be provided. - """ - from pytensor.tensor.math import ge - _x = as_tensor_variable(x) - if ndim is not None and ndim < _x.ndim: - raise ValueError("ndim should be equal or larger than _x.ndim") + A = as_tensor_variable(A) - # If reps is a scalar, integer or vector, we convert it to a list. + # Convert symbolic reps to a tuple if not isinstance(reps, list | tuple): - reps_astensor = as_tensor_variable(reps) - ndim_check = reps_astensor.ndim - if reps_astensor.dtype not in discrete_dtypes: - raise ValueError("elements of reps must be integer dtype") - - # The scalar/integer case - if ndim_check == 0: - reps = [reps] - - # The vector case - elif ndim_check == 1: - if ndim is None: + reps = as_tensor_variable(reps) + if reps.type.ndim == 0: + reps = (reps,) + elif reps.type.ndim == 1: + try: + reps = tuple(reps) + except ValueError: raise ValueError( - "if reps is tensor.vector, you should specify the ndim" + "Length of repetitions tensor cannot be determined. Use specify_shape to set the length." ) - else: - offset = ndim - reps.shape[0] - - # assert that reps.shape[0] does not exceed ndim - offset = assert_op(offset, ge(offset, 0)) + else: + raise ValueError( + f"Repetitions tensor must be a scalar or a vector, got ndim={reps.type.ndim}" + ) - # if reps.ndim is less than _x.ndim, we pad the reps with - # "1" so that reps will have the same ndim as _x. - reps_ = [switch(i < offset, 1, reps[i - offset]) for i in range(ndim)] - reps = reps_ + reps = [as_tensor_variable(rep) for rep in reps] + if not all( + rep.type.ndim == 0 and rep.type.dtype in discrete_dtypes for rep in reps + ): + raise ValueError( + f"All reps entries shoud be scalar integers, got {reps} of type {[rep.type for rep in reps]}" + ) - # For others, raise an error - else: - raise ValueError("the dimension of reps should not exceed 1") - else: - if ndim is not None and len(reps) > ndim: - raise ValueError("len(reps) should be equal or less than ndim") - if not all( - isinstance(r, int) - or (isinstance(r, TensorVariable) and r.dtype in discrete_dtypes) - for r in reps - ): - raise ValueError("elements of reps must be scalars of integer dtype") + len_reps = len(reps) + out_ndim = builtins.max(len_reps, A.type.ndim) + + # Pad reps on the left (if needed) + if len_reps < out_ndim: + reps = (*((1,) * (out_ndim - len_reps)), *reps) + + # Pad A's shape on the left (if needed) + elif A.type.ndim < out_ndim: + A = shape_padleft(A, out_ndim - A.type.ndim) + + # Expand every other dim of A and expand n-reps via Alloc + # A_replicated = alloc(A[None, :, ..., None, :], reps[0], A.shape[0], ..., reps[-1], A.shape[-1]) + A_shape = A.shape + interleaved_reps_shape = [ + d for pair in zip(reps, A.shape, strict=True) for d in pair + ] + every_other_axis = tuple(range(0, out_ndim * 2, 2)) + A_replicated = alloc( + expand_dims(A, every_other_axis), + *interleaved_reps_shape, + ) - # If reps.ndim is less than _x.ndim, we pad the reps with - # "1" so that reps will have the same ndim as _x - reps = list(reps) - if ndim is None: - ndim = builtins.max(len(reps), _x.ndim) - if len(reps) < ndim: - reps = [1] * (ndim - len(reps)) + reps - - _shape = [1] * (ndim - _x.ndim) + [_x.shape[i] for i in range(_x.ndim)] - alloc_shape = reps + _shape - y = alloc(_x, *alloc_shape) - shuffle_ind = np.arange(ndim * 2).reshape(2, ndim) - shuffle_ind = shuffle_ind.transpose().flatten() - y = y.dimshuffle(*shuffle_ind) - new_shapes = [sh * reps[i] for i, sh in enumerate(_shape)] - y = y.reshape(new_shapes) - - return y + # Combine replicate and original dimensions via reshape + # A_tiled = A_replicated.reshape(reps[0] * A.shape[0], ..., reps[-1] * A.shape[-1]) + tiled_shape = tuple(rep * A_dim for rep, A_dim in zip(reps, A_shape, strict=True)) + return A_replicated.reshape(tiled_shape) class ARange(Op): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 754859fa6f..06627c3206 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -2386,194 +2386,120 @@ def test_is_flat(): assert not ptb.is_flat(X.reshape((iscalar(),) * 3)) -def test_tile(): - """ - TODO FIXME: Split this apart and parameterize. Also, find out why it's - unreasonably slow. - """ +class TestTile: + @pytest.mark.parametrize( + "A_shape, reps_test", + [ + ((), (2,)), + ((5,), (2,)), + ((2, 4), (2, 3)), + ((2, 4), (2, 3, 4)), + ( + (2, 4, 3), + ( + 2, + 3, + ), + ), + ((2, 4, 3), (2, 3, 4)), + ((2, 4, 3, 5), (2, 3, 4, 6)), + ], + ) + def test_tile_separate_reps_entries(self, A_shape, reps_test): + rng = np.random.default_rng(2400) - def run_tile(x, x_, reps, use_symbolic_reps): - if use_symbolic_reps: - rep_symbols = [iscalar() for _ in range(len(reps))] - f = function([x, *rep_symbols], tile(x, rep_symbols)) - return f(*([x_, *reps])) - else: - f = function([x], tile(x, reps)) - return f(x_) + A = tensor("A", shape=(None,) * len(A_shape)) + reps = [iscalar(f"r{i}") for i in range(len(reps_test))] + tile_out = tile(A, reps) - rng = np.random.default_rng(utt.fetch_seed()) + tile_fn = function([A, *reps], tile_out) - for use_symbolic_reps in [False, True]: - # Test the one-dimensional case. - x = vector() - x_ = rng.standard_normal(5).astype(config.floatX) - assert np.all(run_tile(x, x_, (2,), use_symbolic_reps) == np.tile(x_, (2,))) + A_test = rng.standard_normal(A_shape).astype(config.floatX) + np.testing.assert_allclose( + tile_fn(A_test, *reps_test), + np.tile(A_test, reps_test), + ) - # Test the two-dimensional case. - x = matrix() - x_ = rng.standard_normal((2, 4)).astype(config.floatX) - assert np.all(run_tile(x, x_, (2, 3), use_symbolic_reps) == np.tile(x_, (2, 3))) - - # Test the three-dimensional case. - x = tensor3() - x_ = rng.standard_normal((2, 4, 3)).astype(config.floatX) - assert np.all( - run_tile(x, x_, (2, 3, 4), use_symbolic_reps) == np.tile(x_, (2, 3, 4)) + @pytest.mark.parametrize("reps", (2, np.array([2, 3, 4]))) + def test_combined_reps_entries(self, reps): + rng = np.random.default_rng(2422) + A_test = rng.standard_normal((2, 4, 3)).astype(config.floatX) + expected_eval = np.tile(A_test, reps) + + A = tensor3("A") + np.testing.assert_allclose( + tile(A, reps).eval({A: A_test}), + expected_eval, ) - # Test the four-dimensional case. - x = tensor4() - x_ = rng.standard_normal((2, 4, 3, 5)).astype(config.floatX) - assert np.all( - run_tile(x, x_, (2, 3, 4, 6), use_symbolic_reps) - == np.tile(x_, (2, 3, 4, 6)) + sym_reps = as_tensor_variable(reps).type() + np.testing.assert_allclose( + tile(A, sym_reps).eval({A: A_test, sym_reps: reps}), + expected_eval, ) - # Test passing a float - x = scalar() - x_val = 1.0 - assert np.array_equal( - run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,)) + def test_mixed_reps_type(self): + A = np.arange(9).reshape(3, 3) + reps = [2, iscalar("3"), 4] + np.testing.assert_allclose( + tile(A, reps).eval({"3": 3}), + np.tile(A, [2, 3, 4]), ) + def test_tensorlike_A(self): # Test when x is a list - x = matrix() x_val = [[1.0, 2.0], [3.0, 4.0]] - assert np.array_equal( - run_tile(x, x_val, (2,), use_symbolic_reps), np.tile(x_val, (2,)) + assert equal_computations( + [tile(x_val, (2,))], + [tile(as_tensor_variable(x_val), (2,))], ) - # Test when reps is integer, scalar or vector. - # Test 1,2,3,4-dimensional cases. - # Test input x has the shape [2], [2, 4], [2, 4, 3], [2, 4, 3, 5]. - test_shape = [2, 4, 3, 5] - k = 0 - for xtype in [vector(), matrix(), tensor3(), tensor4()]: - x = xtype - k = k + 1 - x_ = rng.standard_normal(test_shape[0:k]).astype(config.floatX) - - # integer: - reps_ = 2 - f = function([x], tile(x, reps_)) - assert np.all(f(x_) == np.tile(x_, reps_)) - - # scalar: - reps = iscalar() - reps_ = 2 - f = function([x, reps], tile(x, reps)) - assert np.all(f(x_, reps_) == np.tile(x_, reps_)) - - # vector: - reps = ivector() - reps_ = [2] if k == 1 or k == 2 else [2, 3] - ndim_ = k - f = function([x, reps], tile(x, reps, ndim_)) - assert np.all(f(x_, reps_) == np.tile(x_, reps_)) - - # list of integers: - reps_ = [2, 3, 4] - f = function([x], tile(x, reps_)) - assert np.all(f(x_) == np.tile(x_, reps_)) - - # list of integers and scalars: - d = iscalar() - reps = [2, d, 4] - f = function([x, d], tile(x, reps)) - reps_ = [2, 3, 4] - assert np.all(f(x_, 3) == np.tile(x_, reps_)) - - # reps is list, len(reps) > x.ndim, 3 cases below: - r = [2, 3, 4, 5, 6] - reps_ = r[: k + 1] # len(reps_) = x.ndim+1 - # (1) ndim = None. - f = function([x], tile(x, reps_)) - assert np.all(f(x_) == np.tile(x_, reps_)) - # (2) ndim = len(reps). - ndim_ = len(reps_) - f = function([x], tile(x, reps_, ndim_)) - assert np.all(f(x_) == np.tile(x_, reps_)) - # (3) ndim > len(reps) - ndim_ = len(reps_) + 1 - f = function([x], tile(x, reps_, ndim_)) - assert np.all(f(x_) == np.tile(x_, [1, *reps_])) - - # reps is list, ndim > x.ndim > len(reps): - r = [2, 3, 4, 5] - if k > 1: - ndim_ = k + 1 - reps_ = r[: k - 1] - f = function([x], tile(x, reps_, ndim_)) - assert np.all(f(x_) == np.tile(x_, [1, 1, *reps_])) - + def test_error_unknown_reps_length(self): # error raising test: ndim not specified when reps is vector reps = ivector() - with pytest.raises(ValueError): - tile(x, reps) + with pytest.raises(ValueError, match="Use specify_shape to set the length"): + tile(arange(3), reps) + + # fine with specify_shape + out = tile(arange(3), specify_shape(reps, 2)) + np.testing.assert_allclose( + out.eval({reps: [2, 3]}), + np.tile(np.arange(3), [2, 3]), + ) - # error raising test: not a integer - for reps in [2.5, fscalar(), fvector()]: + def test_error_non_integer_reps(self): + for reps in ( + 2.5, + fscalar(), + vector(shape=(3,), dtype="float64"), + [2, fscalar()], + ): with pytest.raises(ValueError): - tile(x, reps) + tile(arange(3), reps) - # error raising test: the dimension of reps exceeds 1 - reps = imatrix() - with pytest.raises(ValueError): - tile(x, reps) - - # error raising test: ndim is not None, ndim < x.ndim - # 3 cases below (reps is list/scalar/vector): - for reps in [[2, 3, 4], iscalar(), ivector()]: - if k > 1: - ndim = k - 1 - with pytest.raises(ValueError): - tile(x, reps, ndim) - - # error raising test: reps is list, len(reps) > ndim - r = [2, 3, 4, 5, 6] - reps = r[: k + 1] - ndim = k - with pytest.raises(ValueError): - tile(x, reps, ndim) + def test_error_reps_ndim(self): + for reps in ( + matrix(shape=(3, 1), dtype=int), + [2, vector(shape=(2,), dtype=int)], + ): + with pytest.raises(ValueError): + tile(arange(3), reps) + + def test_tile_grad(self): + A = tensor3("A") + reps = vector("reps", shape=(3,), dtype=int) + A_tile = tile(A, reps) + grad_tile = grad(A_tile.sum(), A) - # error raising test: - # reps is vector and len(reps_value) > ndim, - # reps_value is the real value when executing the function. - reps = ivector() - r = [2, 3, 4, 5, 6, 7] - reps_ = r[: k + 2] - ndim_ = k + 1 - f = function([x, reps], tile(x, reps, ndim_)) - with pytest.raises(AssertionError): - f(x_, reps_) - - -def test_tile_grad(): - def grad_tile(x, reps, np_x): - y = tile(x, reps) - z = y.sum() - g = pytensor.function([x], grad(z, x)) - grad_res = g(np_x) # The gradient should be the product of the tiling dimensions # (since the gradients are additive through the tiling operation) - assert np.all(grad_res == np.prod(reps)) - - rng = np.random.default_rng(utt.fetch_seed()) - - # test vector - grad_tile(vector("x"), [3], rng.standard_normal(5).astype(config.floatX)) - # test matrix - grad_tile(matrix("x"), [3, 4], rng.standard_normal((2, 3)).astype(config.floatX)) - # test tensor3 - grad_tile( - tensor3("x"), [3, 4, 5], rng.standard_normal((2, 4, 3)).astype(config.floatX) - ) - # test tensor4 - grad_tile( - tensor4("x"), - [3, 4, 5, 6], - rng.standard_normal((2, 4, 3, 5)).astype(config.floatX), - ) + rng = np.random.default_rng(2489) + A_test = rng.normal(size=(2, 4, 3)).astype(config.floatX) + reps_test = [3, 4, 5] + np.testing.assert_allclose( + grad_tile.eval({A: A_test, reps: reps_test}), + np.full(A_test.shape, np.prod(reps_test)), + ) class TestARange: