Skip to content

Minor cleanup #12265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/test/harness/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
ref,
atol=atol,
rtol=rtol,
equal_nan=True,
), (
f"Output {i} does not match reference output.\n"
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
Expand Down
16 changes: 12 additions & 4 deletions backends/test/operators/facto_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@
import torch

from facto.inputgen.argument.type import ArgType
from facto.inputgen.specs.model import ConstraintProducer as cp, InPosArg, OutArg, Spec
from facto.inputgen.specs.model import (
ConstraintProducer as cp,
InKwArg,
InPosArg,
OutArg,
Spec,
)

"""
This file contains FACTO operator specs for ops not in the standard FACTO db. This mainly
includes ops not in the Core ATen op set and preserved by a backend, such as linear.
"""

LiNEAR_DEFAULT_SPEC = Spec(
LINEAR_DEFAULT_SPEC = Spec(
op="linear.default", # (Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
inspec=[
InPosArg(
Expand Down Expand Up @@ -53,7 +59,9 @@
)

_extra_specs = [
LiNEAR_DEFAULT_SPEC,
LINEAR_DEFAULT_SPEC,
]

ExtraSpecDB: dict[str, Spec] = {s.op: s for s in _extra_specs}
ExtraSpecDB: dict[str, Spec] = {
s.op: s for s in _extra_specs
}
127 changes: 61 additions & 66 deletions backends/test/operators/test_facto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,21 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# This file contains logic to run generated operator tests using the FACTO
# library (https://github.com/pytorch-labs/FACTO). To run the tests, first
# clone and install FACTO by running pip install . from the FACTO source
# directory. Then, from the executorch root directory, run the following:
#
# python -m unittest backends.test.operators.test_facto.FactoTestsXNNPACK
#
# pyre-strict

import copy
import functools
import traceback
from typing import Any, Callable, List, OrderedDict, Sequence, Tuple
import unittest
from typing import Any, Callable, Sequence

import torch
from executorch.backends.test.harness.tester import Tester as TesterBase
from executorch.backends.xnnpack.test.tester.tester import Tester as XnnpackTester
from executorch.backends.xnnpack.test.tester.tester import ToEdgeTransformAndLower, Tester as XnnpackTester
from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
from facto.inputgen.specs.model import ConstraintProducer as cp, Spec
from facto.inputgen.specs.model import Constraint, ConstraintProducer as cp, Spec
from facto.inputgen.utils.random_manager import random_manager
from facto.inputgen.variable.type import ScalarDtype
from facto.specdb.db import SpecDictDB
from torch._ops import OpOverload

Expand All @@ -35,9 +27,9 @@
CombinedSpecDB = SpecDictDB | ExtraSpecDB

COMMON_TENSOR_CONSTRAINTS = [
cp.Rank.Ge(lambda deps: 1), # Avoid zero and high rank tensors.
cp.Rank.Ge(lambda deps: 1),
cp.Rank.Le(lambda deps: 4),
cp.Size.Ge(lambda deps, r, d: 1), # Keep sizes reasonable.
cp.Size.Ge(lambda deps, r, d: 1),
cp.Size.Le(lambda deps, r, d: 2**9),
]

Expand All @@ -54,7 +46,6 @@
"other",
}


def _patch_spec(spec: Spec) -> Spec:
spec = copy.deepcopy(spec)
for inspec in spec.inspec:
Expand All @@ -64,18 +55,16 @@ def _patch_spec(spec: Spec) -> Spec:
inspec.constraints.extend(COMMON_SCALAR_CONSTRAINS)
return spec


class OpModel(torch.nn.Module):
"""
Wraps a single torch operator in an nn.Module.
"""

def __init__(
self,
op: OpOverload,
runtime_input_count: int,
self,
op: OpOverload,
runtime_input_count: int,
fixed_args: Sequence[Any],
fixed_kwargs: dict[str, Any],
fixed_kwargs: dict[str, Any]
):
super().__init__()
self.op = op
Expand All @@ -99,12 +88,9 @@ def __init__(
def forward(self, *args, **kwargs):
return self.op(*(args + self.fixed_args), **(kwargs | self.fixed_kwargs))


class ConvModel(OpModel):
def forward(self, *args, **kwargs):
weight, bias, stride, padding, dilation, transposed, output_padding, groups = (
self.fixed_args
)
weight, bias, stride, padding, dilation, transposed, output_padding, groups = self.fixed_args

if not transposed:
if len(weight.shape) == 3:
Expand All @@ -113,7 +99,7 @@ def forward(self, *args, **kwargs):
op = torch.nn.functional.conv2d
elif len(weight.shape) == 5:
op = torch.nn.functional.conv3d

return op(args[0], weight, bias, stride, padding, dilation, groups)
else:
if len(weight.shape) == 3:
Expand All @@ -122,19 +108,15 @@ def forward(self, *args, **kwargs):
op = torch.nn.functional.conv_transpose2d
elif len(weight.shape) == 5:
op = torch.nn.functional.conv_transpose3d

return op(
args[0], weight, bias, stride, padding, output_padding, groups, dilation
)


return op(args[0], weight, bias, stride, padding, output_padding, groups, dilation)

def get_module_for_op(op: OpOverload):
if op == torch.ops.aten.convolution.default:
return ConvModel
else:
return OpModel


class FactoTestsBase(unittest.TestCase):
def __init__(self, tester_factory: Callable[[], TesterBase], *args, **kwargs):
super().__init__(*args, **kwargs)
Expand All @@ -147,37 +129,36 @@ def _generate_test(op_name: str) -> None:
torch_op = functools.reduce(getattr, sections, torch.ops.aten)

test_name = "test_" + op_name.replace(".", "_")

def test_body(self):
self._test_op(torch_op)
test_body = lambda self: self._test_op(torch_op)

setattr(FactoTestsBase, test_name, test_body)

@staticmethod
def get_runtime_input_count(spec: Spec):
# Determine which inputs are fixed at tracing time (weights, for example),
# vs inputs to the runtime graph. We currently assume that the runtime graph
# inputs start at the beginning of the arg list and are contiguous.
#
#
# Args are consider to be runtime inputs if they are positional and are named
# one of RUNTIME_INPUT_NAMES. If none match, we assume only the first arg is a
# runtime input.
runtime_input_count = 0
for inspec in spec.inspec:
is_runtime_input = (
inspec.type.is_tensor() and inspec.name.lower() in RUNTIME_INPUT_NAMES
inspec.type.is_tensor() and
inspec.name.lower() in RUNTIME_INPUT_NAMES
)
if is_runtime_input:
runtime_input_count += 1
else:
break

return max(1, runtime_input_count)

def setUp(self):
torch.set_printoptions(threshold=3)

def _test_op(self, op: OpOverload) -> None: # noqa: C901
def _test_op(self, op: OpOverload) -> None:
random_manager.seed(0)

# Strip namespace
Expand All @@ -186,15 +167,15 @@ def _test_op(self, op: OpOverload) -> None: # noqa: C901
# Default to .default overload
if "." not in op_name:
op_name += ".default"

# Find and patch op spec
if op_name not in CombinedSpecDB:
if not op_name in CombinedSpecDB:
raise ValueError(f"Operator {op_name} not found in SpecDictDB.")
spec = _patch_spec(CombinedSpecDB[op_name])

runtime_input_count = FactoTestsBase.get_runtime_input_count(spec)

print(f"Op: {op_name}, {runtime_input_count} runtime inputs")
print(f"Op: {op_name}, {runtime_input_count} runtime inputs")

# Run test cases
success_count_delegated = 0
Expand All @@ -207,14 +188,18 @@ def _test_op(self, op: OpOverload) -> None: # noqa: C901

try:
if isinstance(posargs[0], torch.Tensor):
# Temporary for getting around XNN crashes
if posargs[0].dtype not in {torch.float32, torch.float16}:
print("SKIPPING NON FLOAT CASE")
# Temporary for getting around XNN crashes (https://github.com/pytorch/executorch/issues/10960).
# TODO Re-enable when resolved.
if posargs[0].dtype in {torch.int8, torch.uint8}:
print("Skipping (u)int8 case.")
continue

module_cls = get_module_for_op(op)
model = module_cls(
op, runtime_input_count, posargs[runtime_input_count:], inkwargs
op,
runtime_input_count,
posargs[runtime_input_count:],
inkwargs
)

# Sanity check to make sure it runs in eager. This can present nicer error
Expand All @@ -225,13 +210,20 @@ def _test_op(self, op: OpOverload) -> None: # noqa: C901
print(f"Eager execution failed: {e}")
continue

tester = (
self._tester_factory(model, tuple(posargs[:runtime_input_count]))
.export()
.dump_artifact()
.to_edge_transform_and_lower()
tester = self._tester_factory(
model,
tuple(posargs[:runtime_input_count])
)

# Dynamo will also fail to handle some patterns that are valid in eager.
try:
tester.export()
except Exception as e:
print(f"Export failed.")
continue

tester.to_edge_transform_and_lower()

is_delegated = any(
n.target == torch._higher_order_ops.executorch_call_delegate
for n in tester.stages[tester.cur].graph_module.graph.nodes
Expand All @@ -241,19 +233,20 @@ def _test_op(self, op: OpOverload) -> None: # noqa: C901
# Only run the runtime test if the op was delegated.
if is_delegated:
(
tester.to_executorch()
tester
.to_executorch()
.serialize()
.run_method_and_compare_outputs()
)

if is_delegated:
success_count_delegated += 1
else:
success_count_undelegated += 1
#finally:
except Exception as e:
fail_count += 1
print(f"Error: {e}")
print("Args:")
print(f"Args:")
for arg in posargs:
if isinstance(arg, torch.Tensor):
print(f" {arg.dtype} {arg.shape}")
Expand All @@ -262,20 +255,22 @@ def _test_op(self, op: OpOverload) -> None: # noqa: C901

traceback.print_exc()

print(
f"{success_count_delegated + success_count_undelegated} PASS, {fail_count} FAIL"
)
print(
f" {success_count_delegated} DELEGATED, {success_count_undelegated} UNDELEGATED"
)

print(f"{success_count_delegated + success_count_undelegated} PASS, {fail_count} FAIL")
print(f" {success_count_delegated} DELEGATED, {success_count_undelegated} UNDELEGATED")

# Programatically generate tests for each operator.
for op_name in CombinedSpecDB.keys():
FactoTestsBase._generate_test(op_name)


# TODO Figure out where to put these
class FactoTestsXNNPACK(FactoTestsBase):
def __init__(self, *args, **kwargs):
super().__init__(XnnpackTester, *args, **kwargs)

try:
from executorch.backends.apple.coreml.test.tester import CoreMLTester
class FactoTestsCoreML(FactoTestsBase):
def __init__(self, *args, **kwargs):
super().__init__(CoreMLTester, *args, **kwargs)
except:
print("Skipping Core ML facto tests as Core ML AOT is not available.")
16 changes: 16 additions & 0 deletions backends/test/runner/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
add_executable(executorch-test-runner
test_runner.cpp
# TODO
../../../runtime/platform/runtime.cpp
)

target_link_libraries(
executorch-test-runner
PRIVATE executorch
gflags
extension_flat_tensor
extension_flat_tensor_serialize
extension_module
extension_tensor
optimized_native_cpu_ops_lib
xnnpack_backend)
Loading
Loading