diff --git a/backends/xnnpack/operators/__init__.py b/backends/xnnpack/operators/__init__.py
index dcffa42ac3f..9d81b7f8e29 100644
--- a/backends/xnnpack/operators/__init__.py
+++ b/backends/xnnpack/operators/__init__.py
@@ -30,6 +30,7 @@
     op_minimum,
     op_multiply,
     op_negate,
+    op_permute,
     op_prelu,
     op_quantize_per_tensor,
     op_relu,
@@ -42,7 +43,6 @@
     op_squeeze,
     op_static_constant_pad,
     op_static_resize_bilinear_2d,
-    op_static_transpose,
     op_sub,
     op_to_copy,
 )
diff --git a/backends/xnnpack/operators/op_static_transpose.py b/backends/xnnpack/operators/op_permute.py
similarity index 97%
rename from backends/xnnpack/operators/op_static_transpose.py
rename to backends/xnnpack/operators/op_permute.py
index ce1cd43c1ad..0ca92a7a039 100644
--- a/backends/xnnpack/operators/op_static_transpose.py
+++ b/backends/xnnpack/operators/op_permute.py
@@ -20,7 +20,7 @@
 
 
 @register_node_visitor
-class StaticTransposeVisitor(NodeVisitor):
+class PermuteVisitor(NodeVisitor):
     target = "aten.permute_copy.default"
 
     def __init__(self, *args) -> None:
