diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py
index 6170a02a98..ba57ea4d30 100644
--- a/pytensor/tensor/blas.py
+++ b/pytensor/tensor/blas.py
@@ -79,7 +79,6 @@
 import logging
 import os
 import shlex
-import time
 from pathlib import Path

 import numpy as np
@@ -103,10 +102,8 @@
 from pytensor.tensor import basic as ptb
 from pytensor.tensor.basic import expand_dims
 from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
-from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import add, mul, neg, sub, variadic_add
 from pytensor.tensor.shape import shape_padright, specify_broadcastable
-from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
+from pytensor.tensor.type import DenseTensorType, tensor


 _logger = logging.getLogger("pytensor.tensor.blas")
@@ -1148,322 +1145,6 @@ def c_code_cache_version(self):
 pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"]))


-def res_is_a(fgraph, var, op, maxclients=None):
-    if maxclients is not None and var in fgraph.clients:
-        retval = len(fgraph.get_clients(var)) <= maxclients
-    else:
-        retval = True
-
-    return var.owner and var.owner.op == op and retval
-
-
-def _as_scalar(res, dtype=None):
-    """Return ``None`` or a `TensorVariable` of float type"""
-    if dtype is None:
-        dtype = config.floatX
-    if all(s == 1 for s in res.type.shape):
-        while res.owner and isinstance(res.owner.op, DimShuffle):
-            res = res.owner.inputs[0]
-        # may still have some number of True's
-        if res.type.ndim > 0:
-            rval = res.dimshuffle()
-        else:
-            rval = res
-        if rval.type.dtype in integer_dtypes:
-            # We check that the upcast of res and dtype won't change dtype.
-            # If dtype is float64, we will cast int64 to float64.
-            # This is valid when res is a scalar used as input to a dot22
-            # as the cast of the scalar can be done before or after the dot22
-            # and this will give the same result.
-            if pytensor.scalar.upcast(res.dtype, dtype) == dtype:
-                return ptb.cast(rval, dtype)
-            else:
-                return None
-
-        return rval
-
-
-def _is_real_matrix(res):
-    return (
-        res.type.dtype in ("float16", "float32", "float64")
-        and res.type.ndim == 2
-        and res.type.shape[0] != 1
-        and res.type.shape[1] != 1
-    )  # cope with tuple vs. list
-
-
-def _is_real_vector(res):
-    return (
-        res.type.dtype in ("float16", "float32", "float64")
-        and res.type.ndim == 1
-        and res.type.shape[0] != 1
-    )
-
-
-def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
-    # print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
-    # EXPRESSION: (beta * L) + (alpha * M)
-
-    # we've already checked the client counts, now just make the type check.
-    # if res_is_a(M, _dot22, 1):
-    if M.owner and M.owner.op == _dot22:
-        Ml, Mr = M.owner.inputs
-        rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
-        return rval, M
-
-    # it also might be the case that there is a dimshuffle between the +
-    # and the dot22. local_dot_to_dot22 in particular will put in such things.
-    if (
-        M.owner
-        and isinstance(M.owner.op, DimShuffle)
-        and M.owner.inputs[0].owner
-        and isinstance(M.owner.inputs[0].owner.op, Dot22)
-    ):
-        MM = M.owner.inputs[0]
-        if M.owner.op.new_order == (0,):
-            # it is making a column MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle(0)]
-            return rval, MM
-        if M.owner.op.new_order == (1,):
-            # it is making a row MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle(1)]
-            return rval, MM
-        if len(M.owner.op.new_order) == 0:
-            # it is making a row MM into a vector
-            MMl, MMr = MM.owner.inputs
-            g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta)
-            rval = [g.dimshuffle()]
-            return rval, MM
-
-    if recurse_flip:
-        return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False)
-    else:
-        return False, False
-
-
-def _gemm_canonicalize(fgraph, r, scale, rval, maxclients):
-    # Tries to interpret node as a sum of scalars * (vectors or matrices)
-    def scaled(thing):
-        if scale == 1:
-            return thing
-        if scale == -1 and thing.type.dtype != "bool":
-            return -thing
-        else:
-            return scale * thing
-
-    if not isinstance(r.type, TensorType):
-        return None
-
-    if (r.type.ndim not in (1, 2)) or r.type.dtype not in (
-        "float16",
-        "float32",
-        "float64",
-        "complex64",
-        "complex128",
-    ):
-        rval.append(scaled(r))
-        return rval
-
-    if maxclients and len(fgraph.clients[r]) > maxclients:
-        rval.append((scale, r))
-        return rval
-
-    if r.owner and r.owner.op == sub:
-        _gemm_canonicalize(fgraph, r.owner.inputs[0], scale, rval, 1)
-        _gemm_canonicalize(fgraph, r.owner.inputs[1], -scale, rval, 1)
-
-    elif r.owner and r.owner.op == add:
-        for i in r.owner.inputs:
-            _gemm_canonicalize(fgraph, i, scale, rval, 1)
-
-    elif r.owner and r.owner.op == neg:
-        _gemm_canonicalize(fgraph, r.owner.inputs[0], -scale, rval, 1)
-
-    elif r.owner and r.owner.op == mul:
-        scalars = []
-        vectors = []
-        matrices = []
-        for i in r.owner.inputs:
-            if all(s == 1 for s in i.type.shape):
-                while i.owner and isinstance(i.owner.op, DimShuffle):
-                    i = i.owner.inputs[0]
-                if i.type.ndim > 0:
-                    scalars.append(i.dimshuffle())
-                else:
-                    scalars.append(i)
-            elif _is_real_vector(i):
-                vectors.append(i)
-            elif _is_real_matrix(i):
-                matrices.append(i)
-            else:
-                # just put the original arguments as in the base case
-                rval.append((scale, r))
-                return rval
-        if len(matrices) == 1:
-            assert len(vectors) == 0
-            m = matrices[0]
-            if len(scalars) == 0:
-                _gemm_canonicalize(fgraph, m, scale, rval, 1)
-            elif len(scalars) == 1:
-                _gemm_canonicalize(fgraph, m, scaled(scalars[0]), rval, 1)
-            else:
-                _gemm_canonicalize(
-                    fgraph, m, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
-                )
-        elif len(vectors) == 1:
-            assert len(matrices) == 0
-            v = vectors[0]
-            if len(scalars) == 0:
-                _gemm_canonicalize(fgraph, v, scale, rval, 1)
-            elif len(scalars) == 1:
-                _gemm_canonicalize(fgraph, v, scaled(scalars[0]), rval, 1)
-            else:
-                _gemm_canonicalize(
-                    fgraph, v, mul(scaled(scalars[0]), *scalars[1:]), rval, 1
-                )
-        else:  # lets not open this up
-            rval.append((scale, r))
-    else:
-        rval.append((scale, r))
-    return rval
-
-
-def _factor_canonicalized(lst):
-    # remove duplicates from canonicalized list
-
-    # we only delete out of the right end of the list,
-    # once i has touched a list element, it is permantent
-    lst = list(lst)
-    # print 'FACTOR', lst
-    # for t in lst:
-    #    if not isinstance(t, (list, tuple)):
-    #        t = (t,)
-    #    for e in t:
-    #        try:
-    #            pytensor.printing.debugprint(e)
-    #        except TypeError:
-    #            print e, type(e)
-    i = 0
-    while i < len(lst) - 1:
-        try:
-            s_i, M_i = lst[i]
-        except Exception:
-            i += 1
-            continue
-
-        j = i + 1
-        while j < len(lst):
-            try:
-                s_j, M_j = lst[j]
-            except Exception:
-                j += 1
-                continue
-
-            if M_i is M_j:
-                s_i = s_i + s_j
-                lst[i] = (s_i, M_i)
-                del lst[j]
-            else:
-                j += 1
-        i += 1
-    return lst
-
-
-def _gemm_from_factored_list(fgraph, lst):
-    """
-    Returns None, or a list to replace node.outputs.
-
-    """
-    lst2 = []
-    # Remove the tuple that can't be cast correctly.
-    # This can happen when we try to cast a complex to a real
-    for sM in lst:
-        # Make every pair in list have matching dtypes
-        # sM can be a tuple of 2 elements or an PyTensor variable.
-        if isinstance(sM, tuple):
-            sm0, sm1 = sM
-            sm0 = ptb.as_tensor_variable(sm0)
-            if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
-                lst2.append((ptb.cast(sm0, sm1.dtype), sM[1]))
-
-    lst = lst2
-
-    def item_to_var(t):
-        try:
-            s, M = t
-        except Exception:
-            return t
-        if s == 1:
-            return M
-        if s == -1:
-            return -M
-        return s * M
-
-    # Try every pair in the sM_list, trying to turn it into a gemm operation
-    for i in range(len(lst) - 1):
-        s_i, M_i = lst[i]
-
-        for j in range(i + 1, len(lst)):
-            s_j, M_j = lst[j]
-
-            if not M_j.type.in_same_class(M_i.type):
-                continue
-
-            # print 'TRYING', (s_i, M_i, s_j, M_j)
-
-            gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
-                fgraph, s_i, M_i, s_j, M_j
-            )
-            # print 'GOT IT', gemm_of_sM_list
-            if gemm_of_sM_list:
-                assert len(gemm_of_sM_list) == 1
-                add_inputs = [
-                    item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
-                ]
-                add_inputs.extend(gemm_of_sM_list)
-                rval = [variadic_add(*add_inputs)]
-                return rval, old_dot22
-
-
-def _gemm_from_node2(fgraph, node):
-    """
-
-    TODO: In many expressions, there are many ways to turn it into a
-    gemm. For example dot(a,b) + c + d. This function should return all
-    of them, so that if one version of gemm causes a cycle in the graph, then
-    another application of gemm can be tried.
-
-    """
-    lst = []
-    t0 = time.perf_counter()
-    _gemm_canonicalize(fgraph, node.outputs[0], 1.0, lst, 0)
-    t1 = time.perf_counter()
-
-    if len(lst) > 1:
-        lst = _factor_canonicalized(lst)
-        t2 = time.perf_counter()
-        rval = _gemm_from_factored_list(fgraph, lst)
-        t3 = time.perf_counter()
-
-        # It can happen that _factor_canonicalized and
-        # _gemm_from_factored_list return a node with an incorrect
-        # type. This happens in particular when one of the scalar
-        # factors forces the upcast of the whole expression. In that
-        # case, we simply skip that candidate for Gemm. This was
-        # discussed in
-        # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
-        # but never made it into a trac ticket.
-
-        if rval and rval[0][0].type.in_same_class(node.outputs[0].type):
-            return rval, t1 - t0, t2 - t1, t3 - t2
-
-    return None, t1 - t0, 0, 0
-
-
 class Dot22(GemmRelated):
     """Compute a matrix-matrix product.
diff --git a/tests/tensor/rewriting/test_blas.py b/tests/tensor/rewriting/test_blas.py
index efd18c3831..d939ceedce 100644
--- a/tests/tensor/rewriting/test_blas.py
+++ b/tests/tensor/rewriting/test_blas.py
@@ -2,11 +2,39 @@
 import pytest

 from pytensor import function
+from pytensor import tensor as pt
 from pytensor.compile import get_default_mode
-from pytensor.tensor import matmul, tensor, vectorize
+from pytensor.graph import FunctionGraph
+from pytensor.tensor import (
+    col,
+    dscalar,
+    dvector,
+    matmul,
+    matrix,
+    mul,
+    neg,
+    row,
+    scalar,
+    sqrt,
+    tensor,
+    vector,
+    vectorize,
+)
 from pytensor.tensor.blas import BatchedDot
 from pytensor.tensor.blockwise import Blockwise
-from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
+from pytensor.tensor.elemwise import DimShuffle
+from pytensor.tensor.rewriting.blas import (
+    _as_scalar,
+    _factor_canonicalized,
+    _gemm_canonicalize,
+    _is_real_matrix,
+    res_is_a,
+    specialize_matmul_to_batched_dot,
+)
+
+
+def XYZab():
+    return matrix(), matrix(), matrix(), scalar(), scalar()


 @pytest.mark.parametrize("valid_case", (True, False))
@@ -46,3 +74,136 @@ def core_np(x, y):
         vectorize_pt(x_test, y_test),
         vectorize_np(x_test, y_test),
     )
+
+
+def test_gemm_factor():
+    X, Y = matrix("X"), matrix("Y")
+
+    assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
+    assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
+
+
+def test_gemm_canonicalize():
+    X, Y, Z, a, b = (
+        matrix("X"),
+        matrix("Y"),
+        matrix("Z"),
+        scalar("a"),
+        scalar("b"),
+    )
+    c, d = scalar("c"), scalar("d")
+    u = row("u")
+    v = vector("v")
+    w = col("w")
+
+    can = []
+    fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False)
+    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
+    assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
+
+    can = []
+    fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False)
+    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
+    assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
+
+    can = []
+    fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False)
+    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
+    # [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
+    assert can[:2] == [(1.0, X), (1.0, Y)]
+    assert isinstance(can[2], tuple)
+    assert len(can[2]) == 2
+    assert can[2][0] == 1.0
+    assert can[2][1].owner
+    assert isinstance(can[2][1].owner.op, DimShuffle)
+    assert can[2][1].owner.inputs == [v]
+
+    can = []
+    fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False)
+    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
+    assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
+
+    can = []
+    fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False)
+    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
+    assert can[0] == (a, X)
+    assert can[1] == (1.0, Y)
+    assert can[2][0].owner.op == mul
+    assert can[2][0].owner.inputs[0].owner.op == neg
+    assert can[2][0].owner.inputs[0].owner.inputs[0] == c
+    assert can[2][0].owner.inputs[1] == b
+
+    can = []
+    fg = FunctionGraph(
+        [a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False
+    )
+    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
+    assert can[0][0].owner.op == neg
+    assert can[0][0].owner.inputs[0] == d
+    assert can[0][1] == X
+    assert can[1][0].owner.op == neg
+    assert can[1][0].owner.inputs[0] == a
+    assert can[2] == (-1.0, Y)
+    assert can[3][0].owner.op == mul
+    assert can[3][0].owner.inputs == [c, b]
+
+
+def test_res_is_a():
+    X, Y, Z, a, b = XYZab()
+
+    assert not res_is_a(None, a, sqrt)
+    assert not res_is_a(None, a + a, sqrt)
+    assert res_is_a(None, sqrt(a + a), sqrt)
+
+    sqrt_term = sqrt(a + a)
+    fg = FunctionGraph([a], [2 * sqrt_term], clone=False)
+    assert res_is_a(fg, sqrt_term, sqrt, 2)
+    assert not res_is_a(fg, sqrt_term, sqrt, 0)
+
+
+class TestAsScalar:
+    def test_basic(self):
+        # Test that it works on scalar constants
+        a = pt.constant(2.5)
+        b = pt.constant(np.asarray([[[0.5]]]))
+        b2 = b.dimshuffle()
+        assert b2.ndim == 0
+        d_a = DimShuffle(input_ndim=0, new_order=[])(a)
+        d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
+        d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
+
+        assert _as_scalar(a) == a
+        assert _as_scalar(b) != b
+        assert _as_scalar(d_a) != d_a
+        assert _as_scalar(d_b) != d_b
+        assert _as_scalar(d_a2) != d_a2
+
+    def test_basic_1(self):
+        # Test that it fails on nonscalar constants
+        a = pt.constant(np.ones(5))
+        assert _as_scalar(a) is None
+        assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
+
+    def test_basic_2(self):
+        # Test that it works on scalar variables
+        a = dscalar()
+        d_a = DimShuffle(input_ndim=0, new_order=[])(a)
+        d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
+
+        assert _as_scalar(a) is a
+        assert _as_scalar(d_a) is a
+        assert _as_scalar(d_a2) is a
+
+    def test_basic_3(self):
+        # Test that it fails on nonscalar variables
+        a = matrix()
+        assert _as_scalar(a) is None
+        assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
+
+
+class TestRealMatrix:
+    def test_basic(self):
+        assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
+        assert not _is_real_matrix(
+            DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
+        )
diff --git a/tests/tensor/test_blas.py b/tests/tensor/test_blas.py
index 1c0d707c11..1e4afb8928 100644
--- a/tests/tensor/test_blas.py
+++ b/tests/tensor/test_blas.py
@@ -16,7 +16,6 @@
 from pytensor.compile.sharedvalue import shared
 from pytensor.configdefaults import config
 from pytensor.gradient import grad
-from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import in2out
 from pytensor.graph.utils import InconsistencyError
 from pytensor.tensor import inplace
@@ -28,12 +27,8 @@
     Gemm,
     Gemv,
     Ger,
-    _as_scalar,
     _dot22,
     _dot22scalar,
-    _factor_canonicalized,
-    _gemm_canonicalize,
-    _is_real_matrix,
     batched_dot,
     batched_tensordot,
     gemm,
@@ -44,19 +39,15 @@
     gemv_no_inplace,
     ger,
     ger_destructive,
-    res_is_a,
 )
-from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import Dot, dot, mean, mul, neg, outer, sigmoid, sqrt
+from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid
 from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
 from pytensor.tensor.type import (
     cmatrix,
-    col,
     cscalar,
     dmatrix,
     drow,
     dscalar,
-    dvector,
     fmatrix,
     fscalar,
     imatrix,
@@ -65,7 +56,6 @@
     ivector,
     matrices,
     matrix,
-    row,
     scalar,
     scalars,
     tensor,
@@ -572,67 +562,6 @@ def test_gemm(self):
         self.run_gemm(dtype, alpha, beta, tA, tB, tC, sA, sB, sC, rng)


-def test_res_is_a():
-    X, Y, Z, a, b = XYZab()
-
-    assert not res_is_a(None, a, sqrt)
-    assert not res_is_a(None, a + a, sqrt)
-    assert res_is_a(None, sqrt(a + a), sqrt)
-
-    sqrt_term = sqrt(a + a)
-    fg = FunctionGraph([a], [2 * sqrt_term], clone=False)
-    assert res_is_a(fg, sqrt_term, sqrt, 2)
-    assert not res_is_a(fg, sqrt_term, sqrt, 0)
-
-
-class TestAsScalar:
-    def test_basic(self):
-        # Test that it works on scalar constants
-        a = pt.constant(2.5)
-        b = pt.constant(np.asarray([[[0.5]]]))
-        b2 = b.dimshuffle()
-        assert b2.ndim == 0
-        d_a = DimShuffle(input_ndim=0, new_order=[])(a)
-        d_b = DimShuffle(input_ndim=3, new_order=[0, 2, 1])(b)
-        d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x", "x"])(a)
-
-        assert _as_scalar(a) == a
-        assert _as_scalar(b) != b
-        assert _as_scalar(d_a) != d_a
-        assert _as_scalar(d_b) != d_b
-        assert _as_scalar(d_a2) != d_a2
-
-    def test_basic_1(self):
-        # Test that it fails on nonscalar constants
-        a = pt.constant(np.ones(5))
-        assert _as_scalar(a) is None
-        assert _as_scalar(DimShuffle(input_ndim=1, new_order=[0, "x"])(a)) is None
-
-    def test_basic_2(self):
-        # Test that it works on scalar variables
-        a = dscalar()
-        d_a = DimShuffle(input_ndim=0, new_order=[])(a)
-        d_a2 = DimShuffle(input_ndim=0, new_order=["x", "x"])(a)
-
-        assert _as_scalar(a) is a
-        assert _as_scalar(d_a) is a
-        assert _as_scalar(d_a2) is a
-
-    def test_basic_3(self):
-        # Test that it fails on nonscalar variables
-        a = matrix()
-        assert _as_scalar(a) is None
-        assert _as_scalar(DimShuffle(input_ndim=2, new_order=[0, "x", 1])(a)) is None
-
-
-class TestRealMatrix:
-    def test_basic(self):
-        assert _is_real_matrix(DimShuffle(input_ndim=2, new_order=[1, 0])(matrix()))
-        assert not _is_real_matrix(
-            DimShuffle(input_ndim=1, new_order=["x", 0])(dvector())
-        )
-
-
 """
 This test suite ensures that Gemm is inserted where it belongs, and
 that the resulting functions compute the same things as the originals.
