Skip to content

Commit

Permalink
added rewrites for inv(diag(x)) and inv(eye) (#898)
Browse files Browse the repository at this point in the history
* updated tests

* updated rewrites

* parameterized tests and added batch case

* minor changes
  • Loading branch information
tanish1729 authored Aug 30, 2024
1 parent 7eca252 commit 1a1c62b
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 6 deletions.
96 changes: 93 additions & 3 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import cast

from pytensor import Variable
from pytensor import tensor as pt
from pytensor.graph import Apply, FunctionGraph
from pytensor.graph.rewriting.basic import (
copy_stack_trace,
Expand Down Expand Up @@ -48,6 +49,7 @@


logger = logging.getLogger(__name__)
# Matrix-inverse Ops recognised by the inverse rewrites in this module; the
# rewrites test a Blockwise node's ``core_op`` against this tuple with isinstance.
ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv)


def is_matrix_transpose(x: TensorVariable) -> bool:
Expand Down Expand Up @@ -592,11 +594,10 @@ def rewrite_inv_inv(fgraph, node):
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
valid_inverses = (MatrixInverse, MatrixPinv)
# Check if its a valid inverse operation (either inv/pinv)
# In case the outer operation is an inverse, it directly goes to the next step of finding inner operation
# If the outer operation is not a valid inverse, we do not apply this rewrite
if not isinstance(node.op.core_op, valid_inverses):
if not isinstance(node.op.core_op, ALL_INVERSE_OPS):
return None

potential_inner_inv = node.inputs[0].owner
Expand All @@ -607,7 +608,96 @@ def rewrite_inv_inv(fgraph, node):
if not (
potential_inner_inv
and isinstance(potential_inner_inv.op, Blockwise)
and isinstance(potential_inner_inv.op.core_op, valid_inverses)
and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS)
):
return None
return [potential_inner_inv.inputs[0]]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_eye_to_eye(fgraph, node):
    """
    Replace ``inv(eye)`` / ``pinv(eye)`` with the ``eye`` variable itself.

    This rewrite takes advantage of the fact that the inverse of an identity
    matrix is the matrix itself. An identity matrix is recognised as the output
    of an ``Eye`` Op whose last input (the diagonal offset ``k``) is a constant
    equal to 0.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    core_op = node.op.core_op
    if not isinstance(core_op, ALL_INVERSE_OPS):
        return None

    # The input to the inverse must be produced directly by an Eye Op
    potential_eye = node.inputs[0]
    eye_node = potential_eye.owner
    if not (eye_node and isinstance(eye_node.op, Eye)):
        return None

    # The offset is only known at rewrite time when the variable carries a
    # ``data`` attribute. Use ``None`` as the sentinel: a plain-int fallback has
    # no ``.item()`` method and would raise AttributeError for a symbolic offset.
    offset_data = getattr(eye_node.inputs[-1], "data", None)
    if offset_data is None or offset_data.item() != 0:
        return None

    # NOTE(review): this does not check that the Eye is square. For pinv of a
    # rectangular eye the result would presumably need the transposed shape --
    # confirm whether that case must be handled.
    return [potential_eye]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_inv_diag_to_diag_reciprocal(fgraph, node):
    """
    Replace the inverse of a diagonal matrix with the reciprocal of its diagonal.

    This rewrite takes advantage of the fact that for a diagonal matrix, the
    inverse is a diagonal matrix with the new diagonal entries as reciprocals of
    the original diagonal elements. Two ways of building the diagonal matrix are
    handled:

    * ``pt.diag(v)`` of a vector ``v`` (an ``AllocDiag`` with zero offset);
    * the multiplication of eye with a scalar/vector/matrix.

    Parameters
    ----------
    fgraph: FunctionGraph
        Function graph being optimized
    node: Apply
        Node of the function graph to be optimized

    Returns
    -------
    list of Variable, optional
        List of optimized variables, or None if no optimization was performed
    """
    core_op = node.op.core_op
    if not isinstance(core_op, ALL_INVERSE_OPS):
        return None

    # NOTE(review): both branches below take plain reciprocals, so a zero on the
    # diagonal yields inf; for pinv that presumably differs from the true
    # pseudo-inverse -- confirm this is acceptable.
    inputs = node.inputs[0]
    diag_node = inputs.owner

    # Check for use of pt.diag first
    if (
        diag_node
        and isinstance(diag_node.op, AllocDiag)
        and AllocDiag.is_offset_zero(diag_node)
    ):
        inv_input = diag_node.inputs[0]
        # ``pt.diag`` only *builds* a diagonal matrix from a 1-D input; rebuilding
        # through it for a higher-dimensional AllocDiag input would not give the
        # inverse, so the rewrite is skipped in that case.
        if inv_input.type.ndim != 1:
            return None
        return [pt.diag(1 / inv_input)]

    # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
    inputs_or_none = _find_diag_from_eye_mul(inputs)
    if inputs_or_none is None:
        return None

    eye_input, non_eye_inputs = inputs_or_none

    # Dealing with only one other input
    if len(non_eye_inputs) != 1:
        return None

    non_eye_input = non_eye_inputs[0]

    # For a matrix, only the diagonal survives the multiplication with eye, so
    # extract it and add back a broadcastable row axis so that it divides the
    # columns of eye.
    if non_eye_input.type.broadcastable[-2:] == (False, False):
        non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2)
        non_eye_input = pt.shape_padaxis(non_eye_diag, -2)

    return [eye_input / non_eye_input]
100 changes: 97 additions & 3 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
from tests.test_rop import break_op


# Absolute and relative tolerances shared by the numerical comparisons in this
# module; looser when floatX is float32.
ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8


def test_rop_lop():
mx = matrix("mx")
mv = matrix("mv")
Expand Down Expand Up @@ -557,14 +560,105 @@ def test_svd_uv_merge():
assert svd_counter == 1


def get_pt_function(x, op_name):
    """Look up the ``pt.linalg`` function called ``op_name`` and apply it to ``x``."""
    linalg_fn = getattr(pt.linalg, op_name)
    return linalg_fn(x)


@pytest.mark.parametrize("inv_op_1", ["inv", "pinv"])
@pytest.mark.parametrize("inv_op_2", ["inv", "pinv"])
def test_inv_inv_rewrite(inv_op_1, inv_op_2):
    """Check that every inv/pinv combination applied twice rewrites back to ``x``."""
    # Uses the module-level ``get_pt_function`` helper rather than redefining an
    # identical nested copy that would shadow it.
    x = pt.matrix("x")
    op1 = get_pt_function(x, inv_op_1)
    op2 = get_pt_function(op1, inv_op_2)
    rewritten_out = rewrite_graph(op2)
    assert rewritten_out == x


@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_eye_to_eye(inv_op):
    """Check that inv/pinv of ``pt.eye`` is rewritten away and still evaluates to eye."""
    x = pt.eye(10)
    x_inv = get_pt_function(x, inv_op)
    f_rewritten = function([], x_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # Rewrite Test
    # An inverse may still be wrapped in a Blockwise, so compare the wrapped
    # ``core_op`` when there is one instead of only the outer op.
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # Value Test
    x_test = np.eye(10)
    x_inv_val = np.linalg.inv(x_test)
    rewritten_val = f_rewritten()

    assert_allclose(
        x_inv_val,
        rewritten_val,
        atol=ATOL,
        rtol=RTOL,
    )


@pytest.mark.parametrize(
    "shape",
    [(), (7,), (7, 7), (5, 7, 7)],
    ids=["scalar", "vector", "matrix", "batched"],
)
@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_eye_mul(shape, inv_op):
    """Check that inv/pinv of ``eye * x`` is rewritten away and numerically correct."""
    # Initializing x based on scalar/vector/matrix
    x = pt.tensor("x", shape=shape)
    x_diag = pt.eye(7) * x
    # Calculating inverse using pt.linalg.inv
    x_inv = get_pt_function(x_diag, inv_op)

    # REWRITE TEST
    f_rewritten = function([x], x_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # An inverse on batched input stays wrapped in a Blockwise, so compare the
    # wrapped ``core_op`` when there is one; checking only ``node.op`` would let
    # an un-rewritten batched inverse slip through.
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # NUMERIC VALUE TEST
    # ``np.random.rand()`` with no dimensions returns a Python float, so wrap in
    # ``np.asarray`` to handle the scalar shape together with the others.
    x_test = np.asarray(np.random.rand(*shape)).astype(config.floatX)
    x_test_matrix = np.eye(7) * x_test
    inverse_matrix = np.linalg.inv(x_test_matrix)
    rewritten_inverse = f_rewritten(x_test)

    assert_allclose(
        inverse_matrix,
        rewritten_inverse,
        atol=ATOL,
        rtol=RTOL,
    )


@pytest.mark.parametrize("inv_op", ["inv", "pinv"])
def test_inv_diag_from_diag(inv_op):
    """Check that inv/pinv of ``pt.diag(x)`` is rewritten away and numerically correct."""
    x = pt.dvector("x")
    x_diag = pt.diag(x)
    x_inv = get_pt_function(x_diag, inv_op)

    # REWRITE TEST
    f_rewritten = function([x], x_inv, mode="FAST_RUN")
    nodes = f_rewritten.maker.fgraph.apply_nodes

    # An inverse may still be wrapped in a Blockwise, so compare the wrapped
    # ``core_op`` when there is one instead of only the outer op.
    valid_inverses = (MatrixInverse, MatrixPinv)
    assert not any(
        isinstance(getattr(node.op, "core_op", node.op), valid_inverses)
        for node in nodes
    )

    # NUMERIC VALUE TEST
    x_test = np.random.rand(10)
    x_test_matrix = np.eye(10) * x_test
    inverse_matrix = np.linalg.inv(x_test_matrix)
    rewritten_inverse = f_rewritten(x_test)

    assert_allclose(
        inverse_matrix,
        rewritten_inverse,
        atol=ATOL,
        rtol=RTOL,
    )

0 comments on commit 1a1c62b

Please sign in to comment.