diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 1de6dbb373..d76cc04ee9 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -611,3 +611,64 @@ def rewrite_inv_inv(fgraph, node): ): return None return [potential_inner_inv.inputs[0]] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_cholesky_eye_to_eye(fgraph, node): + """ + This rewrite takes advantage of the fact that the cholesky decomposition of an identity matrix is the matrix itself + + The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside Cholesky. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + # Check whether input to Cholesky is Eye and the 1's are on main diagonal + eye_check = node.inputs[0] + if not ( + eye_check.owner + and isinstance(eye_check.owner.op, Eye) + and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0 + ): + return None + return [eye_check] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_cholesky_diag_from_eye_mul(fgraph, node): + # Find whether cholesky op is being applied + if not isinstance(node.op.core_op, Cholesky): + return None + + # Check whether input is diagonal from multiplcation of identity matrix with a tensor + inputs = node.inputs[0] + inputs_or_none = _find_diag_from_eye_mul(inputs) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + eye_input, non_eye_input = eye_input[0], non_eye_inputs[0] + + return [eye_input * (non_eye_input**0.5)] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 7353a82be0..605a81a3c1 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -568,3 +568,66 @@ def get_pt_function(x, op_name): op2 = get_pt_function(op1, inv_op_2) rewritten_out = rewrite_graph(op2) assert rewritten_out == x + + +def test_cholesky_eye_rewrite(): + x = pt.eye(10) + x_mat = pt.matrix("x") + L = pt.linalg.cholesky(x) + L_mat = pt.linalg.cholesky(x_mat) + f_rewritten = function([], L, mode="FAST_RUN") + f_rewritten_mat = function([x_mat], L_mat, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + nodes_mat = f_rewritten_mat.maker.fgraph.apply_nodes + + # Rewrite Test + assert not any(isinstance(node.op, Cholesky) for node in nodes) + assert any(isinstance(node.op, Cholesky) for node in nodes_mat) + + # Value Test + x_test = np.eye(10) + L = np.linalg.cholesky(x_test) + rewritten_val = f_rewritten() + + assert_allclose( + L, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +@pytest.mark.parametrize( + "shape", + [(), (7,), (1, 7), (7, 1), (7, 7), (3, 7, 7)], + ids=["scalar", "vector", "row_vec", "col_vec", "matrix", "batched_input"], +) +def test_cholesky_diag_from_eye_mul(shape): + # Initializing x based on scalar/vector/matrix + x = pt.tensor("x", shape=shape) + y = pt.eye(7) * x + # Performing cholesky decomposition using pt.linalg.cholesky + z_cholesky = pt.linalg.cholesky(y) + + # REWRITE TEST + f_rewritten = function([x], z_cholesky, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Cholesky) for node in nodes) + + # NUMERIC VALUE TEST + if len(shape) == 0: + x_test = np.array(np.random.rand()).astype(config.floatX) + elif len(shape) == 1: + x_test = np.random.rand(*shape).astype(config.floatX) + else: + x_test = np.random.rand(*shape).astype(config.floatX) + x_test_matrix = np.eye(7) * x_test + cholesky_val = np.linalg.cholesky(x_test_matrix) + rewritten_val = f_rewritten(x_test) + + assert_allclose( + cholesky_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + )