From 7889c0fcbe4334045edd5b836753a71e0a34758f Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Tue, 18 Mar 2025 09:29:33 +0100 Subject: [PATCH 1/2] Arm backend: Add check for unsupported dtypes on Ethos-U55 Move all Ethos-U55 support checks into a single file. Signed-off-by: Erik Lundell Change-Id: Ib6444abdbe1cc15d7ec1a91efa15362022f57895 --- backends/arm/operator_support/__init__.py | 1 + .../arm/operator_support/ethos_u55_support.py | 176 ++++++++++++++++++ .../tosa_supported_operators.py | 60 +----- backends/arm/test/ops/test_sigmoid_16bit.py | 38 +--- backends/arm/test/ops/test_sigmoid_32bit.py | 48 +---- 5 files changed, 198 insertions(+), 125 deletions(-) create mode 100644 backends/arm/operator_support/ethos_u55_support.py diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py index c1189b2ae59..bd54c3e1f85 100644 --- a/backends/arm/operator_support/__init__.py +++ b/backends/arm/operator_support/__init__.py @@ -7,6 +7,7 @@ from . import ( # noqa convolution_support, + ethos_u55_support, minmax_support, pool_2d_support, reduce_sum_support, diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py new file mode 100644 index 00000000000..25b7a4d04c9 --- /dev/null +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -0,0 +1,176 @@ +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import typing + +import torch +import torch.fx as fx +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm._passes.insert_table_ops import TableOps +from executorch.exir.backend.utils import WhyNoPartitionReporter + +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx.passes.operator_support import OperatorSupportBase + + +class EthosU55DtypeSupport(OperatorSupportBase): + + def __init__(self, reporter: WhyNoPartitionReporter): + super().__init__() + self.reporter = reporter + + targeted_ops_i8_i16_i32 = [ + exir_ops.edge.aten.cat.default, + exir_ops.edge.aten.repeat.default, + exir_ops.edge.aten.constant_pad_nd.default, + exir_ops.edge.aten.view.default, + exir_ops.edge.aten.permute.default, + ] + + target_ops_i8 = tuple(TableOps.included_ops()) + + def _try_determine_dtype(self, node: fx.Node) -> torch.dtype | None: + """Attempt to figure out the quantized data type of node. On failure, return None.""" + + dtype = get_first_fake_tensor(node).dtype + if not dtype.is_floating_point: + return dtype + + if ( + node.target + is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default + ): + return get_first_fake_tensor(node.all_input_nodes[0]).dtype + + if len(node.users) == 0: + return None + + q_node = list(node.users)[0] + if ( + q_node.target + is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ): + return typing.cast(torch.dtype, q_node.args[-1]) + + # We can't easily figure out dtype, return None + return None + + def is_node_supported( # noqa: C901 + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + + dtype = self._try_determine_dtype(node) + if dtype is None: + # If we couldn't determine dtype, just return ok. + return True + + if node.target in self.targeted_ops_i8_i16_i32: + if dtype not in (torch.int8, torch.int16, torch.int32): + self.reporter.report_reject( + node, f"Unsupported dtype {dtype} (Supports i8, i16, i32)." + ) + return False + + if node.target in self.target_ops_i8: + if dtype not in (torch.int8,): + self.reporter.report_reject( + node, f"Unsupported dtype {dtype} (Supports i8)." + ) + return False + + if node.target == exir_ops.edge.aten.convolution.default: + ifm, weight = node.all_input_nodes[0:2] + ifm_dtype = self._try_determine_dtype(ifm) + if ifm_dtype is not None and ifm_dtype not in (torch.int8, torch.int16): + self.reporter.report_reject( + node, f"Unsupported input dtype {dtype} (Supports i8, i16)." + ) + return False + weight_dtype = self._try_determine_dtype(weight) + if weight_dtype is not None and weight_dtype not in (torch.int8,): + self.reporter.report_reject( + node, f"Unsupported weight dtype {dtype} (Supports i8)." + ) + return False + if len(node.all_input_nodes) > 2: + bias = node.all_input_nodes[2] + bias_dtype = self._try_determine_dtype(bias) + if bias_dtype is not None and bias_dtype not in (torch.int32,): + self.reporter.report_reject( + node, f"Unsupported bias dtype {dtype} (Supports i32)." + ) + return False + + if node.target in ( + exir_ops.edge.aten.mm.default, + exir_ops.edge.aten.bmm.default, + ): + for input_node in node.all_input_nodes: + dtype = self._try_determine_dtype(input_node) + if dtype is not None and dtype != torch.int8: + self.reporter.report_reject( + input_node, + f"Input {input_node.name} has unsupported dtype {dtype} (Supports i8).", + ) + return False + + return True + + +class EthosU55NotSupported(OperatorSupportBase): + """ + Certain operators are not supported on U55. These are listed in `unsupported_ops`. + The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious. + For unimplemented operators, this is the anticipated mapping, and it might be incorrect. + """ + + unsupported_ops = [ + exir_ops.edge.aten.any.default, # REDUCE_ANY + exir_ops.edge.aten.any.dim, # REDUCE_ANY + exir_ops.edge.aten.any.dims, # REDUCE_ANY + exir_ops.edge.aten.bitwise_and.Tensor, + exir_ops.edge.aten.bitwise_or.Tensor, + exir_ops.edge.aten.bitwise_xor.Tensor, + exir_ops.edge.aten.bitwise_not, + exir_ops.edge.aten.logical_and.default, + exir_ops.edge.aten.logical_or.default, + exir_ops.edge.aten.logical_xor.default, + exir_ops.edge.aten.logical_not.default, + exir_ops.edge.aten.amax.default, # REDUCE_MAX + exir_ops.edge.aten.amin.default, # REDUCE_MIN + exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.eq.Scalar, + exir_ops.edge.aten.ge.Tensor, + exir_ops.edge.aten.gt.Tensor, + exir_ops.edge.aten.le.Tensor, + exir_ops.edge.aten.lt.Tensor, + exir_ops.edge.aten.flip.default, # REVERSE + exir_ops.edge.aten.grid_sampler_2d, # GATHER + exir_ops.edge.aten.scatter.src, + exir_ops.edge.aten.scatter.value, + exir_ops.edge.aten.select_scatter.default, + exir_ops.edge.aten.scatter_reduce.two, + exir_ops.edge.aten.scatter_add.default, + exir_ops.edge.aten.upsample_nearest2d.vec, # RESIZE + exir_ops.edge.aten.upsample_bilinear2d.vec, # RESIZE + exir_ops.edge.aten.reflection_pad1d.default, # REVERSE + exir_ops.edge.aten.reflection_pad2d.default, # REVERSE + exir_ops.edge.aten.reflection_pad3d.default, # REVERSE + ] + + def __init__(self, reporter: WhyNoPartitionReporter): + self.reporter = reporter + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + + if node.target in self.unsupported_ops: + self.reporter.report_reject(node, "Op is not supported on U55.") + return False + + return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index f37289fa001..167be984e57 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -18,6 +18,10 @@ from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( FuseQuantizedActivationPass, ) +from executorch.backends.arm.operator_support.ethos_u55_support import ( + EthosU55DtypeSupport, + EthosU55NotSupported, +) from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification from executorch.exir import ExportedProgram from executorch.exir.backend.utils import WhyNoPartitionReporter @@ -118,6 +122,7 @@ def tosa_support_factory( negative_checks.append(CheckProperQuantization(reporter)) if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) + negative_checks.append(EthosU55DtypeSupport(reporter)) return chain( reporter.wrap_check( @@ -216,61 +221,6 @@ def is_node_supported( return supported -class EthosU55NotSupported(OperatorSupportBase): - """ - Certain operators are not supported on U55. These are listed in `unsupported_ops`. - The comment mentions the unsupported TOSA operator that the aten operator maps to where it is not obvious. - For unimplemented operators, this is the anticipated mapping, and it might be incorrect. - """ - - unsupported_ops = [ - exir_ops.edge.aten.any.default, # REDUCE_ANY - exir_ops.edge.aten.any.dim, # REDUCE_ANY - exir_ops.edge.aten.any.dims, # REDUCE_ANY - exir_ops.edge.aten.bitwise_and.Tensor, - exir_ops.edge.aten.bitwise_or.Tensor, - exir_ops.edge.aten.bitwise_xor.Tensor, - exir_ops.edge.aten.bitwise_not, - exir_ops.edge.aten.logical_and.default, - exir_ops.edge.aten.logical_or.default, - exir_ops.edge.aten.logical_xor.default, - exir_ops.edge.aten.logical_not.default, - exir_ops.edge.aten.amax.default, # REDUCE_MAX - exir_ops.edge.aten.amin.default, # REDUCE_MIN - exir_ops.edge.aten.eq.Tensor, - exir_ops.edge.aten.eq.Scalar, - exir_ops.edge.aten.ge.Tensor, - exir_ops.edge.aten.gt.Tensor, - exir_ops.edge.aten.le.Tensor, - exir_ops.edge.aten.lt.Tensor, - exir_ops.edge.aten.flip.default, # REVERSE - exir_ops.edge.aten.grid_sampler_2d, # GATHER - exir_ops.edge.aten.scatter.src, - exir_ops.edge.aten.scatter.value, - exir_ops.edge.aten.select_scatter.default, - exir_ops.edge.aten.scatter_reduce.two, - exir_ops.edge.aten.scatter_add.default, - exir_ops.edge.aten.upsample_nearest2d.vec, # RESIZE - exir_ops.edge.aten.upsample_bilinear2d.vec, # RESIZE - exir_ops.edge.aten.reflection_pad1d.default, # REVERSE - exir_ops.edge.aten.reflection_pad2d.default, # REVERSE - exir_ops.edge.aten.reflection_pad3d.default, # REVERSE - ] - - def __init__(self, reporter: WhyNoPartitionReporter): - self.reporter = reporter - - def is_node_supported( - self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node - ) -> bool: - - if node.target in self.unsupported_ops: - self.reporter.report_reject(node, "Op is not supported on U55.") - return False - - return True - - class NeedsDecompositionCheck(OperatorSupportBase): """ Targeted operators need to be decomposed prior to quantization in order to get a pair of q-dq-nodes surrounding diff --git a/backends/arm/test/ops/test_sigmoid_16bit.py b/backends/arm/test/ops/test_sigmoid_16bit.py index 3f53141543e..c3907887ac9 100644 --- a/backends/arm/test/ops/test_sigmoid_16bit.py +++ b/backends/arm/test/ops/test_sigmoid_16bit.py @@ -13,8 +13,8 @@ from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU55PipelineBI, EthosU85PipelineBI, + OpNotSupportedPipeline, TosaPipelineBI, ) from executorch.backends.xnnpack.test.tester import Quantize @@ -109,22 +109,10 @@ def test_sigmoid_add_sigmoid_tosa_BI(test_data): @common.parametrize( "test_data", test_data_suite, - xfails={ - "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - }, - # int16 tables are not supported, but some tests happen to pass regardless. - # Set them to xfail but strict=False -> ok if they pass. - strict=False, ) -@common.XfailIfNoCorstone300 def test_sigmoid_tosa_u55(test_data): - pipeline = EthosU55PipelineBI( - Sigmoid(), (test_data(),), Sigmoid.aten_op, Sigmoid.exir_op, run_on_fvp=True + pipeline = OpNotSupportedPipeline( + Sigmoid(), (test_data(),), "TOSA-0.80+BI+u55", {Sigmoid.exir_op: 1} ) pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) pipeline.run() @@ -133,26 +121,14 @@ def test_sigmoid_tosa_u55(test_data): @common.parametrize( "test_data", test_data_suite, - xfails={ - "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "ramp": "AsssertionError: Output 0 does not match reference output. MLBEDSW-9770", - }, - # int16 tables are not supported, but some tests happen to pass regardless. - # Set them to xfail but strict=False -> ok if they pass. - strict=False, ) -@common.XfailIfNoCorstone300 def test_sigmoid_add_sigmoid_tosa_u55(test_data): - pipeline = EthosU55PipelineBI( + pipeline = OpNotSupportedPipeline( SigmoidAddSigmoid(), (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - run_on_fvp=True, + "TOSA-0.80+BI+u55", + {Sigmoid.exir_op: 3}, + n_expected_delegates=1, ) pipeline.change_args("quantize", get_16bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) pipeline.run() diff --git a/backends/arm/test/ops/test_sigmoid_32bit.py b/backends/arm/test/ops/test_sigmoid_32bit.py index 6ba4ab2d030..5388eae83c3 100644 --- a/backends/arm/test/ops/test_sigmoid_32bit.py +++ b/backends/arm/test/ops/test_sigmoid_32bit.py @@ -9,8 +9,8 @@ from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( - EthosU55PipelineBI, EthosU85PipelineBI, + OpNotSupportedPipeline, TosaPipelineBI, ) from executorch.backends.xnnpack.test.tester import Quantize @@ -122,53 +122,23 @@ def test_sigmoid_add_sigmoid_tosa_BI(test_data): pipeline.run() -@common.parametrize( - "test_data", - test_data_suite, - xfails={ - "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - }, - # int16 tables are not supported, but some tests happen to pass regardless. - # Set them to xfail but strict=False -> ok if they pass. - strict=False, -) -@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_suite) def test_sigmoid_tosa_u55(test_data): - pipeline = EthosU55PipelineBI( - Sigmoid(), (test_data(),), Sigmoid.aten_op, Sigmoid.exir_op, run_on_fvp=True + pipeline = OpNotSupportedPipeline( + Sigmoid(), (test_data(),), "TOSA-0.80+BI+u55", {Sigmoid.exir_op: 1} ) pipeline.change_args("quantize", get_32bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) pipeline.run() -@common.parametrize( - "test_data", - test_data_suite, - xfails={ - "ones": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "rand_4d": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_pos": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "randn_neg": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - "ramp": "AssertionError: Output 0 does not match reference output. MLBEDSW-9770", - }, - # int16 tables are not supported, but some tests happen to pass regardless. - # Set them to xfail but strict=False -> ok if they pass. - strict=False, -) -@common.XfailIfNoCorstone300 +@common.parametrize("test_data", test_data_suite) def test_sigmoid_add_sigmoid_tosa_u55(test_data): - pipeline = EthosU55PipelineBI( + pipeline = OpNotSupportedPipeline( SigmoidAddSigmoid(), (test_data(),), - Sigmoid.aten_op, - Sigmoid.exir_op, - run_on_fvp=True, + "TOSA-0.80+BI+u55", + {Sigmoid.exir_op: 3}, + n_expected_delegates=1, ) pipeline.change_args("quantize", get_32bit_sigmoid_quantizer("TOSA-0.80+BI+u55")) pipeline.run() From ffe9181e0f7512b65ad5546685c17b15ec302d2e Mon Sep 17 00:00:00 2001 From: Erik Lundell Date: Fri, 21 Mar 2025 15:50:41 +0100 Subject: [PATCH 2/2] Arm backend: Add Ethos-U55 permute check Signed-off-by: Erik Lundell Change-Id: Id7c6d6469e96e4133b7b1a54be6ea66bc7dc861a --- .../arm/operator_support/ethos_u55_support.py | 162 ++++++++++++++---- .../tosa_supported_operators.py | 2 + backends/arm/operators/op_permute.py | 44 +++-- backends/arm/test/ops/test_permute.py | 26 ++- 4 files changed, 185 insertions(+), 49 deletions(-) diff --git a/backends/arm/operator_support/ethos_u55_support.py b/backends/arm/operator_support/ethos_u55_support.py index 25b7a4d04c9..64f3fb3f816 100644 --- a/backends/arm/operator_support/ethos_u55_support.py +++ b/backends/arm/operator_support/ethos_u55_support.py @@ -11,12 +11,27 @@ import torch.fx as fx from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor from executorch.backends.arm._passes.insert_table_ops import TableOps +from executorch.backends.arm.operators.op_permute import transform_permutation_vector +from executorch.backends.arm.tosa_utils import tosa_shape from executorch.exir.backend.utils import WhyNoPartitionReporter from executorch.exir.dialects._ops import ops as exir_ops from torch.fx.passes.operator_support import OperatorSupportBase +def _try_determine_dtype(node: fx.Node) -> torch.dtype | None: + dtype = get_first_fake_tensor(node).dtype + if not dtype.is_floating_point: + return dtype + if node.target is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: + return get_first_fake_tensor(node.all_input_nodes[0]).dtype + q_node = list(node.users)[0] + if q_node.target is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: + return typing.cast(torch.dtype, q_node.args[-1]) + # We can't easily figure out dtype, return None + return None + + class EthosU55DtypeSupport(OperatorSupportBase): def __init__(self, reporter: WhyNoPartitionReporter): @@ -33,37 +48,11 @@ def __init__(self, reporter: WhyNoPartitionReporter): target_ops_i8 = tuple(TableOps.included_ops()) - def _try_determine_dtype(self, node: fx.Node) -> torch.dtype | None: - """Attempt to figure out the quantized data type of node. On failure, return None.""" - - dtype = get_first_fake_tensor(node).dtype - if not dtype.is_floating_point: - return dtype - - if ( - node.target - is exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default - ): - return get_first_fake_tensor(node.all_input_nodes[0]).dtype - - if len(node.users) == 0: - return None - - q_node = list(node.users)[0] - if ( - q_node.target - is exir_ops.edge.quantized_decomposed.quantize_per_tensor.default - ): - return typing.cast(torch.dtype, q_node.args[-1]) - - # We can't easily figure out dtype, return None - return None - def is_node_supported( # noqa: C901 self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node ) -> bool: - dtype = self._try_determine_dtype(node) + dtype = _try_determine_dtype(node) if dtype is None: # If we couldn't determine dtype, just return ok. return True @@ -84,13 +73,13 @@ def is_node_supported( # noqa: C901 if node.target == exir_ops.edge.aten.convolution.default: ifm, weight = node.all_input_nodes[0:2] - ifm_dtype = self._try_determine_dtype(ifm) + ifm_dtype = _try_determine_dtype(ifm) if ifm_dtype is not None and ifm_dtype not in (torch.int8, torch.int16): self.reporter.report_reject( node, f"Unsupported input dtype {dtype} (Supports i8, i16)." ) return False - weight_dtype = self._try_determine_dtype(weight) + weight_dtype = _try_determine_dtype(weight) if weight_dtype is not None and weight_dtype not in (torch.int8,): self.reporter.report_reject( node, f"Unsupported weight dtype {dtype} (Supports i8)." @@ -98,7 +87,7 @@ def is_node_supported( # noqa: C901 return False if len(node.all_input_nodes) > 2: bias = node.all_input_nodes[2] - bias_dtype = self._try_determine_dtype(bias) + bias_dtype = _try_determine_dtype(bias) if bias_dtype is not None and bias_dtype not in (torch.int32,): self.reporter.report_reject( node, f"Unsupported bias dtype {dtype} (Supports i32)." @@ -110,7 +99,7 @@ def is_node_supported( # noqa: C901 exir_ops.edge.aten.bmm.default, ): for input_node in node.all_input_nodes: - dtype = self._try_determine_dtype(input_node) + dtype = _try_determine_dtype(input_node) if dtype is not None and dtype != torch.int8: self.reporter.report_reject( input_node, @@ -174,3 +163,114 @@ def is_node_supported( return False return True + + +shape_t = list[int] + + +class EthosU55TransposeCheck(OperatorSupportBase): + + def __init__(self, reporter: WhyNoPartitionReporter): + super().__init__() + self.reporter = reporter + + def _pad_to_rank_4( + self, shape: shape_t, permutation: list[int] + ) -> tuple[shape_t, shape_t]: + diff = 4 - len(shape) + padded_shape = [1] * diff + shape + for i in range(len(permutation)): + permutation[i] += diff + padded_permutation = list(range(diff)) + permutation + return padded_shape, padded_permutation + + def axes_product(self, nhwc_shape: shape_t) -> int: + product = 1 + for axes in nhwc_shape: + product *= axes + return product + + def _permute_constraint_i8_i16( + self, nhwc_shape: list[int], permutation: list[int] + ) -> bool: + """Returns True if the constraints are ok.""" + N, H, W, C = nhwc_shape + match permutation: + case (0, 1, 2, 3): # NHWC -> NHWC + return True + case (0, 2, 1, 3) | (0, 1, 3, 2) | (0, 3, 1, 2): # NHWC -> NWHC, NHCW, NCWH + return N * H <= 65536 and W <= 65536 and C <= 65536 + case _: + return self.axes_product(nhwc_shape) <= 65536 + + def _permute_constraint_i32( + self, nhwc_shape: list[int], permutation: list[int] + ) -> bool: + """Returns True if the constraints are ok.""" + N, H, W, C = nhwc_shape + match permutation: + case (0, 1, 2, 3): # NHWC -> NHWC + return C <= 32768 + case (0, 2, 1, 3): # NHWC -> NHWC + return N == 1 and H <= 65536 and W <= 65536 and C <= 16384 + case (0, 1, 3, 2): # NHWC -> NHCW + return N * H <= 65536 and W <= 65536 and C <= 65536 + case _: + return False + + def _permute_constraint(self, shape, permutation, dtype): + if dtype in (torch.int8, torch.int16): + return self._permute_constraint_i8_i16(shape, permutation) + if dtype == torch.int32: + return not self._permute_constraint_i32(shape, permutation) + return True + + def is_node_supported( + self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node + ) -> bool: + + if not node.target == exir_ops.edge.aten.permute_copy.default: + return True + + shape = list(get_first_fake_tensor(node).shape) + dtype = _try_determine_dtype(node) + permutation = list(typing.cast(list[int], node.args[1])) + + rank = len(shape) + if rank > 4: + if dtype == torch.int32: + self.reporter.report_reject( + node, f"No support for {permutation=} in int32." + ) + return False + if dtype in (torch.int8, torch.int16): + if self.axes_product(shape) > 65536: + self.reporter.report_reject( + node, + f"No support for {shape=}, {dtype=}. Product of axes must be <65536", + ) + return False + return True + + shape, permutation = self._pad_to_rank_4(shape, permutation) + if rank == 3 or rank == 4: + # For rank 3 and 4, we can have channels first or channels last dim order. + # Since we don't know which at partition-time, test both. + + nhwc_shape = tosa_shape(shape, [0, 2, 3, 1]) + nhwc_permutation = transform_permutation_vector(permutation, [0, 2, 3, 1]) + + if not self._permute_constraint(nhwc_shape, nhwc_permutation, dtype): + self.reporter.report_reject( + node, + f"Unsupported NHWC {nhwc_shape=} for {nhwc_permutation=}, {dtype=}", + ) + return False + + if not self._permute_constraint(shape, permutation, dtype): + self.reporter.report_reject( + node, f"Unsupported NCHW {shape=} for {permutation=}, {dtype=}" + ) + return False + + return True diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 167be984e57..2a31ecbc775 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -21,6 +21,7 @@ from executorch.backends.arm.operator_support.ethos_u55_support import ( EthosU55DtypeSupport, EthosU55NotSupported, + EthosU55TransposeCheck, ) from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification from executorch.exir import ExportedProgram @@ -123,6 +124,7 @@ def tosa_support_factory( if isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset: negative_checks.append(EthosU55NotSupported(reporter)) negative_checks.append(EthosU55DtypeSupport(reporter)) + negative_checks.append(EthosU55TransposeCheck(reporter)) return chain( reporter.wrap_check( diff --git a/backends/arm/operators/op_permute.py b/backends/arm/operators/op_permute.py index 103ae1b9a2f..e659918baf2 100644 --- a/backends/arm/operators/op_permute.py +++ b/backends/arm/operators/op_permute.py @@ -65,6 +65,29 @@ def permutation_matrix_to_vector(permutation_matrix: torch.Tensor) -> list[int]: return p +def transform_permutation_vector(permutation_vector: list[int], dim_order: list[int]): + """Transforms a permutation to dim_order.""" + + # We need to first transform to dim_order, apply the permutation P, + # and then transform back to the original dim_order. + # This transformation, S, is also a permutation, with the dim_order as permutation vector. + + # To do this, represent P and S with permutation matrices. + # Matrices can handle chained transformations and inversion easily. + S = permutation_vector_to_matrix(dim_order) + # The inverse of a permutation matrix is its transpose. + S_inverse = S.t() + P = permutation_vector_to_matrix(permutation_vector) + + # The complete transformation is S * P * S_inverse. + transformation_matrix = S.matmul(P.matmul(S_inverse)) + + # Luckily, since it is just a combination of permutations, the result is also a permutation + # that can again be described by a new permutation vector. + permutation_vector = permutation_matrix_to_vector(transformation_matrix) + return permutation_vector + + @register_node_visitor class PermuteVisitor(NodeVisitor): target = "aten.permute_copy.default" @@ -86,23 +109,10 @@ def define_node( if output.dim_order != tuple(range(len(output.dim_order))): # the permutation vector can't be used directly if we are not in NCHW dim_order. - # We need to first transform to NCHW, apply P, - # and then transform back to the original dim_order. - # This transformation, S, is also a permutation, with the dim_order as permutation vector. - - # To do this, represent P and S with permutation matrices. - # Matrices can handle chained transformations and inversion easily. - S = permutation_vector_to_matrix(output.dim_order) - # The inverse of a permutation matrix is its transpose. - S_inverse = S.transpose(1, 0) - P = permutation_vector_to_matrix(permutation_vector) - - # The complete transformation is S * P * S_inverse. - transformation_matrix = S.matmul(P.matmul(S_inverse)) - - # Luckily, since it is just a combination of permutations, the result is also a permutation - # that can again be described by a new permutation vector. - permutation_vector = permutation_matrix_to_vector(transformation_matrix) + # Transform to dim_order. + permutation_vector = transform_permutation_vector( + permutation_vector, output.dim_order + ) attr = ts.TosaSerializerAttribute() attr.TransposeAttribute(permutation_vector) diff --git a/backends/arm/test/ops/test_permute.py b/backends/arm/test/ops/test_permute.py index 28b7d70af35..e71b6687865 100644 --- a/backends/arm/test/ops/test_permute.py +++ b/backends/arm/test/ops/test_permute.py @@ -1,6 +1,6 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. -# Copyright 2024-2025 Arm Limited and/or its affiliates. # All rights reserved. +# Copyright 2024-2025 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -20,6 +20,7 @@ ) from executorch.backends.arm.test import common, conftest from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.backends.arm.test.tester.test_pipeline import OpNotSupportedPipeline from executorch.backends.arm.tosa_specification import TosaSpecification from executorch.backends.xnnpack.test.tester.tester import Quantize from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -163,3 +164,26 @@ def test_permute_u85_BI_xfails( self._test_permute_ethos_BI_pipeline( self.Permute(dims=dims), common.get_u85_compile_spec(), (test_data,) ) + + +reject_data_suite = { + "int8_r3_axes_product": ([1, 700, 1000], [2, 1, 0], torch.int8), + "int8_r5_axes_product": ([1, 1, 1, 700, 1000], [0, 1, 2, 3, 4], torch.int8), + "int8_r4_NH_too_large": ([700, 100, 1, 1], [0, 1, 3, 2], torch.int8), + "int32_r5_no_support": ([2, 2, 2, 2, 2], [3, 4, 2, 1, 0], torch.int32), +} +input_t = tuple[torch.Tensor] + + +@common.parametrize("test_data", reject_data_suite) +def test_permute_u55_BI_not_delegated(test_data): + # Tests that we don't delegate these ops since they are not supported on U55. + shape, permutation, dtype = test_data + data = ((torch.rand(shape) * 10).to(dtype),) + pipeline = OpNotSupportedPipeline[input_t]( + TestPermute.Permute(dims=permutation), + data, + "TOSA-0.80+BI+u55", + {"executorch_exir_dialects_edge__ops_aten_permute_copy_default": 1}, + ) + pipeline.run()