diff --git a/pytensor/graph/destroyhandler.py b/pytensor/graph/destroyhandler.py index bc29732a1f..1fe59f2c6d 100644 --- a/pytensor/graph/destroyhandler.py +++ b/pytensor/graph/destroyhandler.py @@ -7,7 +7,6 @@ import itertools from collections import deque -import pytensor from pytensor.configdefaults import config from pytensor.graph.basic import Constant from pytensor.graph.features import AlreadyThere, Bookkeeper @@ -223,7 +222,7 @@ def _build_droot_impact(destroy_handler): return droot, impact, root_destroyer -def fast_inplace_check(fgraph, inputs): +def inplace_candidates(fgraph, inputs, protected_inputs=None): """ Return the variables in inputs that are possible candidates for use as destinations of an inplace operation. @@ -234,22 +233,49 @@ def fast_inplace_check(fgraph, inputs): Input variables that you want to use as inplace destinations. """ - Supervisor = pytensor.compile.function.types.Supervisor - protected_inputs = list( - itertools.chain.from_iterable( - f.protected for f in fgraph._features if isinstance(f, Supervisor) + if protected_inputs is None: + from pytensor.compile.function.types import Supervisor + + protected_inputs = set( + itertools.chain.from_iterable( + f.protected for f in fgraph._features if isinstance(f, Supervisor) + ) ) - ) - protected_inputs.extend(fgraph.outputs) - - inputs = [ - i - for i in inputs - if not isinstance(i, Constant) - and not fgraph.has_destroyers([i]) - and i not in protected_inputs - ] - return inputs + protected_inputs.update(fgraph.outputs) + + has_destroyers = fgraph.has_destroyers + view_i = fgraph.destroy_handler.view_i + candidate_roots = {} + candidate_inputs = [] + for inp in inputs: + if isinstance(inp, Constant): + # Can't inplace on constants. + continue + + # Find the root of the view chain, checking along the way whether it passes through any protected inputs.
+ view_of_protected = False + root = inp + try: + while True: + if root in protected_inputs: + view_of_protected = True + root = view_i[root] + except KeyError: + pass + + if root in candidate_roots: + # Another input is a view of the same root; we can't destroy either one + if (invalid_candidate := candidate_roots[root]) is not None: + # Invalidate the previous candidate + candidate_inputs.remove(invalid_candidate) + candidate_roots[root] = None + elif not view_of_protected and not has_destroyers([inp]): + candidate_inputs.append(inp) + candidate_roots[root] = inp + else: + candidate_roots[root] = None + + return candidate_inputs class DestroyHandler(Bookkeeper): diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 4d2a3715c3..88ad4c1522 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -1,10 +1,8 @@ -import itertools - -from pytensor.compile import Supervisor from pytensor.compile.mode import optdb from pytensor.graph import Constant, node_rewriter +from pytensor.graph.destroyhandler import inplace_candidates from pytensor.graph.replace import vectorize_node -from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, out2in +from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.math import Dot @@ -13,6 +11,7 @@ register_specialize, register_stabilize, ) +from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer from pytensor.tensor.shape import Reshape from pytensor.tensor.subtensor import ( AdvancedIncSubtensor, @@ -262,74 +261,77 @@ def local_blockwise_of_subtensor(fgraph, node): return [x[(*none_slices, *core_idxs)]] -@node_rewriter(tracks=[Blockwise], inplace=True) -def blockwise_inplace(fgraph, node): - blockwise_op = node.op - - if blockwise_op.destroy_map: - # Op already has inplace - return - - # Find out valid inputs for inplacing - batch_ndim = blockwise_op.batch_ndim(node) - out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim] - - protected_inputs = [ - f.protected for f in fgraph._features if isinstance(f, Supervisor) - ] - protected_inputs = list(itertools.chain.from_iterable(protected_inputs)) - protected_inputs.extend(fgraph.outputs) - allowed_inplace_inputs = [ - idx - for idx, inp in enumerate(node.inputs) - if - ( - # Constants would need to be recreated every time if inplaced - not isinstance(inp, Constant) - # We can only inplace on inputs that are not being broadcasted - # As those are reused across iterations of Blockwise - and node.inputs[idx].type.broadcastable[:batch_ndim] == out_batch_bcast - # Inputs that are marked as protected or destroyed can't be inplaced - and not fgraph.has_destroyers([inp]) - and inp not in protected_inputs +class InplaceBlockwiseOptimizer(InplaceGraphOptimizer): + op = Blockwise + + def filter_candidate_pairs(self, fgraph, node, protected_inputs): + blockwise_op = node.op + batch_ndim = blockwise_op.batch_ndim(node) + out_batch_bcast = node.outputs[0].type.broadcastable[:batch_ndim] + inputs = node.inputs + + candidate_inputs = set( + inplace_candidates( + fgraph, + [ + inp + for inp in inputs + if inp.type.broadcastable[:batch_ndim] == out_batch_bcast + ], + protected_inputs=protected_inputs, + ) ) - ] - if not allowed_inplace_inputs: - return None + allowed_inplace_inputs = [ + i for i, inp in enumerate(inputs) if inp in candidate_inputs + ] + destroy_map =
blockwise_op.core_op.inplace_on_inputs( + allowed_inplace_inputs=allowed_inplace_inputs + ).destroy_map + + if not destroy_map: + return [] + + outputs = node.outputs + return [ + ((out_idx, outputs[out_idx]), (inp_idx, inputs[inp_idx])) + for out_idx, inp_idxs in destroy_map.items() + for inp_idx in inp_idxs + ] - inplace_core_op = blockwise_op.core_op.inplace_on_inputs( - allowed_inplace_inputs=allowed_inplace_inputs - ) + def create_inplace_node(self, node, inplace_pattern): + blockwise_op = node.op + allowed_inplace_inputs = tuple(v[0] for v in inplace_pattern.values()) + inplace_core_op = blockwise_op.core_op.inplace_on_inputs( + allowed_inplace_inputs=allowed_inplace_inputs + ) - if not inplace_core_op.destroy_map: - return None + if not inplace_core_op.destroy_map: + return node - # Check Op is not trying to inplace on non-candidate inputs - for destroyed_inputs in inplace_core_op.destroy_map.values(): - for destroyed_input in destroyed_inputs: - if destroyed_input not in allowed_inplace_inputs: - raise ValueError( - f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}" - ) + # Check Op is not trying to inplace on non-candidate inputs + for destroyed_inputs in inplace_core_op.destroy_map.values(): + for destroyed_input in destroyed_inputs: + if destroyed_input not in allowed_inplace_inputs: + raise ValueError( + f"Op {blockwise_op.core_op} destroy_map does not respect allowed_inplace_inputs {allowed_inplace_inputs}" + ) - # Recreate core_op with inplace - inplace_blockwise_op = Blockwise( - core_op=inplace_core_op, - signature=blockwise_op.signature, - name=blockwise_op.name, - gufunc_spec=blockwise_op.gufunc_spec, - destroy_map=inplace_core_op.destroy_map, - ) + # Recreate core_op with inplace + inplace_blockwise_op = type(blockwise_op)( + core_op=inplace_core_op, + signature=blockwise_op.signature, + name=blockwise_op.name, + gufunc_spec=blockwise_op.gufunc_spec, + destroy_map=inplace_core_op.destroy_map, + ) - out = inplace_blockwise_op.make_node(*node.inputs).outputs - copy_stack_trace(node.outputs, out) - return out + return inplace_blockwise_op.make_node(*node.inputs) optdb.register( "blockwise_inplace", - in2out(blockwise_inplace), + InplaceBlockwiseOptimizer(), "fast_run", "inplace", position=50.1, diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 98fc4e074c..afe69a198b 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -1,19 +1,21 @@ +import abc import itertools import operator import sys -from collections import Counter, defaultdict, deque -from collections.abc import Generator +from collections import defaultdict, deque +from collections.abc import Generator, Sequence from functools import cache, reduce from typing import TypeVar from warnings import warn -import pytensor import pytensor.scalar.basic as ps from pytensor import clone_replace, compile +from pytensor.compile.function.types import Supervisor from pytensor.compile.mode import get_target_language from pytensor.configdefaults import config -from pytensor.graph import FunctionGraph -from pytensor.graph.basic import Apply, Constant, Variable, ancestors, io_toposort +from pytensor.graph import FunctionGraph, Op +from pytensor.graph.basic import Apply, Variable, ancestors +from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates from pytensor.graph.features import ReplaceValidate from pytensor.graph.fg import Output from pytensor.graph.rewriting.basic 
import ( @@ -43,47 +45,34 @@ register_specialize, ) from pytensor.tensor.shape import shape_padleft -from pytensor.tensor.variable import TensorConstant +from pytensor.tensor.variable import TensorConstant, TensorVariable -class InplaceElemwiseOptimizer(GraphRewriter): - r""" - This is parameterized so that it works for `Elemwise` `Op`\s. - """ - - def __init__(self, OP): - self.op = OP +class InplaceGraphOptimizer(GraphRewriter): + op: type[Op] def add_requirements(self, fgraph): - from pytensor.graph.destroyhandler import DestroyHandler - fgraph.attach_feature(DestroyHandler()) - @classmethod - def print_profile(cls, stream, prof, level=0): - blanc = " " * level - print(blanc, cls.__name__, prof["opt"].op, file=stream) - for k in [ - "node_before", - "nb_call_replace", - "nb_call_validate", - "nb_inconsistent", - ]: - print(blanc, k, prof[k], file=stream) - ndim = prof["ndim"] - if ndim: - print(blanc, "ndim", "nb", file=stream) - for n in sorted(ndim): - print(blanc, n, ndim[n], file=stream) + @abc.abstractmethod + def filter_candidate_pairs( + self, fgraph: FunctionGraph, node: Apply, protected_inputs: Sequence[Variable] + ) -> Sequence[tuple[tuple[int, Variable], tuple[int, Variable]]]: + pass + + @abc.abstractmethod + def create_inplace_node( + self, node: Apply, inplace_pattern: dict[int, Sequence[int]] + ) -> Apply: + pass def apply(self, fgraph): r""" - Attempts to replace all `Elemwise`\s by versions of them that operate - inplace. It operates greedily: for each `Elemwise` that is encountered, - for each output, it tries each input to see if it can operate inplace - on that input. If so, it makes the change and goes to the next output - or `Elemwise`. + Attempts to replace all `Op`\s by versions of them that operate + inplace. It operates greedily: for each `Op` that is encountered, + it tries to inplace all the valid inputs at once (if the Op supports it); + if that fails, it tries to inplace one input at a time. Examples -------- @@ -92,8 +81,7 @@ def apply(self, fgraph): (x + y) * (x * y) -> (x += y) *= (x * y) or (x + y) *= (x *= y) """ - # We should not validate too often as this takes too much time to - # execute! + # We should not validate too often as this takes too much time to execute! # It is the _dfs_toposort() fct in pytensor/graph/destroyhandler.py # that takes so much time. # Should we try to use another lib that does toposort? @@ -111,244 +99,207 @@ def apply(self, fgraph): # Then I think it is the [io_?]toposort (need to validate) so check if # the solution is also applicable there. - # We execute `validate` after this number of change. + # 2025: The above comment is not specific to Elemwise; if we have concerns about this approach, we should + # tackle them in a more general way. The whole try/except approach is probably suboptimal. + # We can consider restricting inputs with static shapes that are large enough. + + if config.tensor__insert_inplace_optimizer_validate_nb != -1: + warn( + "tensor__insert_inplace_optimizer_validate_nb config is deprecated. 
Setting it will fail in a future release.", + FutureWarning, + ) + + reason = f"{self.op}_inplace_optimizer" prof = { "opt": self, "node_before": len(fgraph.apply_nodes), - "nb_call_replace": 0, - "nb_call_validate": 0, + "nb_eager_inconsistent": 0, "nb_inconsistent": 0, - "ndim": Counter(), + "nb_replaced": 0, } + large_graph = len(fgraph.apply_nodes) > 500 - check_each_change = config.tensor__insert_inplace_optimizer_validate_nb - if check_each_change == -1: - if len(fgraph.apply_nodes) > 500: - check_each_change = 10 - else: - check_each_change = 1 - - nb_change_no_validate = 0 - chk = fgraph.checkpoint() - - if fgraph.update_mapping: - update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping] - else: - update_outs = [] - - Supervisor = pytensor.compile.function.types.Supervisor - protected_inputs = list( + protected_inputs = set( itertools.chain.from_iterable( f.protected for f in fgraph._features if isinstance(f, Supervisor) ) ) - protected_inputs.extend(fgraph.outputs) - for node in list(io_toposort(fgraph.inputs, fgraph.outputs)): - op = node.op - if not isinstance(op, self.op): - continue - # If big graph and the outputs are scalar, do not make it - # inplace. + protected_inputs.update(fgraph.outputs) + root_destroyer = fgraph.destroy_handler.root_destroyer + + self_op = self.op + update_mapping = fgraph.update_mapping or {} + op_updates: dict[TensorVariable, TensorVariable] = { + out: fgraph.inputs[update_mapping[out_idx]] + for out_idx, out in enumerate(fgraph.outputs) if ( - check_each_change != 1 - and - # If multiple outputs, they must all have the same size, - # so only check the first. - getattr(node.outputs[0].type, "ndim", -1) == 0 - ): + out_idx in update_mapping + and out.owner + and isinstance(out.owner.op, self_op) + ) + } + set_op_updates = set(op_updates.keys()) + + for node in fgraph.toposort(): + if not isinstance(node.op, self_op) or node.op.destroy_map: continue - if op.inplace_pattern: - # Maybe this isn't needed anymore, but I don't want to - # rish regression now. This case only happen if the - # original node add already some inplace patter and we - # still try to add more pattern. + # If big graph and the outputs are scalar, do not make it inplace. + if large_graph and all(node.outputs[0].type.broadcastable): + continue - baseline = op.inplace_pattern - candidate_outputs = [ - i for i in range(len(node.outputs)) if i not in baseline - ] - # node inputs that are Constant, already destroyed, - # or fgraph protected inputs and fgraph outputs can't be used as - # inplace target. - # Remove here as faster. - candidate_inputs = [ - i - for i in range(len(node.inputs)) - if i not in baseline.values() - and not isinstance(node.inputs[i], Constant) - # the next line should not be costly most of the time. - and not fgraph.has_destroyers([node.inputs[i]]) - and node.inputs[i] not in protected_inputs - ] - else: - baseline = [] - candidate_outputs = range(len(node.outputs)) - # node inputs that are Constant, already destroyed, - # fgraph protected inputs and fgraph outputs can't be used as inplace - # target. - # Remove here as faster. 
- candidate_inputs = [ - i - for i in range(len(node.inputs)) - if not isinstance(node.inputs[i], Constant) - and not fgraph.has_destroyers([node.inputs[i]]) - and node.inputs[i] not in protected_inputs - ] + candidate_pairs = self.filter_candidate_pairs( + fgraph, node, protected_inputs + ) - verbose = False - - raised_warning = not verbose - - for candidate_output in candidate_outputs: - # If the output of the node can be established as an update - # output of the fgraph, visit the candidate_inputs in an order - # that will improve the chances of making the node operate - # inplace on the input it's meant to update - candidate_out_var = node.outputs[candidate_output] - sorted_candidate_inputs = candidate_inputs - - if candidate_out_var in update_outs: - # The candidate output is an update. Sort the - # variables in candidate_inputs in the following order: - # - Vars corresponding to the actual updated input - # (best case scenario is for the node that procudes - # an update to operate inplace on the variable to - # update) - # - Vars computed inplace on the updates input (second - # best scenario if for the node to work inplace on - # a variable obtained by a chain of inplace on the - # variable to update. In some cases, this will be - # equivalent to operating inplace on the variable to - # update) - # - Remaining variables - updated_inputs = [] - for i, f_out in enumerate(fgraph.outputs): - if f_out is candidate_out_var and i in fgraph.update_mapping: - updated_inp_idx = fgraph.update_mapping[i] - updated_inputs.append(fgraph.inputs[updated_inp_idx]) - - updated_vars = [] - vars_from_inplace = [] - other_vars = [] - for inp_idx in candidate_inputs: - inp = node.inputs[inp_idx] - if inp in updated_inputs: - # the candidate input is the actual updated input - updated_vars.append(inp_idx) - elif ( - hasattr(fgraph, "destroy_handler") - and inp.owner - and any( - fgraph.destroy_handler.root_destroyer.get(up_inp, None) - is inp.owner - for up_inp in updated_inputs - ) + if not candidate_pairs: + continue + + sorted_candidate_pairs = candidate_pairs + if op_updates and (node_updates := set(node.outputs) & set_op_updates): + # If the fgraph has updates, we try to prioritize in-placing on the pairs that correspond to the update + direct_update_pairs = [] + indirect_update_pairs = [] + other_update_pairs = [] + for pair in candidate_pairs: + ((o, out), (i, inp)) = pair + if out in node_updates: + direct_update_inp = op_updates[out] + if direct_update_inp is inp: + # This pair is the whole graph update + direct_update_pairs.append(pair) + continue + elif (inp_node := inp.owner) is not None and any( + root_destroyer.get(up_inp, None) is inp_node + for up_inp in op_updates.values() ): - # the candidate input is a variable computed - # inplace on the updated input via a sequence of - # one or more inplace operations - vars_from_inplace.append(inp_idx) - else: - other_vars.append(inp_idx) + # This pair connects to an updated input + indirect_update_pairs.append(pair) + continue + other_update_pairs.append(pair) - sorted_candidate_inputs = ( - updated_vars + vars_from_inplace + other_vars - ) + sorted_candidate_pairs = ( + direct_update_pairs + indirect_update_pairs + other_update_pairs + ) - for candidate_input in sorted_candidate_inputs: - # remove inputs that don't have the same dtype as the output - if ( - node.inputs[candidate_input].type - != node.outputs[candidate_output].type - ): - continue + # Try in-placing all outputs at once + tried_inputs = set() + inplace_pattern = {} + for (o, _), (i, 
_) in sorted_candidate_pairs: + if o not in inplace_pattern and i not in tried_inputs: + inplace_pattern[o] = [i] + tried_inputs.add(i) + + inplace_node = self.create_inplace_node(node, inplace_pattern) + if inplace_node.op.destroy_map == inplace_pattern: + replacements = tuple(zip(node.outputs, inplace_node.outputs)) + try: + fgraph.replace_all_validate(replacements, reason=reason) + except InconsistencyError: + prof["nb_eager_inconsistent"] += 1 + else: + prof["nb_replaced"] += 1 + copy_stack_trace(node.outputs, inplace_node.outputs) + continue - inplace_pattern = dict(baseline) - inplace_pattern[candidate_output] = candidate_input - try: - if hasattr(op.scalar_op, "make_new_inplace"): - new_scal = op.scalar_op.make_new_inplace( - ps.transfer_type( - *[ - inplace_pattern.get(i, o.dtype) - for i, o in enumerate(node.outputs) - ] - ) - ) - else: - new_scal = op.scalar_op.__class__( - ps.transfer_type( - *[ - inplace_pattern.get(i, None) - for i in range(len(node.outputs)) - ] - ) - ) - new_outputs = self.op(new_scal, inplace_pattern)( - *node.inputs, return_list=True - ) - new_node = new_outputs[0].owner + # If it fails or doesn't match the desired inplace pattern, try one output/input at a time + tried_inputs = set() + inplace_pattern = {} + replaced = False + original_node = node + for (o, _), (i, _) in sorted_candidate_pairs: + if o not in inplace_pattern and i not in tried_inputs: + inplace_pattern[o] = [i] + tried_inputs.add(i) + + inplace_node = self.create_inplace_node(node, inplace_pattern) + if inplace_node.op.destroy_map != inplace_pattern: + # This Op can't respect this partial inplace pattern, + # so we assume it can't support any other case either + break + else: + replacements = tuple(zip(node.outputs, inplace_node.outputs)) + try: + fgraph.replace_all_validate(replacements, reason=reason) + node = inplace_node + replaced = True + except InconsistencyError: + prof["nb_inconsistent"] += 1 + # The input, not the output, caused the inconsistency + inplace_pattern.pop(o) + if replaced: + copy_stack_trace(original_node.outputs, node.outputs) + prof["nb_replaced"] += replaced - for r, new_r in zip(node.outputs, new_outputs, strict=True): - prof["nb_call_replace"] += 1 - fgraph.replace( - r, new_r, reason="inplace_elemwise_optimizer" - ) - nb_change_no_validate += 1 - prof["ndim"][candidate_out_var.ndim] += 1 - if nb_change_no_validate >= check_each_change: - prof["nb_call_validate"] += 1 - fgraph.validate() - chk = fgraph.checkpoint() - nb_change_no_validate = 0 - except (ValueError, InconsistencyError) as e: - prof["nb_inconsistent"] += 1 - if check_each_change != 1 and not raised_warning: - print( # noqa: T201 - ( - "Some inplace rewriting was not " - "performed due to an unexpected error:" - ), - file=sys.stderr, - ) - print(e, file=sys.stderr) # noqa: T201 - raised_warning = True - fgraph.revert(chk) - continue - candidate_inputs.remove(candidate_input) - node = new_node - baseline = inplace_pattern - break - - if nb_change_no_validate > 0: - try: - fgraph.validate() - except Exception: - if not raised_warning: - print( # noqa: T201 - ( - "Some inplace rewriting was not " - "performed due to an unexpected error" - ), - file=sys.stderr, - ) - fgraph.revert(chk) return prof + @classmethod + def print_profile(cls, stream, prof, level=0): + blanc = " " * level + print(blanc, cls.__name__, file=stream) + for k in [ + "node_before", + "nb_eager_inconsistent", + "nb_inconsistent", + "nb_replaced", + ]: + print(blanc, k, prof[k], file=stream) + def print_summary(self, stream=sys.stdout, level=0, 
depth=-1): print( - f"{' ' * level}{self.__class__.__name__} ({self.op})", + f"{' ' * level}{self.__class__.__name__}", file=stream, ) - return inplace_elemwise_optimizer -inplace_elemwise_optimizer = InplaceElemwiseOptimizer(Elemwise) +class InplaceElemwiseOptimizer(InplaceGraphOptimizer): + op = Elemwise + + def filter_candidate_pairs(self, fgraph, node, protected_inputs): + candidate_inputs = [ + (node.inputs.index(inp), inp) + for inp in inplace_candidates( + fgraph, + node.inputs, + protected_inputs=protected_inputs, + ) + ] + if not candidate_inputs: + return [] + + return [ + ((o, out), (i, inp)) + for o, out in enumerate(node.outputs) + for i, inp in candidate_inputs + if inp.type == out.type + ] + + def create_inplace_node(self, node, inplace_pattern): + op = node.op + scalar_op = op.scalar_op + inplace_pattern = {i: o for i, [o] in inplace_pattern.items()} + if hasattr(scalar_op, "make_new_inplace"): + new_scalar_op = scalar_op.make_new_inplace( + ps.transfer_type( + *[ + inplace_pattern.get(i, o.dtype) + for i, o in enumerate(node.outputs) + ] + ) + ) + else: + new_scalar_op = type(scalar_op)( + ps.transfer_type( + *[inplace_pattern.get(i, None) for i in range(len(node.outputs))] + ) + ) + return type(op)(new_scalar_op, inplace_pattern).make_node(*node.inputs) + + compile.optdb.register( - "inplace_elemwise_opt", - inplace_elemwise_optimizer, - "inplace_opt", # for historic reason + "inplace_elemwise", + InplaceElemwiseOptimizer(), + "inplace_elemwise_opt", # for historic reasons "inplace_elemwise_optimizer", "fast_run", "inplace", diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 7833cecf91..e259b7d1a6 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -8,6 +8,7 @@ from pytensor import scalar as ps from pytensor import tensor as pt from pytensor.compile.function import function +from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config from pytensor.gradient import grad @@ -1529,3 +1530,31 @@ def test_constant_fold_branches_add_mul(op): new_out = rewrite_graph(out, include=("add_mul_fusion",)) assert len(new_out.owner.inputs) == 3 assert equal_computations([new_out], [op(py_op(a, b), c, x)]) + + +def test_InplaceElemwiseOptimizer_bug(): + # Regression test for https://github.com/pymc-devs/pytensor/issues/1420 + + # This graph fails if InplaceElemwiseOptimizer tries to skip `fgraph.validate` + # between two invalid inplace rewrites.
+ z = pt.matrix("z") + + z1 = ps.float64("z1") + z2 = ps.float64("z2") + out1, out2 = Elemwise(ps.Composite([z1, z2], [z1 + z2, z2 - z1]))(z[1:], z[:-1]) + out = pt.exp(z[1:-1]).sum() + out1.sum() + out2.sum() + + # Add 500 unrelated nodes to trigger the old special behavior + irrelevant_outs = [pt.specify_shape(z, (4, 4)) for _ in range(500)] + + fgraph = FunctionGraph(inputs=[z], outputs=[out, *irrelevant_outs], clone=False) + add_supervisor_to_fgraph(fgraph, [In(z)]) + # with config.change_flags(tensor__insert_inplace_optimizer_validate_nb=10): + rewrite_graph(fgraph, include=("inplace",)) + + pytensor.config.tensor__insert_inplace_optimizer_validate_nb = 1 + with pytest.warns( + FutureWarning, + match="tensor__insert_inplace_optimizer_validate_nb config is deprecated", + ): + rewrite_graph(fgraph, include=("inplace",)) diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py index 9d48f310fe..c2df7e9699 100644 --- a/tests/tensor/test_blockwise.py +++ b/tests/tensor/test_blockwise.py @@ -8,11 +8,21 @@ import pytensor from pytensor import In, config, function, scan from pytensor.compile import get_default_mode, get_mode +from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.gradient import grad -from pytensor.graph import Apply, Op +from pytensor.graph import Apply, FunctionGraph, Op, rewrite_graph from pytensor.graph.replace import vectorize_graph, vectorize_node from pytensor.raise_op import assert_op -from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector +from pytensor.tensor import ( + diagonal, + dmatrix, + log, + matrices, + ones_like, + scalar, + tensor, + vector, +) from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback from pytensor.tensor.nlinalg import MatrixInverse from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot @@ -698,3 +708,57 @@ def test_scan_gradient_core_type(): grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}), np.ones((4, n_steps, 1)), ) + + +def test_partial_inplace(): + class CoreOp(Op): + __props__ = ("inplace",) + + def __init__(self, inplace): + self.inplace = tuple(inplace) + self.destroy_map = {i: [i] for i in inplace} + + def inplace_on_inputs(self, allowed_inplace_inputs): + return type(self)(inplace=allowed_inplace_inputs) + + def make_node(self, x, y, z): + return Apply(self, [x, y, z], [x.type(), y.type(), z.type()]) + + def perform(self, node, inputs, outputs): + [x, y, z] = inputs + if 0 not in self.inplace: + x = x.copy() + if 1 not in self.inplace: + y = y.copy() + if 2 not in self.inplace: + z = z.copy() + outputs[0][0] = x + outputs[1][0] = y + outputs[2][0] = z + + core_op = CoreOp(inplace=()) + blockwise_op = Blockwise(core_op, signature="(),(),()->(),(),()") + x, y, z = matrices("xyz") + + # All can be inplaced + out = blockwise_op(x.T, y.T, z.T) + fgraph = FunctionGraph([x, y, z], out) + add_supervisor_to_fgraph(fgraph, [In(inp, mutable=True) for inp in fgraph.inputs]) + rewrite_graph(fgraph, include=("inplace",)) + assert fgraph.outputs[0].owner.op.destroy_map == {0: [0], 1: [1], 2: [2]} + + # Only x, z can be inplaced, y is protected + out = blockwise_op(x.T, y.T, z.T) + fgraph = FunctionGraph([x, y, z], out) + add_supervisor_to_fgraph( + fgraph, [In(inp, mutable=(i % 2) == 0) for i, inp in enumerate(fgraph.inputs)] + ) + rewrite_graph(fgraph, include=("inplace",)) + assert fgraph.outputs[0].owner.op.destroy_map == {0: [0], 2: [2]} + + # Only y can be inplaced, x is reused for first and third outputs + out = 
blockwise_op(x.T, y.T, x.T) + fgraph = FunctionGraph([x, y, z], out) + add_supervisor_to_fgraph(fgraph, [In(inp, mutable=True) for inp in fgraph.inputs]) + rewrite_graph(fgraph, include=("inplace",)) + assert fgraph.outputs[0].owner.op.destroy_map == {1: [1]}
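
Reviewer note (not part of the patch): a minimal, hedged sketch of how the new `inplace_candidates` helper is meant to be called. It assumes a FunctionGraph with a `DestroyHandler` feature attached (the helper reads `fgraph.destroy_handler.view_i`) and a `Supervisor` registered via `add_supervisor_to_fgraph`; the variables `x` and `y` are illustrative only.

    import pytensor.tensor as pt
    from pytensor import In
    from pytensor.compile.function.types import add_supervisor_to_fgraph
    from pytensor.graph import FunctionGraph
    from pytensor.graph.destroyhandler import DestroyHandler, inplace_candidates

    x, y = pt.matrices("xy")
    fgraph = FunctionGraph([x, y], [x + y], clone=False)
    # inplace_candidates reads view/destroy info from the DestroyHandler feature.
    fgraph.attach_feature(DestroyHandler())
    # Mark x as mutable; y stays protected by the Supervisor.
    add_supervisor_to_fgraph(fgraph, [In(x, mutable=True), In(y, mutable=False)])

    # Constants, protected inputs (and views of them), already-destroyed
    # variables, and duplicate views of the same root are filtered out, so
    # only x should survive as a candidate inplace destination here.
    assert inplace_candidates(fgraph, [x, y]) == [x]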
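Also for reviewers: a sketch of the contract the new `InplaceGraphOptimizer` base class expects from subclasses, mirroring what `InplaceElemwiseOptimizer` and `InplaceBlockwiseOptimizer` do in the patch. `MyOp` and its `with_destroy_map` method are hypothetical stand-ins, not real pytensor APIs.

    from pytensor.graph.destroyhandler import inplace_candidates
    from pytensor.tensor.rewriting.elemwise import InplaceGraphOptimizer

    class InplaceMyOpOptimizer(InplaceGraphOptimizer):
        # The Op (sub)class this rewriter targets; MyOp is hypothetical.
        op = MyOp

        def filter_candidate_pairs(self, fgraph, node, protected_inputs):
            # Return ((out_idx, out), (inp_idx, inp)) pairs the rewriter may try;
            # the base class first attempts all pairs at once, then falls back
            # to one output/input pair at a time.
            candidates = set(
                inplace_candidates(fgraph, node.inputs, protected_inputs=protected_inputs)
            )
            return [
                ((o, out), (i, inp))
                for o, out in enumerate(node.outputs)
                for i, inp in enumerate(node.inputs)
                if inp in candidates and inp.type == out.type
            ]

        def create_inplace_node(self, node, inplace_pattern):
            # inplace_pattern maps output index -> [input index]. The returned
            # Apply's op.destroy_map must equal the requested pattern, otherwise
            # the base class discards the attempt rather than raising.
            return node.op.with_destroy_map(inplace_pattern).make_node(*node.inputs)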