From 3d4779472d4b23274d1ad84b544706d996605d34 Mon Sep 17 00:00:00 2001 From: Ekaterina Ignasheva Date: Wed, 4 Jun 2025 14:35:26 -0700 Subject: [PATCH] Use GraphBuilder in test_replace_ops_passes. #1 (#11344) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/11344 Reviewed By: hsharma35 Differential Revision: D75911655 --- .../aot/tests/test_replace_ops_passes.py | 265 ++++++++---------- 1 file changed, 118 insertions(+), 147 deletions(-) diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py index e8215c378f9..79649063119 100644 --- a/backends/cadence/aot/tests/test_replace_ops_passes.py +++ b/backends/cadence/aot/tests/test_replace_ops_passes.py @@ -57,7 +57,6 @@ from executorch.exir.passes import dead_code_elimination_pass from parameterized.parameterized import parameterized -from torch._ops import OpOverload from torch.fx.passes.infra.pass_base import PassResult @@ -87,36 +86,46 @@ def assertTargetCountsEqual( @parameterized.expand( [ - # Regular MM - [(64, 33), (33, 128)], - # Batched MM - [(2, 48, 48), (2, 48, 48)], - ] + ( + "regular", + (64, 33), # x_shape + (33, 128), # y_shape + ), + ( + "batched", + (2, 48, 48), # x_shape + (2, 48, 48), # y_shape + ), + ], ) @torch.no_grad() def test_replace_matmul_with_transposed_matmul( self, + _, x_shape: Tuple[int], y_shape: Tuple[int], ) -> None: - class MatMul(torch.nn.Module): - def __init__(self) -> None: - super(MatMul, self).__init__() - - def forward(self, x, y): - return torch.matmul(x, y) - - model = MatMul() - X = torch.randn(x_shape) - Y = torch.randn(y_shape) - p = ReplaceMatmulWithTransposedMatmulPass() - inputs = (X, Y) - graph_module = ( - quantize_and_export_to_edge(model, inputs).exported_program().graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*x_shape, dtype=torch.float32)) + y = builder.placeholder("y", torch.randn(*y_shape, dtype=torch.float32)) + matmul = builder.call_operator( + op=exir_ops.edge.cadence.quantized_matmul.default, + args=( + x, + 0, # X_zero_point + y, + 0, # Y_zero_point, + None, # bias + 1, # out_multiplier + 0, # out_shift + 0, # out_zero_point + False, # transposed=False + ), ) - # pyre-fixme[16]: Optional type has no attribute `graph_module` - graph_after_passes = p(graph_module).graph_module - + builder.output([matmul]) + original = builder.get_graph_module() + p = ReplaceMatmulWithTransposedMatmulPass() + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1, @@ -130,33 +139,24 @@ def forward(self, x, y): @parameterized.expand( [ - [(3, 5), (0, 0)], - [ - (20, 1, 80), - (0, 0), - ], - ] + ("2d", (3, 5), [0, 0]), # shape # padding + ("3d", (20, 1, 80), [0, 0, 0]), # shape # padding + ], ) @torch.no_grad() def test_replace_constant_pad_nd_with_slice( - self, shape: Tuple[int], padding: Tuple[int] + self, _, shape: Tuple[int], padding: Tuple[int] ): - # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition. - class Padding(torch.nn.Module): - def __init__(self): - super().__init__() - self.padding = padding - - def forward(self, x: torch.Tensor): - return F.pad(x, self.padding) - - model = Padding() - x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - + builder = GraphBuilder() + x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32)) + matmul = builder.call_operator( + op=exir_ops.edge.aten.constant_pad_nd.default, + args=(x, [0, 0, 0, 0]), + ) + builder.output([matmul]) + original = builder.get_graph_module() p = ReplaceConstantPadNdWithSlicePass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor), 1, @@ -169,28 +169,27 @@ def forward(self, x: torch.Tensor): @parameterized.expand( [ - [(7, 5, 6), 1.23], - [(7, 5), 2], + ["3d", (7, 5, 6), 1.23], + ["2d", (7, 5), 2], + ["1d", (10,), 42949], ] ) @torch.no_grad() - def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float): - class Add(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.add.Scalar(x, other) - - model = Add() + def test_add_replace_scalar_with_tensor_arg( + self, _, shape: Tuple[int], other: float + ): x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - + original = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.add.Scalar, + args=(x, other), + ) p = ReplaceScalarWithTensorArgPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor), 1, ) - self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar), 0, @@ -198,29 +197,27 @@ def forward(self, x): @parameterized.expand( [ - [(7, 5, 6), 1.23], - [(7, 5), 2], - [(10), 42949], + ["3d", (7, 5, 6), 1.23], + ["2d", (7, 5), 2], + ["1d", (10,), 42949], ] ) @torch.no_grad() - def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float): - class Sub(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.sub.Scalar(x, other) - - model = Sub() + def test_sub_replace_scalar_with_tensor_arg( + self, _, shape: Tuple[int], other: float + ): x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - + original = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.sub.Scalar, + args=(x, other), + ) p = ReplaceScalarWithTensorArgPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor), 1, ) - self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar), 0, @@ -228,29 +225,27 @@ def forward(self, x): @parameterized.expand( [ - [(7, 5, 6), 1.23], - [(7, 5), 2], - [(513), 3], + ["3d", (7, 5, 6), 1.23], + ["2d", (7, 5), 2], + ["1d", (10,), 42949], ] ) @torch.no_grad() - def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float): - class Mul(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.mul.Scalar(x, other) - - model = Mul() + def test_mul_replace_scalar_with_tensor_arg( + self, _, shape: Tuple[int], other: float + ): x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - + original = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.mul.Scalar, + args=(x, other), + ) p = ReplaceScalarWithTensorArgPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor), 1, ) - self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar), 0, @@ -258,32 +253,30 @@ def forward(self, x): @parameterized.expand( [ - [(7, 5, 6), 1.23], - [(7, 5), 2], + ["3d", (7, 5, 6), 1.23], + ["2d", (7, 5), 2], + ["1d", (10,), 42949], ] ) @torch.no_grad() def test_div_replace_scalar_with_tensor_arg( self, + _, shape: Tuple[int], other: float, ): - class Div(torch.nn.Module): - def forward(self, x): - return torch.ops.aten.div.Scalar(x, other) - - model = Div() - x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module - + x = torch.randn(*shape) + original = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.div.Scalar, + args=(x, other), + ) p = ReplaceScalarWithTensorArgPass() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor), 1, ) - self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar), 0, @@ -291,20 +284,25 @@ def forward(self, x): @parameterized.expand( [ - [(2, 3, 5, 6)], - [(7, 6, 5)], - [(4, 4)], - [(316)], + ["4d", (2, 3, 5, 6)], + ["3d", (7, 6, 5)], + ["2d", (4, 4)], + ["1d", (316)], ] ) @torch.no_grad() - def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]): - model = torch.nn.ReLU() + def test_replace_functionally_equivalent_op_targets_relu( + self, _, shape: Tuple[int] + ): x = torch.randn(shape) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module + original = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.relu_.default, + args=(x,), + ) p = ReplaceFunctionallyEquivalentOpTargets() + graph_after_passes = cast(PassResult, p(original)).graph_module - graph_after_passes = cast(PassResult, p(graph_module)).graph_module self.assertEqual( count_node(graph_after_passes, exir_ops.edge.aten.relu.default), 1, @@ -315,56 +313,29 @@ def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int] ) @parameterized.expand( - [ - # split the only dimension - [(50,), i, 0] - for i in range(2, 7) - ] - + [ - # split the leading dim - [(10, 2, 3), i, 0] - for i in range(2, 7) - ] - + [ - # split the trailing dim - [(3, 3, 6), i, 2] - for i in range(2, 6) - ] - + [ - # split the dim in the middle - [(3, 5, 14, 2, 3), i, 2] - for i in range(2, 7) - ] + [["split_linear_tensor", (50,), i, 0] for i in range(2, 7)] + + [["split_leading_dim", (10, 2, 3), i, 0] for i in range(2, 7)] + + [["split_trailing_dim", (3, 3, 6), i, 2] for i in range(2, 6)] + + [["split_middle_dim", (3, 5, 14, 2, 3), i, 2] for i in range(2, 7)] ) @torch.no_grad() def test_replace_functionally_equivalent_op_targets_unsafe_split( - self, shape: Tuple[int], split_size: int, dim: int + self, _, shape: Tuple[int], split_size: int, dim: int ): - class TensorSplitWithSizes(torch.nn.Module): - def __init__(self, split_size: int, dim: int, op: OpOverload): - super().__init__() - self.split_size = split_size - self.dim = dim - self.op = op - - def forward(self, x: torch.Tensor): - return self.op(x, self.split_size, self.dim) - x = torch.randn(shape) - model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split) - graph_module = export_to_edge(model, (x,)).exported_program().graph_module + original = single_op_builder( + placeholders=(x,), + op=exir_ops.edge.aten.unsafe_split.Tensor, + args=(x, split_size, dim), + ) p = ReplaceFunctionallyEquivalentOpTargets() - - graph_after_passes = cast(PassResult, p(graph_module)).graph_module + graph_after_passes = cast(PassResult, p(original)).graph_module self.assertEqual( - count_node( - graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default - ), + count_node(graph_after_passes, exir_ops.edge.aten.split_copy.Tensor), 1, ) self.assertEqual( - count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), - 0, + count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x ) @parameterized.expand(