diff --git a/pytensor/tensor/_linalg/solve/rewriting.py b/pytensor/tensor/_linalg/solve/rewriting.py index 8f3cda3e0f..c0a1c5cce8 100644 --- a/pytensor/tensor/_linalg/solve/rewriting.py +++ b/pytensor/tensor/_linalg/solve/rewriting.py @@ -15,24 +15,29 @@ from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.linalg import is_matrix_transpose -from pytensor.tensor.slinalg import Solve, lu_factor, lu_solve +from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve from pytensor.tensor.variable import TensorVariable -def decompose_A(A, assume_a, check_finite): +def decompose_A(A, assume_a, check_finite, lower): if assume_a == "gen": return lu_factor(A, check_finite=check_finite) elif assume_a == "tridiagonal": # We didn't implement check_finite for tridiagonal LU factorization return tridiagonal_lu_factor(A) + elif assume_a == "pos": + return cholesky(A, lower=lower, check_finite=check_finite) else: raise NotImplementedError -def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve): +def solve_decomposed_system( + A_decomp, b, transposed=False, lower=False, *, core_solve_op: Solve +): b_ndim = core_solve_op.b_ndim check_finite = core_solve_op.check_finite assume_a = core_solve_op.assume_a + if assume_a == "gen": return lu_solve( A_decomp, @@ -49,11 +54,19 @@ def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: b_ndim=b_ndim, transposed=transposed, ) + elif assume_a == "pos": + # We can ignore the transposed argument here because A is symmetric by assumption + return cho_solve( + (A_decomp, lower), + b, + b_ndim=b_ndim, + check_finite=check_finite, + ) else: raise NotImplementedError -def _split_lu_solve_steps( +def _split_decomp_and_solve_steps( fgraph, node, *, eager: bool, allowed_assume_a: Container[str] ): if not isinstance(node.op.core_op, Solve): @@ -133,13 +146,21 @@ def find_solve_clients(var, assume_a): if client.op.core_op.check_finite: check_finite_decomp = True break - A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp) + + lower = node.op.core_op.lower + A_decomp = decompose_A( + A, assume_a=assume_a, check_finite=check_finite_decomp, lower=lower + ) replacements = {} for client, transposed in A_solve_clients_and_transpose: _, b = client.inputs - new_x = solve_lu_decomposed_system( - A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op + new_x = solve_decomposed_system( + A_decomp, + b, + transposed=transposed, + lower=lower, + core_solve_op=client.op.core_op, ) [old_x] = client.outputs new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype) @@ -149,7 +170,7 @@ def find_solve_clients(var, assume_a): return replacements -def _scan_split_non_sequence_lu_decomposition_solve( +def _scan_split_non_sequence_decomposition_and_solve( fgraph, node, *, allowed_assume_a: Container[str] ): """If the A of a Solve within a Scan is a function of non-sequences, split the LU decomposition step. @@ -179,7 +200,7 @@ def _scan_split_non_sequence_lu_decomposition_solve( non_sequences = {equiv[non_seq] for non_seq in non_sequences} inner_node = equiv[inner_node] # type: ignore - replace_dict = _split_lu_solve_steps( + replace_dict = _split_decomp_and_solve_steps( new_scan_fgraph, inner_node, eager=True, @@ -207,22 +228,22 @@ def _scan_split_non_sequence_lu_decomposition_solve( @register_specialize @node_rewriter([Blockwise]) -def reuse_lu_decomposition_multiple_solves(fgraph, node): - return _split_lu_solve_steps( - fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal"} +def reuse_decomposition_multiple_solves(fgraph, node): + return _split_decomp_and_solve_steps( + fgraph, node, eager=False, allowed_assume_a={"gen", "tridiagonal", "pos"} ) @node_rewriter([Scan]) -def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): - return _scan_split_non_sequence_lu_decomposition_solve( - fgraph, node, allowed_assume_a={"gen", "tridiagonal"} +def scan_split_non_sequence_decomposition_and_solve(fgraph, node): + return _scan_split_non_sequence_decomposition_and_solve( + fgraph, node, allowed_assume_a={"gen", "tridiagonal", "pos"} ) scan_seqopt1.register( - "scan_split_non_sequence_lu_decomposition_solve", - in2out(scan_split_non_sequence_lu_decomposition_solve, ignore_newtrees=True), + scan_split_non_sequence_decomposition_and_solve.__name__, + in2out(scan_split_non_sequence_decomposition_and_solve, ignore_newtrees=True), "fast_run", "scan", "scan_pushout", @@ -231,28 +252,30 @@ def scan_split_non_sequence_lu_decomposition_solve(fgraph, node): @node_rewriter([Blockwise]) -def reuse_lu_decomposition_multiple_solves_jax(fgraph, node): - return _split_lu_solve_steps(fgraph, node, eager=False, allowed_assume_a={"gen"}) +def reuse_decomposition_multiple_solves_jax(fgraph, node): + return _split_decomp_and_solve_steps( + fgraph, node, eager=False, allowed_assume_a={"gen", "pos"} + ) optdb["specialize"].register( - reuse_lu_decomposition_multiple_solves_jax.__name__, - in2out(reuse_lu_decomposition_multiple_solves_jax, ignore_newtrees=True), + reuse_decomposition_multiple_solves_jax.__name__, + in2out(reuse_decomposition_multiple_solves_jax, ignore_newtrees=True), "jax", use_db_name_as_tag=False, ) @node_rewriter([Scan]) -def scan_split_non_sequence_lu_decomposition_solve_jax(fgraph, node): - return _scan_split_non_sequence_lu_decomposition_solve( - fgraph, node, allowed_assume_a={"gen"} +def scan_split_non_sequence_decomposition_and_solve_jax(fgraph, node): + return _scan_split_non_sequence_decomposition_and_solve( + fgraph, node, allowed_assume_a={"gen", "pos"} ) scan_seqopt1.register( - scan_split_non_sequence_lu_decomposition_solve_jax.__name__, - in2out(scan_split_non_sequence_lu_decomposition_solve_jax, ignore_newtrees=True), + scan_split_non_sequence_decomposition_and_solve_jax.__name__, + in2out(scan_split_non_sequence_decomposition_and_solve_jax, ignore_newtrees=True), "jax", use_db_name_as_tag=False, position=2, diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py index 1bb5dd41a4..f1ea2e1af3 100644 --- a/tests/tensor/linalg/test_rewriting.py +++ b/tests/tensor/linalg/test_rewriting.py @@ -6,8 +6,8 @@ from pytensor.gradient import grad from pytensor.scan.op import Scan from pytensor.tensor._linalg.solve.rewriting import ( - reuse_lu_decomposition_multiple_solves, - scan_split_non_sequence_lu_decomposition_solve, + reuse_decomposition_multiple_solves, + scan_split_non_sequence_decomposition_and_solve, ) from pytensor.tensor._linalg.solve.tridiagonal import ( LUFactorTridiagonal, @@ -15,62 +15,70 @@ ) from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.linalg import solve -from pytensor.tensor.slinalg import LUFactor, Solve, SolveTriangular +from pytensor.tensor.slinalg import ( + Cholesky, + CholeskySolve, + LUFactor, + Solve, + SolveTriangular, +) from pytensor.tensor.type import tensor -def count_vanilla_solve_nodes(nodes) -> int: - return sum( - ( - isinstance(node.op, Solve) - or (isinstance(node.op, Blockwise) and isinstance(node.op.core_op, Solve)) +class DecompSolveOpCounter: + def __init__(self, solve_op, decomp_op, solve_op_value: float = 1.0): + self.solve_op = solve_op + self.decomp_op = decomp_op + self.solve_op_value = solve_op_value + + def check_node_op_or_core_op(self, node, op): + return isinstance(node.op, op) or ( + isinstance(node.op, Blockwise) and isinstance(node.op.core_op, op) ) - for node in nodes - ) + def count_vanilla_solve_nodes(self, nodes) -> int: + return sum(self.check_node_op_or_core_op(node, Solve) for node in nodes) -def count_lu_decom_nodes(nodes) -> int: - return sum( - ( - isinstance(node.op, LUFactor | LUFactorTridiagonal) - or ( - isinstance(node.op, Blockwise) - and isinstance(node.op.core_op, LUFactor | LUFactorTridiagonal) - ) + def count_decomp_nodes(self, nodes) -> int: + return sum( + self.check_node_op_or_core_op(node, self.decomp_op) for node in nodes ) - for node in nodes - ) - -def count_lu_solve_nodes(nodes) -> int: - count = sum( - ( - # LUFactor uses 2 SolveTriangular nodes, so we count each as 0.5 - 0.5 - * ( - isinstance(node.op, SolveTriangular) - or ( - isinstance(node.op, Blockwise) - and isinstance(node.op.core_op, SolveTriangular) - ) - ) - or ( - isinstance(node.op, SolveLUFactorTridiagonal) - or ( - isinstance(node.op, Blockwise) - and isinstance(node.op.core_op, SolveLUFactorTridiagonal) - ) - ) + def count_solve_nodes(self, nodes) -> int: + count = sum( + self.solve_op_value * self.check_node_op_or_core_op(node, self.solve_op) + for node in nodes ) - for node in nodes - ) - return int(count) + return int(count) + + +LUOpCounter = DecompSolveOpCounter( + solve_op=SolveTriangular, + decomp_op=LUFactor, + # Each rewrite introduces two triangular solves, so count them as 1/2 each + solve_op_value=0.5, +) + +TriDiagLUOpCounter = DecompSolveOpCounter( + solve_op=SolveLUFactorTridiagonal, decomp_op=LUFactorTridiagonal, solve_op_value=1.0 +) + +CholeskyOpCounter = DecompSolveOpCounter( + solve_op=CholeskySolve, decomp_op=Cholesky, solve_op_value=1.0 +) @pytest.mark.parametrize("transposed", (False, True)) -@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) -def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): - rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ +@pytest.mark.parametrize( + "assume_a, counter", + ( + ("gen", LUOpCounter), + ("tridiagonal", TriDiagLUOpCounter), + ("pos", CholeskyOpCounter), + ), +) +def test_lu_decomposition_reused_forward_and_gradient(assume_a, counter, transposed): + rewrite_name = reuse_decomposition_multiple_solves.__name__ mode = get_default_mode() A = tensor("A", shape=(3, 3)) @@ -80,19 +88,22 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): grad_x_wrt_A = grad(x.sum(), A) fn_no_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.excluding(rewrite_name)) no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 2 - assert count_lu_decom_nodes(no_opt_nodes) == 0 - assert count_lu_solve_nodes(no_opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 2 + assert counter.count_decomp_nodes(no_opt_nodes) == 0 + assert counter.count_solve_nodes(no_opt_nodes) == 0 fn_opt = function([A, b], [x, grad_x_wrt_A], mode=mode.including(rewrite_name)) opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_lu_decom_nodes(opt_nodes) == 1 - assert count_lu_solve_nodes(opt_nodes) == 2 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_decomp_nodes(opt_nodes) == 1 + assert counter.count_solve_nodes(opt_nodes) == 2 # Make sure results are correct rng = np.random.default_rng(31) A_test = rng.random(A.type.shape, dtype=A.type.dtype) + if assume_a == "pos": + A_test = A_test @ A_test.T # Ensure positive definite for Cholesky + b_test = rng.random(b.type.shape, dtype=b.type.dtype) resx0, resg0 = fn_no_opt(A_test, b_test) resx1, resg1 = fn_opt(A_test, b_test) @@ -102,9 +113,16 @@ def test_lu_decomposition_reused_forward_and_gradient(assume_a, transposed): @pytest.mark.parametrize("transposed", (False, True)) -@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) -def test_lu_decomposition_reused_blockwise(assume_a, transposed): - rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ +@pytest.mark.parametrize( + "assume_a, counter", + ( + ("gen", LUOpCounter), + ("tridiagonal", TriDiagLUOpCounter), + ("pos", CholeskyOpCounter), + ), +) +def test_lu_decomposition_reused_blockwise(assume_a, counter, transposed): + rewrite_name = reuse_decomposition_multiple_solves.__name__ mode = get_default_mode() A = tensor("A", shape=(3, 3)) @@ -113,30 +131,40 @@ def test_lu_decomposition_reused_blockwise(assume_a, transposed): x = solve(A, b, assume_a=assume_a, transposed=transposed) fn_no_opt = function([A, b], [x], mode=mode.excluding(rewrite_name)) no_opt_nodes = fn_no_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 1 - assert count_lu_decom_nodes(no_opt_nodes) == 0 - assert count_lu_solve_nodes(no_opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert counter.count_decomp_nodes(no_opt_nodes) == 0 + assert counter.count_solve_nodes(no_opt_nodes) == 0 fn_opt = function([A, b], [x], mode=mode.including(rewrite_name)) opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_lu_decom_nodes(opt_nodes) == 1 - assert count_lu_solve_nodes(opt_nodes) == 1 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_decomp_nodes(opt_nodes) == 1 + assert counter.count_solve_nodes(opt_nodes) == 1 # Make sure results are correct rng = np.random.default_rng(31) A_test = rng.random(A.type.shape, dtype=A.type.dtype) + if assume_a == "pos": + A_test = A_test @ A_test.T # Ensure positive definite for Cholesky + b_test = rng.random(b.type.shape, dtype=b.type.dtype) resx0 = fn_no_opt(A_test, b_test) resx1 = fn_opt(A_test, b_test) - rtol = rtol = 1e-7 if config.floatX == "float64" else 1e-4 + rtol = 1e-7 if config.floatX == "float64" else 1e-4 np.testing.assert_allclose(resx0, resx1, rtol=rtol) @pytest.mark.parametrize("transposed", (False, True)) -@pytest.mark.parametrize("assume_a", ("gen", "tridiagonal")) -def test_lu_decomposition_reused_scan(assume_a, transposed): - rewrite_name = scan_split_non_sequence_lu_decomposition_solve.__name__ +@pytest.mark.parametrize( + "assume_a, counter", + ( + ("gen", LUOpCounter), + ("tridiagonal", TriDiagLUOpCounter), + ("pos", CholeskyOpCounter), + ), +) +def test_lu_decomposition_reused_scan(assume_a, counter, transposed): + rewrite_name = scan_split_non_sequence_decomposition_and_solve.__name__ mode = get_default_mode() A = tensor("A", shape=(3, 3)) @@ -158,23 +186,26 @@ def test_lu_decomposition_reused_scan(assume_a, transposed): node for node in fn_no_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) ] no_opt_nodes = no_opt_scan_node.op.fgraph.apply_nodes - assert count_vanilla_solve_nodes(no_opt_nodes) == 1 - assert count_lu_decom_nodes(no_opt_nodes) == 0 - assert count_lu_solve_nodes(no_opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(no_opt_nodes) == 1 + assert counter.count_decomp_nodes(no_opt_nodes) == 0 + assert counter.count_solve_nodes(no_opt_nodes) == 0 fn_opt = function([A, x0], [xs], mode=mode.including("scan", rewrite_name)) [opt_scan_node] = [ node for node in fn_opt.maker.fgraph.apply_nodes if isinstance(node.op, Scan) ] opt_nodes = opt_scan_node.op.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 # The LU decomp is outside of the scan! - assert count_lu_decom_nodes(opt_nodes) == 0 - assert count_lu_solve_nodes(opt_nodes) == 1 + assert counter.count_decomp_nodes(opt_nodes) == 0 + assert counter.count_solve_nodes(opt_nodes) == 1 # Make sure results are correct rng = np.random.default_rng(170) A_test = rng.random(A.type.shape, dtype=A.type.dtype) + if assume_a == "pos": + A_test = A_test @ A_test.T # Ensure positive definite for Cholesky + x0_test = rng.random(x0.type.shape, dtype=x0.type.dtype) resx0 = fn_no_opt(A_test, x0_test) resx1 = fn_opt(A_test, x0_test) @@ -182,23 +213,30 @@ def test_lu_decomposition_reused_scan(assume_a, transposed): np.testing.assert_allclose(resx0, resx1, rtol=rtol) -def test_lu_decomposition_reused_preserves_check_finite(): +@pytest.mark.parametrize( + "assume_a, counter", + ( + ("gen", LUOpCounter), + ("pos", CholeskyOpCounter), + ), +) +def test_decomposition_reused_preserves_check_finite(assume_a, counter): # Check that the LU decomposition rewrite preserves the check_finite flag - rewrite_name = reuse_lu_decomposition_multiple_solves.__name__ + rewrite_name = reuse_decomposition_multiple_solves.__name__ A = tensor("A", shape=(2, 2)) b1 = tensor("b1", shape=(2,)) b2 = tensor("b2", shape=(2,)) - x1 = solve(A, b1, assume_a="gen", check_finite=True) - x2 = solve(A, b2, assume_a="gen", check_finite=False) + x1 = solve(A, b1, assume_a=assume_a, check_finite=True) + x2 = solve(A, b2, assume_a=assume_a, check_finite=False) fn_opt = function( [A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name) ) opt_nodes = fn_opt.maker.fgraph.apply_nodes - assert count_vanilla_solve_nodes(opt_nodes) == 0 - assert count_lu_decom_nodes(opt_nodes) == 1 - assert count_lu_solve_nodes(opt_nodes) == 2 + assert counter.count_vanilla_solve_nodes(opt_nodes) == 0 + assert counter.count_decomp_nodes(opt_nodes) == 1 + assert counter.count_solve_nodes(opt_nodes) == 2 # We should get an error if A or b1 is non finite A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index cbaf27da29..30af86c038 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -581,7 +581,7 @@ def test_solve(self, solve_fn, batched_A, batched_b): mode = get_default_mode().excluding( "batched_vector_b_solve_to_matrix_b_solve", - "reuse_lu_decomposition_multiple_solves", + "reuse_decomposition_multiple_solves", ) fn = function([In(A, mutable=True), In(b, mutable=True)], x, mode=mode)