diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 96f4daefba..3ab2960562 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -22,7 +22,7 @@ from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod +from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod from pytensor.tensor.nlinalg import ( SVD, KroneckerProduct, @@ -818,3 +818,72 @@ def rewrite_slogdet_blockdiag(fgraph, node): ) return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] + + +@register_canonicalize +@register_stabilize +@node_rewriter([ExtractDiag]) +def rewrite_diag_kronecker(fgraph, node): + """ + This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector. + + diag(kron(a,b)) -> outer(diag(a), diag(b)) + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner kron operation + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return None + + # Find the matrices + a, b = potential_kron.inputs + diag_a, diag_b = diag(a), diag(b) + outer_prod_as_vector = outer(diag_a, diag_b).flatten() + + return [outer_prod_as_vector] + + +@register_canonicalize +@register_stabilize +@node_rewriter([slogdet]) +def rewrite_slogdet_kronecker(fgraph, node): + """ + This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner kron operation + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return None + + # Find the matrices + a, b = potential_kron.inputs + signs, logdets = zip(*[slogdet(a), slogdet(b)]) + sizes = [a.shape[-1], b.shape[-1]] + prod_sizes = prod(sizes, no_zeros_in_input=True) + signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] + logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] + + return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 133e8d6a31..211facb484 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -751,3 +751,55 @@ def test_slogdet_blockdiag_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_diag_kronecker_rewrite(): + a, b = pt.dmatrices("a", "b") + kron_prod = pt.linalg.kron(a, b) + diag_kron_prod = pt.diag(kron_prod) + f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # Value Test + a_test, b_test = np.random.rand(2, 20, 20) + kron_prod_test = np.kron(a_test, b_test) + diag_kron_prod_test = np.diag(kron_prod_test) + rewritten_val = f_rewritten(a_test, b_test) + assert_allclose( + diag_kron_prod_test, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_slogdet_kronecker_rewrite(): + a, b = pt.dmatrices("a", "b") + kron_prod = pt.linalg.kron(a, b) + sign_output, logdet_output = pt.linalg.slogdet(kron_prod) + f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # Value Test + a_test, b_test = np.random.rand(2, 20, 20) + kron_prod_test = np.kron(a_test, b_test) + sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test) + rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test) + assert_allclose( + sign_output_test, + rewritten_sign_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + logdet_output_test, + rewritten_logdet_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + )