Slogdet returns naive expression and is optimized later #1041

Merged · 16 commits · Nov 17, 2024
29 changes: 28 additions & 1 deletion pytensor/tensor/nlinalg.py
@@ -11,6 +11,7 @@
from pytensor.gradient import DisconnectedType
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.tensor import TensorLike
from pytensor.tensor import basic as ptb
from pytensor.tensor import math as ptm
from pytensor.tensor.basic import as_tensor_variable, diagonal
@@ -266,7 +267,33 @@ def __str__(self):
return "SLogDet"


slogdet = Blockwise(SLogDet())
def slogdet(x: TensorLike) -> tuple[ptb.TensorVariable, ptb.TensorVariable]:
"""
Compute the sign and (natural) logarithm of the determinant of an array.

Returns a naive graph built from the det operation, which is later optimized by graph rewrites (e.g. into a single SLogDet node).

Parameters
----------
x : (..., M, M) tensor or tensor_like
Input tensor, has to be square.

Returns
-------
A tuple with the following elements:

sign : (...) tensor_like
A number representing the sign of the determinant. For a real matrix,
this is 1, 0, or -1.
logabsdet : (...) tensor_like
The natural log of the absolute value of the determinant.

If the determinant is zero, then `sign` will be 0 and `logabsdet`
will be -inf. In all cases, the determinant is equal to
``sign * exp(logabsdet)``.
"""
det_val = det(x)
return ptm.sign(det_val), ptm.log(ptm.abs(det_val))
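
Not part of the diff: a minimal usage sketch of the new naive slogdet graph, assuming the public `pt.linalg.slogdet` entry point and the default compilation rewrites (which may fuse the two outputs back into one SLogDet node, see the specialization rewrite below).

```python
import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.dmatrix("x")
# Naive graph: sign(det(x)) and log(abs(det(x)))
sign, logabsdet = pt.linalg.slogdet(x)

# Compilation is where the default rewrites can fuse this back into SLogDet.
f = pytensor.function([x], [sign, logabsdet])

a = np.random.rand(4, 4)
s, ld = f(a)
# In all cases the determinant equals sign * exp(logabsdet).
np.testing.assert_allclose(s * np.exp(ld), np.linalg.det(a), rtol=1e-8)
```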


class Eig(Op):
120 changes: 72 additions & 48 deletions pytensor/tensor/rewriting/linalg.py
@@ -2,6 +2,8 @@
from collections.abc import Callable
from typing import cast

import numpy as np

from pytensor import Variable
from pytensor import tensor as pt
from pytensor.compile import optdb
@@ -11,7 +13,7 @@
in2out,
node_rewriter,
)
from pytensor.scalar.basic import Mul
from pytensor.scalar.basic import Abs, Log, Mul, Sign
from pytensor.tensor.basic import (
AllocDiag,
ExtractDiag,
@@ -30,11 +32,11 @@
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
det,
inv,
kron,
pinv,
slogdet,
svd,
)
from pytensor.tensor.rewriting.basic import (
@@ -785,45 +787,6 @@
return [prod(det_sub_matrices)]


@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_blockdiag(fgraph, node):
"""
This rewrite simplifies the slogdet of a blockdiagonal matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those

slogdet(block_diag(a,b,c,....)) = prod(sign(a), sign(b), sign(c),...), sum(logdet(a), logdet(b), logdet(c),....)

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check for inner block_diag operation
potential_block_diag = node.inputs[0].owner
if not (
potential_block_diag
and isinstance(potential_block_diag.op, Blockwise)
and isinstance(potential_block_diag.op.core_op, BlockDiagonal)
):
return None

# Find the composing sub_matrices
sub_matrices = potential_block_diag.inputs
sign_sub_matrices, logdet_sub_matrices = zip(
*[slogdet(sub_matrices[i]) for i in range(len(sub_matrices))]
)

return [prod(sign_sub_matrices), sum(logdet_sub_matrices)]


@register_canonicalize
@register_stabilize
@node_rewriter([ExtractDiag])
@@ -860,10 +823,10 @@

@register_canonicalize
@register_stabilize
@node_rewriter([slogdet])
def rewrite_slogdet_kronecker(fgraph, node):
@node_rewriter([det])
def rewrite_det_kronecker(fgraph, node):
"""
This rewrite simplifies the slogdet of a kronecker-structured matrix by extracting the individual sub matrices and returning the sign and logdet values computed using those
This rewrite simplifies the determinant of a Kronecker-structured matrix by extracting the individual sub-matrices and computing the overall determinant from their determinants.

Parameters
----------
@@ -884,13 +847,12 @@

# Find the matrices
a, b = potential_kron.inputs
signs, logdets = zip(*[slogdet(a), slogdet(b)])
dets = [det(a), det(b)]
sizes = [a.shape[-1], b.shape[-1]]
prod_sizes = prod(sizes, no_zeros_in_input=True)
signs_final = [signs[i] ** (prod_sizes / sizes[i]) for i in range(2)]
logdet_final = [logdets[i] * prod_sizes / sizes[i] for i in range(2)]
det_final = prod([dets[i] ** (prod_sizes / sizes[i]) for i in range(2)])

return [prod(signs_final, no_zeros_in_input=True), sum(logdet_final)]
return [det_final]
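
For reference (illustrative, not part of the diff): the identity this rewrite relies on is det(kron(a, b)) = det(a)**m * det(b)**n for a of shape (n, n) and b of shape (m, m), checked numerically here with NumPy.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 3, 4
a = rng.standard_normal((n, n))
b = rng.standard_normal((m, m))

# Identity used by the rewrite: det(kron(a, b)) = det(a)**m * det(b)**n
lhs = np.linalg.det(np.kron(a, b))
rhs = np.linalg.det(a) ** m * np.linalg.det(b) ** n
np.testing.assert_allclose(lhs, rhs, rtol=1e-6)
```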


@register_canonicalize
@@ -989,3 +951,65 @@
"jax",
position=0.9, # Run before canonicalization
)


@register_specialize
@node_rewriter([det])
def slogdet_specialization(fgraph, node):
"""
This rewrite targets specific operations related to slogdet, i.e. sign(det), log(det) and log(abs(det)), and rewrites them using the SLogDet operation.

Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized

Returns
-------
dictionary of Variables, optional
Dictionary mapping variables to their replacements, or None if no optimization was performed
"""
dummy_replacements = {}
for client, _ in fgraph.clients[node.outputs[0]]:
# Check for sign(det)
if isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Sign):
dummy_replacements[client.outputs[0]] = "sign"

# Check for log(abs(det))
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Abs):
potential_log = None
for client_2, _ in fgraph.clients[client.outputs[0]]:
if isinstance(client_2.op, Elemwise) and isinstance(
client_2.op.scalar_op, Log
):
potential_log = client_2
if potential_log:
dummy_replacements[potential_log.outputs[0]] = "log_abs_det"
else:
return None

# Check for log(det)
elif isinstance(client.op, Elemwise) and isinstance(client.op.scalar_op, Log):
dummy_replacements[client.outputs[0]] = "log_det"

# Det is used directly for something else, don't rewrite to avoid computing two dets
else:
return None

if not dummy_replacements:
return None

else:
[x] = node.inputs
sign_det_x, log_abs_det_x = SLogDet()(x)
log_det_x = pt.where(pt.eq(sign_det_x, -1), np.nan, log_abs_det_x)
slogdet_specialization_map = {
"sign": sign_det_x,
"log_abs_det": log_abs_det_x,
"log_det": log_det_x,
}
replacements = {
k: slogdet_specialization_map[v] for k, v in dummy_replacements.items()
}
return replacements
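
Illustrative only (mirrors test_slogdet_specialization below, assuming the FAST_RUN mode used in the tests): after compiling log(abs(det(x))), the graph is expected to contain a single SLogDet node and no Det node.

```python
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.nlinalg import Det, SLogDet

x = pt.dmatrix("x")
log_abs_det = pt.log(pt.abs(pt.linalg.det(x)))

f = pytensor.function([x], log_abs_det, mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes

# The specialization rewrite replaces log(abs(det(x))) with the second
# output of SLogDet, so no Det node should remain in the compiled graph.
assert sum(isinstance(node.op, SLogDet) for node in nodes) == 1
assert not any(isinstance(node.op, Det) for node in nodes)

a = np.random.rand(20, 20)
np.testing.assert_allclose(f(a), np.log(np.abs(np.linalg.det(a))), rtol=1e-8)
```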
6 changes: 4 additions & 2 deletions tests/link/pytorch/test_nlinalg.py
@@ -1,3 +1,5 @@
from collections.abc import Sequence

import numpy as np
import pytest

@@ -22,13 +24,13 @@ def matrix_test():

@pytest.mark.parametrize(
"func",
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
(pt_nla.eig, pt_nla.eigh, pt_nla.SLogDet(), pt_nla.inv, pt_nla.det),
)
def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test

out = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
out_fg = FunctionGraph([x], out if isinstance(out, Sequence) else [out])

def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)
96 changes: 93 additions & 3 deletions tests/tensor/rewriting/test_linalg.py
@@ -21,6 +21,7 @@
KroneckerProduct,
MatrixInverse,
MatrixPinv,
SLogDet,
matrix_inverse,
svd,
)
@@ -719,7 +720,7 @@ def test_det_blockdiag_rewrite():


def test_slogdet_blockdiag_rewrite():
n_matrices = 100
n_matrices = 10
matrix_size = (5, 5)
sub_matrices = pt.tensor("sub_matrices", shape=(n_matrices, *matrix_size))
bd_output = pt.linalg.block_diag(*[sub_matrices[i] for i in range(n_matrices)])
@@ -776,11 +777,34 @@ def test_diag_kronecker_rewrite():
)


def test_det_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
det_output = pt.linalg.det(kron_prod)
f_rewritten = function([a, b], [det_output], mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, KroneckerProduct) for node in nodes)

# Value Test
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
det_output_test = np.linalg.det(kron_prod_test)
rewritten_det_val = f_rewritten(a_test, b_test)
assert_allclose(
det_output_test,
rewritten_det_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_slogdet_kronecker_rewrite():
a, b = pt.dmatrices("a", "b")
kron_prod = pt.linalg.kron(a, b)
sign_output, logdet_output = pt.linalg.slogdet(kron_prod)
f_rewritten = function([kron_prod], [sign_output, logdet_output], mode="FAST_RUN")
f_rewritten = function([a, b], [sign_output, logdet_output], mode="FAST_RUN")

# Rewrite Test
nodes = f_rewritten.maker.fgraph.apply_nodes
@@ -790,7 +814,7 @@ def test_slogdet_kronecker_rewrite():
a_test, b_test = np.random.rand(2, 20, 20)
kron_prod_test = np.kron(a_test, b_test)
sign_output_test, logdet_output_test = np.linalg.slogdet(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(kron_prod_test)
rewritten_sign_val, rewritten_logdet_val = f_rewritten(a_test, b_test)
assert_allclose(
sign_output_test,
rewritten_sign_val,
@@ -906,3 +930,69 @@ def test_rewrite_cholesky_diag_to_sqrt_diag_not_applied():
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert any(isinstance(node.op, Cholesky) for node in nodes)


def test_slogdet_specialization():
x, a = pt.dmatrix("x"), np.random.rand(20, 20)
det_x, det_a = pt.linalg.det(x), np.linalg.det(a)
log_abs_det_x, log_abs_det_a = pt.log(pt.abs(det_x)), np.log(np.abs(det_a))
log_det_x, log_det_a = pt.log(det_x), np.log(det_a)
sign_det_x, sign_det_a = pt.sign(det_x), np.sign(det_a)
exp_det_x = pt.exp(det_x)

# REWRITE TESTS
# sign(det(x))
f = function([x], [sign_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
assert not any(isinstance(node.op, Det) for node in nodes)
rw_sign_det_a = f(a)
assert_allclose(
sign_det_a,
rw_sign_det_a,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)

# log(abs(det(x)))
f = function([x], [log_abs_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
assert not any(isinstance(node.op, Det) for node in nodes)
rw_log_abs_det_a = f(a)
assert_allclose(
log_abs_det_a,
rw_log_abs_det_a,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)

# log(det(x))
f = function([x], [log_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
assert not any(isinstance(node.op, Det) for node in nodes)
rw_log_det_a = f(a)
assert_allclose(
log_det_a,
rw_log_det_a,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)

# More than 1 valid function
f = function([x], [sign_det_x, log_abs_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert len([node for node in nodes if isinstance(node.op, SLogDet)]) == 1
assert not any(isinstance(node.op, Det) for node in nodes)

# Other functions (rewrite shouldn't be applied to these)
# Only invalid functions
f = function([x], [exp_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, SLogDet) for node in nodes)

# Invalid + Valid function
f = function([x], [exp_det_x, sign_det_x], mode="FAST_RUN")
nodes = f.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, SLogDet) for node in nodes)