From e13b1f5dc11bf99eef16c6583d0169d0cb052a0c Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 1 Apr 2025 16:32:01 -0700 Subject: [PATCH 1/2] Update [ghstack-poisoned] --- exir/backend/backend_api.py | 373 ++++++++++++++- exir/backend/test/TARGETS | 52 ++ .../test/backend_with_preprocess_all_demo.py | 277 +++++++++++ .../test/test_to_backend_multi_method.py | 452 ++++++++++++++++++ exir/lowered_backend_module.py | 31 +- 5 files changed, 1169 insertions(+), 16 deletions(-) create mode 100644 exir/backend/test/backend_with_preprocess_all_demo.py create mode 100644 exir/backend/test/test_to_backend_multi_method.py diff --git a/exir/backend/backend_api.py b/exir/backend/backend_api.py index ab2e66f7885..0a2395b7724 100644 --- a/exir/backend/backend_api.py +++ b/exir/backend/backend_api.py @@ -8,8 +8,9 @@ import copy import logging from contextlib import contextmanager, nullcontext +from dataclasses import dataclass from functools import singledispatch -from typing import Generator, List +from typing import Dict, Generator, List import torch @@ -417,3 +418,373 @@ def to_backend( constants=tagged_exported_program.constants, verifiers=[tagged_exported_program.verifier], ) + + +def _create_partitions_in_graph_module( + tagged_graph_module: torch.fx.GraphModule, + partition_result: PartitionResult, + owning_program: ExportedProgram, + is_submodule: bool, +) -> Dict[str, List[torch.fx.Node]]: + backend_id_to_submodule_name = {} + for tag, delegation_spec in partition_result.partition_tags.items(): + # Create partition with nodes containing this tag. There should only be + # one contained submodule per tag + node_list = _get_node_list_with_same_tag( + tagged_graph_module, tag, owning_program + ) + + if len(node_list) == 0: + logging.debug(f"Did not find any nodes for tag {tag}") + continue + + logging.debug(f"For tag {tag}, found nodes {node_list}") + # Tag the nodes that are params as buffers, so we can order the submodule as (Parms + Buffers) (User Inputs) + + replace_ctx = ( + tagged_graph_module._set_replace_hook( + owning_program.graph_signature.get_replace_hook() + ) + if not is_submodule + else nullcontext() + ) + with replace_ctx: + submodule, call_module_node = create_submodule_from_nodes( + tagged_graph_module, node_list, tag + ) + + tagged_graph_module_output_node = [ + node for node in tagged_graph_module.graph.nodes if node.op == "output" + ][0] + submodule_output_node = [ + node for node in submodule.graph.nodes if node.op == "output" + ][0] + # Copy the output node meta from the original output node, because + # create_submodule_from_nodes doesn't cover the meta field + submodule_output_node.meta = tagged_graph_module_output_node.meta + logging.debug(f"Partitioned graph module: {tagged_graph_module}") + ( + submodule_program, + toplevel_input_specs_to_delete, + toplevel_output_specs_to_delete, + ) = create_exported_program_from_submodule( + submodule, + owning_program, + tag, + call_module_node, + is_submodule, + ) + call_module_node.meta["backend_id"] = delegation_spec.backend_id + call_module_node.meta["compile_spec"] = delegation_spec.compile_specs + call_module_node.meta["submodule_program"] = submodule_program + call_module_node.meta["toplevel_input_specs_to_delete"] = ( + toplevel_input_specs_to_delete + ) + call_module_node.meta["toplevel_output_specs_to_delete"] = ( + toplevel_output_specs_to_delete + ) + call_module_node.meta["is_submodule"] = is_submodule + + if delegation_spec.backend_id not in backend_id_to_submodule_name: + 
backend_id_to_submodule_name[delegation_spec.backend_id] = [] + + # The call_module_node created here might not be the same node instance as + # the one in the final graph module. This is because this node might be replaced + # in future edits to the graph. As a result, we just keep track of the node's name + # and at the end we search for this node in our final graph module + backend_id_to_submodule_name[delegation_spec.backend_id].append( + call_module_node.target + ) + + created_submodule_nodes = dict( + (key, []) for key in backend_id_to_submodule_name.keys() + ) + for backend_id, submodule_name in backend_id_to_submodule_name.items(): + for node in tagged_graph_module.graph.nodes: + if node.op == "call_module" and node.target in submodule_name: + created_submodule_nodes[backend_id].append(node) + + # check the number of submodule_names and submodule_nodes are equal + for backend_id in created_submodule_nodes.keys(): + assert len(created_submodule_nodes[backend_id]) == len( + backend_id_to_submodule_name[backend_id] + ) + + return created_submodule_nodes + + +def _create_partitions( + tagged_graph_module: torch.fx.GraphModule, + partition_result: PartitionResult, + owning_program: ExportedProgram, + is_submodule: bool = False, +) -> Dict[str, List[torch.fx.Node]]: + backend_id_to_call_submodules = _create_partitions_in_graph_module( + tagged_graph_module, partition_result, owning_program, is_submodule + ) + + # Recursively partition and lower for submodules + for _, submod, _ in get_control_flow_submodules(tagged_graph_module): + nested_backend_id_to_call_submodules = _create_partitions( + submod, partition_result, owning_program, is_submodule=True + ) + for ( + backend_id, + nested_submodules, + ) in nested_backend_id_to_call_submodules.items(): + if backend_id not in backend_id_to_call_submodules: + backend_id_to_call_submodules[backend_id] = nested_submodules + else: + backend_id_to_call_submodules[backend_id].extend(nested_submodules) + + return backend_id_to_call_submodules + + +def lower_all_submodules_to_backend( + backend_id: str, + method_to_submodules_nodes: Dict[str, List[torch.fx.Node]], + method_to_tagged_edge_program: Dict[str, ExportedProgram], +) -> None: + """ + Lower all submodules nodes given in the method_to_submodule_nodes map to backend_id. 
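+
+    Args:
+        backend_id: name of the BackendDetails subclass whose preprocess_multimethod
+            is used to preprocess all of the given partitions in one call.
+        method_to_submodules_nodes: maps each method name to the call_module nodes
+            created for that method's partitions.
+        method_to_tagged_edge_program: maps each method name to the tagged exported
+            program that owns those nodes.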
+ """ + # The created exported program for the submodules are in the call_module node's meta data + # We just map the method_to_submodule_nodes directly to the method_to_partitioned_exported_programs + method_to_partitioned_program = { + method_name: [node.meta["submodule_program"] for node in call_submodule_nodes] + for method_name, call_submodule_nodes in method_to_submodules_nodes.items() + } + method_to_compile_specs = { + method_name: [node.meta["compile_spec"] for node in call_submodule_nodes] + for method_name, call_submodule_nodes in method_to_submodules_nodes.items() + } + backend_found = False + for cls in BackendDetails.__subclasses__(): + if backend_id == cls.__name__: + method_to_preprocess_result: dict[str, List[PreprocessResult]] = ( + cls.preprocess_multimethod( + method_to_partitioned_program, method_to_compile_specs + ) + ) + backend_found = True + + if not backend_found: + raise NotImplementedError(f"Backend {backend_id} was not found.") + + for method_name in method_to_preprocess_result.keys(): + owning_program = method_to_tagged_edge_program[method_name] + list_of_preprocess_results = method_to_preprocess_result[method_name] + list_of_call_submodule_nodes = method_to_submodules_nodes[method_name] + list_of_compile_specs = method_to_compile_specs[method_name] + assert ( + len(list_of_preprocess_results) == len(list_of_call_submodule_nodes), + f"Expected {len(list_of_call_submodule_nodes)} preprocessed results for method {method_name} but got {len(list_of_preprocess_results)}", + ) + for preprocess_result, call_submodule_node, compile_spec in zip( + list_of_preprocess_results, + list_of_call_submodule_nodes, + list_of_compile_specs, + ): + submodule_program = call_submodule_node.meta["submodule_program"] + lowered_module = LoweredBackendModule( + edge_program=submodule_program, + backend_id=backend_id, + processed_bytes=preprocess_result.processed_bytes, + compile_specs=compile_spec, + ) + owning_graph_module = call_submodule_node.graph.owning_module + is_submodule = call_submodule_node.meta["is_submodule"] + toplevel_input_specs_to_delete = call_submodule_node.meta[ + "toplevel_input_specs_to_delete" + ] + toplevel_output_specs_to_delete = call_submodule_node.meta[ + "toplevel_output_specs_to_delete" + ] + # call delegate args should only use user_inputs + call_delegate_args = [] + # Preserve input order as user_inputs + for inp_name in submodule_program.graph_signature.user_inputs: + for inp_node in call_submodule_node.all_input_nodes: + if inp_node.name == inp_name: + call_delegate_args.append(inp_node) + break + + def generate_debug_handle(ep: ExportedProgram) -> int: + """ + Generate a debug handle for the given ExportedProgram. 
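+
+                The handle is one greater than the largest debug_handle already present
+                in the graph, so the call_delegate node inserted below does not collide
+                with existing handles.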
+ """ + debug_handle = 0 + for node in ep.graph_module.graph.nodes: + debug_handle = max(debug_handle, node.meta.get("debug_handle", 0)) + return debug_handle + 1 + + # Replace the partitioned submodule with a lowered submodule + # Add call_method node with function "forward" + with owning_graph_module.graph.inserting_before(call_submodule_node): + lowered_name = get_lowered_module_name( + owning_graph_module, lowered_module + ) + lowered_node = owning_graph_module.graph.get_attr(lowered_name) + call_delegate_node = owning_graph_module.graph.call_function( + executorch_call_delegate, + (lowered_node,) + tuple(call_delegate_args), + call_submodule_node.kwargs, + ) + call_delegate_node.meta["debug_handle"] = generate_debug_handle( + owning_program + ) + call_delegate_node.meta["val"] = call_submodule_node.meta["val"] + call_submodule_node.replace_all_uses_with(call_delegate_node) + owning_graph_module.graph.erase_node(call_submodule_node) + + if is_submodule: + assert len(toplevel_input_specs_to_delete) == 0 + assert len(toplevel_output_specs_to_delete) == 0 + elif ( + len(toplevel_input_specs_to_delete) > 0 + or len(toplevel_output_specs_to_delete) > 0 + ): + _unsafe_adjust_original_program( + owning_program, + call_delegate_node, + toplevel_input_specs_to_delete, + toplevel_output_specs_to_delete, + ) + + +@dataclass +class MethodProgramsPartitionerSpec: + """ + Since single dispatch for to_backend requires the first argument to be a + valid class, we create the following dataclass spec to hold the dictionaries + mapping the method name to the corresponding program, partitioner + """ + + method_to_edge_program: Dict[str, ExportedProgram] + method_to_partitioner: Dict[str, Partitioner] + + +@to_backend.register +def _( + method_edge_program_partitioners: MethodProgramsPartitionerSpec, +) -> Dict[str, ExportedProgram]: + """ + Add overloaded implementations for to_backend: + + :: + + def to_backend( + method_edge_program_partitioners: MethodProgramsPartitionerSpec + ) -> Dict[str, ExportedProgram]: + + Returns a semantically-equivalent dictionary of programs to the programs given as input (represented + as a graph module in Edge dialect), but with portions of the program targeted for + delegation as determined by the partitioner. + + Args: + method_edge_program_partitioners: contains two mappings, + - method_to_edge_program: mapping of method names to their respective programs in Edge dialect. + - method_to_partitioner: mapping of method names to an instance of the partitioner, in charge with tagging + portions of the specified program for delegation. A valid partitioner must return PartitionerResult + including both tagged exported program and partitioner_tag: Dict[str, DelegationSpec], where each key is a tag name and + the nodes with same tag will be fused a one subgraph and delegated to backend specififed in delegation spec. + + + Returns: + ExportedProgram: The input program, with some portions targeted for delegation. 
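+
+    Example (a minimal sketch mirroring the accompanying tests; MyModule and
+    MyPartitioner are placeholders for a user module and a Partitioner instance)::
+
+        import torch
+        from executorch.exir import to_edge
+        from executorch.exir.backend.backend_api import (
+            MethodProgramsPartitionerSpec,
+            to_backend,
+        )
+
+        ep = to_edge(torch.export.export(MyModule(), (torch.ones(1),))).exported_program()
+        lowered = to_backend(
+            MethodProgramsPartitionerSpec(
+                method_to_edge_program={"forward": ep},
+                method_to_partitioner={"forward": MyPartitioner()},
+            )
+        )
+        # lowered["forward"] is an ExportedProgram with executorch_call_delegate calls
+        # in place of the partitioned subgraphs.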
+ """ + method_to_edge_program = method_edge_program_partitioners.method_to_edge_program + method_to_partitioner = method_edge_program_partitioners.method_to_partitioner + + partitioned_and_lowered_exported_programs = {} + backend_id_to_method_submodules_map = {} + method_to_tagged_exported_program = {} + + for method_name, partitioner_instance in method_to_partitioner.items(): + assert ( + method_name in method_to_edge_program + ), f"Partitioner for method {method_name} is not provided" + edge_program = method_to_edge_program[method_name] + edge_program._validate() + + # Use fake program, with FakeTensors in the state dict, to avoid copying large constant values. + # Fall back to deepcopy if no fake mode is found. TODO(T182910699): Remove this fallback. + try: + fake_edge_program = get_fake_program(edge_program) + except Exception as e: + logging.warning( + f"Error in get_fake_program for graph {edge_program.graph_module}, fallback to deepcopy: {e}" + ) + fake_edge_program = copy.deepcopy(edge_program) + partitioner_result = partitioner_instance(fake_edge_program) + tagged_exported_program = partitioner_result.tagged_exported_program + method_to_tagged_exported_program[method_name] = tagged_exported_program + + # Check that the partitioner did not modify the original graph + if _ENABLE_VALIDATION: + assert is_identical_graph( + tagged_exported_program.graph_module, + edge_program.graph_module, + ), f"The partitioner {partitioner_instance} should not modify the graph module" + else: + logging.warning("Disabled validating the partitioner.") + + assert ( + partitioner_result.partition_tags is not None + ), f"Partitioner {partitioner_instance} needs a `partition_tags` field containing a mapping of tags to delegate spec" + + update_to_real_program(tagged_exported_program, edge_program) + + for tag, _ in partitioner_result.partition_tags.items(): + _maybe_duplicate_constant_nodes(tagged_exported_program, tag) + + backend_id_to_call_submodule_nodes = _create_partitions( + tagged_exported_program.graph_module, + partitioner_result, + tagged_exported_program, + ) + for ( + backend_id, + call_submodule_nodes, + ) in backend_id_to_call_submodule_nodes.items(): + if backend_id not in backend_id_to_method_submodules_map: + backend_id_to_method_submodules_map[backend_id] = {} + backend_id_to_method_submodules_map[backend_id][ + method_name + ] = call_submodule_nodes + + for ( + backend_id, + method_to_submodule_nodes, + ) in backend_id_to_method_submodules_map.items(): + lower_all_submodules_to_backend( + backend_id, + method_to_submodule_nodes, + method_to_tagged_exported_program, + ) + + for method_name in method_to_edge_program.keys(): + if method_name in method_to_tagged_exported_program: + tagged_exported_program = method_to_tagged_exported_program[method_name] + partitioned_and_lowered_exported_programs[method_name] = ExportedProgram( + root=tagged_exported_program.graph_module, + graph=tagged_exported_program.graph_module.graph, + graph_signature=tagged_exported_program.graph_signature, + state_dict=tagged_exported_program.state_dict, + range_constraints=copy.deepcopy( + tagged_exported_program.range_constraints + ), + module_call_graph=copy.deepcopy( + tagged_exported_program.module_call_graph + ), + example_inputs=None, + constants=tagged_exported_program.constants, + verifiers=[tagged_exported_program.verifier], + ) + else: + # this edge program wasn't partitioned, so we can just return it as is + partitioned_and_lowered_exported_programs[method_name] = ( + method_to_edge_program[method_name] 
+ ) + + return partitioned_and_lowered_exported_programs diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS index f0ba618936d..b7f46076868 100644 --- a/exir/backend/test/TARGETS +++ b/exir/backend/test/TARGETS @@ -189,6 +189,58 @@ python_unittest( ], ) +python_unittest( + name = "test_to_backend_multi_method", + srcs = [ + "test_to_backend_multi_method.py", + ], + preload_deps = [ + "//executorch/kernels/portable:custom_ops_generated_lib", + "//executorch/kernels/quantized:custom_ops_generated_lib", + "//executorch/runtime/executor/test:test_backend_compiler_lib", + ], + deps = [ + ":backend_with_preprocess_all_demo", + "//caffe2:torch", + "//caffe2/functorch:functorch_src", + "//executorch/exir:delegate", + "//executorch/exir:graph_module", + "//executorch/exir:lib", + "//executorch/exir:lowered_backend_module", + "//executorch/exir:print_program", + "//executorch/exir:schema", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/dialects:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//executorch/extension/pytree:pylib", + ], +) + +python_library( + name = "backend_with_preprocess_all_demo", + srcs = [ + "backend_with_preprocess_all_demo.py" + ], + deps = [ + "//caffe2:torch", + "//caffe2/functorch:functorch_src", + "//executorch/exir:delegate", + "//executorch/exir:graph_module", + "//executorch/exir:lib", + "//executorch/exir:lowered_backend_module", + "//executorch/exir:print_program", + "//executorch/exir:schema", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/backend:partitioner", + "//executorch/exir/dialects:lib", + "//executorch/extension/pybindings:portable_lib", # @manual + "//executorch/extension/pytree:pylib", + ], +) + python_unittest( name = "test_debug_handle_map", srcs = [ diff --git a/exir/backend/test/backend_with_preprocess_all_demo.py b/exir/backend/test/backend_with_preprocess_all_demo.py new file mode 100644 index 00000000000..be0c15fd765 --- /dev/null +++ b/exir/backend/test/backend_with_preprocess_all_demo.py @@ -0,0 +1,277 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
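+
+# This module provides two demo backends (FirstBackendWithPreprocessAll and
+# SecondBackendWithPreprocessAll) plus a partitioner that tags nodes for both of them.
+# They are only used to exercise preprocess_multimethod in the multi-method
+# to_backend tests; the produced payloads are not functional.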
+ +from typing import Dict, final, List, Tuple + +import torch + +from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult +from executorch.exir.backend.canonical_partitioners.pattern_op_partitioner import ( + generate_pattern_op_partitions, +) + +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import ( + DelegationSpec, + Partitioner, + PartitionResult, +) +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.graph_module import get_control_flow_submodules +from torch.export.exported_program import ExportedProgram +from torch.export import ExportedProgram +from torch.fx.passes.operator_support import any_chain, OperatorSupportBase + + +def _preprocess_multimethod( + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + supported_ops: List[torch._ops.OpOverload], + backend_name: str, +) -> Dict[str, List[PreprocessResult]]: + """ + Helper function to abstract out the logic to be shared between the two backends: + FirstBackendWithPreprocessAll and SecondBackendWithPreprocessAll. This will be used + in testing for a partitioner which tags different partitions for different backends + to be lowered to + """ + total_number_of_ops = 0 + for edge_program in edge_programs.values(): + for partitioned_program in edge_program: + for node in partitioned_program.graph.nodes: + if node.op == "call_function": + if node.target in supported_ops: + total_number_of_ops += 1 + all_processed_results = dict((key, []) for key in edge_programs.keys()) + + for method_name, partitioned_programs in edge_programs.items(): + compile_specs_for_method = compile_specs[method_name] + + assert len(compile_specs_for_method) == len(partitioned_programs) + for compile_spec_for_partition, partitioned_program in zip( + compile_specs_for_method, partitioned_programs + ): + debug_handle_map = {} + processed_bytes = f"{backend_name}#{total_number_of_ops}#" + for node in partitioned_program.graph.nodes: + if node.op == "call_function": + if node.target in supported_ops: + op_name = node.target.__name__ + processed_bytes += f"{op_name}:" + original_debug_id = node.meta["debug_handle"] + new_debug_id = original_debug_id + debug_handle_map[new_debug_id] = (original_debug_id,) + else: + raise RuntimeError( + f"{node.op} {node.target.__name__} is not supported in backend {backend_name}" + ) + elif node.op == "placeholder": + continue + elif node.op == "output": + continue + elif node.op == "get_attr": + continue + else: + raise RuntimeError( + f"{node.op} is not supported in backend {backend_name}" + ) + + processed_bytes += "#" + for cs in compile_spec_for_partition: + processed_bytes += f"{cs.key}:{cs.value};" + + all_processed_results[method_name].append( + PreprocessResult( + processed_bytes=bytes(processed_bytes, encoding="utf8"), + debug_handle_map=debug_handle_map, + ) + ) + + return all_processed_results + + +@final +class FirstBackendWithPreprocessAll(BackendDetails): + """ + Backend used to test the preprocess_multimethod for multi methods lowering. 
+ lowered modules are returned in the format: + FirstBackendWithPreprocessAll##::#;: + + + lowered blobs are not functional, and are purely used for testing purposes + """ + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Not used for testing + """ + return PreprocessResult( + processed_bytes=bytes(b"\x00"), + debug_handle_map={}, + ) + + @staticmethod + def preprocess_multimethod( + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + ) -> Dict[str, list[PreprocessResult]]: + """ + Preprocess all the edge programs in the given dictionary and return a dictionary + of preprocess results. The preprocess result is a tuple of processed bytes and + a map from the node name to the original debug handle. + """ + match_ops = [ + exir_ops.edge.aten.sin.default, + exir_ops.edge.aten.add.Tensor, + ] + + return _preprocess_multimethod( + edge_programs, compile_specs, match_ops, "FirstBackendWithPreprocessAll" + ) + + +@final +class SecondBackendWithPreprocessAll(BackendDetails): + """ + Backend used to test the preprocess_multimethod for multi methods lowering. + lowered modules are returned in the format: + SecondBackendWithPreprocessAll##::#;: + + + lowered blobs are not functional, and are purely used for testing purposes + """ + + @staticmethod + def preprocess( + edge_program: ExportedProgram, + compile_specs: List[CompileSpec], + ) -> PreprocessResult: + """ + Not used for testing + """ + return PreprocessResult( + processed_bytes=bytes(b"\x00"), + debug_handle_map={}, + ) + + @staticmethod + def preprocess_multimethod( + edge_programs: Dict[str, List[ExportedProgram]], + compile_specs: Dict[str, List[List[CompileSpec]]], + ) -> Dict[str, list[PreprocessResult]]: + """ + Preprocess all the edge programs in the given dictionary and return a dictionary + of preprocess results. The preprocess result is a tuple of processed bytes and + a map from the node name to the original debug handle. + """ + match_ops = [ + exir_ops.edge.aten.cos.default, + exir_ops.edge.aten.sub.Tensor, + ] + + return _preprocess_multimethod( + edge_programs, compile_specs, match_ops, "SecondBackendWithPreprocessAll" + ) + + +class AddSinOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + exir_ops.edge.aten.add.Tensor, + exir_ops.edge.aten.sin.default, + ] + + +class SubCosOperatorSupport(OperatorSupportBase): + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + return node.op == "call_function" and node.target in [ + exir_ops.edge.aten.sub.Tensor, + exir_ops.edge.aten.cos.default, + ] + + +@final +class BackendWithPreprocessAllPartitioner(Partitioner): + """ + Partitioner that partitions for both FirstBackendWithPreprocessAll + and SecondBackendWithPreprocessAll. 
+ + - The partitioner tags all sin and add nodes for delegation to + FirstBackendWithPreprocessAll + - The partitioner tags all cos and sub nodes for delegation to + SecondBackendWithPreprocessAll + """ + + def __init__(self) -> None: + self.add_sin_support = any_chain(AddSinOperatorSupport()) + self.add_sin_backend_id = FirstBackendWithPreprocessAll.__name__ + + self.sub_cos_support = any_chain(SubCosOperatorSupport()) + self.sub_cos_backend_id = SecondBackendWithPreprocessAll.__name__ + + def _partition_graph_module( + self, + graph_module: torch.fx.GraphModule, + id_start=0, + ) -> Tuple[Dict[str, DelegationSpec], int]: + partition_tags: Dict[str, DelegationSpec] = {} + + num_partitions_in_gm = 0 + for op_support, backend_id, tag_prefix in [ + (self.add_sin_support, self.add_sin_backend_id, "first"), + (self.sub_cos_support, self.sub_cos_backend_id, "second"), + ]: + partition_list = generate_pattern_op_partitions( + graph_module, op_support=op_support + ) + num_partitions_in_gm = num_partitions_in_gm + len(partition_list) + for partition in partition_list: + compile_specs = [] + delegation_tag = f"{tag_prefix}_tag{id_start + partition.id}" + for node in partition.nodes: + node.meta["delegation_tag"] = delegation_tag + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.add.Tensor + ): + compile_specs.append(CompileSpec("add", bytes(b"\x00"))) + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.sin.default + ): + compile_specs.append(CompileSpec("sin", bytes(b"\x01"))) + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.sub.Tensor + ): + compile_specs.append(CompileSpec("sub", bytes(b"\x02"))) + if ( + node.op == "call_function" + and node.target == exir_ops.edge.aten.cos.default + ): + compile_specs.append(CompileSpec("cos", bytes(b"\x03"))) + + delegation_spec = DelegationSpec(backend_id, compile_specs) + partition_tags[delegation_tag] = delegation_spec + + start_idx_for_submodules = num_partitions_in_gm + for _, submodule, _ in get_control_flow_submodules(graph_module): + ret_partition_tags, start_idx_for_submodules = self._partition_graph_module( + submodule, id_start=start_idx_for_submodules + ) + partition_tags.update(ret_partition_tags) + + return partition_tags, start_idx_for_submodules + + def partition(self, exported_program: ExportedProgram) -> PartitionResult: + partition_tags, _ = self._partition_graph_module(exported_program.graph_module) + return PartitionResult( + tagged_exported_program=exported_program, partition_tags=partition_tags + ) diff --git a/exir/backend/test/test_to_backend_multi_method.py b/exir/backend/test/test_to_backend_multi_method.py new file mode 100644 index 00000000000..d33ac2c6427 --- /dev/null +++ b/exir/backend/test/test_to_backend_multi_method.py @@ -0,0 +1,452 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest +from typing import Dict, List, Tuple + +import torch + +from executorch.exir import to_edge +from executorch.exir.backend.backend_api import ( + MethodProgramsPartitionerSpec, + to_backend, +) +from executorch.exir.backend.canonical_partitioners.all_node_partitioner import ( + AllNodePartitioner, +) +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.partitioner import Partitioner + +from executorch.exir.backend.test.backend_with_preprocess_all_demo import ( + BackendWithPreprocessAllPartitioner, +) +from executorch.exir.graph_module import get_control_flow_submodules +from executorch.exir.lowered_backend_module import ( + get_lowered_submodules, + LoweredBackendModule, +) +from torch.export.exported_program import ExportedProgram + +from torch.testing import FileCheck + + +class TestToBackendMultiMethod(unittest.TestCase): + """ + Testing suite used to test multi method to_backend lowering. The test suite uses demo backends + FirstBackendWithPreprocessAll and SecondBackendWithPreprocessAll. + - FirstBackendWithPreprocessAll: supports add + sin + - SecondBackendWithPreprocessAll: supports sub + cos + + Both backends lower exported programs into payloads in the string format: + - (backend_id)#(total_number_ops across methods)#[op_target_name;]#[compile_spec.key:compile_spec.value;] + + We leverage the above expectation to test various lowering across different modules, and ensure + that the right exported programs and compile specs are given when lowering a specifed exported program + + We leverage the demo partitioner BackendWithPreprocessAll which partitions add + sin nodes to + FirstBackendWithPreprocessAll and sub + cos nodes to SecondBackendWithPreprocessAll. This allows + us to test cases in which multiple backends are being lowered. 
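+
+    For example, lowering a single sin node to FirstBackendWithPreprocessAll with the
+    compile spec ("max_value", b'\x01') is expected to produce the payload
+    "FirstBackendWithPreprocessAll#1#aten.sin.default:#max_value:b'\\x01';".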
+ """ + + def _get_lowered_submodules_across_controlflow( + self, graph_module: torch.fx.GraphModule + ) -> List[Tuple[str, LoweredBackendModule, torch.fx.Node]]: + top_level_submodules = get_lowered_submodules(graph_module) + + for _, submodule, _ in get_control_flow_submodules(graph_module): + top_level_submodules.extend( + self._get_lowered_submodules_across_controlflow(submodule) + ) + + return top_level_submodules + + def _test( + self, test_set: Dict[str, Tuple[ExportedProgram, Partitioner, List[str]]] + ): + method_to_edge_program = { + method_name: ep for method_name, (ep, _, _) in test_set.items() + } + + method_to_partitioner = { + method_name: partitioner + for method_name, (_, partitioner, _) in test_set.items() + } + + lowered_ep_dict = to_backend( + MethodProgramsPartitionerSpec( + method_to_edge_program, + method_to_partitioner, + ) + ) + + self.assertEqual(len(lowered_ep_dict.keys()), len(test_set.keys())) + for method_name in test_set.keys(): + self.assertTrue(method_name in lowered_ep_dict.keys()) + (_, _, list_of_payload_as_string) = test_set[method_name] + lowered_ep = lowered_ep_dict[method_name] + FileCheck().check_count( + "torch.ops.higher_order.executorch_call_delegate", + len(list_of_payload_as_string), + exactly=True, + ).run(str(lowered_ep)) + lowered_submodules = self._get_lowered_submodules_across_controlflow( + lowered_ep.graph_module + ) + self.assertEqual(len(lowered_submodules), len(list_of_payload_as_string)) + + for idx, (_, lowered_backend_module, _) in enumerate(lowered_submodules): + self.assertEqual( + lowered_backend_module.processed_bytes.decode("utf-8"), + list_of_payload_as_string[idx], + ) + + def test_multi_method_to_backend_single_method(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + # Payload String: + # [Number of Ops lowered across all methods/partitions]#OpTargetNames#CompileSpecs; + test_set = { + "forward": ( + edgeir_m.exported_program(), + AllNodePartitioner( + "FirstBackendWithPreprocessAll", + [CompileSpec("max_value", bytes([1]))], + ), + [ + "FirstBackendWithPreprocessAll#1#aten.sin.default:#max_value:b'\\x01';" + ], + ) + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + sin_edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + add_edgeir_m = to_edge(torch.export.export(AddModule(), (torch.ones(1),))) + sin_partitioner = AllNodePartitioner( + "FirstBackendWithPreprocessAll", [CompileSpec("sin", bytes([2]))] + ) + add_partitioner = AllNodePartitioner( + "FirstBackendWithPreprocessAll", [CompileSpec("add", bytes([3]))] + ) + # Payload String: + # [Number of Ops lowered across all methods/partitions]#OpTargetNames#CompileSpecs; + test_set = { + "sin": ( + sin_edgeir_m.exported_program(), + sin_partitioner, + ["FirstBackendWithPreprocessAll#2#aten.sin.default:#sin:b'\\x02';"], + ), + "add": ( + add_edgeir_m.exported_program(), + add_partitioner, + ["FirstBackendWithPreprocessAll#2#aten.add.Tensor:#add:b'\\x03';"], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods_multiple_partitions(self): + class AddModule(torch.nn.Module): + def __init__(self): 
+ super().__init__() + + def forward(self, x): + y = x + x + y = y * y + y = y + y + return y + + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.sin(x) + y = y * y + return torch.sin(y) + + add_edgeir_m = to_edge(torch.export.export(AddModule(), (torch.ones(1),))) + sin_edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + test_set = { + "add": ( + add_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + ], + ), + "sin": ( + sin_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods_different_partitions(self): + class AddSinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + x + y = y * y + y = torch.sin(y) + return y + + class SinAddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.sin(x) + y = y * y + return y + y + + add_sin_edgeir_m = to_edge( + torch.export.export(AddSinModule(), (torch.ones(1),)) + ) + sin_add_edgeir_m = to_edge( + torch.export.export(SinAddModule(), (torch.ones(1),)) + ) + test_set = { + "add_sin": ( + add_sin_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + ], + ), + "sin_add": ( + sin_add_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods_different_partitions(self): + class AddSinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + x + y = y * y + y = torch.sin(y) + return y + + class SinAddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.sin(x) + y = y * y + return y + y + + add_sin_edgeir_m = to_edge( + torch.export.export(AddSinModule(), (torch.ones(1),)) + ) + sin_add_edgeir_m = to_edge( + torch.export.export(SinAddModule(), (torch.ones(1),)) + ) + test_set = { + "add_sin": ( + add_sin_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + ], + ), + "sin_add": ( + sin_add_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_two_methods_different_backends(self): + class AddSinCosSubModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = x + x + y = torch.sin(y) + y = torch.cos(y) + y = y - x + return y + + class CosSubAddSinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.cos(x) + y = y 
- x + y = y + y + y = torch.sin(y) + return y + + first_second_edgeir_m = to_edge( + torch.export.export(AddSinCosSubModule(), (torch.ones(1),)) + ) + second_first_edgeir_m = to_edge( + torch.export.export(CosSubAddSinModule(), (torch.ones(1),)) + ) + test_set = { + "first_second": ( + first_second_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:aten.sin.default:#add:b'\\x00';sin:b'\\x01';", + "SecondBackendWithPreprocessAll#4#aten.cos.default:aten.sub.Tensor:#cos:b'\\x03';sub:b'\\x02';", + ], + ), + "second_first": ( + second_first_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "SecondBackendWithPreprocessAll#4#aten.cos.default:aten.sub.Tensor:#cos:b'\\x03';sub:b'\\x02';", + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:aten.sin.default:#add:b'\\x00';sin:b'\\x01';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_control_flow(self): + class SinCosModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def true_fn(self, x): + return torch.sin(x) + + def false_fn(self, x): + return torch.cos(x) + + def forward(self, x): + x = x + x + return torch.cond(x > 0, self.true_fn, self.false_fn, [x]) + + class SinAddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def true_fn(self, x): + return torch.sin(x) + + def false_fn(self, x): + return x + x + + def forward(self, x): + return torch.cond(x > 0, self.true_fn, self.false_fn, [x]) + + sin_cos_edgeir_m = to_edge( + torch.export.export(SinCosModule(), (torch.ones(1),)) + ) + sin_add_edgeir_m = to_edge( + torch.export.export(SinAddModule(), (torch.ones(1),)) + ) + + test_set = { + "sin_cos": ( + sin_cos_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + # True Module Partition + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + # False Module Partition + "SecondBackendWithPreprocessAll#1#aten.cos.default:#cos:b'\\x03';", + ], + ), + "sin_add": ( + sin_add_edgeir_m.exported_program(), + BackendWithPreprocessAllPartitioner(), + [ + # True Module Partition + "FirstBackendWithPreprocessAll#4#aten.sin.default:#sin:b'\\x01';", + # False Module Partition + "FirstBackendWithPreprocessAll#4#aten.add.Tensor:#add:b'\\x00';", + ], + ), + } + self._test(test_set) + + def test_multi_method_to_backend_not_found(self): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + class AddModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + x + + sin_edgeir_m = to_edge(torch.export.export(SinModule(), (torch.ones(1),))) + add_edgeir_m = to_edge(torch.export.export(AddModule(), (torch.ones(1),))) + sin_partitioner = AllNodePartitioner( + "Invalid", [CompileSpec("sin", bytes([2]))] + ) + add_partitioner = AllNodePartitioner( + "FirstBackendWithPreprocessAll", [CompileSpec("add", bytes([3]))] + ) + + test_set = { + "sin": ( + sin_edgeir_m.exported_program(), + sin_partitioner, + [], + ), + "add": ( + sin_edgeir_m.exported_program(), + add_partitioner, + [], + ), + } + with self.assertRaisesRegex( + NotImplementedError, "Backend Invalid was not found." 
+ ): + self._test(test_set) diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 6bcc1b2f3d8..78b031a238e 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -766,15 +766,15 @@ def create_submodule_from_nodes( gm = insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) submodule_node = None for node in gm.graph.nodes: - if node.op == "call_module": - if node.target == submodule_name: - submodule_node = node - else: - raise RuntimeError( - f"The submodule created with nodes {node_list} did not form \ - one fully contained subgraph. Check that these nodes form a \ - fully contained graph. Partitioned graph: {gm.graph}." - ) + if node.op == "call_module" and node.target == submodule_name: + submodule_node = node + + if submodule_node is None: + raise RuntimeError( + f"The submodule created with nodes {node_list} did not form \ + one fully contained subgraph. Check that these nodes form a \ + fully contained graph. Partitioned graph: {gm.graph}." + ) if len(orig_outputs) == 1 and isinstance(orig_outputs[0].meta["val"], FakeTensor): # If the original output is a single tensor, it has been @@ -809,12 +809,13 @@ def create_submodule_from_nodes( for node in gm.graph.nodes: if node.op == "call_module" and node.target == submodule_name: submodule_node = node - elif node.op == "call_module": - raise RuntimeError( - f"The submodule created with nodes {node_list} did not form \ - one fully contained subgraph. Check that these nodes form a \ - fully contained graph. Partitioned graph: {gm.graph}." - ) + + if submodule_node is None: + raise RuntimeError( + f"The submodule created with nodes {node_list} did not form \ + one fully contained subgraph. Check that these nodes form a \ + fully contained graph. Partitioned graph: {gm.graph}." + ) assert ( submodule_node is not None From fb05e4c4f99a54dc83f6acc7760c82b8fbf4a820 Mon Sep 17 00:00:00 2001 From: Max Ren Date: Tue, 8 Apr 2025 15:40:02 -0700 Subject: [PATCH 2/2] Update [ghstack-poisoned] --- exir/backend/test/test_to_backend_multi_method.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/exir/backend/test/test_to_backend_multi_method.py b/exir/backend/test/test_to_backend_multi_method.py index 183bf7d9285..d4f8fccb8f2 100644 --- a/exir/backend/test/test_to_backend_multi_method.py +++ b/exir/backend/test/test_to_backend_multi_method.py @@ -434,6 +434,12 @@ def forward(self, x): self._test(test_set) def test_multi_method_end_to_end(self): + """ + Tests multi method lowering end-to-end. Lowers the same Sin Module for two methods + "forward" and "forward_copy". Ensures that the lowered program has two delegates + but only one serialized blob. Ensures that the lowered program runs correctly. + """ + class SinModule(torch.nn.Module): def __init__(self): super().__init__() @@ -471,6 +477,7 @@ def forward(self, x): exec_prog = new_edge_manager.to_executorch() program = exec_prog.executorch_program + # Since the preprocessed bytes are the same, there should only be on copy self.assertEqual(len(program.backend_delegate_data), 1) self.check_backend_delegate(