From 09bd030d1ecdde053bf057694f4fcd2fc116a232 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 12 Jun 2025 15:47:13 -0700 Subject: [PATCH 1/4] Extend decomp+solve rewrite machinery to `assume_a="pos"` --- pytensor/tensor/_linalg/solve/rewriting.py | 58 ++++--- tests/tensor/linalg/test_rewriting.py | 183 +++++++++++++++++++-- 2 files changed, 205 insertions(+), 36 deletions(-) diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py index 8f3cda3e0f..2bc1116079 100644 --- a/pytensor/tensor/_linalg/solve/rewriting.py +++ b/pytensor/tensor/_linalg/solve/rewriting.py @@ -15,7 +15,7 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.linalg import is_matrix_transpose -from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve +from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve from pytensor.tensor.variable import TensorVariable @@ -25,14 +25,17 @@ def decompose_A(A, assume_a, check_finite): elif assume_a == "tridiagonal": # We didn't implement check_finite for tridiagonal LU factorization return tridiagonal_lu_factor(A) + elif assume_a == "pos": + return cholesky(A, lower=True, check_finite=check_finite) else: raise NotImplementedError -def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve): +def solve_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve): b_ndim = core_solve_op.b_ndim check_finite = core_solve_op.check_finite assume_a = core_solve_op.assume_a + if assume_a == "gen": return lu_solve( A_decomp, @@ -49,11 +52,18 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: b_ndim=b_ndim, transposed=transposed, ) + elif assume_a == "pos": + return cho_solve( + (A_decomp, True), + b, + b_ndim=b_ndim, + check_finite=check_finite, + ) else: raise NotImplementedError -def _split_lu_solve_steps( +def _split_decomp_and_solve_steps( fgraph, node, *, eager: bool, allowed_assume_a: Container[str] ): if not isinstance(node.op.core_op, Solve): @@ -138,7 +148,7 @@ def find_solve_clients(var, assume_a): replacements = {} for client, transposed in A_solve_clients_and_transpose: _, b = client.inputs - new_x = solve_lu_decomposed_system( + new_x = solve_decomposed_system( A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op ) [old_x] = client.outputs @@ -149,7 +159,7 @@ def find_solve_clients(var, assume_a): return replacements -def _scan_split_non_sequence_lu_decomposition_solve( +def _scan_split_non_sequence_decomposition_and_solve( fgraph, node, *, allowed_assume_a: Container[str] ): """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step. @@ -179,7 +189,7 @@ def _scan_split_non_sequence_lu_decomposition_solve( non_sequences = {equiv[non_seq] for non_seq in non_sequences} inner_node = equiv[inner_node] # type: ignore - replace_dict = _split_lu_solve_steps( + replace_dict = _split_decomp_and_solve_steps( new_scan_fgraph, inner_node, eager=True, @@ -207,22 +217,22 @@ def _scan_split_non_sequence_lu_decomposition_solve( @register_specialize @node_rewriter([Blockwise]) -def reuse_lu_decomposition_multiple_solves(fgraph, node): - return _split_lu_solve_steps( - fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"} +def reuse_decomposition_multiple_solves(fgraph, node): + return _split_decomp_and_solve_steps( + fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"} ) @node_rewriter([Scan]) -def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): - return _scan_split_non_sequence_lu_decomposition_solve( - fgraph, node, allowed_assume_a={"gen", "tridiagonal"} +def scan_split_non_sequence_decomposition_and_solve(fgraph, node): + return _scan_split_non_sequence_decomposition_and_solve( + fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"} ) scan_seqopt1.register( - "scan_split_non_sequence_lu_decomposition_solve", - in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True), + scan_split_non_sequence_decomposition_and_solve.__name__, + in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True), "fast_run", "scan", "scan_pushout", @@ -231,28 +241,30 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): @node_rewriter([Blockwise]) -def reuse_lu_decomposition_multiple_solves_jax(fgraph, node): - return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"}) +def reuse_decomposition_multiple_solves_jax(fgraph, node): + return _split_decomp_and_solve_steps( + fgraph, node, eager=False, allowed_assume_a={"gen", "pos"} + ) optdb["specialize"].register( - reuse_lu_decomposition_multiple_solves_jax.__name__, - in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True), + reuse_decomposition_multiple_solves_jax.__name__, + in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True), "jax", use_db_name_as_tag=False, ) @node_rewriter([Scan]) -def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node): - return _scan_split_non_sequence_lu_decomposition_solve( - fgraph, node, allowed_assume_a={"gen"} +def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node): + return _scan_split_non_sequence_decomposition_and_solve( + fgraph, node, allowed_assume_a={"gen", "pos"} ) scan_seqopt1.register( - scan_split_non_sequence_lu_decomposition_solve_jax.__name__, - in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True), + scan_split_non_sequence_decomposition_and_solve_jax.__name__, + in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True), "jax", use_db_name_as_tag=False, position=2, diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py index 1bb5dd41a4..e9f0f0a5ae 100644 --- a/tests/tensor/linalg/test_rewriting.py +++ b/tests/tensor/linalg/test_rewriting.py @@ -6,8 +6,8 @@ from pytensor.gradient import grad from pytensor.scan.op import Scan from pytensor.tensor._linalg.solve.rewriting import ( - reuse_lu_decomposition_multiple_solves, - scan_split_non_sequence_lu_decomposition_solve, + reuse_decomposition_multiple_solves, + scan_split_non_sequence_decomposition_and_solve, ) from pytensor.tensor._linalg.solve.tridiagonal import ( LUFactorTridiagonal, @@ -15,7 +15,13 @@ ) from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.linalg import solve -from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular +from pytensor.tensor.slinalg import ( + Cholesky, + CholeskySolve, + LUFactor, + Solve, + SolveTriangular, +) from pytensor.tensor.type import tensor @@ -42,6 +48,18 @@ def count_lu_decom_nodes(nodes) -> int: ) +def count_cholesky_decom_nodes(nodes) -> int: + return sum( + ( + isinstance(node.op, Cholesky) + or ( + isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky) + ) + ) + for node in nodes + ) + + def count_lu_solve_nodes(nodes) -> int: count = sum( ( @@ -67,10 +85,23 @@ def count_lu_solve_nodes(nodes) -> int: return int(count) +def count_cholesky_solve_nodes(nodes) -> int: + return sum( + ( + isinstance(node.op, CholeskySolve) + or ( + isinstance(node.op, Blockwise) + and isinstance(node.op.core_op, CholeskySolve) + ) + ) + for node in nodes + ) + + @pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): - rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ + rewrite_name = reuse_decomposition_multiple_solves.__name__ mode = get_default_mode() A = tensor("A", shape=(3, 3)) @@ -101,10 +132,46 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): np.testing.assert_allclose(resg0, resg1, rtol=rtol) +def test_cholesky_reused_forward_and_gradient(): + rewrite_name = reuse_decomposition_multiple_solves.__name__ + mode = get_default_mode() + + A = tensor("A", shape=(3, 3)) + b = tensor("b", shape=(3, 4)) + + x = solve(A, b, assume_a="pos") + grad_x_wrt_A = grad(x.sum(), A) + fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) + no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes + + assert count_vanilla_solve_nodes(no_opt_nodes) == 2 + assert count_cholesky_decom_nodes(no_opt_nodes) == 0 + assert count_cholesky_solve_nodes(no_opt_nodes) == 0 + + fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name)) + opt_nodes = fn_opt.maker.fgraph.apply_nodes + assert count_vanilla_solve_nodes(opt_nodes) == 0 + assert count_cholesky_decom_nodes(opt_nodes) == 1 + assert count_cholesky_solve_nodes(opt_nodes) == 2 + + # Make sure results are correct + # A has to actually be positive definite, or else fn_opt and fn_no_opt won't agree + rng = np.random.default_rng(31) + L = rng.random(A.type.shape, dtype=A.type.dtype) + A_test = L @ L.T + b_test = rng.random(b.type.shape, dtype=b.type.dtype) + + resx0, resg0 = fn_no_opt(A_test, b_test) + resx1, resg1 = fn_opt(A_test, b_test) + rtol = 1e-7 if config.floatX == "float64" else 1e-4 + np.testing.assert_allclose(resx0, resx1, rtol=rtol) + np.testing.assert_allclose(resg0, resg1, rtol=rtol) + + @pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) def test_lu_decomposition_reused_blockwise(assume_a, transposed): - rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ + rewrite_name = reuse_decomposition_multiple_solves.__name__ mode = get_default_mode() A = tensor("A", shape=(3, 3)) @@ -129,14 +196,46 @@ def test_lu_decomposition_reused_blockwise(assume_a, transposed): b_test = rng.random(b.type.shape, dtype=b.type.dtype) resx0 = fn_no_opt(A_test, b_test) resx1 = fn_opt(A_test, b_test) - rtol = rtol = 1e-7 if config.floatX == "float64" else 1e-4 + rtol = 1e-7 if config.floatX == "float64" else 1e-4 + np.testing.assert_allclose(resx0, resx1, rtol=rtol) + + +def test_cholesky_decomposition_reused_blockwise(): + rewrite_name = reuse_decomposition_multiple_solves.__name__ + mode = get_default_mode() + + A = tensor("A", shape=(3, 3)) + b = tensor("b", shape=(2, 3, 4)) + + x = solve(A, b, assume_a="pos") + fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) + no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes + assert count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert count_cholesky_decom_nodes(no_opt_nodes) == 0 + assert count_cholesky_solve_nodes(no_opt_nodes) == 0 + + fn_opt = function([A, b], [x], mode=mode.including(rewrite_name)) + opt_nodes = fn_opt.maker.fgraph.apply_nodes + assert count_vanilla_solve_nodes(opt_nodes) == 0 + assert count_cholesky_decom_nodes(opt_nodes) == 1 + assert count_cholesky_solve_nodes(opt_nodes) == 1 + + # Make sure results are correct + rng = np.random.default_rng(31) + L = rng.random(A.type.shape, dtype=A.type.dtype) + A_test = L @ L.T + + b_test = rng.random(b.type.shape, dtype=b.type.dtype) + resx0 = fn_no_opt(A_test, b_test) + resx1 = fn_opt(A_test, b_test) + rtol = 1e-7 if config.floatX == "float64" else 1e-4 np.testing.assert_allclose(resx0, resx1, rtol=rtol) @pytest.mark.parametrize("transposed", (False, True)) @pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) def test_lu_decomposition_reused_scan(assume_a, transposed): - rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__ + rewrite_name = scan_split_non_sequence_decomposition_and_solve.__name__ mode = get_default_mode() A = tensor("A", shape=(3, 3)) @@ -182,23 +281,81 @@ def test_lu_decomposition_reused_scan(assume_a, transposed): np.testing.assert_allclose(resx0, resx1, rtol=rtol) -def test_lu_decomposition_reused_preserves_check_finite(): +def test_cholesky_decomposition_reused_scan(): + rewrite_name = scan_split_non_sequence_decomposition_and_solve.__name__ + mode = get_default_mode() + + A = tensor("A", shape=(3, 3)) + x0 = tensor("b", shape=(3, 4)) + + xs, _ = scan( + lambda xtm1, A: solve(A, xtm1, assume_a="pos"), + outputs_info=[x0], + non_sequences=[A], + n_steps=10, + ) + + fn_no_opt = function( + [A, x0], + [xs], + mode=mode.excluding(rewrite_name), + ) + [no_opt_scan_node] = [ + node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ] + no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes + assert count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert count_cholesky_decom_nodes(no_opt_nodes) == 0 + assert count_cholesky_solve_nodes(no_opt_nodes) == 0 + + fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name)) + [opt_scan_node] = [ + node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) + ] + opt_nodes = opt_scan_node.op.fgraph.apply_nodes + assert count_vanilla_solve_nodes(opt_nodes) == 0 + # The cholesky decomposition is outside of the scan! + assert count_cholesky_decom_nodes(opt_nodes) == 0 + assert count_cholesky_solve_nodes(opt_nodes) == 1 + + # Make sure results are correct + rng = np.random.default_rng(170) + L = rng.random(A.type.shape, dtype=A.type.dtype) + A_test = L @ L.T + + x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype) + resx0 = fn_no_opt(A_test, x0_test) + resx1 = fn_opt(A_test, x0_test) + rtol = 1e-7 if config.floatX == "float64" else 1e-4 + np.testing.assert_allclose(resx0, resx1, rtol=rtol) + + +@pytest.mark.parametrize( + "assume_a, count_decomp_fn, count_solve_fn", + ( + ("gen", count_lu_decom_nodes, count_lu_solve_nodes), + ("pos", count_cholesky_decom_nodes, count_cholesky_solve_nodes), + ), +) +def test_decomposition_reused_preserves_check_finite( + assume_a, count_decomp_fn, count_solve_fn +): # Check that the LU decomposition rewrite preserves the check_finite flag - rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ + rewrite_name = reuse_decomposition_multiple_solves.__name__ A = tensor("A", shape=(2, 2)) b1 = tensor("b1", shape=(2,)) b2 = tensor("b2", shape=(2,)) - x1 = solve(A, b1, assume_a="gen", check_finite=True) - x2 = solve(A, b2, assume_a="gen", check_finite=False) + x1 = solve(A, b1, assume_a=assume_a, check_finite=True) + x2 = solve(A, b2, assume_a=assume_a, check_finite=False) fn_opt = function( [A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name) ) opt_nodes = fn_opt.maker.fgraph.apply_nodes assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_lu_decom_nodes(opt_nodes) == 1 - assert count_lu_solve_nodes(opt_nodes) == 2 + assert count_decomp_fn(opt_nodes) == 1 + assert count_solve_fn(opt_nodes) == 2 # We should get an error if A or b1 is non finite A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype) From dd6e0c38b7c35b1eaf2552703ef9e0fa21f04417 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Thu, 12 Jun 2025 16:28:28 -0700 Subject: [PATCH 2/4] Update rewrite name in test --- tests/tensor/test_blockwise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index cbaf27da29..30af86c038 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -581,7 +581,7 @@ def test_solve(self, solve_fn, batched_A, batched_b): mode = get_default_mode().excluding( "batched_vector_b_solve_to_matrix_b_solve", - "reuse_lu_decomposition_multiple_solves", + "reuse_decomposition_multiple_solves", ) fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode) From aae2e09ba46d5535e0ae95e42499deaa7c839411 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 13 Jun 2025 11:17:47 -0500 Subject: [PATCH 3/4] Refactor tests to be nicer --- tests/tensor/linalg/test_rewriting.py | 297 ++++++++------------------ 1 file changed, 89 insertions(+), 208 deletions(-) diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py index e9f0f0a5ae..f1ea2e1af3 100644 --- a/tests/tensor/linalg/test_rewriting.py +++ b/tests/tensor/linalg/test_rewriting.py @@ -25,82 +25,59 @@ from pytensor.tensor.type import tensor -def count_vanilla_solve_nodes(nodes) -> int: - return sum( - ( - isinstance(node.op, Solve) - or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)) +class DecompSolveOpCounter: + def __init__(self, solve_op, decomp_op, solve_op_value: float = 1.0): + self.solve_op = solve_op + self.decomp_op = decomp_op + self.solve_op_value = solve_op_value + + def check_node_op_or_core_op(self, node, op): + return isinstance(node.op, op) or ( + isinstance(node.op, Blockwise) and isinstance(node.op.core_op, op) ) - for node in nodes - ) + def count_vanilla_solve_nodes(self, nodes) -> int: + return sum(self.check_node_op_or_core_op(node, Solve) for node in nodes) -def count_lu_decom_nodes(nodes) -> int: - return sum( - ( - isinstance(node.op, LUFactor | LUFactorTridiagonal) - or ( - isinstance(node.op, Blockwise) - and isinstance(node.op.core_op, LUFactor | LUFactorTridiagonal) - ) + def count_decomp_nodes(self, nodes) -> int: + return sum( + self.check_node_op_or_core_op(node, self.decomp_op) for node in nodes ) - for node in nodes - ) - -def count_cholesky_decom_nodes(nodes) -> int: - return sum( - ( - isinstance(node.op, Cholesky) - or ( - isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Cholesky) - ) + def count_solve_nodes(self, nodes) -> int: + count = sum( + self.solve_op_value * self.check_node_op_or_core_op(node, self.solve_op) + for node in nodes ) - for node in nodes - ) + return int(count) -def count_lu_solve_nodes(nodes) -> int: - count = sum( - ( - # LUFactor uses 2 SolveTriangular nodes, so we count each as 0.5 - 0.5 - * ( - isinstance(node.op, SolveTriangular) - or ( - isinstance(node.op, Blockwise) - and isinstance(node.op.core_op, SolveTriangular) - ) - ) - or ( - isinstance(node.op, SolveLUFactorTridiagonal) - or ( - isinstance(node.op, Blockwise) - and isinstance(node.op.core_op, SolveLUFactorTridiagonal) - ) - ) - ) - for node in nodes - ) - return int(count) +LUOpCounter = DecompSolveOpCounter( + solve_op=SolveTriangular, + decomp_op=LUFactor, + # Each rewrite introduces two triangular solves, so count them as 1/2 each + solve_op_value=0.5, +) +TriDiagLUOpCounter = DecompSolveOpCounter( + solve_op=SolveLUFactorTridiagonal, decomp_op=LUFactorTridiagonal, solve_op_value=1.0 +) -def count_cholesky_solve_nodes(nodes) -> int: - return sum( - ( - isinstance(node.op, CholeskySolve) - or ( - isinstance(node.op, Blockwise) - and isinstance(node.op.core_op, CholeskySolve) - ) - ) - for node in nodes - ) +CholeskyOpCounter = DecompSolveOpCounter( + solve_op=CholeskySolve, decomp_op=Cholesky, solve_op_value=1.0 +) @pytest.mark.parametrize("transposed", (False, True)) -@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) -def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): +@pytest.mark.parametrize( + "assume_a, counter", + ( + ("gen", LUOpCounter), + ("tridiagonal", TriDiagLUOpCounter), + ("pos", CholeskyOpCounter), + ), +) +def test_lu_decomposition_reused_forward_and_gradient(assume_a, counter, transposed): rewrite_name = reuse_decomposition_multiple_solves.__name__ mode = get_default_mode() @@ -111,56 +88,23 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): grad_x_wrt_A = grad(x.sum(), A) fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 2 - assert count_lu_decom_nodes(no_opt_nodes) == 0 - assert count_lu_solve_nodes(no_opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 2 + assert counter.count_decomp_nodes(no_opt_nodes) == 0 + assert counter.count_solve_nodes(no_opt_nodes) == 0 fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name)) opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_lu_decom_nodes(opt_nodes) == 1 - assert count_lu_solve_nodes(opt_nodes) == 2 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_decomp_nodes(opt_nodes) == 1 + assert counter.count_solve_nodes(opt_nodes) == 2 # Make sure results are correct rng = np.random.default_rng(31) A_test = rng.random(A.type.shape, dtype=A.type.dtype) - b_test = rng.random(b.type.shape, dtype=b.type.dtype) - resx0, resg0 = fn_no_opt(A_test, b_test) - resx1, resg1 = fn_opt(A_test, b_test) - rtol = 1e-7 if config.floatX == "float64" else 1e-4 - np.testing.assert_allclose(resx0, resx1, rtol=rtol) - np.testing.assert_allclose(resg0, resg1, rtol=rtol) + if assume_a == "pos": + A_test = A_test @ A_test.T # Ensure positive definite for Cholesky - -def test_cholesky_reused_forward_and_gradient(): - rewrite_name = reuse_decomposition_multiple_solves.__name__ - mode = get_default_mode() - - A = tensor("A", shape=(3, 3)) - b = tensor("b", shape=(3, 4)) - - x = solve(A, b, assume_a="pos") - grad_x_wrt_A = grad(x.sum(), A) - fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) - no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes - - assert count_vanilla_solve_nodes(no_opt_nodes) == 2 - assert count_cholesky_decom_nodes(no_opt_nodes) == 0 - assert count_cholesky_solve_nodes(no_opt_nodes) == 0 - - fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name)) - opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_cholesky_decom_nodes(opt_nodes) == 1 - assert count_cholesky_solve_nodes(opt_nodes) == 2 - - # Make sure results are correct - # A has to actually be positive definite, or else fn_opt and fn_no_opt won't agree - rng = np.random.default_rng(31) - L = rng.random(A.type.shape, dtype=A.type.dtype) - A_test = L @ L.T b_test = rng.random(b.type.shape, dtype=b.type.dtype) - resx0, resg0 = fn_no_opt(A_test, b_test) resx1, resg1 = fn_opt(A_test, b_test) rtol = 1e-7 if config.floatX == "float64" else 1e-4 @@ -169,8 +113,15 @@ def test_cholesky_reused_forward_and_gradient(): @pytest.mark.parametrize("transposed", (False, True)) -@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) -def test_lu_decomposition_reused_blockwise(assume_a, transposed): +@pytest.mark.parametrize( + "assume_a, counter", + ( + ("gen", LUOpCounter), + ("tridiagonal", TriDiagLUOpCounter), + ("pos", CholeskyOpCounter), + ), +) +def test_lu_decomposition_reused_blockwise(assume_a, counter, transposed): rewrite_name = reuse_decomposition_multiple_solves.__name__ mode = get_default_mode() @@ -180,50 +131,21 @@ def test_lu_decomposition_reused_blockwise(assume_a, transposed): x = solve(A, b, assume_a=assume_a, transposed=transposed) fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 1 - assert count_lu_decom_nodes(no_opt_nodes) == 0 - assert count_lu_solve_nodes(no_opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert counter.count_decomp_nodes(no_opt_nodes) == 0 + assert counter.count_solve_nodes(no_opt_nodes) == 0 fn_opt = function([A, b], [x], mode=mode.including(rewrite_name)) opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_lu_decom_nodes(opt_nodes) == 1 - assert count_lu_solve_nodes(opt_nodes) == 1 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_decomp_nodes(opt_nodes) == 1 + assert counter.count_solve_nodes(opt_nodes) == 1 # Make sure results are correct rng = np.random.default_rng(31) A_test = rng.random(A.type.shape, dtype=A.type.dtype) - b_test = rng.random(b.type.shape, dtype=b.type.dtype) - resx0 = fn_no_opt(A_test, b_test) - resx1 = fn_opt(A_test, b_test) - rtol = 1e-7 if config.floatX == "float64" else 1e-4 - np.testing.assert_allclose(resx0, resx1, rtol=rtol) - - -def test_cholesky_decomposition_reused_blockwise(): - rewrite_name = reuse_decomposition_multiple_solves.__name__ - mode = get_default_mode() - - A = tensor("A", shape=(3, 3)) - b = tensor("b", shape=(2, 3, 4)) - - x = solve(A, b, assume_a="pos") - fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) - no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 1 - assert count_cholesky_decom_nodes(no_opt_nodes) == 0 - assert count_cholesky_solve_nodes(no_opt_nodes) == 0 - - fn_opt = function([A, b], [x], mode=mode.including(rewrite_name)) - opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_cholesky_decom_nodes(opt_nodes) == 1 - assert count_cholesky_solve_nodes(opt_nodes) == 1 - - # Make sure results are correct - rng = np.random.default_rng(31) - L = rng.random(A.type.shape, dtype=A.type.dtype) - A_test = L @ L.T + if assume_a == "pos": + A_test = A_test @ A_test.T # Ensure positive definite for Cholesky b_test = rng.random(b.type.shape, dtype=b.type.dtype) resx0 = fn_no_opt(A_test, b_test) @@ -233,8 +155,15 @@ def test_cholesky_decomposition_reused_blockwise(): @pytest.mark.parametrize("transposed", (False, True)) -@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) -def test_lu_decomposition_reused_scan(assume_a, transposed): +@pytest.mark.parametrize( + "assume_a, counter", + ( + ("gen", LUOpCounter), + ("tridiagonal", TriDiagLUOpCounter), + ("pos", CholeskyOpCounter), + ), +) +def test_lu_decomposition_reused_scan(assume_a, counter, transposed): rewrite_name = scan_split_non_sequence_decomposition_and_solve.__name__ mode = get_default_mode() @@ -257,71 +186,25 @@ def test_lu_decomposition_reused_scan(assume_a, transposed): node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) ] no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 1 - assert count_lu_decom_nodes(no_opt_nodes) == 0 - assert count_lu_solve_nodes(no_opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert counter.count_decomp_nodes(no_opt_nodes) == 0 + assert counter.count_solve_nodes(no_opt_nodes) == 0 fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name)) [opt_scan_node] = [ node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) ] opt_nodes = opt_scan_node.op.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 # The LU decomp is outside of the scan! - assert count_lu_decom_nodes(opt_nodes) == 0 - assert count_lu_solve_nodes(opt_nodes) == 1 + assert counter.count_decomp_nodes(opt_nodes) == 0 + assert counter.count_solve_nodes(opt_nodes) == 1 # Make sure results are correct rng = np.random.default_rng(170) A_test = rng.random(A.type.shape, dtype=A.type.dtype) - x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype) - resx0 = fn_no_opt(A_test, x0_test) - resx1 = fn_opt(A_test, x0_test) - rtol = 1e-7 if config.floatX == "float64" else 1e-4 - np.testing.assert_allclose(resx0, resx1, rtol=rtol) - - -def test_cholesky_decomposition_reused_scan(): - rewrite_name = scan_split_non_sequence_decomposition_and_solve.__name__ - mode = get_default_mode() - - A = tensor("A", shape=(3, 3)) - x0 = tensor("b", shape=(3, 4)) - - xs, _ = scan( - lambda xtm1, A: solve(A, xtm1, assume_a="pos"), - outputs_info=[x0], - non_sequences=[A], - n_steps=10, - ) - - fn_no_opt = function( - [A, x0], - [xs], - mode=mode.excluding(rewrite_name), - ) - [no_opt_scan_node] = [ - node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) - ] - no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 1 - assert count_cholesky_decom_nodes(no_opt_nodes) == 0 - assert count_cholesky_solve_nodes(no_opt_nodes) == 0 - - fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name)) - [opt_scan_node] = [ - node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) - ] - opt_nodes = opt_scan_node.op.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - # The cholesky decomposition is outside of the scan! - assert count_cholesky_decom_nodes(opt_nodes) == 0 - assert count_cholesky_solve_nodes(opt_nodes) == 1 - - # Make sure results are correct - rng = np.random.default_rng(170) - L = rng.random(A.type.shape, dtype=A.type.dtype) - A_test = L @ L.T + if assume_a == "pos": + A_test = A_test @ A_test.T # Ensure positive definite for Cholesky x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype) resx0 = fn_no_opt(A_test, x0_test) @@ -331,15 +214,13 @@ def test_cholesky_decomposition_reused_scan(): @pytest.mark.parametrize( - "assume_a, count_decomp_fn, count_solve_fn", + "assume_a, counter", ( - ("gen", count_lu_decom_nodes, count_lu_solve_nodes), - ("pos", count_cholesky_decom_nodes, count_cholesky_solve_nodes), + ("gen", LUOpCounter), + ("pos", CholeskyOpCounter), ), ) -def test_decomposition_reused_preserves_check_finite( - assume_a, count_decomp_fn, count_solve_fn -): +def test_decomposition_reused_preserves_check_finite(assume_a, counter): # Check that the LU decomposition rewrite preserves the check_finite flag rewrite_name = reuse_decomposition_multiple_solves.__name__ @@ -353,9 +234,9 @@ def test_decomposition_reused_preserves_check_finite( [A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name) ) opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_decomp_fn(opt_nodes) == 1 - assert count_solve_fn(opt_nodes) == 2 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_decomp_nodes(opt_nodes) == 1 + assert counter.count_solve_nodes(opt_nodes) == 2 # We should get an error if A or b1 is non finite A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype) From c11a9b5ca4d711013c480ef171ac25a15c1763b8 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 13 Jun 2025 11:34:51 -0500 Subject: [PATCH 4/4] Respect core op `lower` flag when rewriting to ChoSolve --- pytensor/tensor/_linalg/solve/rewriting.py | 23 ++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py index 2bc1116079..c0a1c5cce8 100644 --- a/pytensor/tensor/_linalg/solve/rewriting.py +++ b/pytensor/tensor/_linalg/solve/rewriting.py @@ -19,19 +19,21 @@ from pytensor.tensor.variable import TensorVariable -def decompose_A(A, assume_a, check_finite): +def decompose_A(A, assume_a, check_finite, lower): if assume_a == "gen": return lu_factor(A, check_finite=check_finite) elif assume_a == "tridiagonal": # We didn't implement check_finite for tridiagonal LU factorization return tridiagonal_lu_factor(A) elif assume_a == "pos": - return cholesky(A, lower=True, check_finite=check_finite) + return cholesky(A, lower=lower, check_finite=check_finite) else: raise NotImplementedError -def solve_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve): +def solve_decomposed_system( + A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve +): b_ndim = core_solve_op.b_ndim check_finite = core_solve_op.check_finite assume_a = core_solve_op.assume_a @@ -53,8 +55,9 @@ def solve_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Sol transposed=transposed, ) elif assume_a == "pos": + # We can ignore the transposed argument here because A is symmetric by assumption return cho_solve( - (A_decomp, True), + (A_decomp, lower), b, b_ndim=b_ndim, check_finite=check_finite, @@ -143,13 +146,21 @@ def find_solve_clients(var, assume_a): if client.op.core_op.check_finite: check_finite_decomp = True break - A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp) + + lower = node.op.core_op.lower + A_decomp = decompose_A( + A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower + ) replacements = {} for client, transposed in A_solve_clients_and_transpose: _, b = client.inputs new_x = solve_decomposed_system( - A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op + A_decomp, + b, + transposed=transposed, + lower=lower, + core_solve_op=client.op.core_op, ) [old_x] = client.outputs new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)