From cfe2462b969644f96169ff2abf2fec27369619ed Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 1 Apr 2025 16:31:43 -0700 Subject: [PATCH 1/7] Update [ghstack-poisoned] --- exir/backend/canonical_partitioners/TARGETS | 1 + .../all_node_partitioner.py | 55 ++++++ exir/backend/test/test_backends.py | 179 ++++++++++++++++++ exir/backend/test/test_backends_lifted.py | 15 ++ exir/backend/test/test_compatibility.py | 49 +++++ .../backend/test/test_delegate_map_builder.py | 5 + exir/program/TARGETS | 1 + exir/program/_program.py | 38 +++- 8 files changed, 333 insertions(+), 10 deletions(-) create mode 100644 exir/backend/canonical_partitioners/all_node_partitioner.py diff --git a/exir/backend/canonical_partitioners/TARGETS b/exir/backend/canonical_partitioners/TARGETS index 22a6e2c51bd..8d3e28968b3 100644 --- a/exir/backend/canonical_partitioners/TARGETS +++ b/exir/backend/canonical_partitioners/TARGETS @@ -7,6 +7,7 @@ runtime.python_library( srcs = [ "duplicate_dequant_node_pass.py", "pattern_op_partitioner.py", + "all_node_partitioner.py", ], visibility = [ "//executorch/...", diff --git a/exir/backend/canonical_partitioners/all_node_partitioner.py b/exir/backend/canonical_partitioners/all_node_partitioner.py new file mode 100644 index 00000000000..bc45f2b5239 --- /dev/null +++ b/exir/backend/canonical_partitioners/all_node_partitioner.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List + +import torch +from executorch.exir.backend.backend_details import ExportedProgram +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param + + +def is_non_tensor_placeholder(node: torch.fx.Node, ep: ExportedProgram) -> bool: + """ + Returns true if the node is a placeholder node and it is not a tensor + """ + return node.op == "placeholder" and not ( + is_param(ep, node) or is_buffer(ep, node) or is_lifted_tensor_constant(ep, node) + ) + + +class AllNodePartitioner(Partitioner): + def __init__( + self, + backend_id: str, + compile_specs: List[CompileSpec], + ): + """ + Partitioner that lowers every single node in the graph module to the + specified backend_id + """ + super().__init__() + self.delegation_spec = DelegationSpec(backend_id, compile_specs) + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + # tag all nodes + partition_tags: Dict[str, DelegationSpec] = {} + for node in exported_program.graph_module.graph.nodes: + if is_non_tensor_placeholder(node, exported_program) or node.op == "output": + continue + + delegation_tag = self.delegation_spec.backend_id + node.meta["delegation_tag"] = delegation_tag + partition_tags[delegation_tag] = self.delegation_spec + + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py index d2bcfa31676..b5a38d875c2 100644 --- a/exir/backend/test/test_backends.py +++ b/exir/backend/test/test_backends.py @@ -10,7 +10,11 @@ import executorch.exir as exir import torch +from executorch.exir import to_edge from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( + AllNodePartitioner, +) from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, @@ -1266,3 +1270,178 @@ def forward(self, x: List[torch.Tensor]): gm = exir.capture(ComposedM(), inputs, exir.CaptureConfig()).to_edge() gm(*inputs) + + def test_to_backend_delegation_spec(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return [torch.sin(x)] + + sin_module = SinModule() + model_inputs = (torch.ones(1),) + max_value = model_inputs[0].shape[0] + + partitioner = AllNodePartitioner( + "BackendWithCompilerDemo", [CompileSpec("max_value", bytes([max_value]))] + ) + + edgeir_m = to_edge(torch.export.export(sin_module, model_inputs)) + edgeir_m = edgeir_m.to_backend(partitioner) + exec_prog = edgeir_m.to_executorch() + graph_module = exec_prog.exported_program().graph_module + # Check that there is not an aten.sin node. + self.assertTrue( + exir_ops.edge.aten.sin + not in {node.target for node in graph_module.graph.nodes} + ) + + # Check that there exists a call_delegate, representing the call to the + # delegated function + FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( + graph_module.code + ) + lowered_submodules = get_lowered_submodules(graph_module) + self.assertEqual(len(lowered_submodules), 1) + + for node in graph_module.graph.nodes: + if node.op == "call_function" and node.target == executorch_call_delegate: + # Check that first arg is lowered_module_{unique_id} + self.assertEqual(node.args[0].target, "lowered_module_0") + + program = exec_prog.executorch_program + + # Check the program can be printed + print_program(program) + + # Check the backend delegate + self.check_backend_delegate( + program=program, + delegate=program.execution_plan[0].delegates[0], + expected_id=BackendWithCompilerDemo.__name__, + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + ) + + # Check the delegate instruction + self.assertTrue( + isinstance( + program.execution_plan[0].chains[0].instructions[0].instr_args, + DelegateCall, + ) + ) + buff = exec_prog.buffer + + executorch_module = _load_for_executorch_from_buffer(buff) + model_inputs = torch.ones(1) + model_outputs = executorch_module.forward([model_inputs]) + self.assertEqual( + model_inputs, + torch.ones(1), + ) + expected_output = 0.8333 * torch.ones(1) + + self.assertTrue( + torch.allclose(model_outputs[0], expected_output, atol=1e-03, rtol=1e-03) + ) + + def test_to_backend_multimethod_delegation_spec(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + def inputs(self): + return (torch.ones(1),) + + class AddMulModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a, x, b): + y = torch.mm(a, x) + z = torch.add(y, b) + return z + + def inputs(self): + return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2)) + + sin_module = SinModule() + max_value_sin = sin_module.inputs()[0].shape[0] + sin_partitioner = AllNodePartitioner( + "BackendWithCompilerDemo", + [CompileSpec("max_value", bytes([max_value_sin]))], + ) + + add_mul_module = AddMulModule() + max_value_add_mul = add_mul_module.inputs()[0].shape[0] + add_mul_partitioner = AllNodePartitioner( + "BackendWithCompilerDemo", + [CompileSpec("max_value", bytes([max_value_add_mul]))], + ) + + edgeir_m = to_edge( + { + "sin": torch.export.export(sin_module, sin_module.inputs()), + "add_mul": torch.export.export(add_mul_module, add_mul_module.inputs()), + } + ) + edgeir_m = edgeir_m.to_backend( + { + "sin": sin_partitioner, + "add_mul": add_mul_partitioner, + } + ) + exec_prog = edgeir_m.to_executorch() + + for method_name in ["sin", "add_mul"]: + graph_module = exec_prog.exported_program(method_name).graph_module + # Check delegated nodes are gone + self.assertTrue( + exir_ops.edge.aten.sin + not in {node.target for node in graph_module.graph.nodes} + ) + self.assertTrue( + exir_ops.edge.aten.add + not in {node.target for node in graph_module.graph.nodes} + ) + self.assertTrue( + exir_ops.edge.aten.mm + not in {node.target for node in graph_module.graph.nodes} + ) + # Check that there exists a call_delegate, representing the call to the + # delegated function + FileCheck().check("torch.ops.higher_order.executorch_call_delegate").run( + graph_module.code + ) + lowered_submodules = get_lowered_submodules(graph_module) + self.assertEqual(len(lowered_submodules), 1) + + program = exec_prog.executorch_program + + # Check the program can be printed + print_program(program) + + buff = exec_prog.buffer + + executorch_module = _load_for_executorch_from_buffer(buff) + + for method_name, module in { + "sin": sin_module, + "add_mul": add_mul_module, + }.items(): + inputs_flattened, _ = tree_flatten(module.inputs()) + model_outputs = executorch_module.run_method( + method_name, tuple(inputs_flattened) + ) + + if method_name == "sin": + # backend with compiler demo does a taylor approximation of sin + ref_output = 0.8333 * torch.ones(1) + else: + ref_output = module(*module.inputs()) + self.assertTrue( + torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03) + ) diff --git a/exir/backend/test/test_backends_lifted.py b/exir/backend/test/test_backends_lifted.py index 3c55bebd320..be9527b8ccb 100644 --- a/exir/backend/test/test_backends_lifted.py +++ b/exir/backend/test/test_backends_lifted.py @@ -11,6 +11,9 @@ import torch from executorch.exir import to_edge from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( + AllNodePartitioner, +) from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.partitioner import ( DelegationSpec, @@ -138,6 +141,18 @@ def forward(self, x): self.assertTrue(torch.allclose(new_res, expected_res)) + # Test same flow but through edge_program_manager + edgeir_m = to_edge(export(sin_module, model_inputs, strict=True)) + loweredir_m = edgeir_m.to_backend( + AllNodePartitioner(BackendWithCompilerDemo.__name__, []) + ) + lowered_sin_module = get_lowered_submodules( + loweredir_m.exported_program().graph_module + )[0][1] + + new_res = lowered_sin_module(*model_inputs)[0] + + self.assertTrue(torch.allclose(new_res, expected_res)) # TODO(tkaruturi): emitting single LoweredBackendModule # program = to_edge(export(graph_module)).to_exectorch()._emitter_output.program diff --git a/exir/backend/test/test_compatibility.py b/exir/backend/test/test_compatibility.py index 9d87aa5be0e..bcda1d36516 100644 --- a/exir/backend/test/test_compatibility.py +++ b/exir/backend/test/test_compatibility.py @@ -10,6 +10,9 @@ from executorch.exir import to_edge from executorch.exir._serialize import _serialize_pte_binary from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( + AllNodePartitioner, +) from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.test.backend_with_compiler_demo import ( BackendWithCompilerDemo, @@ -65,3 +68,49 @@ def forward(self, x): "loading method forward failed with error 0x30", ): executorch_module = _load_for_executorch_from_buffer(buff) + + def test_compatibility_in_runtime_edge_program_manager(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + sin_module = SinModule() + model_inputs = (torch.ones(1),) + edgeir_m = to_edge(export(sin_module, model_inputs, strict=True)) + max_value = model_inputs[0].shape[0] + compile_specs = [CompileSpec("max_value", bytes([max_value]))] + lowered_edge_irm = edgeir_m.to_backend( + AllNodePartitioner("BackendWithCompilerDemo", compile_specs) + ) + exec_prog = lowered_edge_irm.to_executorch() + + buff = exec_prog.buffer + + # The demo backend works well + executorch_module = _load_for_executorch_from_buffer(buff) + model_inputs = torch.ones(1) + _ = executorch_module.forward([model_inputs]) + + prog = exec_prog.executorch_program + # Rewrite the delegate version number from 0 to 1. + prog.backend_delegate_data[0].data = bytes( + "1version:1#op:demo::aten.sin.default, numel:1, dtype:torch.float321#", + encoding="utf8", + ) + + # Generate the .pte file with the wrong version. + buff = bytes( + _serialize_pte_binary( + program=prog, + ) + ) + + # Throw runtime error with error code 0x30, meaning delegate is incompatible. + with self.assertRaisesRegex( + RuntimeError, + "loading method forward failed with error 0x30", + ): + executorch_module = _load_for_executorch_from_buffer(buff) diff --git a/exir/backend/test/test_delegate_map_builder.py b/exir/backend/test/test_delegate_map_builder.py index 827cb8cdebc..fcd23b110b6 100644 --- a/exir/backend/test/test_delegate_map_builder.py +++ b/exir/backend/test/test_delegate_map_builder.py @@ -9,12 +9,17 @@ import torch from executorch import exir +from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( + AllNodePartitioner, +) from executorch.exir.backend.test.backend_with_delegate_mapping_demo import ( BackendWithDelegateMappingDemo, ) from executorch.exir.backend.utils import DelegateMappingBuilder +from executorch.exir.lowered_backend_module import get_lowered_submodules class TestDelegateMapBuilder(unittest.TestCase): diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 33e417e7326..911c33ec692 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -31,6 +31,7 @@ python_library( "//executorch/exir/_serialize:lib", "//executorch/exir/backend:backend_api", "//executorch/exir/backend:partitioner", + "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", "//executorch/exir/capture:config", "//executorch/exir/emit:emit", "//executorch/exir/emit:lib", diff --git a/exir/program/_program.py b/exir/program/_program.py index 7a2120f9e9b..2d72b4f406f 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -24,7 +24,10 @@ from executorch.exir._serialize.data_serializer import DataSerializer from executorch.exir._warnings import experimental from executorch.exir.backend.backend_api import to_backend -from executorch.exir.backend.partitioner import Partitioner +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( + AllNodePartitioner, +) +from executorch.exir.backend.partitioner import DelegationSpec, Partitioner from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig from executorch.exir.delegate import executorch_call_delegate, is_lowered_module from executorch.exir.emit import emit_program, EmitterOutput @@ -1439,7 +1442,13 @@ def transform( @et_logger("to_backend") def to_backend( - self, partitioner: Union[Partitioner, Dict[str, Partitioner]] + self, + partitioner: Union[ + DelegationSpec, + Dict[str, DelegationSpec], + Partitioner, + Dict[str, Partitioner], + ], ) -> "EdgeProgramManager": """ Returns a semantically-equivalent program to the one given as input, @@ -1447,12 +1456,15 @@ def to_backend( for delegation as determined by the partitioner. Args: - partitioner: The partitioner can either be a Partitioner subclass instance, or a - dictionary mapping method names to Partitioner subclass instance. If it is a - Partitioner subclass, all programs in the given EdgeProgramManager - will be lowered using the given partitioner. If it is a - dictionary, only method names specified in the dictionary will be - lowered with the given partitioner. + partitioner: The partitioner can be: + - Partitioner Subclass Instance; all programs in the EdgeProgramManager are lowered with + this partitioner + - Dictionary mapping method name to partitioner subclass instance; Only method names specified + in the dictionary will be lowered by the given partitioner. + - DelegationSpec; All programs are completely lowered to the backend_id specified in the + DelegationSpec + - Dictionary mapping method name to DelegationSpec; Only method names specified in the dictionary + will be lowered to the backend_id specified in the DelegationSpec The Partitioner subclass instance is in charge with tagging portions of the input program for delegation. A valid partitioner must return PartitionerResult including valid @@ -1468,13 +1480,19 @@ def to_backend( if isinstance(partitioner, dict): for name, program in self._edge_programs.items(): if name in partitioner.keys(): - new_edge_programs[name] = to_backend(program, partitioner[name]) + partitioner_to_use = partitioner[name] + if isinstance(partitioner_to_use, DelegationSpec): + partitioner_to_use = AllNodePartitioner(partitioner_to_use) + new_edge_programs[name] = to_backend(program, partitioner_to_use) else: new_edge_programs[name] = program else: # apply partitioner to every method for name, program in self._edge_programs.items(): - new_edge_programs[name] = to_backend(program, partitioner) + partitioner_to_use = partitioner + if isinstance(partitioner, DelegationSpec): + partitioner_to_use = AllNodePartitioner(partitioner) + new_edge_programs[name] = to_backend(program, partitioner_to_use) config = EdgeCompileConfig(_check_ir_validity=False) return EdgeProgramManager( From 64792f14d533bd8bd29bd236dd6acd6fb8cfc173 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 1 Apr 2025 16:31:53 -0700 Subject: [PATCH 2/7] Update [ghstack-poisoned] --- exir/backend/backend_details.py | 72 ++++++++++++++++++++++++++++----- 1 file changed, 63 insertions(+), 9 deletions(-) diff --git a/exir/backend/backend_details.py b/exir/backend/backend_details.py index 248d03f2b05..513ae7c64b3 100644 --- a/exir/backend/backend_details.py +++ b/exir/backend/backend_details.py @@ -50,15 +50,6 @@ class BackendDetails(ABC): the decorators, this interface will be static, abstract and all inheritances are enforced to implement this method. - Args: - edge_program: The original exported program. It will not be modified in place. - compile_specs: List of values needed for compilation - - Returns: - PreprocessResult: It wraps the following information: - processed_bytes -> bytes: A compiled blob - a binary that can run the desired program in the backend. - debug_handle_map (Optional[Dict[int, Tuple[int]]]): For profiling purposes, a map from the node_id in the final graph (either EXIR or the user's self-defined IR) - to debug handle id attached in the original exported program. """ @staticmethod @@ -70,6 +61,69 @@ def preprocess( edge_program: ExportedProgram, compile_specs: List[CompileSpec], ) -> PreprocessResult: + """ + Preprocesses an edge program and returns the preprocess result fo the given backend + + Args: + edge_program: The original exported program. It will not be modified in place. + compile_specs: List of values needed for compilation + + Returns: + PreprocessResult: It wraps the following information: + processed_bytes -> bytes: A compiled blob - a binary that can run the desired + program in the backend. + debug_handle_map (Optional[Dict[int, Tuple[int]]]): For profiling purposes, a + map from the node_id in the final graph (either EXIR or the user's self-defined + IR) to debug handle id attached in the original exported program. + """ # Users should return a compiled blob - a binary that can run the desired # program in the backend. pass + + @classmethod + def preprocess_multimethod( + cls, + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + ) -> Dict[str, list[PreprocessResult]]: + """ + Runs preprocess on all partitioned Edge Programs across multiple methods. This allows + backends to share information across partitioned graphs. Backend can serialize shared + data by putting the shared data into the data_store_output of the preprocess results. + This will record the shared data used by that specific partition. + + Default implementation is running the existing preprocess implementation on all + + Args: + edge_programs: Dictionary mapping the method name to a list of all the partitioned + edge_programs from that method to be lowered. + compile_specs: Dictionary mapping the method name to a list of compile_specs. The + list of compile specs maps directly to the list of edge_programs for the + same given method name i.e. edge_program[method_name][i] --> compile_specs[method_name][i] + + Returns: + Dictionary mapping the method name to a list of PreprocessResults. The list of + PreprocessResults maps directly to the list of edge_programs for the same given + method name. i.e. edge_program[method_name][i] --> result[method_name][i] + + + """ + preprocess_results = {} + for method_name, programs in edge_programs.items(): + assert ( + method_name in compile_specs + ), f"Error: missing compile specs for {method_name}" + compile_specs_for_method = compile_specs[method_name] + assert len(compile_specs_for_method) == len( + programs + ), f"Error: method {method_name} has {len(programs)} partitions but only {len(compile_specs_for_method)}" + results_for_method = [] + for program, compile_spec_for_program in zip( + programs, compile_specs_for_method + ): + preprocess_result = cls.preprocess(program, compile_spec_for_program) + results_for_method.append(preprocess_result) + + preprocess_results[method_name] = results_for_method + + return preprocess_results From 15bb0c6ad89e55579d080e3fb1a02118104bc93b Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 1 Apr 2025 16:56:24 -0700 Subject: [PATCH 3/7] Update [ghstack-poisoned] --- exir/backend/test/test_delegate_map_builder.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/exir/backend/test/test_delegate_map_builder.py b/exir/backend/test/test_delegate_map_builder.py index fcd23b110b6..827cb8cdebc 100644 --- a/exir/backend/test/test_delegate_map_builder.py +++ b/exir/backend/test/test_delegate_map_builder.py @@ -9,17 +9,12 @@ import torch from executorch import exir -from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend -from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( - AllNodePartitioner, -) from executorch.exir.backend.test.backend_with_delegate_mapping_demo import ( BackendWithDelegateMappingDemo, ) from executorch.exir.backend.utils import DelegateMappingBuilder -from executorch.exir.lowered_backend_module import get_lowered_submodules class TestDelegateMapBuilder(unittest.TestCase): From cff2f0d4459886455216e1c0db7489e31fbdc8c8 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Thu, 3 Apr 2025 16:34:17 -0700 Subject: [PATCH 4/7] Update [ghstack-poisoned] --- .../all_node_partitioner.py | 4 +- exir/program/_program.py | 38 +++++-------------- 2 files changed, 12 insertions(+), 30 deletions(-) diff --git a/exir/backend/canonical_partitioners/all_node_partitioner.py b/exir/backend/canonical_partitioners/all_node_partitioner.py index bc45f2b5239..a5cf7605343 100644 --- a/exir/backend/canonical_partitioners/all_node_partitioner.py +++ b/exir/backend/canonical_partitioners/all_node_partitioner.py @@ -33,8 +33,8 @@ def __init__( compile_specs: List[CompileSpec], ): """ - Partitioner that lowers every single node in the graph module to the - specified backend_id + Partitioner that lowers every single node in the graph module unconditionally + to the specified backend_id """ super().__init__() self.delegation_spec = DelegationSpec(backend_id, compile_specs) diff --git a/exir/program/_program.py b/exir/program/_program.py index 2d72b4f406f..7a2120f9e9b 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -24,10 +24,7 @@ from executorch.exir._serialize.data_serializer import DataSerializer from executorch.exir._warnings import experimental from executorch.exir.backend.backend_api import to_backend -from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( - AllNodePartitioner, -) -from executorch.exir.backend.partitioner import DelegationSpec, Partitioner +from executorch.exir.backend.partitioner import Partitioner from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig from executorch.exir.delegate import executorch_call_delegate, is_lowered_module from executorch.exir.emit import emit_program, EmitterOutput @@ -1442,13 +1439,7 @@ def transform( @et_logger("to_backend") def to_backend( - self, - partitioner: Union[ - DelegationSpec, - Dict[str, DelegationSpec], - Partitioner, - Dict[str, Partitioner], - ], + self, partitioner: Union[Partitioner, Dict[str, Partitioner]] ) -> "EdgeProgramManager": """ Returns a semantically-equivalent program to the one given as input, @@ -1456,15 +1447,12 @@ def to_backend( for delegation as determined by the partitioner. Args: - partitioner: The partitioner can be: - - Partitioner Subclass Instance; all programs in the EdgeProgramManager are lowered with - this partitioner - - Dictionary mapping method name to partitioner subclass instance; Only method names specified - in the dictionary will be lowered by the given partitioner. - - DelegationSpec; All programs are completely lowered to the backend_id specified in the - DelegationSpec - - Dictionary mapping method name to DelegationSpec; Only method names specified in the dictionary - will be lowered to the backend_id specified in the DelegationSpec + partitioner: The partitioner can either be a Partitioner subclass instance, or a + dictionary mapping method names to Partitioner subclass instance. If it is a + Partitioner subclass, all programs in the given EdgeProgramManager + will be lowered using the given partitioner. If it is a + dictionary, only method names specified in the dictionary will be + lowered with the given partitioner. The Partitioner subclass instance is in charge with tagging portions of the input program for delegation. A valid partitioner must return PartitionerResult including valid @@ -1480,19 +1468,13 @@ def to_backend( if isinstance(partitioner, dict): for name, program in self._edge_programs.items(): if name in partitioner.keys(): - partitioner_to_use = partitioner[name] - if isinstance(partitioner_to_use, DelegationSpec): - partitioner_to_use = AllNodePartitioner(partitioner_to_use) - new_edge_programs[name] = to_backend(program, partitioner_to_use) + new_edge_programs[name] = to_backend(program, partitioner[name]) else: new_edge_programs[name] = program else: # apply partitioner to every method for name, program in self._edge_programs.items(): - partitioner_to_use = partitioner - if isinstance(partitioner, DelegationSpec): - partitioner_to_use = AllNodePartitioner(partitioner) - new_edge_programs[name] = to_backend(program, partitioner_to_use) + new_edge_programs[name] = to_backend(program, partitioner) config = EdgeCompileConfig(_check_ir_validity=False) return EdgeProgramManager( From 73c492bc80923948b3a8e80be7c1e7095cf00f70 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 7 Apr 2025 17:04:33 -0700 Subject: [PATCH 5/7] Update [ghstack-poisoned] --- CMakeLists.txt | 4 ++++ pytest.ini | 2 -- runtime/executor/test/CMakeLists.txt | 20 ++++++++++++++++++++ setup.py | 1 + tools/cmake/cmake_deps.toml | 14 ++++++++++++++ 5 files changed, 39 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6dbb66afdaa..1ac7de469b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -810,6 +810,10 @@ if(EXECUTORCH_BUILD_PYBIND) torch ) + if(EXECUTORCH_BUILD_TESTS) + list(APPEND _dep_libs test_backend_compiler_lib) + endif() + if(EXECUTORCH_BUILD_KERNELS_OPTIMIZED) list(APPEND _dep_libs optimized_native_cpu_ops_lib) else() diff --git a/pytest.ini b/pytest.ini index cd647c43a1c..8c661aa9ee4 100644 --- a/pytest.ini +++ b/pytest.ini @@ -63,8 +63,6 @@ addopts = --ignore=exir/backend/test/demos --ignore=exir/backend/test/test_backends.py --ignore=exir/backend/test/test_backends_lifted.py - --ignore=exir/backend/test/test_compatibility.py - --ignore=exir/backend/test/test_lowered_backend_module.py --ignore=exir/backend/test/test_partitioner.py --ignore=exir/tests/test_common.py --ignore=exir/tests/test_memory_format_ops_pass_aten.py diff --git a/runtime/executor/test/CMakeLists.txt b/runtime/executor/test/CMakeLists.txt index 2de32c9176a..3003a5d2c74 100644 --- a/runtime/executor/test/CMakeLists.txt +++ b/runtime/executor/test/CMakeLists.txt @@ -152,3 +152,23 @@ target_include_directories( PRIVATE "${CMAKE_INSTALL_PREFIX}/schema/include" "${EXECUTORCH_ROOT}/third-party/flatbuffers/include" ) + +list(TRANSFORM _test_backend_compiler_lib__srcs PREPEND "${EXECUTORCH_ROOT}/") +add_library( + test_backend_compiler_lib + STATIC + ${_test_backend_compiler_lib__srcs} +) + +target_link_libraries( + test_backend_compiler_lib + PUBLIC + executorch_core +) + +target_link_options_shared_lib(test_backend_compiler_lib) + +install( + TARGETS test_backend_compiler_lib + DESTINATION lib +) diff --git a/setup.py b/setup.py index 44fb9a712a3..e29ce76ff7e 100644 --- a/setup.py +++ b/setup.py @@ -718,6 +718,7 @@ def run(self): # enabled. TODO(dbort): Remove this override once this option is # managed by cmake itself. "-DEXECUTORCH_SEPARATE_FLATCC_HOST_PROJECT=OFF", + "-DEXECUTORCH_BUILD_TESTS=ON", ] build_args = [f"-j{self.parallel}"] diff --git a/tools/cmake/cmake_deps.toml b/tools/cmake/cmake_deps.toml index ee810c2bfd5..9913a02c4d5 100644 --- a/tools/cmake/cmake_deps.toml +++ b/tools/cmake/cmake_deps.toml @@ -150,6 +150,20 @@ deps = [ "optimized_cpublas", "portable_kernels", ] + +[targets.test_backend_compiler_lib] +buck_targets = [ + "//runtime/executor/test:test_backend_compiler_lib", +] +filters = [ + ".cpp$", +] +excludes = [ +] +deps = [ + "executorch", + "executorch_core", +] # ---------------------------------- core end ---------------------------------- # ---------------------------------- extension start ---------------------------------- [targets.extension_data_loader] From 13409b0584e5e229d35f97a1adb2871af9c5563b Mon Sep 17 00:00:00 2001 From: Max Ren Date: Mon, 7 Apr 2025 18:08:28 -0700 Subject: [PATCH 6/7] Update [ghstack-poisoned] --- exir/backend/test/test_lowered_backend_module.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/exir/backend/test/test_lowered_backend_module.py b/exir/backend/test/test_lowered_backend_module.py index dcc5841bc3e..6cdaf92b3d2 100644 --- a/exir/backend/test/test_lowered_backend_module.py +++ b/exir/backend/test/test_lowered_backend_module.py @@ -22,7 +22,6 @@ from executorch.extension.pybindings.portable_lib import ( # @manual _load_for_executorch_from_buffer, ) -from hypothesis import given, settings, strategies as st from torch.export import export @@ -65,7 +64,6 @@ def forward(self, *args): .executorch_program ) - @settings(deadline=500000) def test_emit_lowered_backend_module_end_to_end(self): class SinModule(torch.nn.Module): def __init__(self): @@ -109,11 +107,7 @@ def forward(self, x): torch.allclose(model_outputs[0], expected_res, atol=1e-03, rtol=1e-03) ) - @given( - unlift=st.booleans(), # verify both lifted and unlifted graph - ) - @settings(deadline=500000) - def test_emit_lowered_backend_module(self, unlift): + def test_emit_lowered_backend_module(self): module_list = [ models.Emformer(), models.Repeat(), @@ -166,11 +160,7 @@ def test_emit_lowered_backend_module(self, unlift): _ = lowered_model.buffer() self.validate_lowered_module_program(program) - @given( - unlift=st.booleans(), # verify both lifted and unlifted graph - ) - @settings(deadline=500000) - def test_emit_nested_lowered_backend_module(self, unlift): + def test_emit_nested_lowered_backend_module(self): module_list = [ models.Emformer(), models.Repeat(), From 4ed3e444c789c017038cdbffa5a52772a0f5d7e2 Mon Sep 17 00:00:00 2001 From: Max Ren <40742183+mcr229@users.noreply.github.com> Date: Wed, 16 Apr 2025 09:10:36 -0700 Subject: [PATCH 7/7] [ExecuTorch][to_backend] Enable to_backend API to leverage preprocess_multimethod (#9824) ### Summary We add a new to_backend api which for multi method models. Specifically, we pass in a dictionary mapping MethodName to ExportedProgram, as well as a dictionary mapping MethodName to Partitioner. We then return a dictionary mapping MethodName to the partitioned and lowered exported program. In addition, we also provide a new preprocess API for backends to implement. This API is preprocess_multimethod. The signature of the new method is as follows: ``` def preprocess_multimethod( cls, edge_programs: Dict[str, List[ExportedProgram]], compile_specs: Dict[str, List[List[CompileSpec]]], ) -> Dict[str, list[PreprocessResult]]: """ Runs preprocess on all partitioned Edge Programs across multiple methods. This allows backends to share information across partitioned graphs. Backend can serialize shared data by putting the shared data into the data_store_output of the preprocess results. This will record the shared data used by that specific partition. Default implementation is running the existing preprocess implementation on all Args: edge_programs: Dictionary mapping the method name to a list of all the partitioned edge_programs from that method to be lowered. compile_specs: Dictionary mapping the method name to a list of compile_specs. The list of compile specs maps directly to the list of edge_programs for the same given method name i.e. edge_program[method_name][i] --> compile_specs[method_name][i] Returns: Dictionary mapping the method name to a list of PreprocessResults. The list of PreprocessResults maps directly to the list of edge_programs for the same given method name. i.e. edge_program[method_name][i] --> result[method_name][i] """ ``` This new API enableds backends to preprocess all partitions/methods at once. This way, while processing blobs, they can identify shared components between preprocessed blobs. Shared components can be serialized within the NamedDataStore. The key change in backend infra, is that when partitioning, we now have to identify all the partitioned graphs to be lowered at once, and pass them to preprocess_multimethod at once. Previously, as we found lowerable partitions, we preprocessed and embedded them into the graph. ### Testing python -m unittest exir.backend.test.test_to_backend_multi_method --- exir/backend/backend_api.py | 451 +++++++++++++-- exir/backend/test/TARGETS | 52 ++ .../test/backend_with_preprocess_all_demo.py | 266 +++++++++ .../test/test_to_backend_multi_method.py | 524 ++++++++++++++++++ exir/lowered_backend_module.py | 31 +- 5 files changed, 1257 insertions(+), 67 deletions(-) create mode 100644 exir/backend/test/backend_with_preprocess_all_demo.py create mode 100644 exir/backend/test/test_to_backend_multi_method.py diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index ab2e66f7885..310e5ea9379 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -8,8 +8,9 @@ import copy import logging from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from functools import singledispatch -from typing import Generator, List +from typing import Dict, Generator, List, Mapping import torch @@ -36,7 +37,7 @@ update_to_real_program, ) from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param -from torch.export import ExportedProgram +from torch.export.exported_program import ExportedProgram, InputSpec, OutputSpec @singledispatch @@ -190,6 +191,65 @@ def _get_node_list_with_same_tag( return node_list +def _insert_lowered_submodule( + submodule_program: ExportedProgram, + owning_program: ExportedProgram, + call_submodule_node: torch.fx.Node, + submodule_output_node: torch.fx.Node, + lowered_module: LoweredBackendModule, + is_submodule: bool, + toplevel_input_specs_to_delete: Dict[str, InputSpec], + toplevel_output_specs_to_delete: Dict[str, OutputSpec], +): + owning_graph_module = call_submodule_node.graph.owning_module + # call delegate args should only use user_inputs + call_delegate_args = [] + # Preserve input order as user_inputs + for inp_name in submodule_program.graph_signature.user_inputs: + for inp_node in call_submodule_node.all_input_nodes: + if inp_node.name == inp_name: + call_delegate_args.append(inp_node) + break + + def generate_debug_handle(ep: ExportedProgram) -> int: + """ + Generate a debug handle for the given ExportedProgram. + """ + debug_handle = 0 + for node in ep.graph_module.graph.nodes: + debug_handle = max(debug_handle, node.meta.get("debug_handle", 0)) + return debug_handle + 1 + + # Replace the partitioned submodule with a lowered submodule + # Add call_method node with function "forward" + with owning_graph_module.graph.inserting_before(call_submodule_node): + lowered_name = get_lowered_module_name(owning_graph_module, lowered_module) + lowered_node = owning_graph_module.graph.get_attr(lowered_name) + call_delegate_node = owning_graph_module.graph.call_function( + executorch_call_delegate, + (lowered_node,) + tuple(call_delegate_args), + call_submodule_node.kwargs, + ) + call_delegate_node.meta["debug_handle"] = generate_debug_handle(owning_program) + call_delegate_node.meta["val"] = submodule_output_node.meta["val"] + call_submodule_node.replace_all_uses_with(call_delegate_node) + owning_graph_module.graph.erase_node(call_submodule_node) + + if is_submodule: + assert len(toplevel_input_specs_to_delete) == 0 + assert len(toplevel_output_specs_to_delete) == 0 + elif ( + len(toplevel_input_specs_to_delete) > 0 + or len(toplevel_output_specs_to_delete) > 0 + ): + _unsafe_adjust_original_program( + owning_program, + call_delegate_node, + toplevel_input_specs_to_delete, + toplevel_output_specs_to_delete, + ) + + def _partition_and_lower_one_graph_module( tagged_graph_module: torch.fx.GraphModule, partition_result: PartitionResult, @@ -254,56 +314,16 @@ def _partition_and_lower_one_graph_module( delegation_spec.compile_specs, ) - # call delegate args should only use user_inputs - call_delegate_args = [] - # Preserve input order as user_inputs - for inp_name in submodule_program.graph_signature.user_inputs: - for inp_node in call_module_node.all_input_nodes: - if inp_node.name == inp_name: - call_delegate_args.append(inp_node) - break - - def generate_debug_handle(ep: ExportedProgram) -> int: - """ - Generate a debug handle for the given ExportedProgram. - """ - debug_handle = 0 - for node in ep.graph_module.graph.nodes: - debug_handle = max(debug_handle, node.meta.get("debug_handle", 0)) - return debug_handle + 1 - - # Replace the partitioned submodule with a lowered submodule - # Add call_method node with function "forward" - with tagged_graph_module.graph.inserting_before(call_module_node): - lowered_name = get_lowered_module_name( - tagged_graph_module, lowered_submodule - ) - lowered_node = tagged_graph_module.graph.get_attr(lowered_name) - call_delegate_node = tagged_graph_module.graph.call_function( - executorch_call_delegate, - (lowered_node,) + tuple(call_delegate_args), - call_module_node.kwargs, - ) - call_delegate_node.meta["debug_handle"] = generate_debug_handle( - owning_program - ) - call_delegate_node.meta["val"] = submodule_output_node.meta["val"] - call_module_node.replace_all_uses_with(call_delegate_node) - tagged_graph_module.graph.erase_node(call_module_node) - - if is_submodule: - assert len(toplevel_input_specs_to_delete) == 0 - assert len(toplevel_output_specs_to_delete) == 0 - elif ( - len(toplevel_input_specs_to_delete) > 0 - or len(toplevel_output_specs_to_delete) > 0 - ): - _unsafe_adjust_original_program( - owning_program, - call_delegate_node, - toplevel_input_specs_to_delete, - toplevel_output_specs_to_delete, - ) + _insert_lowered_submodule( + submodule_program, + owning_program, + call_module_node, + submodule_output_node, + lowered_submodule, + is_submodule, + toplevel_input_specs_to_delete, + toplevel_output_specs_to_delete, + ) return tagged_graph_module @@ -417,3 +437,330 @@ def to_backend( constants=tagged_exported_program.constants, verifiers=[tagged_exported_program.verifier], ) + + +def _create_partitions_in_graph_module( + tagged_graph_module: torch.fx.GraphModule, + partition_result: PartitionResult, + owning_program: ExportedProgram, + is_submodule: bool, +) -> Dict[str, List[torch.fx.Node]]: + backend_id_to_submodule_name = {} + for tag, delegation_spec in partition_result.partition_tags.items(): + # Create partition with nodes containing this tag. There should only be + # one contained submodule per tag + node_list = _get_node_list_with_same_tag( + tagged_graph_module, tag, owning_program + ) + + if len(node_list) == 0: + logging.debug(f"Did not find any nodes for tag {tag}") + continue + + logging.debug(f"For tag {tag}, found nodes {node_list}") + # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs) + + replace_ctx = ( + tagged_graph_module._set_replace_hook( + owning_program.graph_signature.get_replace_hook() + ) + if not is_submodule + else nullcontext() + ) + with replace_ctx: + submodule, call_module_node = create_submodule_from_nodes( + tagged_graph_module, node_list, tag + ) + + tagged_graph_module_output_node = [ + node for node in tagged_graph_module.graph.nodes if node.op == "output" + ][0] + submodule_output_node = [ + node for node in submodule.graph.nodes if node.op == "output" + ][0] + # Copy the output node meta from the original output node, because + # create_submodule_from_nodes doesn't cover the meta field + submodule_output_node.meta = tagged_graph_module_output_node.meta + logging.debug(f"Partitioned graph module: {tagged_graph_module}") + ( + submodule_program, + toplevel_input_specs_to_delete, + toplevel_output_specs_to_delete, + ) = create_exported_program_from_submodule( + submodule, + owning_program, + tag, + call_module_node, + is_submodule, + ) + call_module_node.meta["backend_id"] = delegation_spec.backend_id + call_module_node.meta["compile_spec"] = delegation_spec.compile_specs + call_module_node.meta["submodule_program"] = submodule_program + call_module_node.meta["toplevel_input_specs_to_delete"] = ( + toplevel_input_specs_to_delete + ) + call_module_node.meta["toplevel_output_specs_to_delete"] = ( + toplevel_output_specs_to_delete + ) + call_module_node.meta["is_submodule"] = is_submodule + call_module_node.meta["submodule_output_node"] = submodule_output_node + + if delegation_spec.backend_id not in backend_id_to_submodule_name: + backend_id_to_submodule_name[delegation_spec.backend_id] = [] + + # The call_module_node created here might not be the same node instance as + # the one in the final graph module. This is because this node might be replaced + # in future edits to the graph. As a result, we just keep track of the node's name + # and at the end we search for this node in our final graph module + backend_id_to_submodule_name[delegation_spec.backend_id].append( + call_module_node.target + ) + + created_submodule_nodes = {key: [] for key in backend_id_to_submodule_name.keys()} + for backend_id, submodule_name in backend_id_to_submodule_name.items(): + for node in tagged_graph_module.graph.nodes: + if node.op == "call_module" and node.target in submodule_name: + created_submodule_nodes[backend_id].append(node) + + # check the number of submodule_names and submodule_nodes are equal + for backend_id in created_submodule_nodes.keys(): + assert len(created_submodule_nodes[backend_id]) == len( + backend_id_to_submodule_name[backend_id] + ) + + return created_submodule_nodes + + +def _create_partitions( + tagged_graph_module: torch.fx.GraphModule, + partition_result: PartitionResult, + owning_program: ExportedProgram, + is_submodule: bool = False, +) -> Dict[str, List[torch.fx.Node]]: + backend_id_to_call_submodules = _create_partitions_in_graph_module( + tagged_graph_module, partition_result, owning_program, is_submodule + ) + + # Recursively partition and lower for submodules + for _, submod, _ in get_control_flow_submodules(tagged_graph_module): + nested_backend_id_to_call_submodules = _create_partitions( + submod, partition_result, owning_program, is_submodule=True + ) + for ( + backend_id, + nested_submodules, + ) in nested_backend_id_to_call_submodules.items(): + if backend_id not in backend_id_to_call_submodules: + backend_id_to_call_submodules[backend_id] = nested_submodules + else: + backend_id_to_call_submodules[backend_id].extend(nested_submodules) + + return backend_id_to_call_submodules + + +def lower_all_submodules_to_backend( + backend_id: str, + method_to_submodules_nodes: Dict[str, List[torch.fx.Node]], + method_to_tagged_edge_program: Dict[str, ExportedProgram], +) -> None: + """ + Lower all submodules nodes given in the method_to_submodule_nodes map to backend_id. + """ + # The created exported program for the submodules are in the call_module node's meta data + # We just map the method_to_submodule_nodes directly to the method_to_partitioned_exported_programs + method_to_partitioned_program = { + method_name: [node.meta["submodule_program"] for node in call_submodule_nodes] + for method_name, call_submodule_nodes in method_to_submodules_nodes.items() + } + method_to_compile_specs = { + method_name: [node.meta["compile_spec"] for node in call_submodule_nodes] + for method_name, call_submodule_nodes in method_to_submodules_nodes.items() + } + backend_found = False + for cls in BackendDetails.__subclasses__(): + if backend_id == cls.__name__: + method_to_preprocess_result: dict[str, List[PreprocessResult]] = ( + cls.preprocess_multimethod( + method_to_partitioned_program, method_to_compile_specs + ) + ) + backend_found = True + + if not backend_found: + raise NotImplementedError(f"Backend {backend_id} was not found.") + + for method_name in method_to_preprocess_result.keys(): + owning_program = method_to_tagged_edge_program[method_name] + list_of_preprocess_results = method_to_preprocess_result[method_name] + list_of_call_submodule_nodes = method_to_submodules_nodes[method_name] + list_of_compile_specs = method_to_compile_specs[method_name] + for preprocess_result, call_submodule_node, compile_spec in zip( + list_of_preprocess_results, + list_of_call_submodule_nodes, + list_of_compile_specs, + ): + submodule_program = call_submodule_node.meta["submodule_program"] + lowered_module = LoweredBackendModule( + edge_program=submodule_program, + backend_id=backend_id, + processed_bytes=preprocess_result.processed_bytes, + compile_specs=compile_spec, + named_data_store_output=preprocess_result.data_store_output, + ) + is_submodule = call_submodule_node.meta["is_submodule"] + toplevel_input_specs_to_delete = call_submodule_node.meta[ + "toplevel_input_specs_to_delete" + ] + toplevel_output_specs_to_delete = call_submodule_node.meta[ + "toplevel_output_specs_to_delete" + ] + submodule_output_node = call_submodule_node.meta["submodule_output_node"] + + _insert_lowered_submodule( + submodule_program, + owning_program, + call_submodule_node, + submodule_output_node, + lowered_module, + is_submodule, + toplevel_input_specs_to_delete, + toplevel_output_specs_to_delete, + ) + + +@dataclass +class MethodProgramsPartitionerSpec: + """ + Since single dispatch for to_backend requires the first argument to be a + valid class, we create the following dataclass spec to hold the dictionaries + mapping the method name to the corresponding program, partitioner + """ + + method_to_edge_program: Mapping[str, ExportedProgram] + method_to_partitioner: Mapping[str, Partitioner] + + +@to_backend.register +def _( + method_edge_program_partitioners: MethodProgramsPartitionerSpec, +) -> Dict[str, ExportedProgram]: + """ + Add overloaded implementations for to_backend: + + :: + + def to_backend( + method_edge_program_partitioners: MethodProgramsPartitionerSpec + ) -> Dict[str, ExportedProgram]: + + Returns a semantically-equivalent dictionary of programs to the programs given as input (represented + as a graph module in Edge dialect), but with portions of the program targeted for + delegation as determined by the partitioner. + + Args: + method_edge_program_partitioners: contains two mappings, + - method_to_edge_program: mapping of method names to their respective programs in Edge dialect. + - method_to_partitioner: mapping of method names to an instance of the partitioner, in charge with tagging + portions of the specified program for delegation. A valid partitioner must return PartitionerResult + including both tagged exported program and partitioner_tag: Dict[str, DelegationSpec], where each key is a tag name and + the nodes with same tag will be fused a one subgraph and delegated to backend specififed in delegation spec. + + + Returns: + ExportedProgram: The input program, with some portions targeted for delegation. + """ + method_to_edge_program = method_edge_program_partitioners.method_to_edge_program + method_to_partitioner = method_edge_program_partitioners.method_to_partitioner + + partitioned_and_lowered_exported_programs = {} + backend_id_to_method_submodules_map = {} + method_to_tagged_exported_program = {} + + for method_name, partitioner_instance in method_to_partitioner.items(): + assert ( + method_name in method_to_edge_program + ), f"Partitioner for method {method_name} is not provided" + edge_program = method_to_edge_program[method_name] + edge_program._validate() + + # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values. + # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback. + try: + fake_edge_program = get_fake_program(edge_program) + except Exception as e: + logging.warning( + f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}" + ) + fake_edge_program = copy.deepcopy(edge_program) + partitioner_result = partitioner_instance(fake_edge_program) + tagged_exported_program = partitioner_result.tagged_exported_program + method_to_tagged_exported_program[method_name] = tagged_exported_program + + # Check that the partitioner did not modify the original graph + if _ENABLE_VALIDATION: + assert is_identical_graph( + tagged_exported_program.graph_module, + edge_program.graph_module, + ), f"The partitioner {partitioner_instance} should not modify the graph module" + else: + logging.warning("Disabled validating the partitioner.") + + assert ( + partitioner_result.partition_tags is not None + ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec" + + update_to_real_program(tagged_exported_program, edge_program) + + for tag, _ in partitioner_result.partition_tags.items(): + _maybe_duplicate_constant_nodes(tagged_exported_program, tag) + + backend_id_to_call_submodule_nodes = _create_partitions( + tagged_exported_program.graph_module, + partitioner_result, + tagged_exported_program, + ) + for ( + backend_id, + call_submodule_nodes, + ) in backend_id_to_call_submodule_nodes.items(): + if backend_id not in backend_id_to_method_submodules_map: + backend_id_to_method_submodules_map[backend_id] = {} + backend_id_to_method_submodules_map[backend_id][ + method_name + ] = call_submodule_nodes + + for ( + backend_id, + method_to_submodule_nodes, + ) in backend_id_to_method_submodules_map.items(): + lower_all_submodules_to_backend( + backend_id, + method_to_submodule_nodes, + method_to_tagged_exported_program, + ) + + for method_name in method_to_edge_program.keys(): + if method_name in method_to_tagged_exported_program: + tagged_exported_program = method_to_tagged_exported_program[method_name] + partitioned_and_lowered_exported_programs[method_name] = ExportedProgram( + root=tagged_exported_program.graph_module, + graph=tagged_exported_program.graph_module.graph, + graph_signature=tagged_exported_program.graph_signature, + state_dict=tagged_exported_program.state_dict, + range_constraints=copy.deepcopy( + tagged_exported_program.range_constraints + ), + module_call_graph=copy.deepcopy( + tagged_exported_program.module_call_graph + ), + example_inputs=None, + constants=tagged_exported_program.constants, + verifiers=[tagged_exported_program.verifier], + ) + else: + # this edge program wasn't partitioned, so we can just return it as is + partitioned_and_lowered_exported_programs[method_name] = ( + method_to_edge_program[method_name] + ) + + return partitioned_and_lowered_exported_programs diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS index f0ba618936d..b7f46076868 100644 --- a/exir/backend/test/TARGETS +++ b/exir/backend/test/TARGETS @@ -189,6 +189,58 @@ python_unittest( ], ) +python_unittest( + name = "test_to_backend_multi_method", + srcs = [ + "test_to_backend_multi_method.py", + ], + preload_deps = [ + "//executorch/kernels/portable:custom_ops_generated_lib", + "//executorch/kernels/quantized:custom_ops_generated_lib", + "//executorch/runtime/executor/test:test_backend_compiler_lib", + ], + deps = [ + ":backend_with_preprocess_all_demo", + "//caffe2:torch", + "//caffe2/functorch:functorch_src", + "//executorch/exir:delegate", + "//executorch/exir:graph_module", + "//executorch/exir:lib", + "//executorch/exir:lowered_backend_module", + "//executorch/exir:print_program", + "//executorch/exir:schema", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/dialects:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//executorch/extension/pytree:pylib", + ], +) + +python_library( + name = "backend_with_preprocess_all_demo", + srcs = [ + "backend_with_preprocess_all_demo.py" + ], + deps = [ + "//caffe2:torch", + "//caffe2/functorch:functorch_src", + "//executorch/exir:delegate", + "//executorch/exir:graph_module", + "//executorch/exir:lib", + "//executorch/exir:lowered_backend_module", + "//executorch/exir:print_program", + "//executorch/exir:schema", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/dialects:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//executorch/extension/pytree:pylib", + ], +) + python_unittest( name = "test_debug_handle_map", srcs = [ diff --git a/exir/backend/test/backend_with_preprocess_all_demo.py b/exir/backend/test/backend_with_preprocess_all_demo.py new file mode 100644 index 00000000000..ae9a8174be5 --- /dev/null +++ b/exir/backend/test/backend_with_preprocess_all_demo.py @@ -0,0 +1,266 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, final, List, Tuple + +import torch + +from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_pattern_op_partitions, +) + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_control_flow_submodules +from torch.export.exported_program import ExportedProgram +from torch.fx.passes.operator_support import any_chain, OperatorSupportBase + + +def _preprocess_multimethod( + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + supported_ops: List[torch._ops.OpOverload], + backend_name: str, +) -> Dict[str, List[PreprocessResult]]: + """ + Helper function to abstract out the logic to be shared between the two backends: + FirstBackendWithPreprocessAll and SecondBackendWithPreprocessAll. This will be used + in testing for a partitioner which tags different partitions for different backends + to be lowered to + """ + total_number_of_ops = 0 + for edge_program in edge_programs.values(): + for partitioned_program in edge_program: + for node in partitioned_program.graph.nodes: + if node.op == "call_function": + if node.target in supported_ops: + total_number_of_ops += 1 + all_processed_results = {key: [] for key in edge_programs.keys()} + + for method_name, partitioned_programs in edge_programs.items(): + compile_specs_for_method = compile_specs[method_name] + + assert len(compile_specs_for_method) == len(partitioned_programs) + for compile_spec_for_partition, partitioned_program in zip( + compile_specs_for_method, partitioned_programs + ): + debug_handle_map = {} + processed_bytes = f"{backend_name}#{total_number_of_ops}#" + for node in partitioned_program.graph.nodes: + if node.op == "call_function": + if node.target in supported_ops: + op_name = node.target.__name__ + processed_bytes += f"{op_name}:" + original_debug_id = node.meta["debug_handle"] + new_debug_id = original_debug_id + debug_handle_map[new_debug_id] = (original_debug_id,) + else: + raise RuntimeError( + f"{node.op} {node.target.__name__} is not supported in backend {backend_name}" + ) + + processed_bytes += "#" + for cs in compile_spec_for_partition: + processed_bytes += f"{cs.key}:{cs.value};" + + all_processed_results[method_name].append( + PreprocessResult( + processed_bytes=bytes(processed_bytes, encoding="utf8"), + debug_handle_map=debug_handle_map, + ) + ) + + return all_processed_results + + +@final +class FirstBackendWithPreprocessAll(BackendDetails): + """ + Backend used to test the preprocess_multimethod for multi methods lowering. + lowered modules are returned in the format: + FirstBackendWithPreprocessAll##::#;: + + + lowered blobs are not functional, and are purely used for testing purposes + """ + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Not used for testing + """ + return PreprocessResult( + processed_bytes=bytes(b"\x00"), + debug_handle_map={}, + ) + + @staticmethod + def preprocess_multimethod( + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + ) -> Dict[str, list[PreprocessResult]]: + """ + Preprocess all the edge programs in the given dictionary and return a dictionary + of preprocess results. The preprocess result is a tuple of processed bytes and + a map from the node name to the original debug handle. + """ + match_ops = [ + exir_ops.edge.aten.sin.default, + exir_ops.edge.aten.add.Tensor, + ] + + return _preprocess_multimethod( + edge_programs, compile_specs, match_ops, "FirstBackendWithPreprocessAll" + ) + + +@final +class SecondBackendWithPreprocessAll(BackendDetails): + """ + Backend used to test the preprocess_multimethod for multi methods lowering. + lowered modules are returned in the format: + SecondBackendWithPreprocessAll##::#;: + + + lowered blobs are not functional, and are purely used for testing purposes + """ + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Not used for testing + """ + return PreprocessResult( + processed_bytes=bytes(b"\x00"), + debug_handle_map={}, + ) + + @staticmethod + def preprocess_multimethod( + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + ) -> Dict[str, list[PreprocessResult]]: + """ + Preprocess all the edge programs in the given dictionary and return a dictionary + of preprocess results. The preprocess result is a tuple of processed bytes and + a map from the node name to the original debug handle. + """ + match_ops = [ + exir_ops.edge.aten.cos.default, + exir_ops.edge.aten.sub.Tensor, + ] + + return _preprocess_multimethod( + edge_programs, compile_specs, match_ops, "SecondBackendWithPreprocessAll" + ) + + +class AddSinOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sin.default, + ] + + +class SubCosOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.cos.default, + ] + + +@final +class BackendWithPreprocessAllPartitioner(Partitioner): + """ + Partitioner that partitions for both FirstBackendWithPreprocessAll + and SecondBackendWithPreprocessAll. + + - The partitioner tags all sin and add nodes for delegation to + FirstBackendWithPreprocessAll + - The partitioner tags all cos and sub nodes for delegation to + SecondBackendWithPreprocessAll + """ + + def __init__(self) -> None: + self.add_sin_support = any_chain(AddSinOperatorSupport()) + self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__ + + self.sub_cos_support = any_chain(SubCosOperatorSupport()) + self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__ + + def _partition_graph_module( + self, + graph_module: torch.fx.GraphModule, + id_start=0, + ) -> Tuple[Dict[str, DelegationSpec], int]: + partition_tags: Dict[str, DelegationSpec] = {} + + num_partitions_in_gm = 0 + for op_support, backend_id, tag_prefix in [ + (self.add_sin_support, self.add_sin_backend_id, "first"), + (self.sub_cos_support, self.sub_cos_backend_id, "second"), + ]: + partition_list = generate_pattern_op_partitions( + graph_module, op_support=op_support + ) + num_partitions_in_gm = num_partitions_in_gm + len(partition_list) + for partition in partition_list: + compile_specs = [] + delegation_tag = f"{tag_prefix}_tag{id_start + partition.id}" + for node in partition.nodes: + node.meta["delegation_tag"] = delegation_tag + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.add.Tensor + ): + compile_specs.append(CompileSpec("add", bytes(b"\x00"))) + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.sin.default + ): + compile_specs.append(CompileSpec("sin", bytes(b"\x01"))) + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.sub.Tensor + ): + compile_specs.append(CompileSpec("sub", bytes(b"\x02"))) + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.cos.default + ): + compile_specs.append(CompileSpec("cos", bytes(b"\x03"))) + + delegation_spec = DelegationSpec(backend_id, compile_specs) + partition_tags[delegation_tag] = delegation_spec + + start_idx_for_submodules = num_partitions_in_gm + for _, submodule, _ in get_control_flow_submodules(graph_module): + ret_partition_tags, start_idx_for_submodules = self._partition_graph_module( + submodule, id_start=start_idx_for_submodules + ) + partition_tags.update(ret_partition_tags) + + return partition_tags, start_idx_for_submodules + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + partition_tags, _ = self._partition_graph_module(exported_program.graph_module) + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) diff --git a/exir/backend/test/test_to_backend_multi_method.py b/exir/backend/test/test_to_backend_multi_method.py new file mode 100644 index 00000000000..d4f8fccb8f2 --- /dev/null +++ b/exir/backend/test/test_to_backend_multi_method.py @@ -0,0 +1,524 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from typing import Dict, List, Tuple + +import torch + +from executorch.exir import EdgeProgramManager, to_edge +from executorch.exir.backend.backend_api import ( + MethodProgramsPartitionerSpec, + to_backend, +) + +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( + AllNodePartitioner, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import Partitioner +from executorch.exir.backend.test.backend_with_compiler_demo import ( + BackendWithCompilerDemo, +) + +from executorch.exir.backend.test.backend_with_preprocess_all_demo import ( + BackendWithPreprocessAllPartitioner, +) +from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.lowered_backend_module import ( + get_lowered_submodules, + LoweredBackendModule, +) +from executorch.exir.schema import ( + BackendDelegate, + BackendDelegateDataReference, + DataLocation, + Program, +) +from executorch.extension.pybindings.portable_lib import ( # @manual + _load_for_executorch_from_buffer, +) +from torch.export.exported_program import ExportedProgram + +from torch.testing import FileCheck + + +class TestToBackendMultiMethod(unittest.TestCase): + """ + Testing suite used to test multi method to_backend lowering. The test suite uses demo backends + FirstBackendWithPreprocessAll and SecondBackendWithPreprocessAll. + - FirstBackendWithPreprocessAll: supports add + sin + - SecondBackendWithPreprocessAll: supports sub + cos + + Both backends lower exported programs into payloads in the string format: + - (backend_id)#(total_number_ops across methods)#[op_target_name;]#[compile_spec.key:compile_spec.value;] + + We leverage the above expectation to test various lowering across different modules, and ensure + that the right exported programs and compile specs are given when lowering a specifed exported program + + We leverage the demo partitioner BackendWithPreprocessAll which partitions add + sin nodes to + FirstBackendWithPreprocessAll and sub + cos nodes to SecondBackendWithPreprocessAll. This allows + us to test cases in which multiple backends are being lowered. + """ + + def _get_lowered_submodules_across_controlflow( + self, graph_module: torch.fx.GraphModule + ) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]: + top_level_submodules = get_lowered_submodules(graph_module) + + for _, submodule, _ in get_control_flow_submodules(graph_module): + top_level_submodules.extend( + self._get_lowered_submodules_across_controlflow(submodule) + ) + + return top_level_submodules + + def check_backend_delegate( + self, + program: Program, + delegate: BackendDelegate, + expected_id: str, + expected_processed: bytes, + ) -> None: + self.assertEqual(delegate.id, expected_id) + processed: BackendDelegateDataReference = delegate.processed + self.assertEqual(processed.location, DataLocation.INLINE) + self.assertLess(processed.index, len(program.backend_delegate_data)) + self.assertEqual( + program.backend_delegate_data[processed.index].data, expected_processed + ) + + def _test( + self, test_set: Dict[str, Tuple[ExportedProgram, Partitioner, List[str]]] + ): + method_to_edge_program = { + method_name: ep for method_name, (ep, _, _) in test_set.items() + } + + method_to_partitioner = { + method_name: partitioner + for method_name, (_, partitioner, _) in test_set.items() + } + + lowered_ep_dict = to_backend( + MethodProgramsPartitionerSpec( + method_to_edge_program, + method_to_partitioner, + ) + ) + + self.assertEqual(len(lowered_ep_dict.keys()), len(test_set.keys())) + for method_name in test_set.keys(): + self.assertTrue(method_name in lowered_ep_dict.keys()) + (_, _, list_of_payload_as_string) = test_set[method_name] + lowered_ep = lowered_ep_dict[method_name] + FileCheck().check_count( + "torch.ops.higher_order.executorch_call_delegate", + len(list_of_payload_as_string), + exactly=True, + ).run(str(lowered_ep)) + lowered_submodules = self._get_lowered_submodules_across_controlflow( + lowered_ep.graph_module + ) + self.assertEqual(len(lowered_submodules), len(list_of_payload_as_string)) + + for idx, (_, lowered_backend_module, _) in enumerate(lowered_submodules): + self.assertEqual( + lowered_backend_module.processed_bytes.decode("utf-8"), + list_of_payload_as_string[idx], + ) + + def test_multi_method_to_backend_single_method(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + # Payload String: + # [Number of Ops lowered across all methods/partitions]#OpTargetNames#CompileSpecs; + test_set = { + "forward": ( + edgeir_m.exported_program(), + AllNodePartitioner( + "FirstBackendWithPreprocessAll", + [CompileSpec("max_value", bytes([1]))], + ), + [ + "FirstBackendWithPreprocessAll#1#aten.sin.default:#max_value:b'\\x01';" + ], + ) + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + sin_edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + add_edgeir_m = to_edge(torch.export.export(AddModule(), (torch.ones(1),))) + sin_partitioner = AllNodePartitioner( + "FirstBackendWithPreprocessAll", [CompileSpec("sin", bytes([2]))] + ) + add_partitioner = AllNodePartitioner( + "FirstBackendWithPreprocessAll", [CompileSpec("add", bytes([3]))] + ) + # Payload String: + # [Number of Ops lowered across all methods/partitions]#OpTargetNames#CompileSpecs; + test_set = { + "sin": ( + sin_edgeir_m.exported_program(), + sin_partitioner, + ["FirstBackendWithPreprocessAll#2#aten.sin.default:#sin:b'\\x02';"], + ), + "add": ( + add_edgeir_m.exported_program(), + add_partitioner, + ["FirstBackendWithPreprocessAll#2#aten.add.Tensor:#add:b'\\x03';"], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods_multiple_partitions(self): + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + x + y = y * y + y = y + y + return y + + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.sin(x) + y = y * y + return torch.sin(y) + + add_edgeir_m = to_edge(torch.export.export(AddModule(), (torch.ones(1),))) + sin_edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + test_set = { + "add": ( + add_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + ], + ), + "sin": ( + sin_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods_different_partitions(self): + class AddSinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + x + y = y * y + y = torch.sin(y) + return y + + class SinAddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.sin(x) + y = y * y + return y + y + + add_sin_edgeir_m = to_edge( + torch.export.export(AddSinModule(), (torch.ones(1),)) + ) + sin_add_edgeir_m = to_edge( + torch.export.export(SinAddModule(), (torch.ones(1),)) + ) + test_set = { + "add_sin": ( + add_sin_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + ], + ), + "sin_add": ( + sin_add_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods_different_backends(self): + class AddSinCosSubModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + x + y = torch.sin(y) + y = torch.cos(y) + y = y - x + return y + + class CosSubAddSinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.cos(x) + y = y - x + y = y + y + y = torch.sin(y) + return y + + first_second_edgeir_m = to_edge( + torch.export.export(AddSinCosSubModule(), (torch.ones(1),)) + ) + second_first_edgeir_m = to_edge( + torch.export.export(CosSubAddSinModule(), (torch.ones(1),)) + ) + test_set = { + "first_second": ( + first_second_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:aten.sin.default:#add:b'\\x00';sin:b'\\x01';", + "SecondBackendWithPreprocessAll#4#aten.cos.default:aten.sub.Tensor:#cos:b'\\x03';sub:b'\\x02';", + ], + ), + "second_first": ( + second_first_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "SecondBackendWithPreprocessAll#4#aten.cos.default:aten.sub.Tensor:#cos:b'\\x03';sub:b'\\x02';", + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:aten.sin.default:#add:b'\\x00';sin:b'\\x01';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_control_flow(self): + class SinCosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def true_fn(self, x): + return torch.sin(x) + + def false_fn(self, x): + return torch.cos(x) + + def forward(self, x): + x = x + x + return torch.cond(x > 0, self.true_fn, self.false_fn, [x]) + + class SinAddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def true_fn(self, x): + return torch.sin(x) + + def false_fn(self, x): + return x + x + + def forward(self, x): + return torch.cond(x > 0, self.true_fn, self.false_fn, [x]) + + sin_cos_edgeir_m = to_edge( + torch.export.export(SinCosModule(), (torch.ones(1),)) + ) + sin_add_edgeir_m = to_edge( + torch.export.export(SinAddModule(), (torch.ones(1),)) + ) + + test_set = { + "sin_cos": ( + sin_cos_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + # True Module Partition + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + # False Module Partition + "SecondBackendWithPreprocessAll#1#aten.cos.default:#cos:b'\\x03';", + ], + ), + "sin_add": ( + sin_add_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + # True Module Partition + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + # False Module Partition + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_not_found(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + sin_edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + add_edgeir_m = to_edge(torch.export.export(AddModule(), (torch.ones(1),))) + sin_partitioner = AllNodePartitioner( + "Invalid", [CompileSpec("sin", bytes([2]))] + ) + add_partitioner = AllNodePartitioner( + "FirstBackendWithPreprocessAll", [CompileSpec("add", bytes([3]))] + ) + + test_set = { + "sin": ( + sin_edgeir_m.exported_program(), + sin_partitioner, + [], + ), + "add": ( + add_edgeir_m.exported_program(), + add_partitioner, + [], + ), + } + with self.assertRaisesRegex( + NotImplementedError, "Backend Invalid was not found." + ): + self._test(test_set) + + def test_multi_method_end_to_end(self): + """ + Tests multi method lowering end-to-end. Lowers the same Sin Module for two methods + "forward" and "forward_copy". Ensures that the lowered program has two delegates + but only one serialized blob. Ensures that the lowered program runs correctly. + """ + + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + sin_edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + sin_edgeir_m_copy = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + + method_edge_program = { + "forward": sin_edgeir_m.exported_program(), + "forward_copy": sin_edgeir_m_copy.exported_program(), + } + compile_specs = [CompileSpec("max_value", bytes([1]))] + + method_partitioner = { + "forward": AllNodePartitioner( + BackendWithCompilerDemo.__name__, compile_specs + ), + "forward_copy": AllNodePartitioner( + BackendWithCompilerDemo.__name__, compile_specs + ), + } + + lowered_ep_dict = to_backend( + MethodProgramsPartitionerSpec( + method_edge_program, + method_partitioner, + ) + ) + + new_edge_manager = EdgeProgramManager(lowered_ep_dict) + + exec_prog = new_edge_manager.to_executorch() + + program = exec_prog.executorch_program + # Since the preprocessed bytes are the same, there should only be on copy + self.assertEqual(len(program.backend_delegate_data), 1) + + self.check_backend_delegate( + program=program, + delegate=program.execution_plan[0].delegates[0], + expected_id=BackendWithCompilerDemo.__name__, + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + ) + self.check_backend_delegate( + program=program, + delegate=program.execution_plan[1].delegates[0], + expected_id=BackendWithCompilerDemo.__name__, + expected_processed=b"1version:0#op:demo::aten.sin.default, numel:1, dtype:torch.float322#", + ) + + # Check that there are two methods + self.assertEqual(len(program.execution_plan), 2) + + delegate_method_1 = program.execution_plan[0].delegates + delegate_method_2 = program.execution_plan[1].delegates + + # 1 delegate blob for each method + self.assertEqual(len(delegate_method_1), 1) + self.assertEqual(len(delegate_method_2), 1) + + # Delegate Blobs reference the same underlying bytes + delegate_reference1 = delegate_method_1[0].processed + delegate_reference2 = delegate_method_2[0].processed + self.assertEqual(delegate_reference1.index, delegate_reference2.index) + + et_module = _load_for_executorch_from_buffer(exec_prog.buffer) + model_inputs = torch.ones(1) + model_outputs = et_module.run_method("forward", [model_inputs]) + self.assertEqual(model_inputs, torch.ones(1)) + model_outputs_from_copy_method = et_module.run_method( + "forward_copy", [model_inputs] + ) + self.assertEqual(model_inputs, torch.ones(1)) + self.assertEqual(model_outputs, model_outputs_from_copy_method) + self.assertTrue( + torch.allclose( + model_outputs[0], 0.8333 * torch.ones(1), atol=1e-03, rtol=1e-03 + ) + ) diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 6bcc1b2f3d8..78b031a238e 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -766,15 +766,15 @@ def create_submodule_from_nodes( gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) submodule_node = None for node in gm.graph.nodes: - if node.op == "call_module": - if node.target == submodule_name: - submodule_node = node - else: - raise RuntimeError( - f"The submodule created with nodes {node_list} did not form \ - one fully contained subgraph. Check that these nodes form a \ - fully contained graph. Partitioned graph: {gm.graph}." - ) + if node.op == "call_module" and node.target == submodule_name: + submodule_node = node + + if submodule_node is None: + raise RuntimeError( + f"The submodule created with nodes {node_list} did not form \ + one fully contained subgraph. Check that these nodes form a \ + fully contained graph. Partitioned graph: {gm.graph}." + ) if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor): # If the original output is a single tensor, it has been @@ -809,12 +809,13 @@ def create_submodule_from_nodes( for node in gm.graph.nodes: if node.op == "call_module" and node.target == submodule_name: submodule_node = node - elif node.op == "call_module": - raise RuntimeError( - f"The submodule created with nodes {node_list} did not form \ - one fully contained subgraph. Check that these nodes form a \ - fully contained graph. Partitioned graph: {gm.graph}." - ) + + if submodule_node is None: + raise RuntimeError( + f"The submodule created with nodes {node_list} did not form \ + one fully contained subgraph. Check that these nodes form a \ + fully contained graph. Partitioned graph: {gm.graph}." + ) assert ( submodule_node is not None