Skip to content

Commit

Permalink
fixed merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
tanish1729 committed Jul 21, 2024
1 parent 981688c commit 736782b
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 0 deletions.
61 changes: 61 additions & 0 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,3 +611,64 @@ def rewrite_inv_inv(fgraph, node):
):
return None
return [potential_inner_inv.inputs[0]]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_cholesky_eye_to_eye(fgraph, node):
"""
This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself
The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky.
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None

# Check whether input to Cholesky is Eye and the 1's are on main diagonal
eye_check = node.inputs[0]
if not (
eye_check.owner
and isinstance(eye_check.owner.op, Eye)
and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0
):
return None
return [eye_check]


@register_canonicalize
@register_stabilize
@node_rewriter([Blockwise])
def rewrite_cholesky_diag_from_eye_mul(fgraph, node):
# Find whether cholesky op is being applied
if not isinstance(node.op.core_op, Cholesky):
return None

# Check whether input is diagonal from multiplcation of identity matrix with a tensor
inputs = node.inputs[0]
inputs_or_none = _find_diag_from_eye_mul(inputs)
if inputs_or_none is None:
return None

eye_input, non_eye_inputs = inputs_or_none

# Dealing with only one other input
if len(non_eye_inputs) != 1:
return None

eye_input, non_eye_input = eye_input[0], non_eye_inputs[0]

return [eye_input * (non_eye_input**0.5)]
63 changes: 63 additions & 0 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,66 @@ def get_pt_function(x, op_name):
op2 = get_pt_function(op1, inv_op_2)
rewritten_out = rewrite_graph(op2)
assert rewritten_out == x


def test_cholesky_eye_rewrite():
x = pt.eye(10)
x_mat = pt.matrix("x")
L = pt.linalg.cholesky(x)
L_mat = pt.linalg.cholesky(x_mat)
f_rewritten = function([], L, mode="FAST_RUN")
f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes

# Rewrite Test
assert not any(isinstance(node.op, Cholesky) for node in nodes)
assert any(isinstance(node.op, Cholesky) for node in nodes_mat)

# Value Test
x_test = np.eye(10)
L = np.linalg.cholesky(x_test)
rewritten_val = f_rewritten()

assert_allclose(
L,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


@pytest.mark.parametrize(
"shape",
[(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)],
ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"],
)
def test_cholesky_diag_from_eye_mul(shape):
# Initializing x based on scalar/vector/matrix
x = pt.tensor("x", shape=shape)
y = pt.eye(7) * x
# Performing cholesky decomposition using pt.linalg.cholesky
z_cholesky = pt.linalg.cholesky(y)

# REWRITE TEST
f_rewritten = function([x], z_cholesky, mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Cholesky) for node in nodes)

# NUMERIC VALUE TEST
if len(shape) == 0:
x_test = np.array(np.random.rand()).astype(config.floatX)
elif len(shape) == 1:
x_test = np.random.rand(*shape).astype(config.floatX)
else:
x_test = np.random.rand(*shape).astype(config.floatX)
x_test_matrix = np.eye(7) * x_test
cholesky_val = np.linalg.cholesky(x_test_matrix)
rewritten_val = f_rewritten(x_test)

assert_allclose(
cholesky_val,
rewritten_val,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)

0 comments on commit 736782b

Please sign in to comment.