[Backend Tester] Add backend test suite skeleton #11960

Merged 1 commit on Jul 14, 2025
2 changes: 1 addition & 1 deletion .ci/scripts/unittest-buck2.sh
@@ -10,7 +10,7 @@ set -eux
# TODO: can't query cadence & vulkan backends
# TODO: can't query //kernels/prim_ops because of non-buckified stuff in OSS.
buck2 query "//backends/apple/... + //backends/example/... + \
-//backends/mediatek/... + //backends/test/... + //backends/transforms/... + \
+//backends/mediatek/... + //backends/transforms/... + \
//backends/xnnpack/... + //configurations/... + //kernels/aten/... + \
//kernels/optimized/... + //kernels/portable/... + //kernels/quantized/... + \
//kernels/test/... + //runtime/... + //schema/... + //test/... + //util/..."
File renamed without changes.
4 changes: 2 additions & 2 deletions backends/test/harness/tester.py
@@ -366,11 +366,11 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
f"Output {i} does not match reference output.\n"
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
-f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
+f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref).to(torch.double))}.\n"
f"\t-- Model vs. Reference --\n"
f"\t Numel: {model.numel()}, {ref.numel()}\n"
f"\tMedian: {model.median()}, {ref.median()}\n"
-f"\t  Mean: {model.mean()}, {ref.mean()}\n"
+f"\t  Mean: {model.to(torch.double).mean()}, {ref.to(torch.double).mean()}\n"
f"\t Max: {model.max()}, {ref.max()}\n"
f"\t Min: {model.min()}, {ref.min()}\n"
)
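The change above casts to double before computing the mean. A plausible motivation (an assumption, not stated in the diff) is that `Tensor.mean()` raises for integer dtypes and can lose precision for float16, so casting first makes the diagnostic work for any output dtype. A minimal illustration:

```python
import torch

# Integer tensors don't support .mean() directly; casting to double first
# makes the mean diagnostic valid for any output dtype.
t = torch.arange(6, dtype=torch.int32)
print(t.to(torch.double).mean().item())  # 2.5
```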
23 changes: 23 additions & 0 deletions backends/test/suite/README.md
@@ -0,0 +1,23 @@
# Backend Test Suite

This directory contains tests that validate the correctness and coverage of backends. The tests treat each backend as a black box: the test logic verifies that the backend can handle a given pattern without erroring out (declining to partition is acceptable) and can run the resulting graphs to yield reasonable outputs. As backends may differ significantly in implementation, numerical bounds are intentionally left loose.

These tests are intended to ensure that backends are robust and provide a smooth, out-of-the-box experience for users across the full span of input patterns. They are not a replacement for backend-specific tests: they do not attempt to validate performance, nor to verify that backends delegate the operators they are expected to.

## Backend Registration

To plug into the test framework, each backend should provide an implementation of the Tester class, defined in `backends/test/harness/tester.py`. Backends can provide implementations of each stage, or use the default implementation, as appropriate.

At a minimum, the backend will likely need to provide a custom implementation of the `Partition` and `ToEdgeTransformAndLower` stages using the appropriate backend partitioner. See `backends/xnnpack/test/tester/tester.py` for an example implementation.

Once a tester is available, the backend flow(s) can be registered in `__init__.py` in this directory by adding an entry to `ALL_TEST_FLOWS`. Each flow entry consists of a name (used in test case naming) and a function that instantiates a tester for a given model and input tuple.
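A minimal sketch of such a registration (the backend name and tester class here are hypothetical; real entries live in this directory's `__init__.py`):

```python
# Hypothetical flow registration. Each entry pairs a flow name (used in
# generated test names) with a factory that builds a tester from a model
# and an input tuple.
ALL_TEST_FLOWS = []

class MyBackendTester:  # stand-in for a backend-specific Tester subclass
    def __init__(self, module, inputs):
        self.module = module
        self.inputs = inputs

ALL_TEST_FLOWS.append(("my_backend", MyBackendTester))

flow_name, tester_factory = ALL_TEST_FLOWS[0]
print(flow_name)  # my_backend
```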

## Test Cases

Operator test cases are defined under the `operators/` directory. Tests are written in a backend-independent manner, and each test is programmatically expanded into a variant for each registered backend flow. The `@operator_test` decorator is applied to each test class to trigger this behavior. Tests can also be tagged with an appropriate type specifier, such as `@dtype_test`, to generate a variant for each dtype. The decorators and "magic" live in `__init__.py` in this directory; see `operators/test_add.py` for a complete example.

## Evolution of this Test Suite

This test suite is experimental and under active development. Tests may be added, removed, or modified without notice. It is anticipated that this suite will be stabilized by the 1.0 release of ExecuTorch.

There is currently no expectation that all backends pass all tests, as the content of the test suite is still under development and open questions remain around error reporting, accuracy thresholds, and more.
3 changes: 3 additions & 0 deletions backends/test/suite/TARGETS
@@ -0,0 +1,3 @@
load(":targets.bzl", "define_common_targets")

define_common_targets(is_fbcode = True)
176 changes: 176 additions & 0 deletions backends/test/suite/__init__.py
@@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import logging
import os
import unittest

from enum import Enum
from typing import Any, Callable, Tuple

import torch
from executorch.backends.test.harness import Tester

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Read enabled backends from the environment variable. Enable all if
# not specified (signalled by None).
def get_enabled_backends():
et_test_backends = os.environ.get("ET_TEST_ENABLED_BACKENDS")
if et_test_backends is not None:
return et_test_backends.split(",")
else:
return None


_ENABLED_BACKENDS = get_enabled_backends()


def is_backend_enabled(backend):
if _ENABLED_BACKENDS is None:
return True
else:
return backend in _ENABLED_BACKENDS


ALL_TEST_FLOWS = []

if is_backend_enabled("xnnpack"):
from executorch.backends.xnnpack.test.tester import Tester as XnnpackTester

XNNPACK_TEST_FLOW = ("xnnpack", XnnpackTester)
ALL_TEST_FLOWS.append(XNNPACK_TEST_FLOW)

if is_backend_enabled("coreml"):
try:
from executorch.backends.apple.coreml.test.tester import CoreMLTester

COREML_TEST_FLOW = ("coreml", CoreMLTester)
ALL_TEST_FLOWS.append(COREML_TEST_FLOW)
except Exception:
print("Core ML AOT is not available.")


DTYPES = [
torch.int8,
torch.uint8,
torch.int16,
torch.uint16,
torch.int32,
torch.uint32,
torch.int64,
torch.uint64,
torch.float16,
torch.float32,
torch.float64,
]

FLOAT_DTYPES = [
torch.float16,
torch.float32,
torch.float64,
]


# The type of test function. This controls the test generation and expected signature.
# Standard tests are run, as is. Dtype tests get a variant generated for each dtype and
# take an additional dtype parameter.
class TestType(Enum):
STANDARD = 1
DTYPE = 2


# Function annotation for dtype tests. This instructs the test framework to run the test
# for each supported dtype and to pass dtype as a test parameter.
def dtype_test(func):
func.test_type = TestType.DTYPE
return func


# Class annotation for operator tests. This triggers the test framework to register
# the tests.
def operator_test(cls):
_create_tests(cls)
return cls


# Generate test cases for each backend flow.
def _create_tests(cls):
for key in dir(cls):
if key.startswith("test_"):
_expand_test(cls, key)


# Expand a test into variants for each registered flow.
def _expand_test(cls, test_name: str):
test_func = getattr(cls, test_name)
for flow_name, tester_factory in ALL_TEST_FLOWS:
_create_test_for_backend(cls, test_func, flow_name, tester_factory)
delattr(cls, test_name)


def _make_wrapped_test(test_func, *args, **kwargs):
def wrapped_test(self):
test_func(self, *args, **kwargs)

return wrapped_test


def _make_wrapped_dtype_test(test_func, dtype, tester_factory):
def wrapped_test(self):
test_func(self, dtype, tester_factory)

return wrapped_test


def _create_test_for_backend(
cls,
test_func: Callable,
flow_name: str,
tester_factory: Callable[[torch.nn.Module, Tuple[Any]], Tester],
):
test_type = getattr(test_func, "test_type", TestType.STANDARD)

if test_type == TestType.STANDARD:
wrapped_test = _make_wrapped_test(test_func, tester_factory)
test_name = f"{test_func.__name__}_{flow_name}"
setattr(cls, test_name, wrapped_test)
elif test_type == TestType.DTYPE:
for dtype in DTYPES:
# wrapped_test = _make_wrapped_dtype_test(test_func, dtype, tester_factory)
wrapped_test = _make_wrapped_test(test_func, dtype, tester_factory)
dtype_name = str(dtype)[6:] # strip "torch."
test_name = f"{test_func.__name__}_{dtype_name}_{flow_name}"
setattr(cls, test_name, wrapped_test)
else:
raise NotImplementedError(f"Unknown test type {test_type}.")


class OperatorTest(unittest.TestCase):
def _test_op(self, model, inputs, tester_factory):
tester = (
tester_factory(
model,
inputs,
)
.export()
.to_edge_transform_and_lower()
)

is_delegated = any(
n.target == torch._higher_order_ops.executorch_call_delegate
for n in tester.stages[tester.cur].graph_module.graph.nodes
if n.op == "call_function"
)

# Only run the runtime test if the op was delegated.
if is_delegated:
(tester.to_executorch().serialize().run_method_and_compare_outputs())
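The `ET_TEST_ENABLED_BACKENDS` filtering above can be sketched in isolation (mirroring `get_enabled_backends` and `is_backend_enabled` from this file):

```python
import os

# A comma-separated allowlist read from the environment; an unset
# variable (None) means all backends are enabled.
def get_enabled_backends():
    value = os.environ.get("ET_TEST_ENABLED_BACKENDS")
    return value.split(",") if value is not None else None

def is_backend_enabled(backend, enabled):
    return enabled is None or backend in enabled

os.environ["ET_TEST_ENABLED_BACKENDS"] = "xnnpack,coreml"
enabled = get_enabled_backends()
print(is_backend_enabled("xnnpack", enabled))  # True
print(is_backend_enabled("vulkan", enabled))   # False
```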
7 changes: 7 additions & 0 deletions backends/test/suite/operators/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
82 changes: 82 additions & 0 deletions backends/test/suite/operators/test_add.py
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


from typing import Callable

import torch

from executorch.backends.test.suite import dtype_test, operator_test, OperatorTest


class Model(torch.nn.Module):
def forward(self, x, y):
return x + y


class ModelAlpha(torch.nn.Module):
def __init__(self, alpha):
super().__init__()
self.alpha = alpha

def forward(self, x, y):
return torch.add(x, y, alpha=self.alpha)


@operator_test
class Add(OperatorTest):
@dtype_test
def test_add_dtype(self, dtype, tester_factory: Callable) -> None:
self._test_op(
Model(),
(
(torch.rand(2, 10) * 100).to(dtype),
(torch.rand(2, 10) * 100).to(dtype),
),
tester_factory,
)

def test_add_f32_bcast_first(self, tester_factory: Callable) -> None:
self._test_op(
Model(),
(
torch.randn(5),
torch.randn(1, 5, 1, 5),
),
tester_factory,
)

def test_add_f32_bcast_second(self, tester_factory: Callable) -> None:
self._test_op(
Model(),
(
torch.randn(4, 4, 2, 7),
torch.randn(2, 7),
),
tester_factory,
)

def test_add_f32_bcast_unary(self, tester_factory: Callable) -> None:
self._test_op(
Model(),
(
torch.randn(5),
torch.randn(1, 1, 5),
),
tester_factory,
)

def test_add_f32_alpha(self, tester_factory: Callable) -> None:
self._test_op(
ModelAlpha(alpha=2),
(
torch.randn(1, 25),
torch.randn(1, 25),
),
tester_factory,
)
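Given the `@dtype_test` annotation above, `_create_test_for_backend` in `backends/test/suite/__init__.py` expands `test_add_dtype` into one variant per dtype per flow. The naming can be sketched as:

```python
import torch

# Mirrors the name mangling in _create_test_for_backend:
# <test>_<dtype>_<flow>, where the dtype name drops the "torch." prefix.
def expanded_name(test_name, dtype, flow_name):
    dtype_name = str(dtype)[6:]  # strip "torch."
    return f"{test_name}_{dtype_name}_{flow_name}"

print(expanded_name("test_add_dtype", torch.float32, "xnnpack"))
# test_add_dtype_float32_xnnpack
```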