diff --git a/backends/xnnpack/operators/op_skip_ops.py b/backends/xnnpack/operators/op_skip_ops.py
index 83b6eee32b0..345b7896d34 100644
--- a/backends/xnnpack/operators/op_skip_ops.py
+++ b/backends/xnnpack/operators/op_skip_ops.py
@@ -113,12 +113,3 @@ class OpSymSizeInt(OpSkipOps):
     """
 
     target = "sym_size.int"
-
-
-@register_node_visitor
-class OpPermuteCopyDefault(OpSkipOps):
-    """
-    do nothing if node is permute_copy.default
-    """
-
-    target = "aten.permute_copy.default"
diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp
index 0c1c9e6d42c..8498bd84c5f 100644
--- a/backends/xnnpack/runtime/XNNCompiler.cpp
+++ b/backends/xnnpack/runtime/XNNCompiler.cpp
@@ -1517,6 +1517,7 @@ __ET_NODISCARD Error XNNCompiler::compileModel(
   if (!executor->qinputs_.empty() && flatbuffer_graph->xnodes()->size() > 0 &&
       flatbuffer_graph->xnodes()->Get(0)->xnode_union_type() ==
           fb_xnnpack::XNodeUnion::XNNFullyConnected) {
+#ifdef ENABLE_DYNAMIC_QUANTIZATION
     // This delegate is for DQLinear which supports dynamic input shapes
     if (executor->getNumInputs() < 1 || executor->getNumOutputs() != 1) {
       ET_LOG(
@@ -1525,6 +1526,10 @@
       return Error::NotSupported;
     }
     executor->setNeedsResizeOutput();
+#else
+    ET_LOG(Error, "DQ Linear is not supported");
+    return Error::NotSupported;
+#endif
   }
 
   return err;
diff --git a/backends/xnnpack/runtime/XNNExecutor.cpp b/backends/xnnpack/runtime/XNNExecutor.cpp
index 5e39c86c1ba..30b60ee329d 100644
--- a/backends/xnnpack/runtime/XNNExecutor.cpp
+++ b/backends/xnnpack/runtime/XNNExecutor.cpp
@@ -7,7 +7,9 @@
  */
 
 #include
+#ifdef ENABLE_DYNAMIC_QUANTIZATION
 #include
+#endif
 
 namespace torch {
 namespace executor {
@@ -17,6 +19,7 @@ namespace delegate {
 Error XNNExecutor::set_external_input(uint32_t id, Tensor* input) {
   auto qinput_pair = qinputs_.find(id);
   if (qinput_pair != qinputs_.end()) {
+#ifdef ENABLE_DYNAMIC_QUANTIZATION
     auto qinput = qinput_pair->second;
     // dq the input and copy it in to qinput
     float input_min, input_max;
@@ -60,6 +63,10 @@ Error XNNExecutor::set_external_input(uint32_t id, Tensor* input) {
         {static_cast(input_qparam.scale),
          static_cast(input_qparam.zero_point)},
         batch_size});
+#else
+    ET_LOG(Error, "Dynamic Quantization is not supported");
+    return Error::NotSupported;
+#endif
   } else {
     externals_.emplace_back(xnn_external_value{id, input->mutable_data_ptr()});
   }
diff --git a/backends/xnnpack/targets.bzl b/backends/xnnpack/targets.bzl
index c4aa62e7f8e..b53a1fc8f0f 100644
--- a/backends/xnnpack/targets.bzl
+++ b/backends/xnnpack/targets.bzl
@@ -65,6 +65,7 @@ def define_common_targets():
             "//executorch/extension/pybindings/...",
            "@EXECUTORCH_CLIENTS",
         ],
+        preprocessor_flags = [] if runtime.is_oss else ["-DENABLE_DYNAMIC_QUANTIZATION"],
         deps = [
             third_party_dep("XNNPACK"),
             ":xnnpack_schema",
diff --git a/backends/xnnpack/test/TARGETS b/backends/xnnpack/test/TARGETS
index 6ccc8d4c345..3305f259931 100644
--- a/backends/xnnpack/test/TARGETS
+++ b/backends/xnnpack/test/TARGETS
@@ -123,3 +123,16 @@ python_unittest(
         "//executorch/backends/xnnpack/test/tester:tester",
     ],
 )
+
+python_unittest(
+    name = "test_xnnpack_models",
+    srcs = glob([
+        "models/*.py",
+    ]),
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/backends/xnnpack/test/tester:tester",
+        "//pytorch/vision:torchvision",
+    ],
+)
diff --git a/backends/xnnpack/test/models/mobilenet_v2.py b/backends/xnnpack/test/models/mobilenet_v2.py
new file mode 100644
index 00000000000..939512bec63
--- /dev/null
+++ b/backends/xnnpack/test/models/mobilenet_v2.py
@@ -0,0 +1,72 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+import torchvision.models as models
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackQuantizedPartitioner2,
+)
+from executorch.backends.xnnpack.test.tester import Partition, Tester
+from executorch.backends.xnnpack.test.tester.tester import Export
+from executorch.backends.xnnpack.utils.configs import get_xnnpack_capture_config
+from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
+
+
+class TestXNNPACKMobileNetV2(unittest.TestCase):
+    export_stage = Export(get_xnnpack_capture_config(enable_aot=True))
+
+    mv2 = models.__dict__["mobilenet_v2"](weights=MobileNet_V2_Weights)
+    mv2 = mv2.eval()
+    model_inputs = (torch.ones(1, 3, 224, 244),)
+
+    all_operators = {
+        "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
+        "executorch_exir_dialects_edge__ops_aten_add_Tensor",
+        "executorch_exir_dialects_edge__ops_aten_permute_copy_default",
+        "executorch_exir_dialects_edge__ops_aten_addmm_default",
+        "executorch_exir_dialects_edge__ops_aten_mean_dim",
+        "executorch_exir_dialects_edge__ops_aten_hardtanh_default",
+        "executorch_exir_dialects_edge__ops_aten_convolution_default",
+    }
+
+    def test_fp32(self):
+
+        (
+            Tester(self.mv2, self.model_inputs)
+            .export(self.export_stage)
+            .to_edge()
+            .check(list(self.all_operators))
+            .partition()
+            .check(["torch.ops.executorch_call_delegate"])
+            .check_not(list(self.all_operators))
+            .to_executorch()
+            .serialize()
+            .run_method()
+            .compare_outputs()
+        )
+
+    def test_qs8_pt2e(self):
+        # Quantization fuses away batchnorm, so it is no longer in the graph
+        ops_after_quantization = self.all_operators - {
+            "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default",
+        }
+
+        (
+            Tester(self.mv2, self.model_inputs)
+            .quantize2()
+            .export(self.export_stage)
+            .to_edge()
+            .check(list(ops_after_quantization))
+            .partition(Partition(partitioner=XnnpackQuantizedPartitioner2))
+            .check(["torch.ops.executorch_call_delegate"])
+            .check_not(list(ops_after_quantization))
+            .to_executorch()
+            .serialize()
+            .run_method()
+            .compare_outputs()
+        )
diff --git a/backends/xnnpack/test/ops/add.py b/backends/xnnpack/test/ops/add.py
index ee19be67cdd..fe7686d1f99 100644
--- a/backends/xnnpack/test/ops/add.py
+++ b/backends/xnnpack/test/ops/add.py
@@ -75,9 +75,9 @@ def test_add_quantized_pt2e(self):
 
         (
             Tester(add_module, model_inputs)
+            .quantize2()
             .export()
             .check_count({"torch.ops.aten.add.Tensor": 4})
-            .quantize2()
             .check(["torch.ops.quantized_decomposed"])
             .to_edge()
             .check_count({"executorch_exir_dialects_edge__ops_aten_add_Tensor": 4})
diff --git a/backends/xnnpack/test/tester/tester.py b/backends/xnnpack/test/tester/tester.py
index 23602bde8bf..a736284bd9d 100644
--- a/backends/xnnpack/test/tester/tester.py
+++ b/backends/xnnpack/test/tester/tester.py
@@ -10,6 +10,7 @@
 from typing import Any, Dict, List, Optional, Tuple
 
 import torch
+import torch._export as export
 from executorch import exir
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
     XnnpackFloatingPointPartitioner,
@@ -145,23 +146,23 @@ def __init__(
 
         self.quantizer.set_global(self.quantization_config)
 
-        self.converted_program = None
+        self.converted_graph = None
 
     def run(
-        self, artifact: ExirExportedProgram, inputs: Optional[Tuple[torch.Tensor]]
+        self, artifact: torch.nn.Module, inputs: Optional[Tuple[torch.Tensor]]
     ) -> None:
-        prepared = prepare_pt2e(artifact.exported_program.graph_module, self.quantizer)
+        captured_graph = export.capture_pre_autograd_graph(artifact, inputs)
+        prepared = prepare_pt2e(captured_graph, self.quantizer)
         converted = convert_pt2e(prepared)
-        artifact.exported_program._graph_module = converted
-        self.converted_program = artifact
+        self.converted_graph = converted
 
     @property
-    def artifact(self) -> ExirExportedProgram:
-        return self.converted_program
+    def artifact(self) -> torch.fx.GraphModule:
+        return self.converted_graph
 
     @property
     def graph_module(self) -> str:
-        return self.converted_program.exported_program.graph_module
+        return self.converted_graph
 
 
 @register_stage
@@ -274,12 +275,11 @@ def __init__(
         self.inputs = inputs
         self.stages: Dict[str, Stage] = OrderedDict.fromkeys(list(_stages_.keys()))
         self.pipeline = {
+            self._stage_name(Quantize2): [self._stage_name(Export)],
             self._stage_name(Quantize): [self._stage_name(Export)],
             self._stage_name(Export): [
-                self._stage_name(Quantize2),
                 self._stage_name(ToEdge),
             ],
-            self._stage_name(Quantize2): [self._stage_name(ToEdge)],
             self._stage_name(ToEdge): [self._stage_name(Partition)],
             # TODO Make this Stage optional
             self._stage_name(Partition): [self._stage_name(ToExecutorch)],
diff --git a/backends/xnnpack/utils/configs.py b/backends/xnnpack/utils/configs.py
index b6114f939d9..653e11a9746 100644
--- a/backends/xnnpack/utils/configs.py
+++ b/backends/xnnpack/utils/configs.py
@@ -36,4 +36,8 @@ def get_xnnpack_capture_config(dynamic_shape=False, enable_aot: Optional[bool] =
     if enable_aot is None:
         return CaptureConfig(enable_dynamic_shape=dynamic_shape)
     else:
-        return CaptureConfig(enable_dynamic_shape=dynamic_shape, enable_aot=enable_aot)
+        return CaptureConfig(
+            enable_dynamic_shape=dynamic_shape,
+            enable_aot=enable_aot,
+            _unlift=enable_aot,
+        )
diff --git a/examples/backend/TARGETS b/examples/backend/TARGETS
new file mode 100644
index 00000000000..6c29686d2ea
--- /dev/null
+++ b/examples/backend/TARGETS
@@ -0,0 +1,13 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+runtime.python_binary(
+    name = "xnnpack_lowering_examples",
+    main_src = "xnnpack_lowering_examples.py",
+    deps = [
+        "//caffe2:torch",
+        "//executorch/backends/xnnpack:xnnpack_preprocess",
+        "//executorch/backends/xnnpack/partition:xnnpack_partitioner",
+        "//executorch/examples/models:models",
+        "//executorch/exir/backend:backend_api",
+    ],
+)
diff --git a/examples/backend/xnnpack_lowering_examples.py b/examples/backend/xnnpack_lowering_examples.py
new file mode 100644
index 00000000000..79da6a5a5cf
--- /dev/null
+++ b/examples/backend/xnnpack_lowering_examples.py
@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script for exporting simple models to flatbuffer
+
+import argparse
+import copy
+
+import executorch.exir as exir
+import torch._export as export
+from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
+    XnnpackFloatingPointPartitioner,
+    XnnpackQuantizedPartitioner2,
+)
+from executorch.exir.backend.backend_api import to_backend, validation_disabled
+
+from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
+from torch.ao.quantization.quantizer.xnnpack_quantizer import (
+    get_symmetric_quantization_config,
+    XNNPACKQuantizer,
+)
+
+from ..models import MODEL_NAME_TO_MODEL
+
+# Note: for mv3, the mul op is not supported in XNNPACKQuantizer, that could be supported soon
+XNNPACK_MODEL_NAME_TO_MODEL = {
+    name: MODEL_NAME_TO_MODEL[name] for name in ["linear", "add", "add_mul", "mv2"]
+}
+
+
+def quantize(model, example_inputs):
+    """This is the official recommended flow for quantization in pytorch 2.0 export"""
+    m = model.eval()
+    m = export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs))
+    quantizer = XNNPACKQuantizer()
+    # if we set is_per_channel to True, we also need to add out_variant of quantize_per_channel/dequantize_per_channel
+    operator_config = get_symmetric_quantization_config(is_per_channel=False)
+    quantizer.set_global(operator_config)
+    m = prepare_pt2e(m, quantizer)
+    # calibration
+    m(*example_inputs)
+    m = convert_pt2e(m)
+    return m
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-m",
+        "--model_name",
+        required=True,
+        help=f"Provide model name. Valid ones: {list(XNNPACK_MODEL_NAME_TO_MODEL.keys())}",
+    )
+    parser.add_argument(
+        "-q",
+        "--quantize",
+        action="store_true",
+        required=False,
+        default=False,
+        help="Flag for producing quantized or floating-point model",
+    )
+    args = parser.parse_args()
+
+    if args.model_name not in XNNPACK_MODEL_NAME_TO_MODEL:
+        raise RuntimeError(
+            f"Model {args.model_name} is not a valid name. or not quantizable right now, "
+            "please contact executorch team if you want to learn why or how to support "
+            "quantization for the requested model"
+            f"Available models are {list(XNNPACK_MODEL_NAME_TO_MODEL.keys())}."
+        )
+
+    model, example_inputs = MODEL_NAME_TO_MODEL[args.model_name]()
+    model = model.eval()
+
+    partitioner = XnnpackFloatingPointPartitioner
+    if args.quantize:
+        print("Quantizing Model...")
+        model = quantize(model, example_inputs)
+        # Partitioner will eventually be a single partitioner for both fp32 and quantized models
+        partitioner = XnnpackQuantizedPartitioner2
+
+    edge = exir.capture(
+        model, example_inputs, exir.CaptureConfig(enable_aot=True, _unlift=True)
+    ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
+    print("Exported graph:\n", edge.exported_program.graph)
+
+    with validation_disabled():
+        edge.exported_program = to_backend(edge.exported_program, partitioner)
+    print("Lowered graph:\n", edge.exported_program.graph)
+
+    exec_prog = edge.to_executorch()
+    buffer = exec_prog.buffer
+    quant_tag = "_quantize" if args.quantize else ""
+    filename = f"xnnpack_{args.model_name}{quant_tag}.pte"
+    print(f"Saving exported program to {filename}.")
+    with open(filename, "wb") as f:
+        f.write(buffer)
diff --git a/examples/executor_runner/targets.bzl b/examples/executor_runner/targets.bzl
index 0e71947aa68..77b8838a727 100644
--- a/examples/executor_runner/targets.bzl
+++ b/examples/executor_runner/targets.bzl
@@ -7,6 +7,25 @@ def define_common_targets():
     TARGETS and BUCK files that call this function.
     """
 
