diff --git a/pytensor/compile/function/pfunc.py b/pytensor/compile/function/pfunc.py index ed17715e04..4c45b07ebc 100644 --- a/pytensor/compile/function/pfunc.py +++ b/pytensor/compile/function/pfunc.py @@ -4,7 +4,7 @@ """ from copy import copy -from typing import Optional +from typing import Optional, Sequence, Union, overload from pytensor.compile.function.types import Function, UnusedInputError, orig_function from pytensor.compile.io import In, Out @@ -15,8 +15,9 @@ from pytensor.graph.fg import FunctionGraph +@overload def rebuild_collect_shared( - outputs, + outputs: Variable, inputs=None, replace=None, updates=None, @@ -24,7 +25,107 @@ def rebuild_collect_shared( copy_inputs_over=True, no_default_updates=False, clone_inner_graphs=False, -): +) -> tuple[ + list[Variable], + Variable, + tuple[ + dict[Variable, Variable], + dict[SharedVariable, Variable], + list[Variable], + list[SharedVariable], + ], +]: + ... + + +@overload +def rebuild_collect_shared( + outputs: Sequence[Variable], + inputs=None, + replace=None, + updates=None, + rebuild_strict=True, + copy_inputs_over=True, + no_default_updates=False, + clone_inner_graphs=False, +) -> tuple[ + list[Variable], + list[Variable], + tuple[ + dict[Variable, Variable], + dict[SharedVariable, Variable], + list[Variable], + list[SharedVariable], + ], +]: + ... + + +@overload +def rebuild_collect_shared( + outputs: Out, + inputs=None, + replace=None, + updates=None, + rebuild_strict=True, + copy_inputs_over=True, + no_default_updates=False, + clone_inner_graphs=False, +) -> tuple[ + list[Variable], + Out, + tuple[ + dict[Variable, Variable], + dict[SharedVariable, Variable], + list[Variable], + list[SharedVariable], + ], +]: + ... + + +@overload +def rebuild_collect_shared( + outputs: Sequence[Out], + inputs=None, + replace=None, + updates=None, + rebuild_strict=True, + copy_inputs_over=True, + no_default_updates=False, + clone_inner_graphs=False, +) -> tuple[ + list[Variable], + list[Out], + tuple[ + dict[Variable, Variable], + dict[SharedVariable, Variable], + list[Variable], + list[SharedVariable], + ], +]: + ... + + +def rebuild_collect_shared( + outputs: Union[Sequence[Variable], Variable, Out, Sequence[Out]], + inputs=None, + replace=None, + updates=None, + rebuild_strict=True, + copy_inputs_over=True, + no_default_updates=False, + clone_inner_graphs=False, +) -> tuple[ + list[Variable], + Union[list[Variable], Variable, Out, list[Out]], + tuple[ + dict[Variable, Variable], + dict[SharedVariable, Variable], + list[Variable], + list[SharedVariable], + ], +]: r"""Replace subgraphs of a computational graph. It returns a set of dictionaries and lists which collect (partial?) @@ -260,7 +361,7 @@ def clone_inputs(i): return ( input_variables, cloned_outputs, - [clone_d, update_d, update_expr, shared_inputs], + (clone_d, update_d, update_expr, shared_inputs), ) diff --git a/pytensor/graph/replace.py b/pytensor/graph/replace.py index 2213c70578..d16f4119ba 100644 --- a/pytensor/graph/replace.py +++ b/pytensor/graph/replace.py @@ -1,27 +1,56 @@ from functools import partial -from typing import ( - Collection, - Dict, - Iterable, - List, - Optional, - Sequence, - Tuple, - Union, - cast, -) - -from pytensor.graph.basic import Constant, Variable, truncated_graph_inputs +from typing import Iterable, Optional, Sequence, Union, cast, overload + +from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs from pytensor.graph.fg import FunctionGraph +ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]] + + +def _format_replace(replace: Optional[ReplaceTypes] = None) -> dict[Variable, Variable]: + items: dict[Variable, Variable] + if isinstance(replace, dict): + # PyLance has issues with type resolution + items = cast(dict[Variable, Variable], replace) + elif isinstance(replace, Iterable): + items = dict(replace) + elif replace is None: + items = {} + else: + raise ValueError( + "replace is neither a dictionary, list, " + f"tuple or None ! The value provided is {replace}," + f"of type {type(replace)}" + ) + return items + + +@overload +def clone_replace( + output: Sequence[Variable], + replace: Optional[ReplaceTypes] = None, + **rebuild_kwds, +) -> list[Variable]: + ... + + +@overload def clone_replace( - output: Collection[Variable], + output: Variable, replace: Optional[ - Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]] + Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]] ] = None, **rebuild_kwds, -) -> List[Variable]: +) -> Variable: + ... + + +def clone_replace( + output: Union[Sequence[Variable], Variable], + replace: Optional[ReplaceTypes] = None, + **rebuild_kwds, +) -> Union[list[Variable], Variable]: """Clone a graph and replace subgraphs within it. It returns a copy of the initial subgraph with the corresponding @@ -39,19 +68,8 @@ def clone_replace( """ from pytensor.compile.function.pfunc import rebuild_collect_shared - items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]] - if isinstance(replace, dict): - items = list(replace.items()) - elif isinstance(replace, (list, tuple)): - items = replace - elif replace is None: - items = [] - else: - raise ValueError( - "replace is neither a dictionary, list, " - f"tuple or None ! The value provided is {replace}," - f"of type {type(replace)}" - ) + items = list(_format_replace(replace).items()) + tmp_replace = [(x, x.type()) for x, y in items] new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)] _, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds) @@ -59,20 +77,40 @@ def clone_replace( # TODO Explain why we call it twice ?! _, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds) - return cast(List[Variable], outs) + return outs +@overload +def graph_replace( + outputs: Variable, + replace: Optional[ReplaceTypes] = None, + *, + strict=True, +) -> Variable: + ... + + +@overload def graph_replace( outputs: Sequence[Variable], - replace: Dict[Variable, Variable], + replace: Optional[ReplaceTypes] = None, + *, + strict=True, +) -> list[Variable]: + ... + + +def graph_replace( + outputs: Union[Sequence[Variable], Variable], + replace: Optional[ReplaceTypes] = None, *, strict=True, -) -> List[Variable]: +) -> Union[list[Variable], Variable]: """Replace variables in ``outputs`` by ``replace``. Parameters ---------- - outputs: Sequence[Variable] + outputs: Union[Sequence[Variable], Variable] Output graph replace: Dict[Variable, Variable] Replace mapping @@ -83,20 +121,26 @@ def graph_replace( Returns ------- - List[Variable] - Output graph with subgraphs replaced + Union[Variable, List[Variable]] + Output graph with subgraphs replaced, see function overload for the exact type Raises ------ ValueError - If some replacemens could not be applied and strict is True + If some replacements could not be applied and strict is True """ + as_list = False + if not isinstance(outputs, Sequence): + outputs = [outputs] + else: + as_list = True + replace_dict = _format_replace(replace) # collect minimum graph inputs which is required to compute outputs # and depend on replacements # additionally remove constants, they do not matter in clone get equiv conditions = [ c - for c in truncated_graph_inputs(outputs, replace) + for c in truncated_graph_inputs(outputs, replace_dict) if not isinstance(c, Constant) ] # for the function graph we need the clean graph where @@ -117,7 +161,7 @@ def graph_replace( # replace the conditions back fg_replace = {equiv[c]: c for c in conditions} # add the replacements on top of input mappings - fg_replace.update({equiv[r]: v for r, v in replace.items() if r in equiv}) + fg_replace.update({equiv[r]: v for r, v in replace_dict.items() if r in equiv}) # replacements have to be done in reverse topological order so that nested # expressions get recursively replaced correctly @@ -126,12 +170,14 @@ def graph_replace( # So far FunctionGraph does these replacements inplace it is thus unsafe # apply them using fg.replace, it may change the original graph if strict: - non_fg_replace = {r: v for r, v in replace.items() if r not in equiv} + non_fg_replace = {r: v for r, v in replace_dict.items() if r not in equiv} if non_fg_replace: raise ValueError(f"Some replacements were not used: {non_fg_replace}") toposort = fg.toposort() - def toposort_key(fg: FunctionGraph, ts, pair): + def toposort_key( + fg: FunctionGraph, ts: list[Apply], pair: tuple[Variable, Variable] + ) -> int: key, _ = pair if key.owner is not None: return ts.index(key.owner) @@ -148,4 +194,7 @@ def toposort_key(fg: FunctionGraph, ts, pair): reverse=True, ) fg.replace_all(sorted_replacements, import_missing=True) - return list(fg.outputs) + if as_list: + return list(fg.outputs) + else: + return fg.outputs[0] diff --git a/tests/graph/test_replace.py b/tests/graph/test_replace.py index 487aa87f4b..7fc0e530f9 100644 --- a/tests/graph/test_replace.py +++ b/tests/graph/test_replace.py @@ -169,6 +169,17 @@ def test_graph_replace(self): # the old reference is still kept assert oc.owner.inputs[0].owner.inputs[1] is w + def test_non_list_input(self): + x = MyVariable("x") + y = MyVariable("y") + o = MyOp("xyop")(x, y) + new_x = x.clone(name="x_new") + new_y = y.clone(name="y2_new") + # test non list inputs as well + oc = graph_replace(o, {x: new_x, y: new_y}) + assert oc.owner.inputs[1] is new_y + assert oc.owner.inputs[0] is new_x + def test_graph_replace_advanced(self): x = MyVariable("x") y = MyVariable("y")