diff --git a/pytensor/tensor/nlinalg.py b/pytensor/tensor/nlinalg.py index e7093a82bd..47c6699cca 100644 --- a/pytensor/tensor/nlinalg.py +++ b/pytensor/tensor/nlinalg.py @@ -11,6 +11,7 @@ from pytensor.gradient import DisconnectedType from pytensor.graph.basic import Apply from pytensor.graph.op import Op +from pytensor.tensor import TensorLike from pytensor.tensor import basic as ptb from pytensor.tensor import math as ptm from pytensor.tensor.basic import as_tensor_variable, diagonal @@ -266,7 +267,33 @@ def __str__(self): return "SLogDet" -slogdet = Blockwise(SLogDet()) +def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]: + """ + Compute the sign and (natural) logarithm of the determinant of an array. + + Returns a naive graph which is optimized later using rewrites with the det operation. + + Parameters + ---------- + x : (..., M, M) tensor or tensor_like + Input tensor, has to be square. + + Returns + ------- + A tuple with the following attributes: + + sign : (...) tensor_like + A number representing the sign of the determinant. For a real matrix, + this is 1, 0, or -1. + logabsdet : (...) tensor_like + The natural log of the absolute value of the determinant. + + If the determinant is zero, then `sign` will be 0 and `logabsdet` + will be -inf. In all cases, the determinant is equal to + ``sign * exp(logabsdet)``. 
+ """ + det_val = det(x) + return ptm.sign(det_val), ptm.log(ptm.abs(det_val)) class Eig(Op): diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index a2418147cf..cd202fe3ed 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -2,6 +2,8 @@ from collections.abc import Callable from typing import cast +import numpy as np + from pytensor import Variable from pytensor import tensor as pt from pytensor.compile import optdb @@ -11,7 +13,7 @@ in2out, node_rewriter, ) -from pytensor.scalar.basic import Mul +from pytensor.scalar.basic import Abs, Log, Mul, Sign from pytensor.tensor.basic import ( AllocDiag, ExtractDiag, @@ -30,11 +32,11 @@ KroneckerProduct, MatrixInverse, MatrixPinv, + SLogDet, det, inv, kron, pinv, - slogdet, svd, ) from pytensor.tensor.rewriting.basic import ( @@ -785,45 +787,6 @@ def rewrite_det_blockdiag(fgraph, node): return [prod(det_sub_matrices)] -@register_canonicalize -@register_stabilize -@node_rewriter([slogdet]) -def rewrite_slogdet_blockdiag(fgraph, node): - """ - This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those - - slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....) 
- - Parameters - ---------- - fgraph: FunctionGraph - Function graph being optimized - node: Apply - Node of the function graph to be optimized - - Returns - ------- - list of Variable, optional - List of optimized variables, or None if no optimization was performed - """ - # Check for inner block_diag operation - potential_block_diag = node.inputs[0].owner - if not ( - potential_block_diag - and isinstance(potential_block_diag.op, Blockwise) - and isinstance(potential_block_diag.op.core_op, BlockDiagonal) - ): - return None - - # Find the composing sub_matrices - sub_matrices = potential_block_diag.inputs - sign_sub_matrices, logdet_sub_matrices = zip( - *[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))] - ) - - return [prod(sign_sub_matrices), sum(logdet_sub_matrices)] - - @register_canonicalize @register_stabilize @node_rewriter([ExtractDiag]) @@ -860,10 +823,10 @@ def rewrite_diag_kronecker(fgraph, node): @register_canonicalize @register_stabilize -@node_rewriter([slogdet]) -def rewrite_slogdet_kronecker(fgraph, node): +@node_rewriter([det]) +def rewrite_det_kronecker(fgraph, node): """ - This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those + This rewrite simplifies the determinant of a kronecker-structured matrix by extracting the individual sub matrices and returning the det values computed using those Parameters ---------- @@ -884,13 +847,12 @@ def rewrite_slogdet_kronecker(fgraph, node): # Find the matrices a, b = potential_kron.inputs - signs, logdets = zip(*[slogdet(a), slogdet(b)]) + dets = [det(a), det(b)] sizes = [a.shape[-1], b.shape[-1]] prod_sizes = prod(sizes, no_zeros_in_input=True) - signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)] - logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)] + det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)]) - return 
[prod(signs_final, no_zeros_in_input=True), sum(logdet_final)] + return [det_final] @register_canonicalize @@ -989,3 +951,65 @@ def jax_bilinaer_lyapunov_to_direct(fgraph: FunctionGraph, node: Apply): "jax", position=0.9, # Run before canonicalization ) + + +@register_specialize +@node_rewriter([det]) +def slogdet_specialization(fgraph, node): + """ + This rewrite targets specific operations related to slogdet, i.e., sign(det), log(det) and log(abs(det)), and rewrites them using the SLogDet operation. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + dictionary of Variables, optional + Dictionary of nodes and what they should be replaced with, or None if no optimization was performed + """ + dummy_replacements = {} + for client, _ in fgraph.clients[node.outputs[0]]: + # Check for sign(det) + if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign): + dummy_replacements[client.outputs[0]] = "sign" + + # Check for log(abs(det)) + elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs): + potential_log = None + for client_2, _ in fgraph.clients[client.outputs[0]]: + if isinstance(client_2.op, Elemwise) and isinstance( + client_2.op.scalar_op, Log + ): + potential_log = client_2 + if potential_log: + dummy_replacements[potential_log.outputs[0]] = "log_abs_det" + else: + return None + + # Check for log(det) + elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log): + dummy_replacements[client.outputs[0]] = "log_det" + + # Det is used directly for something else, don't rewrite to avoid computing two dets + else: + return None + + if not dummy_replacements: + return None + else: + [x] = node.inputs + sign_det_x, log_abs_det_x = SLogDet()(x) + log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x) + slogdet_specialization_map = { + "sign": sign_det_x, + "log_abs_det": log_abs_det_x, + 
"log_det": log_det_x, + } + replacements = { + k: slogdet_specialization_map[v] for k, v in dummy_replacements.items() + } + return replacements diff --git a/tests/link/pytorch/test_nlinalg.py b/tests/link/pytorch/test_nlinalg.py index 7d69ac0500..55e7c447e3 100644 --- a/tests/link/pytorch/test_nlinalg.py +++ b/tests/link/pytorch/test_nlinalg.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import numpy as np import pytest @@ -22,13 +24,13 @@ def matrix_test(): @pytest.mark.parametrize( "func", - (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det), + (pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det), ) def test_lin_alg_no_params(func, matrix_test): x, test_value = matrix_test out = func(x) - out_fg = FunctionGraph([x], out if isinstance(out, list) else [out]) + out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out]) def assert_fn(x, y): np.testing.assert_allclose(x, y, rtol=1e-3) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 9dd2a247a8..c9b9afff19 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -21,6 +21,7 @@ KroneckerProduct, MatrixInverse, MatrixPinv, + SLogDet, matrix_inverse, svd, ) @@ -719,7 +720,7 @@ def test_det_blockdiag_rewrite(): def test_slogdet_blockdiag_rewrite(): - n_matrices = 100 + n_matrices = 10 matrix_size = (5, 5) sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size)) bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)]) @@ -776,11 +777,34 @@ def test_diag_kronecker_rewrite(): ) +def test_det_kronecker_rewrite(): + a, b = pt.dmatrices("a", "b") + kron_prod = pt.linalg.kron(a, b) + det_output = pt.linalg.det(kron_prod) + f_rewritten = function([a, b], [det_output], mode="FAST_RUN") + + # Rewrite Test + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, KroneckerProduct) for node in nodes) + + # Value Test + 
a_test, b_test = np.random.rand(2, 20, 20) + kron_prod_test = np.kron(a_test, b_test) + det_output_test = np.linalg.det(kron_prod_test) + rewritten_det_val = f_rewritten(a_test, b_test) + assert_allclose( + det_output_test, + rewritten_det_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + def test_slogdet_kronecker_rewrite(): a, b = pt.dmatrices("a", "b") kron_prod = pt.linalg.kron(a, b) sign_output, logdet_output = pt.linalg.slogdet(kron_prod) - f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN") + f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN") # Rewrite Test nodes = f_rewritten.maker.fgraph.apply_nodes @@ -790,7 +814,7 @@ def test_slogdet_kronecker_rewrite(): a_test, b_test = np.random.rand(2, 20, 20) kron_prod_test = np.kron(a_test, b_test) sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test) - rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test) + rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test) assert_allclose( sign_output_test, rewritten_sign_val, @@ -906,3 +930,69 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied(): f_rewritten = function([x], z_cholesky, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes assert any(isinstance(node.op, Cholesky) for node in nodes) + + +def test_slogdet_specialization(): + x, a = pt.dmatrix("x"), np.random.rand(20, 20) + det_x, det_a = pt.linalg.det(x), np.linalg.det(a) + log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a)) + log_det_x, log_det_a = pt.log(det_x), np.log(det_a) + sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a) + exp_det_x = pt.exp(det_x) + + # REWRITE TESTS + # sign(det(x)) + f = function([x], [sign_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, 
Det) for node in nodes) + rw_sign_det_a = f(a) + assert_allclose( + sign_det_a, + rw_sign_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + # log(abs(det(x))) + f = function([x], [log_abs_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + rw_log_abs_det_a = f(a) + assert_allclose( + log_abs_det_a, + rw_log_abs_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + # log(det(x)) + f = function([x], [log_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + rw_log_det_a = f(a) + assert_allclose( + log_det_a, + rw_log_det_a, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + # More than 1 valid function + f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1 + assert not any(isinstance(node.op, Det) for node in nodes) + + # Other functions (rewrite shouldn't be applied to these) + # Only invalid functions + f = function([x], [exp_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, SLogDet) for node in nodes) + + # Invalid + Valid function + f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN") + nodes = f.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, SLogDet) for node in nodes)