diff --git a/backends/arm/_passes/match_arg_ranks_pass.py b/backends/arm/_passes/match_arg_ranks_pass.py index 0f5bdfae599..04fc8d00c70 100644 --- a/backends/arm/_passes/match_arg_ranks_pass.py +++ b/backends/arm/_passes/match_arg_ranks_pass.py @@ -47,6 +47,7 @@ def __init__(self, exported_program): exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.bitwise_right_shift.Tensor, exir_ops.edge.aten.bitwise_left_shift.Tensor, + exir_ops.edge.aten.eq.Tensor, ] def _match_op_rank(self, graph_module, node, arg, max_rank): diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py index 8eac968de0e..4843a17eb1d 100644 --- a/backends/arm/operator_support/tosa_supported_operators.py +++ b/backends/arm/operator_support/tosa_supported_operators.py @@ -158,6 +158,7 @@ def is_node_supported( exir_ops.edge.aten.hardswish.default, exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.exp.default, exir_ops.edge.aten.log.default, exir_ops.edge.aten.linear.default, @@ -235,6 +236,7 @@ class EthosU55NotSupported(OperatorSupportBase): exir_ops.edge.aten.amax.default, # REDUCE_MAX exir_ops.edge.aten.amin.default, # REDUCE_MIN exir_ops.edge.aten.eq.Tensor, + exir_ops.edge.aten.eq.Scalar, exir_ops.edge.aten.ge.Tensor, exir_ops.edge.aten.gt.Tensor, exir_ops.edge.aten.le.Tensor, diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 42da8fde7d2..e270fb18205 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -31,11 +31,10 @@ class TestConformer(unittest.TestCase): # .to_executorch step, i.e. after Arm partitioner. ops_after_partitioner = { "executorch_exir_dialects_edge__ops_aten_max_default": 1, - "executorch_exir_dialects_edge__ops_aten_eq_Scalar": 2, "executorch_exir_dialects_edge__ops_aten_where_self": 4, "torch.ops.aten._assert_scalar.default": 10, "torch.ops.aten._local_scalar_dense.default": 1, - "torch.ops.higher_order.executorch_call_delegate": 6, + "torch.ops.higher_order.executorch_call_delegate": 4, } dim = 16 diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py index 2656c12417d..a6da04b0e2e 100644 --- a/backends/arm/test/models/test_llama.py +++ b/backends/arm/test/models/test_llama.py @@ -114,7 +114,7 @@ def test_llama_tosa_MI(self): ) .export() .to_edge_transform_and_lower() - .check_count({"torch.ops.higher_order.executorch_call_delegate": 26}) + .check_count({"torch.ops.higher_order.executorch_call_delegate": 14}) .to_executorch() .run_method_and_compare_outputs( inputs=llama_inputs, diff --git a/backends/arm/test/ops/test_eq.py b/backends/arm/test/ops/test_eq.py index 329f65dfead..727cde2b11d 100644 --- a/backends/arm/test/ops/test_eq.py +++ b/backends/arm/test/ops/test_eq.py @@ -5,7 +5,6 @@ from typing import Tuple -import pytest import torch from executorch.backends.arm.test import common @@ -16,13 +15,15 @@ TosaPipelineMI, ) -aten_op = "torch.ops.aten.eq.Tensor" -exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor" input_t = Tuple[torch.Tensor] class Equal(torch.nn.Module): + aten_op_BI = "torch.ops.aten.eq.Tensor" + aten_op_MI = "torch.ops.aten.eq.Scalar" + exir_op = "executorch_exir_dialects_edge__ops_aten_eq_Tensor" + def __init__(self, input, other): super().__init__() self.input_ = input @@ -31,7 +32,7 @@ def __init__(self, input, other): def forward( self, input_: torch.Tensor, - other_: torch.Tensor, + other_: torch.Tensor | int | float, ): return input_ == other_ @@ -39,98 +40,111 @@ def get_inputs(self): return (self.input_, self.other_) -op_eq_rank1_ones = Equal( +op_eq_tensor_rank1_ones = Equal( torch.ones(5), torch.ones(5), ) -op_eq_rank2_rand = Equal( +op_eq_tensor_rank2_rand = Equal( torch.rand(4, 5), torch.rand(1, 5), ) -op_eq_rank3_randn = Equal( +op_eq_tensor_rank3_randn = Equal( torch.randn(10, 5, 2), torch.randn(10, 5, 2), ) -op_eq_rank4_randn = Equal( +op_eq_tensor_rank4_randn = Equal( torch.randn(3, 2, 2, 2), torch.randn(3, 2, 2, 2), ) -test_data_common = { - "eq_rank1_ones": op_eq_rank1_ones, - "eq_rank2_rand": op_eq_rank2_rand, - "eq_rank3_randn": op_eq_rank3_randn, - "eq_rank4_randn": op_eq_rank4_randn, +op_eq_scalar_rank1_ones = Equal(torch.ones(5), 1.0) +op_eq_scalar_rank2_rand = Equal(torch.rand(4, 5), 0.2) +op_eq_scalar_rank3_randn = Equal(torch.randn(10, 5, 2), -0.1) +op_eq_scalar_rank4_randn = Equal(torch.randn(3, 2, 2, 2), 0.3) + +test_data_tensor = { + "eq_tensor_rank1_ones": op_eq_tensor_rank1_ones, + "eq_tensor_rank2_rand": op_eq_tensor_rank2_rand, + "eq_tensor_rank3_randn": op_eq_tensor_rank3_randn, + "eq_tensor_rank4_randn": op_eq_tensor_rank4_randn, } +test_data_scalar = { + "eq_scalar_rank1_ones": op_eq_scalar_rank1_ones, + "eq_scalar_rank2_rand": op_eq_scalar_rank2_rand, + "eq_scalar_rank3_randn": op_eq_scalar_rank3_randn, + "eq_scalar_rank4_randn": op_eq_scalar_rank4_randn, +} + + +@common.parametrize("test_module", test_data_tensor) +def test_eq_tensor_tosa_MI(test_module): + pipeline = TosaPipelineMI[input_t]( + test_module, test_module.get_inputs(), Equal.aten_op_BI, Equal.exir_op + ) + pipeline.run() -@common.parametrize("test_module", test_data_common) -def test_eq_tosa_MI(test_module): + +@common.parametrize("test_module", test_data_scalar) +def test_eq_scalar_tosa_MI(test_module): pipeline = TosaPipelineMI[input_t]( - test_module, test_module.get_inputs(), aten_op, exir_op + test_module, + test_module.get_inputs(), + Equal.aten_op_MI, + Equal.exir_op, ) pipeline.run() -@common.parametrize("test_module", test_data_common) +@common.parametrize("test_module", test_data_tensor | test_data_scalar) def test_eq_tosa_BI(test_module): pipeline = TosaPipelineBI[input_t]( - test_module, test_module.get_inputs(), aten_op, exir_op + test_module, test_module.get_inputs(), Equal.aten_op_BI, Equal.exir_op ) pipeline.run() -@common.parametrize("test_module", test_data_common) -def test_eq_u55_BI(test_module): +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone300 +def test_eq_tensor_u55_BI(test_module): # EQUAL is not supported on U55. pipeline = OpNotSupportedPipeline[input_t]( test_module, test_module.get_inputs(), "TOSA-0.80+BI+u55", - {exir_op: 1}, - ) - pipeline.run() - - -@common.parametrize("test_module", test_data_common) -def test_eq_u85_BI(test_module): - pipeline = EthosU85PipelineBI[input_t]( - test_module, - test_module.get_inputs(), - aten_op, - exir_op, - run_on_fvp=False, - use_to_edge_transform_and_lower=True, + {Equal.exir_op: 1}, ) pipeline.run() -@common.parametrize("test_module", test_data_common) -@pytest.mark.skip(reason="The same as test_eq_u55_BI") -def test_eq_u55_BI_on_fvp(test_module): +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone300 +def test_eq_scalar_u55_BI(test_module): # EQUAL is not supported on U55. pipeline = OpNotSupportedPipeline[input_t]( test_module, test_module.get_inputs(), "TOSA-0.80+BI+u55", - {exir_op: 1}, + {Equal.exir_op: 1}, + n_expected_delegates=1, ) pipeline.run() @common.parametrize( "test_module", - test_data_common, - xfails={"eq_rank4_randn": "4D fails because boolean Tensors can't be subtracted"}, + test_data_tensor | test_data_scalar, + xfails={ + "eq_tensor_rank4_randn": "4D fails because boolean Tensors can't be subtracted", + }, ) -@common.SkipIfNoCorstone320 -def test_eq_u85_BI_on_fvp(test_module): +@common.XfailIfNoCorstone320 +def test_eq_u85_BI(test_module): pipeline = EthosU85PipelineBI[input_t]( test_module, test_module.get_inputs(), - aten_op, - exir_op, + Equal.aten_op_BI, + Equal.exir_op, run_on_fvp=True, - use_to_edge_transform_and_lower=True, ) pipeline.run() diff --git a/backends/transforms/replace_scalar_with_tensor.py b/backends/transforms/replace_scalar_with_tensor.py index 1f79a525437..b1bab5b0b66 100644 --- a/backends/transforms/replace_scalar_with_tensor.py +++ b/backends/transforms/replace_scalar_with_tensor.py @@ -26,12 +26,14 @@ class ReplaceScalarWithTensorArgPass(ExportPass): exir_ops.edge.aten.div.Scalar: exir_ops.edge.aten.div.Tensor, exir_ops.edge.aten.__rshift__.Scalar: exir_ops.edge.aten.bitwise_right_shift.Tensor, exir_ops.edge.aten.__lshift__.Scalar: exir_ops.edge.aten.bitwise_left_shift.Tensor, + exir_ops.edge.aten.eq.Scalar: exir_ops.edge.aten.eq.Tensor, torch.ops.aten.add.Scalar: torch.ops.aten.add.Tensor, torch.ops.aten.sub.Scalar: torch.ops.aten.sub.Tensor, torch.ops.aten.mul.Scalar: torch.ops.aten.mul.Tensor, torch.ops.aten.div.Scalar: torch.ops.aten.div.Tensor, torch.ops.aten.__rshift__.Scalar: torch.ops.aten.bitwise_right_shift.Tensor, torch.ops.aten.__lshift__.Scalar: torch.ops.aten.bitwise_left_shift.Tensor, + torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor, } def get_replacement(self, op, args, kwargs, meta):