From 625e98cb185cf3ad86648719a4150aa29e773770 Mon Sep 17 00:00:00 2001 From: Tanish Date: Mon, 5 Aug 2024 17:58:28 +0530 Subject: [PATCH 1/4] updated tests --- tests/tensor/rewriting/test_linalg.py | 90 +++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index 7353a82be0..e2e2121c61 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -41,6 +41,9 @@ from tests.test_rop import break_op +ATOL = RTOL = 1e-3 if config.floatX == "float32" else 1e-8 + + def test_rop_lop(): mx = matrix("mx") mv = matrix("mv") @@ -568,3 +571,90 @@ def get_pt_function(x, op_name): op2 = get_pt_function(op1, inv_op_2) rewritten_out = rewrite_graph(op2) assert rewritten_out == x + + +def test_inv_eye_to_eye(): + x = pt.eye(10) + x_inv = pt.linalg.inv(x) + f_rewritten = function([], x_inv, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + # Rewrite Test + valid_inverses = (MatrixInverse, MatrixPinv) + assert not any(isinstance(node.op, valid_inverses) for node in nodes) + + # Value Test + x_test = np.eye(10) + x_inv_val = np.linalg.inv(x_test) + rewritten_val = f_rewritten() + + assert_allclose( + x_inv_val, + rewritten_val, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +@pytest.mark.parametrize( + "shape", + [(), (7,), (7, 7)], + ids=["scalar", "vector", "matrix"], +) +def test_inv_diag_from_eye_mul(shape): + # Initializing x based on scalar/vector/matrix + x = pt.tensor("x", shape=shape) + x_diag = pt.eye(7) * x + # Calculating inverse using pt.linalg.inv + x_inv = pt.linalg.inv(x_diag) + + # REWRITE TEST + f_rewritten = function([x], x_inv, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + valid_inverses = (MatrixInverse, MatrixPinv) + assert not any(isinstance(node.op, valid_inverses) for node in nodes) + + # NUMERIC VALUE TEST + if len(shape) == 0: + x_test = np.array(np.random.rand()).astype(config.floatX) + elif len(shape) == 1: + x_test = np.random.rand(*shape).astype(config.floatX) + else: + x_test = np.random.rand(*shape).astype(config.floatX) + x_test_matrix = np.eye(7) * x_test + inverse_matrix = np.linalg.inv(x_test_matrix) + rewritten_inverse = f_rewritten(x_test) + + assert_allclose( + inverse_matrix, + rewritten_inverse, + atol=ATOL, + rtol=RTOL, + ) + + +def test_inv_diag_from_diag(): + x = pt.dvector("x") + x_diag = pt.diag(x) + x_inv = pt.linalg.inv(x_diag) + + # REWRITE TEST + f_rewritten = function([x], x_inv, mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + + valid_inverses = (MatrixInverse, MatrixPinv) + assert not any(isinstance(node.op, valid_inverses) for node in nodes) + + # NUMERIC VALUE TEST + x_test = np.random.rand(10) + x_test_matrix = np.eye(10) * x_test + inverse_matrix = np.linalg.inv(x_test_matrix) + rewritten_inverse = f_rewritten(x_test) + + assert_allclose( + inverse_matrix, + rewritten_inverse, + atol=ATOL, + rtol=RTOL, + ) From 1db99999bb42663112e4c96626e804c294a37086 Mon Sep 17 00:00:00 2001 From: Tanish Date: Mon, 5 Aug 2024 18:03:16 +0530 Subject: [PATCH 2/4] updated rewrites --- pytensor/tensor/rewriting/linalg.py | 94 +++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 1de6dbb373..7278285e5f 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -3,6 +3,7 @@ from typing import cast from pytensor import Variable +from pytensor import tensor as pt from pytensor.graph import Apply, FunctionGraph from pytensor.graph.rewriting.basic import ( copy_stack_trace, @@ -611,3 +612,96 @@ def rewrite_inv_inv(fgraph, node): ): return None return [potential_inner_inv.inputs[0]] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_inv_eye_to_eye(fgraph, node): + """ + This rewrite takes advantage of the fact that the inverse of an identity matrix is the matrix itself + The presence of an identity matrix is identified by checking whether we have k = 0 for an Eye Op inside an inverse op. + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + valid_inverses = (MatrixInverse, MatrixPinv) + core_op = node.op.core_op + if not (isinstance(core_op, valid_inverses)): + return None + + # Check whether input to inverse is Eye and the 1's are on main diagonal + eye_check = node.inputs[0] + if not ( + eye_check.owner + and isinstance(eye_check.owner.op, Eye) + and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0 + ): + return None + return [eye_check] + + +@register_canonicalize +@register_stabilize +@node_rewriter([Blockwise]) +def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): + """ + This rewrite takes advantage of the fact that for a diagonal matrix, the inverse is a diagonal matrix with the new diagonal entries as reciprocals of the original diagonal elements. + This function deals with diagonal matrix arising from the multiplicaton of eye with a scalar/vector/matrix + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + valid_inverses = (MatrixInverse, MatrixPinv) + core_op = node.op.core_op + if not (isinstance(core_op, valid_inverses)): + return None + + inputs = node.inputs[0] + # Check for use of pt.diag first + if ( + inputs.owner + and isinstance(inputs.owner.op, AllocDiag) + and AllocDiag.is_offset_zero(inputs.owner) + ): + inv_input = inputs.owner.inputs[0] + if inv_input.type.ndim == 1: + inv_val = pt.diag(1 / inv_input) + return [inv_val] + + # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix + inputs_or_none = _find_diag_from_eye_mul(inputs) + if inputs_or_none is None: + return None + + eye_input, non_eye_inputs = inputs_or_none + + # Dealing with only one other input + if len(non_eye_inputs) != 1: + return None + + non_eye_input = non_eye_inputs[0] + + # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those + if non_eye_input.type.broadcastable[-2:] == (False, False): + # For Matrix + return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)] + else: + # For Vector or Scalar + return [eye_input / non_eye_input] From 3359a2aa0cbb1d98fdddc33bf9babbfc4890086d Mon Sep 17 00:00:00 2001 From: Tanish Date: Tue, 6 Aug 2024 19:27:52 +0530 Subject: [PATCH 3/4] paramterized tests and added batch case --- pytensor/tensor/rewriting/linalg.py | 37 ++++++++++++--------------- tests/tensor/rewriting/test_linalg.py | 26 +++++++++++-------- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 7278285e5f..9ec5a458f0 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -49,6 +49,7 @@ logger = logging.getLogger(__name__) +ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv) def is_matrix_transpose(x: TensorVariable) -> bool: @@ -593,11 +594,11 @@ def rewrite_inv_inv(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - valid_inverses = (MatrixInverse, MatrixPinv) + ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv) # Check if its a valid inverse operation (either inv/pinv) # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation # If the outer operation is not a valid inverse, we do not apply this rewrite - if not isinstance(node.op.core_op, valid_inverses): + if not isinstance(node.op.core_op, ALL_INVERSE_OPS): return None potential_inner_inv = node.inputs[0].owner @@ -608,7 +609,7 @@ def rewrite_inv_inv(fgraph, node): if not ( potential_inner_inv and isinstance(potential_inner_inv.op, Blockwise) - and isinstance(potential_inner_inv.op.core_op, valid_inverses) + and isinstance(potential_inner_inv.op.core_op, ALL_INVERSE_OPS) ): return None return [potential_inner_inv.inputs[0]] @@ -632,20 +633,19 @@ def rewrite_inv_eye_to_eye(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - valid_inverses = (MatrixInverse, MatrixPinv) core_op = node.op.core_op - if not (isinstance(core_op, valid_inverses)): + if not (isinstance(core_op, ALL_INVERSE_OPS)): return None # Check whether input to inverse is Eye and the 1's are on main diagonal - eye_check = node.inputs[0] + potential_eye = node.inputs[0] if not ( - eye_check.owner - and isinstance(eye_check.owner.op, Eye) - and getattr(eye_check.owner.inputs[-1], "data", -1).item() == 0 + potential_eye.owner + and isinstance(potential_eye.owner.op, Eye) + and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0 ): return None - return [eye_check] + return [potential_eye] @register_canonicalize @@ -668,9 +668,8 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - valid_inverses = (MatrixInverse, MatrixPinv) core_op = node.op.core_op - if not (isinstance(core_op, valid_inverses)): + if not (isinstance(core_op, ALL_INVERSE_OPS)): return None inputs = node.inputs[0] @@ -681,9 +680,8 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): and AllocDiag.is_offset_zero(inputs.owner) ): inv_input = inputs.owner.inputs[0] - if inv_input.type.ndim == 1: - inv_val = pt.diag(1 / inv_input) - return [inv_val] + inv_val = pt.diag(1 / inv_input) + return [inv_val] # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix inputs_or_none = _find_diag_from_eye_mul(inputs) @@ -700,8 +698,7 @@ def rewrite_inv_diag_to_diag_reciprocal(fgraph, node): # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those if non_eye_input.type.broadcastable[-2:] == (False, False): - # For Matrix - return [eye_input / non_eye_input.diagonal(axis1=-1, axis2=-2)] - else: - # For Vector or Scalar - return [eye_input / non_eye_input] + non_eye_diag = non_eye_input.diagonal(axis1=-1, axis2=-2) + non_eye_input = pt.shape_padaxis(non_eye_diag, -2) + + return [eye_input / non_eye_input] diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index e2e2121c61..0bee56eb30 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -560,12 +560,13 @@ def test_svd_uv_merge(): assert svd_counter == 1 +def get_pt_function(x, op_name): + return getattr(pt.linalg, op_name)(x) + + @pytest.mark.parametrize("inv_op_1", ["inv", "pinv"]) @pytest.mark.parametrize("inv_op_2", ["inv", "pinv"]) def test_inv_inv_rewrite(inv_op_1, inv_op_2): - def get_pt_function(x, op_name): - return getattr(pt.linalg, op_name)(x) - x = pt.matrix("x") op1 = get_pt_function(x, inv_op_1) op2 = get_pt_function(op1, inv_op_2) @@ -573,9 +574,10 @@ def get_pt_function(x, op_name): assert rewritten_out == x -def test_inv_eye_to_eye(): +@pytest.mark.parametrize("inv_op", ["inv", "pinv"]) +def test_inv_eye_to_eye(inv_op): x = pt.eye(10) - x_inv = pt.linalg.inv(x) + x_inv = get_pt_function(x, inv_op) f_rewritten = function([], x_inv, mode="FAST_RUN") nodes = f_rewritten.maker.fgraph.apply_nodes @@ -598,15 +600,16 @@ def test_inv_eye_to_eye(): @pytest.mark.parametrize( "shape", - [(), (7,), (7, 7)], - ids=["scalar", "vector", "matrix"], + [(), (7,), (7, 7), (5, 7, 7)], + ids=["scalar", "vector", "matrix", "batched"], ) -def test_inv_diag_from_eye_mul(shape): +@pytest.mark.parametrize("inv_op", ["inv", "pinv"]) +def test_inv_diag_from_eye_mul(shape, inv_op): # Initializing x based on scalar/vector/matrix x = pt.tensor("x", shape=shape) x_diag = pt.eye(7) * x # Calculating inverse using pt.linalg.inv - x_inv = pt.linalg.inv(x_diag) + x_inv = get_pt_function(x_diag, inv_op) # REWRITE TEST f_rewritten = function([x], x_inv, mode="FAST_RUN") @@ -634,10 +637,11 @@ def test_inv_diag_from_eye_mul(shape): ) -def test_inv_diag_from_diag(): +@pytest.mark.parametrize("inv_op", ["inv", "pinv"]) +def test_inv_diag_from_diag(inv_op): x = pt.dvector("x") x_diag = pt.diag(x) - x_inv = pt.linalg.inv(x_diag) + x_inv = get_pt_function(x_diag, inv_op) # REWRITE TEST f_rewritten = function([x], x_inv, mode="FAST_RUN") From 370f2d8f128814381d32fb26038b5028ea33ccbc Mon Sep 17 00:00:00 2001 From: Tanish Date: Sun, 18 Aug 2024 17:42:32 +0530 Subject: [PATCH 4/4] minor changes --- pytensor/tensor/rewriting/linalg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index 9ec5a458f0..47ca08cf21 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -594,7 +594,6 @@ def rewrite_inv_inv(fgraph, node): list of Variable, optional List of optimized variables, or None if no optimization was performed """ - ALL_INVERSE_OPS = (MatrixInverse, MatrixPinv) # Check if its a valid inverse operation (either inv/pinv) # In case the outer operation is an inverse, it directly goes to the next step of finding inner operation # If the outer operation is not a valid inverse, we do not apply this rewrite