@@ -774,78 +703,6 @@ def test_gemm_opt_double_gemm():
     assert max_abs_err <= eps, "GEMM is computing the wrong output. max_rel_err ="


-def test_gemm_canonicalize():
-    X, Y, Z, a, b = (
-        matrix("X"),
-        matrix("Y"),
-        matrix("Z"),
-        scalar("a"),
-        scalar("b"),
-    )
-    c, d = scalar("c"), scalar("d")
-    u = row("u")
-    v = vector("v")
-    w = col("w")
-
-    can = []
-    fg = FunctionGraph([X, Y, Z], [X + Y + Z], clone=False)
-    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
-    assert can == [(1.0, X), (1.0, Y), (1.0, Z)]
-
-    can = []
-    fg = FunctionGraph([X, Y, u], [X + Y + u], clone=False)
-    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
-    assert can == [(1.0, X), (1.0, Y), (1.0, u)], can
-
-    can = []
-    fg = FunctionGraph([X, Y, v], [X + Y + v], clone=False)
-    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
-    # [(1.0, X), (1.0, Y), (1.0, InplaceDimShuffle{x,0}(v))]
-    assert can[:2] == [(1.0, X), (1.0, Y)]
-    assert isinstance(can[2], tuple)
-    assert len(can[2]) == 2
-    assert can[2][0] == 1.0
-    assert can[2][1].owner
-    assert isinstance(can[2][1].owner.op, DimShuffle)
-    assert can[2][1].owner.inputs == [v]
-
-    can = []
-    fg = FunctionGraph([X, Y, w], [X + Y + w], clone=False)
-    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
-    assert can == [(1.0, X), (1.0, Y), (1.0, w)], can
-
-    can = []
-    fg = FunctionGraph([a, X, Y, b, Z, c], [a * X + Y - b * Z * c], clone=False)
-    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
-    assert can[0] == (a, X)
-    assert can[1] == (1.0, Y)
-    assert can[2][0].owner.op == mul
-    assert can[2][0].owner.inputs[0].owner.op == neg
-    assert can[2][0].owner.inputs[0].owner.inputs[0] == c
-    assert can[2][0].owner.inputs[1] == b
-
-    can = []
-    fg = FunctionGraph(
-        [a, X, Y, b, Z, c, d], [(-d) * X - (a * X + Y - b * Z * c)], clone=False
-    )
-    _gemm_canonicalize(fg, fg.outputs[0], 1.0, can, 0)
-    assert can[0][0].owner.op == neg
-    assert can[0][0].owner.inputs[0] == d
-    assert can[0][1] == X
-    assert can[1][0].owner.op == neg
-    assert can[1][0].owner.inputs[0] == a
-    assert can[2] == (-1.0, Y)
-    assert can[3][0].owner.op == mul
-    assert can[3][0].owner.inputs == [c, b]
-
-
-def test_gemm_factor():
-    X, Y = matrix("X"), matrix("Y")
-
-    assert [(1.0, X), (1.0, Y)] == _factor_canonicalized([(1.0, X), (1.0, Y)])
-    assert [(2.0, X)] == _factor_canonicalized([(1.0, X), (1.0, X)])
-
-
 def test_upcasting_scalar_nogemm():
     # Test that the optimization does not crash when the scale has an incorrect
     # dtype, and forces upcasting of the result