+    # Wraps a commandline executable that can be linked against any desired
+    # kernel or backend implementations. Contains a main() function.
+    runtime.cxx_library(
+        name = "executor_runner_lib",
+        srcs = ["executor_runner.cpp"],
+        deps = [
+            "//executorch/runtime/executor:program",
+            "//executorch/extension/data_loader:file_data_loader",
+            "//executorch/util:util",
+        ],
+        external_deps = [
+            "gflags",
+        ],
+        define_static_target = True,
+        visibility = [
+            "//executorch/examples/...",
+        ],
+    )
+
     register_custom_op = native.read_config("executorch", "register_custom_op", "0")
 
     if register_custom_op == "1":
@@ -16,20 +35,31 @@
     else:
         custom_ops_lib = []
 
-    # Test driver for models, uses all portable kernels.
+    # Test driver for models, uses all portable kernels and a demo backend. This
+    # is intended to have minimal dependencies. If you want a runner that links
+    # against a different backend or kernel library, define a new executable
+    # based on :executor_runner_lib.
     runtime.cxx_binary(
         name = "executor_runner",
-        srcs = ["executor_runner.cpp"],
+        srcs = [],
         deps = [
+            ":executor_runner_lib",
             "//executorch/runtime/executor/test:test_backend_compiler_lib",
-            "//executorch/runtime/executor:program",
-            "//executorch/extension/data_loader:file_data_loader",
-            "//executorch/util:util",
             "//executorch/kernels/portable:generated_lib_all_ops",
         ] + custom_ops_lib,
-        external_deps = [
-            "gflags",
-        ],
+        define_static_target = True,
+        **get_oss_build_kwargs()
+    )
+
+    # executor runner for XNNPACK Backend and portable kernels.
+    runtime.cxx_binary(
+        name = "xnn_executor_runner",
+        srcs = [],
+        deps = [
+            ":executor_runner_lib",
+            "//executorch/backends/xnnpack:xnnpack_backend",
+            "//executorch/kernels/portable:generated_lib_all_ops",
+        ] + custom_ops_lib,
         define_static_target = True,
         **get_oss_build_kwargs()
     )