From 032cbca5a7739a74eac9036ccdc05138de3a0bc7 Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 17 Aug 2024 12:31:15 +0530 Subject: [PATCH 1/3] Added rewrite for diag of kronecker product --- pytensor/tensor/rewriting/linalg.py | 19 ++++++++++++++++++- tests/tensor/rewriting/test_linalg.py | 23 +++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 96f4daefba..fb3beaa9cd 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -22,7 +22,7 @@ from pytensor.tensor.blas import Dot22 from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, prod +from pytensor.tensor.math import Dot, Prod, _matrix_matrix_matmul, log, outer, prod from pytensor.tensor.nlinalg import ( SVD, KroneckerProduct, @@ -818,3 +818,20 @@ def rewrite_slogdet_blockdiag(fgraph, node): ) return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] + + +@register_canonicalize +@register_stabilize +@node_rewriter([ExtractDiag]) +def rewrite_diag_kronecker(fgraph, node): + # Check for inner kron operation + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return None + + # Find the matrices + a, b = potential_kron.inputs + diag_a, diag_b = diag(a), diag(b) + outer_prod_as_vector = outer(diag_a, diag_b).flatten() + + return [outer_prod_as_vector] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 133e8d6a31..42a31f65d8 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -751,3 +751,26 @@ def test_slogdet_blockdiag_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_diag_kronecker_rewrite(): + a, b = pt.dmatrices("a", "b") + kron_prod = pt.linalg.kron(a, b) + diag_kron_prod = pt.diag(kron_prod) + f_rewritten = function([a, b], diag_kron_prod, mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # Value Test + a_test, b_test = np.random.rand(2, 20, 20) + kron_prod_test = np.kron(a_test, b_test) + diag_kron_prod_test = np.diag(kron_prod_test) + rewritten_val = f_rewritten(a_test, b_test) + assert_allclose( + diag_kron_prod_test, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From d76c1b2b087c0c3dd4e357c533755c5f34aa8dc3 Mon Sep 17 00:00:00 2001 From: Tanish Date: Mon, 19 Aug 2024 14:25:07 +0530 Subject: [PATCH 2/3] Added rewrite for slogdet; added docstrings for rewrites --- pytensor/tensor/rewriting/linalg.py | 52 +++++++++++++++++++++++++++ tests/tensor/rewriting/test_linalg.py | 29 +++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index fb3beaa9cd..bf436a4fff 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -824,6 +824,23 @@ def rewrite_slogdet_blockdiag(fgraph, node): @register_stabilize @node_rewriter([ExtractDiag]) def rewrite_diag_kronecker(fgraph, node): + """ + This rewrite simplifies the diagonal of the kronecker product of 2 matrices by extracting the individual sub matrices and returning their outer product as a vector. + + diag(kron(a,b)) -> outer(diag(a), diag(b)) + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ # Check for inner kron operation potential_kron = node.inputs[0].owner if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): @@ -835,3 +852,38 @@ def rewrite_diag_kronecker(fgraph, node): outer_prod_as_vector = outer(diag_a, diag_b).flatten() return [outer_prod_as_vector] + + +@register_canonicalize +@register_stabilize +@node_rewriter([slogdet]) +def rewrite_slogdet_kronecker(fgraph, node): + """ + This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check for inner kron operation + potential_kron = node.inputs[0].owner + if not (potential_kron and isinstance(potential_kron.op, KroneckerProduct)): + return None + + # Find the matrices + a, b = potential_kron.inputs + signs, logdets = zip(*[slogdet(a), slogdet(b)]) + sizes = [a.shape[-1], b.shape[-1]] + prod_sizes = prod(sizes, no_zeros_in_input=True) + signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] + logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] + + return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 42a31f65d8..211facb484 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -774,3 +774,32 @@ def test_diag_kronecker_rewrite(): atol=1e-3 if config.floatX == "float32" else 1e-8, rtol=1e-3 if config.floatX == "float32" else 1e-8, ) + + +def test_slogdet_kronecker_rewrite(): + a, b = pt.dmatrices("a", "b") + kron_prod = pt.linalg.kron(a, b) + sign_output, logdet_output = pt.linalg.slogdet(kron_prod) + f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # Value Test + a_test, b_test = np.random.rand(2, 20, 20) + kron_prod_test = np.kron(a_test, b_test) + sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test) + rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test) + assert_allclose( + sign_output_test, + rewritten_sign_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + logdet_output_test, + rewritten_logdet_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) From 329353dd55bad1ffc5bfe638ee64d28f5816aa2d Mon Sep 17 00:00:00 2001 From: Tanish Date: Sat, 31 Aug 2024 13:15:10 +0530 Subject: [PATCH 3/3] fixed typo --- pytensor/tensor/rewriting/linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index bf436a4fff..3ab2960562 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -859,7 +859,7 @@ def rewrite_diag_kronecker(fgraph, node): @node_rewriter([slogdet]) def rewrite_slogdet_kronecker(fgraph, node): """ - This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those Parameters ----------