diff --git a/backends/apple/coreml/test/tester.py b/backends/apple/coreml/test/tester.py
new file mode 100644
index 00000000000..643a51473e0
--- /dev/null
+++ b/backends/apple/coreml/test/tester.py
@@ -0,0 +1,62 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Callable, List, Optional, Sequence, Tuple, Type
+
+import executorch
+import executorch.backends.test.harness.stages as BaseStages
+
+import torch
+from executorch.backends.apple.coreml.partition import CoreMLPartitioner
+from executorch.backends.test.harness import Tester as TesterBase
+from executorch.backends.test.harness.stages import StageType
+from executorch.exir import EdgeCompileConfig
+from executorch.exir.backend.partitioner import Partitioner
+
+
+class Partition(BaseStages.Partition):
+    def __init__(self, partitioner: Optional[Partitioner] = None):
+        super().__init__(
+            partitioner=partitioner or CoreMLPartitioner(),
+        )
+
+
+class ToEdgeTransformAndLower(BaseStages.ToEdgeTransformAndLower):
+    def __init__(
+        self,
+        partitioners: Optional[List[Partitioner]] = None,
+        edge_compile_config: Optional[EdgeCompileConfig] = None,
+    ):
+        super().__init__(
+            default_partitioner_cls=CoreMLPartitioner,
+            partitioners=partitioners,
+            edge_compile_config=edge_compile_config,
+        )
+
+
+class CoreMLTester(TesterBase):
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        example_inputs: Tuple[torch.Tensor],
+        dynamic_shapes: Optional[Tuple[Any]] = None,
+    ):
+        # Specialize the partition and lowering stages for Core ML.
+        stage_classes = (
+            executorch.backends.test.harness.Tester.default_stage_classes()
+            | {
+                StageType.PARTITION: Partition,
+                StageType.TO_EDGE_TRANSFORM_AND_LOWER: ToEdgeTransformAndLower,
+            }
+        )
+
+        super().__init__(
+            module=module,
+            stage_classes=stage_classes,
+            example_inputs=example_inputs,
+            dynamic_shapes=dynamic_shapes,
+        )
+
diff --git a/backends/test/compliance_suite/README.md b/backends/test/compliance_suite/README.md
new file mode 100644
index 00000000000..04ff4c961f6
--- /dev/null
+++ b/backends/test/compliance_suite/README.md
@@ -0,0 +1,15 @@
+# Operator Compliance Test Suite
+
+This directory contains operator tests that all backends are expected to pass. Not every backend will implement every operator or permutation, but backend partitioners are expected to partition only the nodes that the backend can support. A partitioner should never error out because it encounters an input node it does not support.
+
+## Backend Registration
+
+To plug into the test framework, each backend should provide an implementation of the Tester class defined in backends/test/harness/tester.py. Backends can provide implementations of each stage, or use the default implementation, as appropriate.
+
+At a minimum, the backend will likely need to provide custom implementations of the Partition and ToEdgeTransformAndLower stages using the appropriate backend partitioner. See backends/xnnpack/test/tester/tester.py for an example implementation.
+
+Once a tester is available, the backend flow(s) can be registered by adding an entry to `ALL_TEST_FLOWS` in __init__.py in this directory. Each flow entry consists of a name (used in test case naming) and a callable that instantiates a tester for a given model and input tuple.
+
+## Test Cases
+
+Operator test cases are defined under the operators/ directory. Tests are written in a backend-independent manner, and each test is programmatically expanded to generate a variant for each registered backend flow. The `@operator_test` decorator is applied to each test class to trigger this expansion. Tests can also be tagged with an appropriate type specifier, such as `@dtype_test`, to generate variants for each dtype. The decorators and the expansion logic live in __init__.py in this directory.
diff --git a/backends/test/compliance_suite/TARGETS b/backends/test/compliance_suite/TARGETS
new file mode 100644
index 00000000000..8832b48d98a
--- /dev/null
+++ b/backends/test/compliance_suite/TARGETS
@@ -0,0 +1,3 @@
+load(":targets.bzl", "define_common_targets")
+
+define_common_targets(is_fbcode = True)
diff --git a/backends/test/compliance_suite/__init__.py b/backends/test/compliance_suite/__init__.py
new file mode 100644
index 00000000000..7a4467ade95
--- /dev/null
+++ b/backends/test/compliance_suite/__init__.py
@@ -0,0 +1,141 @@
+import os
+import unittest
+
+from enum import Enum
+from typing import Any, Callable, Tuple
+
+import logging
+import torch
+from executorch.backends.test.harness import Tester
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+
+
+# Read enabled backends from the environment variable. Enable all if
+# not specified (signalled by None).
+def get_enabled_backends():
+    et_test_backends = os.environ.get("ET_TEST_BACKENDS")
+    if et_test_backends is not None:
+        return et_test_backends.split(",")
+    else:
+        return None
+
+_ENABLED_BACKENDS = get_enabled_backends()
+
+def is_backend_enabled(backend):
+    if _ENABLED_BACKENDS is None:
+        return True
+    else:
+        return backend in _ENABLED_BACKENDS
+
+ALL_TEST_FLOWS = []
+
+if is_backend_enabled("xnnpack"):
+    from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester
+
+    XNNPACK_TEST_FLOW = ("xnnpack", XnnpackTester)
+    ALL_TEST_FLOWS.append(XNNPACK_TEST_FLOW)
+
+if is_backend_enabled("coreml"):
+    from executorch.backends.apple.coreml.test.tester import CoreMLTester
+
+    COREML_TEST_FLOW = ("coreml", CoreMLTester)
+    ALL_TEST_FLOWS.append(COREML_TEST_FLOW)
+
+
+DTYPES = [
+    torch.int8,
+    torch.uint8,
+    torch.int16,
+    torch.uint16,
+    torch.int32,
+    torch.uint32,
+    torch.int64,
+    torch.uint64,
+    torch.float16,
+    torch.float32,
+    torch.float64,
+]
+
+FLOAT_DTYPES = [
+    torch.float16,
+    torch.float32,
+    torch.float64,
+]
+
+class TestType(Enum):
+    STANDARD = 1
+    DTYPE = 2
+
+def dtype_test(func):
+    setattr(func, "test_type", TestType.DTYPE)
+    return func
+
+def operator_test(cls):
+    _create_tests(cls)
+    return cls
+
+def _create_tests(cls):
+    for key in dir(cls):
+        if key.startswith("test_"):
+            _expand_test(cls, key)
+
+def _expand_test(cls, test_name: str):
+    test_func = getattr(cls, test_name)
+    for (flow_name, tester_factory) in ALL_TEST_FLOWS:
+        _create_test_for_backend(cls, test_func, flow_name, tester_factory)
+    delattr(cls, test_name)
+
+def _create_test_for_backend(
+    cls,
+    test_func: Callable,
+    flow_name: str,
+    tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester]
+):
+    test_type = getattr(test_func, "test_type", TestType.STANDARD)
+
+    if test_type == TestType.STANDARD:
+        def wrapped_test(self):
+            test_func(self, tester_factory)
+
+        test_name = f"{test_func.__name__}_{flow_name}"
+        setattr(cls, test_name, wrapped_test)
+    elif test_type == TestType.DTYPE:
+        for dtype in DTYPES:
+            def wrapped_test(self, dtype=dtype):  # Bind dtype per iteration to avoid the late-binding closure pitfall.
+                test_func(self, dtype, tester_factory)
+
+            dtype_name = str(dtype)[6:]  # strip "torch."
+ test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}" + setattr(cls, test_name, wrapped_test) + else: + raise NotImplementedError(f"Unknown test type {test_type}.") + + +class OperatorTest(unittest.TestCase): + def _test_op(self, model, inputs, tester_factory): + tester = ( + tester_factory( + model, + inputs, + ) + .export() + .to_edge_transform_and_lower() + ) + + is_delegated = any( + n.target == torch._higher_order_ops.executorch_call_delegate + for n in tester.stages[tester.cur].graph_module.graph.nodes + if n.op == "call_function" + ) + + # Only run the runtime test if the op was delegated. + if is_delegated: + ( + tester + .to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + \ No newline at end of file diff --git a/backends/test/compliance_suite/operators/__init__.py b/backends/test/compliance_suite/operators/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/test/compliance_suite/operators/test_add.py b/backends/test/compliance_suite/operators/test_add.py new file mode 100644 index 00000000000..d7aceecb601 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_add.py @@ -0,0 +1,74 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def forward(self, x, y): + return x + y + +class ModelAlpha(torch.nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = alpha + + def forward(self, x, y): + return torch.add(x, y, alpha=self.alpha) + +@operator_test +class Add(OperatorTest): + @dtype_test + def test_add_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + (torch.rand(2, 10) * 100).to(dtype), + (torch.rand(2, 10) * 100).to(dtype), + ), + tester_factory) + + def test_add_f32_bcast_first(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 5, 1, 5), + ), + tester_factory) + + def test_add_f32_bcast_second(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(4, 4, 2, 7), + torch.randn(2, 7), + ), + tester_factory) + + def test_add_f32_bcast_unary(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 1, 5), + ), + tester_factory) + + def test_add_f32_alpha(self, tester_factory: Callable) -> None: + self._test_op( + ModelAlpha(alpha=2), + ( + torch.randn(1, 25), + torch.randn(1, 25), + ), + tester_factory) + diff --git a/backends/test/compliance_suite/operators/test_div.py b/backends/test/compliance_suite/operators/test_div.py new file mode 100644 index 00000000000..d4df2ab857c --- /dev/null +++ b/backends/test/compliance_suite/operators/test_div.py @@ -0,0 +1,82 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable, Optional + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def forward(self, x, y): + return x / y + +class ModelWithRounding(torch.nn.Module): + def __init__(self, rounding_mode: Optional[str]): + super().__init__() + self.rounding_mode = rounding_mode + + def forward(self, x, y): + return torch.div(x, y, rounding_mode=self.rounding_mode) + +@operator_test +class Divide(OperatorTest): + @dtype_test + def test_divide_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + (torch.rand(2, 10) * 100).to(dtype), + (torch.rand(2, 10) * 100 + 0.1).to(dtype), # Adding 0.1 to avoid division by zero + ), + tester_factory) + + def test_divide_f32_bcast_first(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 5, 1, 5).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero + ), + tester_factory) + + def test_divide_f32_bcast_second(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(4, 4, 2, 7), + torch.randn(2, 7).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero + ), + tester_factory) + + def test_divide_f32_bcast_unary(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 1, 5).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero + ), + tester_factory) + + def test_divide_f32_trunc(self, tester_factory: Callable) -> None: + self._test_op( + ModelWithRounding(rounding_mode="trunc"), + ( + torch.randn(3, 4) * 10, + torch.randn(3, 4).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero + ), + tester_factory) + + def test_divide_f32_floor(self, tester_factory: Callable) -> None: + self._test_op( + ModelWithRounding(rounding_mode="floor"), + ( + torch.randn(3, 4) * 10, + torch.randn(3, 4).abs() + 0.1, # Using abs and adding 0.1 to avoid division by zero + ), + tester_factory) diff --git a/backends/test/compliance_suite/operators/test_elu.py b/backends/test/compliance_suite/operators/test_elu.py new file mode 100644 index 00000000000..6ffcba051b5 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_elu.py @@ -0,0 +1,41 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, alpha=1.0, inplace=False): + super().__init__() + self.alpha = alpha + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.elu(x, alpha=self.alpha, inplace=self.inplace) + +@operator_test +class TestELU(OperatorTest): + @dtype_test + def test_elu_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), tester_factory) + + def test_elu_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_elu_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_elu_f32_alpha(self, tester_factory: Callable) -> None: + self._test_op(Model(alpha=0.5), (torch.randn(3, 4, 5),), tester_factory) + + def test_elu_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + diff --git a/backends/test/compliance_suite/operators/test_gelu.py b/backends/test/compliance_suite/operators/test_gelu.py new file mode 100644 index 00000000000..71dd6ed25c2 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_gelu.py @@ -0,0 +1,46 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, approximate="none"): + super().__init__() + self.approximate = approximate + + def forward(self, x): + return torch.nn.functional.gelu(x, approximate=self.approximate) + +@operator_test +class TestGELU(OperatorTest): + @dtype_test + def test_gelu_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory) + + def test_gelu_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_gelu_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_gelu_f32_tanh_approximation(self, tester_factory: Callable) -> None: + self._test_op(Model(approximate="tanh"), (torch.randn(3, 4, 5),), tester_factory) + + def test_gelu_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific values spanning negative and positive ranges + x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + self._test_op(Model(), (x,), tester_factory) + + def test_gelu_f32_tanh_boundary_values(self, tester_factory: Callable) -> None: + # Test tanh approximation with specific values + x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0]) + self._test_op(Model(approximate="tanh"), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_glu.py b/backends/test/compliance_suite/operators/test_glu.py new file mode 100644 index 00000000000..5aba05c855c --- /dev/null +++ b/backends/test/compliance_suite/operators/test_glu.py @@ -0,0 +1,46 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, dim=-1): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.glu(x, dim=self.dim) + +@operator_test +class TestGLU(OperatorTest): + @dtype_test + def test_glu_dtype(self, dtype, tester_factory: Callable) -> None: + # Input must have even number of elements in the specified dimension + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory) + + def test_glu_f32_dim_last(self, tester_factory: Callable) -> None: + # Default dim is -1 (last dimension) + self._test_op(Model(), (torch.randn(3, 4, 6),), tester_factory) + + def test_glu_f32_dim_first(self, tester_factory: Callable) -> None: + # Test with dim=0 (first dimension) + self._test_op(Model(dim=0), (torch.randn(4, 3, 5),), tester_factory) + + def test_glu_f32_dim_middle(self, tester_factory: Callable) -> None: + # Test with dim=1 (middle dimension) + self._test_op(Model(dim=1), (torch.randn(3, 8, 5),), tester_factory) + + def test_glu_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific values spanning negative and positive ranges + # Input must have even number of elements in the specified dimension + x = torch.tensor([[-10.0, -5.0, -1.0, 0.0], [1.0, 5.0, 10.0, -2.0]]) + self._test_op(Model(dim=1), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_hardsigmoid.py b/backends/test/compliance_suite/operators/test_hardsigmoid.py new file mode 100644 index 00000000000..602e3c2e1d7 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_hardsigmoid.py @@ -0,0 +1,41 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, inplace=False): + super().__init__() + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.hardsigmoid(x, inplace=self.inplace) + +@operator_test +class TestHardsigmoid(OperatorTest): + @dtype_test + def test_hardsigmoid_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), tester_factory) + + def test_hardsigmoid_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_hardsigmoid_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_hardsigmoid_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + + def test_hardsigmoid_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with values that span the hardsigmoid's piecewise regions + x = torch.tensor([-5.0, -3.0, -1.0, 0.0, 1.0, 3.0, 5.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_hardswish.py b/backends/test/compliance_suite/operators/test_hardswish.py new file mode 100644 index 00000000000..3d2ebe0fd46 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_hardswish.py @@ -0,0 +1,41 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, inplace=False): + super().__init__() + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.hardswish(x, inplace=self.inplace) + +@operator_test +class TestHardswish(OperatorTest): + @dtype_test + def test_hardswish_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10)).to(dtype),), tester_factory) + + def test_hardswish_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_hardswish_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_hardswish_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + + def test_hardswish_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with values that span the hardswish's piecewise regions + x = torch.tensor([-5.0, -3.0, -1.0, 0.0, 1.0, 3.0, 5.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_hardtanh.py b/backends/test/compliance_suite/operators/test_hardtanh.py new file mode 100644 index 00000000000..8ec665f0de4 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_hardtanh.py @@ -0,0 +1,46 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, min_val=-1.0, max_val=1.0, inplace=False): + super().__init__() + self.min_val = min_val + self.max_val = max_val + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.hardtanh(x, min_val=self.min_val, max_val=self.max_val, inplace=self.inplace) + +@operator_test +class TestHardtanh(OperatorTest): + @dtype_test + def test_hardtanh_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 4 - 2).to(dtype),), tester_factory) + + def test_hardtanh_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_hardtanh_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_hardtanh_f32_custom_range(self, tester_factory: Callable) -> None: + self._test_op(Model(min_val=-2.0, max_val=2.0), (torch.randn(3, 4, 5),), tester_factory) + + def test_hardtanh_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + + def test_hardtanh_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with values that span the hardtanh's piecewise regions + x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_leaky_relu.py b/backends/test/compliance_suite/operators/test_leaky_relu.py new file mode 100644 index 00000000000..6065a2dae9b --- /dev/null +++ b/backends/test/compliance_suite/operators/test_leaky_relu.py @@ -0,0 +1,46 @@ +# (c) Meta Platforms, Inc. 
and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, negative_slope=0.01, inplace=False): + super().__init__() + self.negative_slope = negative_slope + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.leaky_relu(x, negative_slope=self.negative_slope, inplace=self.inplace) + +@operator_test +class TestLeakyReLU(OperatorTest): + @dtype_test + def test_leaky_relu_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 2 - 1).to(dtype),), tester_factory) + + def test_leaky_relu_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_leaky_relu_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_leaky_relu_f32_custom_slope(self, tester_factory: Callable) -> None: + self._test_op(Model(negative_slope=0.1), (torch.randn(3, 4, 5),), tester_factory) + + def test_leaky_relu_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + + def test_leaky_relu_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific positive and negative values + x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) + self._test_op(Model(), (x,), tester_factory) + diff --git a/backends/test/compliance_suite/operators/test_logsigmoid.py b/backends/test/compliance_suite/operators/test_logsigmoid.py new file mode 100644 index 00000000000..f7fdc567a45 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_logsigmoid.py @@ -0,0 +1,34 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.logsigmoid(x) + +@operator_test +class TestLogSigmoid(OperatorTest): + @dtype_test + def test_logsigmoid_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory) + + def test_logsigmoid_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_logsigmoid_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_logsigmoid_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific values spanning negative and positive ranges + x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_mul.py b/backends/test/compliance_suite/operators/test_mul.py new file mode 100644 index 00000000000..c885505bd3c --- /dev/null +++ b/backends/test/compliance_suite/operators/test_mul.py @@ -0,0 +1,56 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def forward(self, x, y): + return x * y + +@operator_test +class Multiply(OperatorTest): + @dtype_test + def test_multiply_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + (torch.rand(2, 10) * 100).to(dtype), + (torch.rand(2, 10) * 100).to(dtype), + ), + tester_factory) + + def test_multiply_f32_bcast_first(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 5, 1, 5), + ), + tester_factory) + + def test_multiply_f32_bcast_second(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(4, 4, 2, 7), + torch.randn(2, 7), + ), + tester_factory) + + def test_multiply_f32_bcast_unary(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 1, 5), + ), + tester_factory) diff --git a/backends/test/compliance_suite/operators/test_prelu.py b/backends/test/compliance_suite/operators/test_prelu.py new file mode 100644 index 00000000000..4c08485888a --- /dev/null +++ b/backends/test/compliance_suite/operators/test_prelu.py @@ -0,0 +1,49 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, num_parameters=1, init=0.25): + super().__init__() + self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=init) + + def forward(self, x): + return self.prelu(x) + +@operator_test +class TestPReLU(OperatorTest): + @dtype_test + def test_prelu_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model().to(dtype), ((torch.rand(2, 10) * 2 - 1).to(dtype),), tester_factory) + + def test_prelu_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_prelu_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_prelu_f32_custom_init(self, tester_factory: Callable) -> None: + self._test_op(Model(init=0.1), (torch.randn(3, 4, 5),), tester_factory) + + def test_prelu_f32_channel_shared(self, tester_factory: Callable) -> None: + # Default num_parameters=1 means the parameter is shared across all channels + self._test_op(Model(num_parameters=1), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_prelu_f32_per_channel_parameter(self, tester_factory: Callable) -> None: + # num_parameters=3 means each channel has its own parameter (for dim=1) + self._test_op(Model(num_parameters=3), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_prelu_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific positive and negative values + x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_relu.py b/backends/test/compliance_suite/operators/test_relu.py new file mode 100644 index 00000000000..a2a95466e3e --- /dev/null +++ b/backends/test/compliance_suite/operators/test_relu.py @@ -0,0 +1,37 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, inplace=False): + super().__init__() + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.relu(x, self.inplace) + +@operator_test +class TestReLU(OperatorTest): + @dtype_test + def test_relu_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 100).to(dtype),), tester_factory) + + def test_relu_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_relu_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),) , tester_factory) + + def test_relu_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + \ No newline at end of file diff --git a/backends/test/compliance_suite/operators/test_sigmoid.py b/backends/test/compliance_suite/operators/test_sigmoid.py new file mode 100644 index 00000000000..e4806909b28 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_sigmoid.py @@ -0,0 +1,34 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.sigmoid(x) + +@operator_test +class TestSigmoid(OperatorTest): + @dtype_test + def test_sigmoid_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory) + + def test_sigmoid_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_sigmoid_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_sigmoid_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific values spanning negative and positive ranges + x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_silu.py b/backends/test/compliance_suite/operators/test_silu.py new file mode 100644 index 00000000000..4436da9af17 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_silu.py @@ -0,0 +1,41 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, inplace=False): + super().__init__() + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.silu(x, inplace=self.inplace) + +@operator_test +class TestSiLU(OperatorTest): + @dtype_test + def test_silu_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.randn(2, 10) * 100).to(dtype),), tester_factory) + + def test_silu_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_silu_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_silu_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + + def test_silu_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific values spanning negative and positive ranges + x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_sub.py b/backends/test/compliance_suite/operators/test_sub.py new file mode 100644 index 00000000000..693562610b2 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_sub.py @@ -0,0 +1,73 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def forward(self, x, y): + return x - y + +class ModelAlpha(torch.nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = alpha + + def forward(self, x, y): + return torch.sub(x, y, alpha=self.alpha) + +@operator_test +class Subtract(OperatorTest): + @dtype_test + def test_subtract_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + (torch.rand(2, 10) * 100).to(dtype), + (torch.rand(2, 10) * 100).to(dtype), + ), + tester_factory) + + def test_subtract_f32_bcast_first(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 5, 1, 5), + ), + tester_factory) + + def test_subtract_f32_bcast_second(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(4, 4, 2, 7), + torch.randn(2, 7), + ), + tester_factory) + + def test_subtract_f32_bcast_unary(self, tester_factory: Callable) -> None: + self._test_op( + Model(), + ( + torch.randn(5), + torch.randn(1, 1, 5), + ), + tester_factory) + + def test_subtract_f32_alpha(self, tester_factory: Callable) -> None: + self._test_op( + ModelAlpha(alpha=2), + ( + torch.randn(1, 25), + torch.randn(1, 25), + ), + tester_factory) diff --git a/backends/test/compliance_suite/operators/test_tanh.py b/backends/test/compliance_suite/operators/test_tanh.py new file mode 100644 index 00000000000..b6a32ae4b27 --- /dev/null +++ b/backends/test/compliance_suite/operators/test_tanh.py @@ -0,0 +1,34 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. 
+ +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def forward(self, x): + return torch.nn.functional.tanh(x) + +@operator_test +class TestTanh(OperatorTest): + @dtype_test + def test_tanh_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory) + + def test_tanh_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_tanh_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_tanh_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific values spanning negative and positive ranges + x = torch.tensor([-10.0, -5.0, -1.0, 0.0, 1.0, 5.0, 10.0]) + self._test_op(Model(), (x,), tester_factory) diff --git a/backends/test/compliance_suite/operators/test_threshold.py b/backends/test/compliance_suite/operators/test_threshold.py new file mode 100644 index 00000000000..6c61f38c5ae --- /dev/null +++ b/backends/test/compliance_suite/operators/test_threshold.py @@ -0,0 +1,56 @@ +# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. + +# pyre-strict + +from typing import Callable + +import torch + +from executorch.backends.test.compliance_suite import ( + dtype_test, + operator_test, + OperatorTest, +) + +class Model(torch.nn.Module): + def __init__(self, threshold=0.0, value=0.0, inplace=False): + super().__init__() + self.threshold = threshold + self.value = value + self.inplace = inplace + + def forward(self, x): + return torch.nn.functional.threshold(x, threshold=self.threshold, value=self.value, inplace=self.inplace) + +@operator_test +class TestThreshold(OperatorTest): + @dtype_test + def test_threshold_dtype(self, dtype, tester_factory: Callable) -> None: + self._test_op(Model(), ((torch.rand(2, 10) * 10 - 5).to(dtype),), tester_factory) + + def test_threshold_f32_single_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(20),), tester_factory) + + def test_threshold_f32_multi_dim(self, tester_factory: Callable) -> None: + self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory) + + def test_threshold_f32_custom_threshold(self, tester_factory: Callable) -> None: + self._test_op(Model(threshold=1.0), (torch.randn(3, 4, 5),), tester_factory) + + def test_threshold_f32_custom_value(self, tester_factory: Callable) -> None: + self._test_op(Model(value=2.0), (torch.randn(3, 4, 5),), tester_factory) + + def test_threshold_f32_custom_threshold_value(self, tester_factory: Callable) -> None: + self._test_op(Model(threshold=0.5, value=1.0), (torch.randn(3, 4, 5),), tester_factory) + + def test_threshold_f32_inplace(self, tester_factory: Callable) -> None: + self._test_op(Model(inplace=True), (torch.randn(3, 4, 5),), tester_factory) + + def test_threshold_f32_boundary_values(self, tester_factory: Callable) -> None: + # Test with specific values around the threshold + x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]) + self._test_op(Model(), (x,), tester_factory) + + def test_threshold_f32_all_params(self, tester_factory: Callable) -> None: + # Test with all parameters customized + self._test_op(Model(threshold=0.5, value=3.0, inplace=True), (torch.randn(3, 4, 5),), tester_factory) diff --git 
a/backends/test/compliance_suite/targets.bzl b/backends/test/compliance_suite/targets.bzl new file mode 100644 index 00000000000..b4ce253a353 --- /dev/null +++ b/backends/test/compliance_suite/targets.bzl @@ -0,0 +1,37 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_tests_for_backend(name, deps): + runtime.python_test( + name = "compliance_suite_" + name, + srcs = glob([ + "operators/*.py", + ]) + [ + "__init__.py", + ], + deps = [ + "//executorch/backends/xnnpack/test/tester:tester", + "//executorch/exir:lib", + "fbsource//third-party/pypi/parameterized:parameterized", + ] + deps, + external_deps = [ + "libtorch", + ], + supports_static_listing = False, + labels = [ + "exclude_from_coverage", + ], + env = { + "ET_TEST_BACKENDS": name, + }, + ) + + +def define_common_targets(is_fbcode): + if is_fbcode: + define_tests_for_backend("xnnpack", [ + "//executorch/backends/xnnpack/partition:xnnpack_partitioner", + ]) + + define_tests_for_backend("coreml", [ + "//executorch/backends/apple/coreml:tester", + ]) diff --git a/backends/test/operators/__init__.py b/backends/test/operators/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/test/operators/facto_specs.py b/backends/test/operators/facto_specs.py new file mode 100644 index 00000000000..96fe86b2ea7 --- /dev/null +++ b/backends/test/operators/facto_specs.py @@ -0,0 +1,59 @@ +import facto.specdb.function as fn +import torch + +from facto.inputgen.argument.type import ArgType +from facto.inputgen.specs.model import ConstraintProducer as cp, InPosArg, OutArg, Spec + +""" +This file contains FACTO operator specs for ops not in the standard FACTO db. This mainly +includes ops not in the Core ATen op set and preserved by a backend, such as linear. +""" + +LiNEAR_DEFAULT_SPEC = Spec( + op="linear.default", # (Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + inspec=[ + InPosArg( + ArgType.Tensor, + name="input", + deps=[1, 2], + constraints=[ + cp.Dtype.Eq(lambda deps: deps[0].dtype), + cp.Rank.Ge(lambda deps: 2), + cp.Size.In( + lambda deps, r, d: fn.broadcast_to( + (fn.safe_size(deps[0], 0), fn.safe_size(deps[1], 1)), r, d + ) + ), + ], + ), + InPosArg( + ArgType.Tensor, + name="weight", + constraints=[ + cp.Dtype.Ne(lambda deps: torch.bool), + cp.Rank.Eq(lambda deps: 2), + ], + ), + InPosArg( + ArgType.Tensor, + name="bias", + deps=[1], + constraints=[ + cp.Dtype.Eq(lambda deps: deps[0].dtype), + cp.Rank.Eq(lambda deps: 2), + cp.Size.Eq( + lambda deps, r, d: fn.safe_size(deps[0], 1) if d == 0 else None + ), + ], + ), + ], + outspec=[ + OutArg(ArgType.Tensor), + ], +) + +_extra_specs = [ + LiNEAR_DEFAULT_SPEC, +] + +ExtraSpecDB: dict[str, Spec] = {s.op: s for s in _extra_specs} diff --git a/backends/test/operators/test_facto.py b/backends/test/operators/test_facto.py new file mode 100644 index 00000000000..208aaa042a9 --- /dev/null +++ b/backends/test/operators/test_facto.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# +# This file contains logic to run generated operator tests using the FACTO +# library (https://github.com/pytorch-labs/FACTO). To run the tests, first +# clone and install FACTO by running pip install . from the FACTO source +# directory. 
Then, from the executorch root directory, run the following: +# +# python -m unittest backends.test.operators.test_facto.FactoTestsXNNPACK +# + +import copy +import functools +import traceback +import unittest +from typing import Any, Callable, Sequence + +import torch +from executorch.backends.test.harness.tester import Tester as TesterBase +from executorch.backends.xnnpack.test.tester.tester import Tester as XnnpackTester +from facto.inputgen.argtuple.gen import ArgumentTupleGenerator +from facto.inputgen.specs.model import ConstraintProducer as cp, Spec +from facto.inputgen.utils.random_manager import random_manager +from facto.specdb.db import SpecDictDB +from torch._ops import OpOverload + +from .facto_specs import ExtraSpecDB + +CombinedSpecDB = SpecDictDB | ExtraSpecDB + +COMMON_TENSOR_CONSTRAINTS = [ + cp.Rank.Ge(lambda deps: 1), # Avoid zero and high rank tensors. + cp.Rank.Le(lambda deps: 4), + cp.Size.Ge(lambda deps, r, d: 1), # Keep sizes reasonable. + cp.Size.Le(lambda deps, r, d: 2**9), +] + +COMMON_SCALAR_CONSTRAINS = [ + cp.Value.Ge(lambda deps, dtype: -1000), + cp.Value.Le(lambda deps, dtype: 1000), +] + +# Operator args are treated as runtime graph inputs if the argument name is +# in this list. +RUNTIME_INPUT_NAMES = { + "self", + "tensor", + "other", +} + + +def _patch_spec(spec: Spec) -> Spec: + spec = copy.deepcopy(spec) + for inspec in spec.inspec: + if inspec.type.is_tensor(): + inspec.constraints.extend(COMMON_TENSOR_CONSTRAINTS) + elif inspec.type.is_scalar(): + inspec.constraints.extend(COMMON_SCALAR_CONSTRAINS) + return spec + + +class OpModel(torch.nn.Module): + """ + Wraps a single torch operator in an nn.Module. + """ + + def __init__( + self, + op: OpOverload, + runtime_input_count: int, + fixed_args: Sequence[Any], + fixed_kwargs: dict[str, Any], + ): + super().__init__() + self.op = op + self.runtime_input_count = runtime_input_count + self.fixed_kwargs = fixed_kwargs + + # Register parameters for fixed tensors. Some things will choke on + # constant tensor weights, for example. 
+ new_args = [] + for i, arg in enumerate(fixed_args): + if isinstance(arg, torch.Tensor): + param = torch.nn.Parameter(arg, requires_grad=False) + param_name = f"arg_{i}_param" + setattr(self, param_name, param) + self.register_parameter(param_name, param) + new_args.append(param) + else: + new_args.append(arg) + self.fixed_args = tuple(new_args) + + def forward(self, *args, **kwargs): + return self.op(*(args + self.fixed_args), **(kwargs | self.fixed_kwargs)) + + +class ConvModel(OpModel): + def forward(self, *args, **kwargs): + weight, bias, stride, padding, dilation, transposed, output_padding, groups = ( + self.fixed_args + ) + + if not transposed: + if len(weight.shape) == 3: + op = torch.nn.functional.conv1d + elif len(weight.shape) == 4: + op = torch.nn.functional.conv2d + elif len(weight.shape) == 5: + op = torch.nn.functional.conv3d + + return op(args[0], weight, bias, stride, padding, dilation, groups) + else: + if len(weight.shape) == 3: + op = torch.nn.functional.conv_transpose1d + elif len(weight.shape) == 4: + op = torch.nn.functional.conv_transpose2d + elif len(weight.shape) == 5: + op = torch.nn.functional.conv_transpose3d + + return op( + args[0], weight, bias, stride, padding, output_padding, groups, dilation + ) + + +def get_module_for_op(op: OpOverload): + if op == torch.ops.aten.convolution.default: + return ConvModel + else: + return OpModel + + +class FactoTestsBase(unittest.TestCase): + def __init__(self, tester_factory: Callable[[], TesterBase], *args, **kwargs): + super().__init__(*args, **kwargs) + self._tester_factory = tester_factory + + @staticmethod + def _generate_test(op_name: str) -> None: + # Find the torch op with the given name. + sections = op_name.split(".") + torch_op = functools.reduce(getattr, sections, torch.ops.aten) + + test_name = "test_" + op_name.replace(".", "_") + + def test_body(self): + self._test_op(torch_op) + + setattr(FactoTestsBase, test_name, test_body) + + @staticmethod + def get_runtime_input_count(spec: Spec): + # Determine which inputs are fixed at tracing time (weights, for example), + # vs inputs to the runtime graph. We currently assume that the runtime graph + # inputs start at the beginning of the arg list and are contiguous. + # + # Args are consider to be runtime inputs if they are positional and are named + # one of RUNTIME_INPUT_NAMES. If none match, we assume only the first arg is a + # runtime input. + runtime_input_count = 0 + for inspec in spec.inspec: + is_runtime_input = ( + inspec.type.is_tensor() and inspec.name.lower() in RUNTIME_INPUT_NAMES + ) + if is_runtime_input: + runtime_input_count += 1 + else: + break + + return max(1, runtime_input_count) + + def setUp(self): + torch.set_printoptions(threshold=3) + + def _test_op(self, op: OpOverload) -> None: # noqa: C901 + random_manager.seed(0) + + # Strip namespace + op_name = op.name().split("::")[-1] + + # Default to .default overload + if "." 
not in op_name: + op_name += ".default" + + # Find and patch op spec + if op_name not in CombinedSpecDB: + raise ValueError(f"Operator {op_name} not found in SpecDictDB.") + spec = _patch_spec(CombinedSpecDB[op_name]) + + runtime_input_count = FactoTestsBase.get_runtime_input_count(spec) + + print(f"Op: {op_name}, {runtime_input_count} runtime inputs") + + # Run test cases + success_count_delegated = 0 + success_count_undelegated = 0 + fail_count = 0 + + i = 0 + for posargs, inkwargs, _ in ArgumentTupleGenerator(spec).gen(): + i += 1 + + try: + if isinstance(posargs[0], torch.Tensor): + # Temporary for getting around XNN crashes + if posargs[0].dtype not in {torch.float32, torch.float16}: + print("SKIPPING NON FLOAT CASE") + continue + + module_cls = get_module_for_op(op) + model = module_cls( + op, runtime_input_count, posargs[runtime_input_count:], inkwargs + ) + + # Sanity check to make sure it runs in eager. This can present nicer error + # messages sometimes compared to tracing. + try: + model(*posargs[:runtime_input_count]) + except Exception as e: + print(f"Eager execution failed: {e}") + continue + + tester = ( + self._tester_factory(model, tuple(posargs[:runtime_input_count])) + .export() + .dump_artifact() + .to_edge_transform_and_lower() + ) + + is_delegated = any( + n.target == torch._higher_order_ops.executorch_call_delegate + for n in tester.stages[tester.cur].graph_module.graph.nodes + if n.op == "call_function" + ) + + # Only run the runtime test if the op was delegated. + if is_delegated: + ( + tester.to_executorch() + .serialize() + .run_method_and_compare_outputs() + ) + + if is_delegated: + success_count_delegated += 1 + else: + success_count_undelegated += 1 + except Exception as e: + fail_count += 1 + print(f"Error: {e}") + print("Args:") + for arg in posargs: + if isinstance(arg, torch.Tensor): + print(f" {arg.dtype} {arg.shape}") + else: + print(f" {arg}") + + traceback.print_exc() + + print( + f"{success_count_delegated + success_count_undelegated} PASS, {fail_count} FAIL" + ) + print( + f" {success_count_delegated} DELEGATED, {success_count_undelegated} UNDELEGATED" + ) + + +# Programatically generate tests for each operator. +for op_name in CombinedSpecDB.keys(): + FactoTestsBase._generate_test(op_name) + + +# TODO Figure out where to put these +class FactoTestsXNNPACK(FactoTestsBase): + def __init__(self, *args, **kwargs): + super().__init__(XnnpackTester, *args, **kwargs) diff --git a/backends/xnnpack/test/tester/__init__.py b/backends/xnnpack/test/tester/__init__.py index 44933c43309..a4527d9edc8 100644 --- a/backends/xnnpack/test/tester/__init__.py +++ b/backends/xnnpack/test/tester/__init__.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# TODO: Be more delibrate on module structure from executorch.backends.xnnpack.test.tester.tester import ( Export, Partition, @@ -18,13 +17,13 @@ ) __all__ = [ - Export, - ToEdge, - Partition, - Quantize, - RunPasses, - ToEdgeTransformAndLower, - Tester, - Serialize, - ToExecutorch, + "Export", + "ToEdge", + "Partition", + "Quantize", + "RunPasses", + "ToEdgeTransformAndLower", + "Tester", + "Serialize", + "ToExecutorch", ]
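For reference, a new test in the compliance suite follows the pattern described in the README above: a small eager module, an `@operator_test`-decorated class deriving from `OperatorTest`, and per-case calls into `self._test_op` with a `tester_factory`. The sketch below is illustrative only; the operator (clamp), file name, and shapes are hypothetical, but the decorators and plumbing mirror the operator test files added in this diff.

```python
# Hypothetical backends/test/compliance_suite/operators/test_clamp.py,
# assuming the decorators and OperatorTest base class added in this diff.
from typing import Callable

import torch

from executorch.backends.test.compliance_suite import (
    dtype_test,
    operator_test,
    OperatorTest,
)

class Model(torch.nn.Module):
    def __init__(self, min_val=-1.0, max_val=1.0):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def forward(self, x):
        return torch.clamp(x, self.min_val, self.max_val)

@operator_test
class Clamp(OperatorTest):
    @dtype_test
    def test_clamp_dtype(self, dtype, tester_factory: Callable) -> None:
        # Expanded into one test per dtype and per registered backend flow.
        self._test_op(Model(), ((torch.rand(2, 10) * 4 - 2).to(dtype),), tester_factory)

    def test_clamp_f32_multi_dim(self, tester_factory: Callable) -> None:
        # Expanded into one test per registered backend flow.
        self._test_op(Model(), (torch.randn(2, 3, 4, 5),), tester_factory)
```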
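Registering an additional backend flow, as described under "Backend Registration" in the README, amounts to adding a guarded entry to `ALL_TEST_FLOWS` in the suite's __init__.py. A minimal sketch with a hypothetical "mybackend" flow; the `MyBackendTester` import path is a placeholder for a backend-specific Tester subclass built the same way as the CoreML and XNNPACK testers above.

```python
# Hypothetical addition to backends/test/compliance_suite/__init__.py.
if is_backend_enabled("mybackend"):
    # Placeholder import path; the tester subclasses the shared test harness Tester.
    from executorch.backends.mybackend.test.tester import MyBackendTester

    MYBACKEND_TEST_FLOW = ("mybackend", MyBackendTester)
    ALL_TEST_FLOWS.append(MYBACKEND_TEST_FLOW)
```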
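The FACTO-based suite in backends/test/operators/test_facto.py is wired to a backend in a similar way: a thin subclass of `FactoTestsBase` that passes the backend's tester factory to the base constructor. A sketch of a hypothetical Core ML variant, mirroring `FactoTestsXNNPACK` and assuming the `CoreMLTester` added earlier in this diff:

```python
# Hypothetical companion to FactoTestsXNNPACK in test_facto.py; assumes the
# CoreMLTester defined in backends/apple/coreml/test/tester.py is available.
from executorch.backends.apple.coreml.test.tester import CoreMLTester


class FactoTestsCoreML(FactoTestsBase):
    def __init__(self, *args, **kwargs):
        super().__init__(CoreMLTester, *args, **kwargs)
```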