From 6431223a722bd40e00180a3414d687e61e5db241 Mon Sep 17 00:00:00 2001
From: Ekaterina Ignasheva
Date: Thu, 12 Jun 2025 15:19:25 -0700
Subject: [PATCH] Use GraphBuilder in test_replace_ops_passes. #3 (#11501)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/11501

Reviewed By: zonglinpeng

Differential Revision: D76157744
---
 .../aot/tests/test_replace_ops_passes.py      | 537 +++++++++---------
 1 file changed, 271 insertions(+), 266 deletions(-)

diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index 81b628ef232..4ff84a296e8 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -4,18 +4,13 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-unsafe
+# pyre-strict
 
 import operator
 import unittest
 from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
 
 import torch
-from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import (
-    export_to_edge,
-    quantize_and_export_to_edge,
-)
 from executorch.backends.cadence.aot.graph_builder import (
     GraphBuilder,
     single_op_builder,
@@ -122,9 +117,9 @@ def test_replace_matmul_with_transposed_matmul(
             ),
         )
         builder.output([matmul])
-        original = builder.get_graph_module()
+        original_gm = builder.get_graph_module()
         p = ReplaceMatmulWithTransposedMatmulPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
             1,
@@ -153,9 +148,9 @@ def test_replace_constant_pad_nd_with_slice(
             args=(x, [0, 0, 0, 0]),
         )
         builder.output([matmul])
-        original = builder.get_graph_module()
+        original_gm = builder.get_graph_module()
         p = ReplaceConstantPadNdWithSlicePass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor),
             1,
@@ -178,13 +173,13 @@ def test_add_replace_scalar_with_tensor_arg(
         self, _, shape: Tuple[int], other: float
     ):
         x = torch.randn(shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.add.Scalar,
             args=(x, other),
         )
         p = ReplaceScalarWithTensorArgPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
             1,
@@ -206,13 +201,13 @@ def test_sub_replace_scalar_with_tensor_arg(
         self, _, shape: Tuple[int], other: float
    ):
         x = torch.randn(shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.sub.Scalar,
             args=(x, other),
         )
         p = ReplaceScalarWithTensorArgPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor),
             1,
@@ -234,13 +229,13 @@ def test_mul_replace_scalar_with_tensor_arg(
         self, _, shape: Tuple[int], other: float
     ):
         x = torch.randn(shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.mul.Scalar,
             args=(x, other),
         )
         p = ReplaceScalarWithTensorArgPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
             1,
@@ -265,13 +260,13 @@ def test_div_replace_scalar_with_tensor_arg(
         other: float,
     ):
         x = torch.randn(*shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.div.Scalar,
             args=(x, other),
         )
         p = ReplaceScalarWithTensorArgPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor),
             1,
@@ -294,13 +289,13 @@ def test_replace_functionally_equivalent_op_targets_relu(
         self, _, shape: Tuple[int]
     ):
         x = torch.randn(shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.relu_.default,
             args=(x,),
         )
         p = ReplaceFunctionallyEquivalentOpTargets()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
 
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.relu.default),
@@ -322,13 +317,13 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split(
         self, _, shape: Tuple[int], split_size: int, dim: int
     ):
         x = torch.randn(shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.unsafe_split.Tensor,
             args=(x, split_size, dim),
         )
         p = ReplaceFunctionallyEquivalentOpTargets()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.split_copy.Tensor),
             1,
@@ -402,12 +397,12 @@ def test_replace_transposed_conv_with_linear(
             args=(convolution, [0, 2, 1]),
         )
         builder.output([convolution])
-        original = builder.get_graph_module()
+        original_gm = builder.get_graph_module()
 
         p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
         p2 = ReplaceTransposedConvWithLinearPass()
         graph_after_passes = cast(
-            PassResult, p2(cast(PassResult, p1(original)).graph_module)
+            PassResult, p2(cast(PassResult, p1(original_gm)).graph_module)
         ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
@@ -488,9 +483,9 @@ def test_replace_convolution_optional_args_with_concrete_args(
             args=(convolution, [0, 2, 1]),
         )
         builder.output([convolution])
-        original = builder.get_graph_module()
+        original_gm = builder.get_graph_module()
         p = ReplaceConvolutionOptionalArgsWithConcreteArgsPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
             1,
@@ -512,13 +507,13 @@ def test_replace_convolution_optional_args_with_concrete_args(
     @torch.no_grad()
     def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]):
         x = torch.randn(shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.constant_pad_nd.default,
             args=(x, padding),
         )
         p = ReplacePadWithCatPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
             1,
@@ -531,13 +526,13 @@ def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]):
     @torch.no_grad()
     def test_replace_repeat_with_cat(self):
         x = torch.randn([3, 5])
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.repeat.default,
             args=(x, [1, 2]),
         )
         p = ReplaceRepeatWithCatPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
             1,
@@ -587,9 +582,9 @@ def test_replace_masked_scalar_tensor_with_full(
             args=(mask, scalar_tensor, x),
         )
         builder.output([aten_where_self])
-        original = builder.get_graph_module()
+        original_gm = builder.get_graph_module()
         p = ReplaceScalarTensorWithFullPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
             1,
@@ -607,13 +602,13 @@ def test_replace_masked_scalar_tensor_with_full(
     def test_replace_scalar_tensor_with_full(
         self,
    ):
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(),
             op=exir_ops.edge.aten.scalar_tensor.default,
             args=(0.123,),
         )
         p = ReplaceScalarTensorWithFullPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
             1,
@@ -640,8 +635,10 @@ def test_replace_linear_with_fully_connected(self):
             args=(x, permute_copy),
         )
         builder.output([mm])
-        original = builder.get_graph_module()
-        gm = cast(PassResult, ReplacePermuteWithTransposePass()(original)).graph_module
+        original_gm = builder.get_graph_module()
+        gm = cast(
+            PassResult, ReplacePermuteWithTransposePass()(original_gm)
+        ).graph_module
         gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
         gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module
         graph_after_passes = cast(
@@ -688,8 +685,10 @@ def test_replace_addmm_with_linear(
             kwargs={"beta": beta, "alpha": alpha},
         )
         builder.output([addmm])
-        original = builder.get_graph_module()
-        gm = cast(PassResult, ReplacePermuteWithTransposePass()(original)).graph_module
+        original_gm = builder.get_graph_module()
+        gm = cast(
+            PassResult, ReplacePermuteWithTransposePass()(original_gm)
+        ).graph_module
         graph_after_passes = cast(
             PassResult, ReplaceAddMMWithLinearPass()(gm)
         ).graph_module
@@ -708,13 +707,13 @@ def test_replace_mm_with_addmm(self):
         M, K, N = 14, 48, 24
         x = torch.randn([M, K])
         y = torch.randn([K, N])
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x, y),
             op=exir_ops.edge.aten.mm.default,
             args=(x, y),
         )
         p = ReplaceMMWithAddMMPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.addmm.default),
@@ -729,7 +728,7 @@ def test_replace_mm_with_addmm(self):
         [
             # shape
             [(5, 1, 6, 7)],
-            [(1)],
+            [1],
             [(4, 3, 2)],
             # shape, dim to squeeze
             [(2, 1), 0],
@@ -741,19 +740,19 @@ def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
         x = torch.randn(shape)
         if dim:
-            original = single_op_builder(
+            original_gm = single_op_builder(
                 placeholders=(x,),
                 op=exir_ops.edge.aten.squeeze_copy.dim,
                 args=(x, dim),
             )
         else:
-            original = single_op_builder(
+            original_gm = single_op_builder(
                 placeholders=(x,),
                 op=exir_ops.edge.aten.squeeze_copy.default,
                 args=(x,),
             )
         p = ReplaceSqueezeAndUnsqueezeWithViewPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
@@ -782,13 +781,13 @@ def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
     @torch.no_grad()
     def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int):
         x = torch.randn(shape)
-        original = single_op_builder(
+        original_gm = single_op_builder(
             placeholders=(x,),
             op=exir_ops.edge.aten.unsqueeze_copy.default,
             args=(x, dim),
         )
         p = ReplaceSqueezeAndUnsqueezeWithViewPass()
-        graph_after_passes = cast(PassResult, p(original)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertIsNotNone(graph_after_passes)
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
@@ -805,33 +804,56 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar(
         in_features: int = 16,
         out_features: int = 16,
     ):
-        # Tensors - these will be inputs to graph.
-        x = torch.randn([1, in_features])
-
-        inputs = (x,)
-        model = torch.nn.Linear(in_features=in_features, out_features=out_features)
-
-        exported_program = quantize_and_export_to_edge(model, inputs).exported_program()
-
-        # By default, the quantized linear op should have constant scalar attributes.
-        self.assertTargetCountsEqual(
-            exported_program.graph_module,
-            [
-                # One quantized linear op.
-                (exir_ops.edge.cadence.quantized_linear.default, 1),
-                # No per tensor quantized linear ops.
-                (exir_ops.edge.cadence.quantized_linear.per_tensor, 0),
-                # Three aten.full ops.
-                (exir_ops.edge.aten.full.default, 3),
-            ],
+        src_zero_point = 0
+        out_zero_point = 0
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([1, in_features]))
+        weights = builder.placeholder(
+            "weights", torch.randn([in_features, out_features], dtype=torch.float32)
         )
-
-        # Apply replacement pass.
+        bias = builder.placeholder(
+            "bias", torch.randn([out_features], dtype=torch.float32)
+        )
+        quantized_input = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.01431146077811718, 57, -128, 127, torch.int8),
+        )
+        weight_zero_point = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0),
+        )
+        out_multiplier = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0),
+        )
+        out_shift = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0),
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantized_input,
+                weights,
+                bias,
+                src_zero_point,
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                None,
+            ),
+        )
+        dequantized_output = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(output, 0.010696045123040676, -31, -128, 127, torch.int8),
+        )
+        builder.output([dequantized_output])
+        original_gm = builder.get_graph_module()
         p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass()
-        graph_after_passes = p(exported_program.graph_module)
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertIsNotNone(graph_after_passes)
-        gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module
-
+        gm = dead_code_elimination_pass(graph_after_passes).graph_module
         # By default, the quantized linear op should have constant scalar attributes.
         self.assertTargetCountsEqual(
             gm,
@@ -851,37 +873,63 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_
         in_features: int = 16,
         out_features: int = 16,
     ):
-        # Tensors - these will be inputs to graph.
-        x = torch.randn([1, in_features])
-
-        inputs = (x,)
-        model = torch.nn.Linear(in_features=in_features, out_features=out_features)
-
-        exported_program = quantize_and_export_to_edge(model, inputs).exported_program()
-
-        # By default, the quantized linear op should have constant scalar attributes.
-        self.assertTargetCountsEqual(
-            exported_program.graph_module,
-            [
-                # One quantized linear op.
-                (exir_ops.edge.cadence.quantized_linear.default, 1),
-                # No per tensor quantized linear ops.
-                (exir_ops.edge.cadence.quantized_linear.per_tensor, 0),
-                # Three aten.full ops.
-                (exir_ops.edge.aten.full.default, 3),
-            ],
+        src_zero_point = 0
+        out_zero_point = 0
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn([1, in_features]))
+        weights = builder.placeholder(
+            "weights", torch.randn([in_features, out_features], dtype=torch.float32)
+        )
+        bias = builder.placeholder(
+            "bias", torch.randn([out_features], dtype=torch.float32)
         )
+        quantized_input = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 0.01431146077811718, 57, -128, 127, torch.int8),
+        )
+        weight_zero_point = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0),
+        )
+        out_multiplier = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0),
+        )
+        out_shift = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([1], 0),
+        )
+        output = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quantized_input,
+                weights,
+                bias,
+                src_zero_point,
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                out_zero_point,
+                None,
+            ),
+        )
+        dequantized_output = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(output, 0.010696045123040676, -31, -128, 127, torch.int8),
+        )
+        builder.output([dequantized_output])
+        original_gm = builder.get_graph_module()
 
-        for node in exported_program.graph_module.graph.nodes:
+        for node in original_gm.graph.nodes:
             # Replace the `shape` argument for aten.full op with a tuple.
             if node.target == exir_ops.edge.aten.full.default:
                 node.args = (tuple(node.args[0]), node.args[1])
 
         # Apply replacement pass.
         p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass()
-        graph_after_passes = p(exported_program.graph_module)
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertIsNotNone(graph_after_passes)
-        gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module
+        gm = dead_code_elimination_pass(graph_after_passes).graph_module
 
         # By default, the quantized linear op should have constant scalar attributes.
         self.assertTargetCountsEqual(
@@ -898,23 +946,17 @@ def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_
 
     @torch.no_grad()
     def test_replace_conv1d_with_linear(self):
-        class Conv(torch.nn.Module):
-            def __init__(self, in_features: int, out_features: int, kernel_size: int):
-                super().__init__()
-                self.conv1d = torch.nn.Conv1d(in_features, out_features, kernel_size)
-
-            def forward(self, x):
-                return self.conv1d(x)
-
-        model_conv1d = Conv(96, 192, 7)
         x = torch.randn(1, 96, 7)
-        graph_module = (
-            export_to_edge(model_conv1d, (x,)).exported_program().graph_module
+        weights = torch.randn(192, 96, 7)
+        bias = torch.randn(192)
+        original_gm = single_op_builder(
+            placeholders=(x, weights, bias),
+            op=exir_ops.edge.cadence.convolution.default,
+            args=(x, weights, bias, [1], [0], [1], 1, False),
         )
-
-        # First, replace the aten convolution with a cadence.convolution op
         p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
-        temp_graph = p1(graph_module).graph_module
+        temp_graph = p1(original_gm).graph_module
         self.assertIsNotNone(temp_graph)
 
         p2 = ReplaceTrivialConvWithLinear()
@@ -937,23 +979,17 @@ def test_replace_conv1d_with_linear(self):
 
     @torch.no_grad()
     def test_replace_conv2d_with_linear(self):
-        class Conv(torch.nn.Module):
-            def __init__(self, in_features: int, out_features: int, kernel_size: int):
-                super().__init__()
-                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
-
-            def forward(self, x):
-                return self.conv2d(x)
-
-        model_conv2d = Conv(96, 192, 7)
         x = torch.randn(1, 96, 7, 7)
-        graph_module = (
-            export_to_edge(model_conv2d, (x,)).exported_program().graph_module
+        weights = torch.randn(192, 96, 7, 7)
+        bias = torch.randn(192)
+        original_gm = single_op_builder(
+            placeholders=(x, weights, bias),
+            op=exir_ops.edge.cadence.convolution.default,
+            args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1, False),
        )
-
-        # First, replace the aten convolution with a cadence.convolution op
         p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
-        temp_graph = p1(graph_module).graph_module
+        temp_graph = p1(original_gm).graph_module
         self.assertIsNotNone(temp_graph)
 
         p2 = ReplaceTrivialConvWithLinear()
@@ -976,24 +1012,16 @@ def test_replace_conv2d_with_linear(self):
 
     @torch.no_grad()
     def test_replace_conv2d_with_im2row_and_linear(self):
-        class Conv(torch.nn.Module):
-            def __init__(self, in_features: int, out_features: int, kernel_size: int):
-                super().__init__()
-                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
-
-            def forward(self, x):
-                return self.conv2d(x)
-
-        model_conv2d = Conv(96, 192, 7)
         x = torch.randn(1, 96, 47, 37)
-        graph_module = (
-            compiler.export_to_cadence(model_conv2d, (x,))
-            .exported_program()
-            .graph_module
+        weights = torch.randn(192, 96, 7, 7)
+        bias = torch.randn(192)
+        original_gm = single_op_builder(
+            placeholders=(x, weights, bias),
+            op=exir_ops.edge.cadence.convolution.default,
+            args=(x, weights, bias, [1, 1], [0, 0], [1, 1], 1, False),
         )
-
         p = ReplaceConvWithIm2RowAndLinear()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
 
         # Assert that the convolution is converted to im2row + linear
         self.assertEqual(
@@ -1014,17 +1042,14 @@
     )
     @torch.no_grad()
     def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int):
-        class Select(torch.nn.Module):
-            def forward(self, x):
-                return x.select(dim, index)
-
         x = torch.randn(shape)
-        graph_module = export_to_edge(Select(), (x,)).exported_program().graph_module
-
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.select_copy.int,
+            args=(x, dim, index),
+        )
         p = ReplaceSelectWithViewOpPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         # Assert that select op was replaced with view op
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
@@ -1048,17 +1073,17 @@ def test_replace_nop_transpose_with_view(
         dim1: int,
         dtype: torch.dtype = torch.float32,
     ):
-        class Transpose(torch.nn.Module):
-            def forward(self, x):
-                return x.transpose(dim0, dim1)
-
-        _max_value = 127
-        x = (torch.rand(shape) * _max_value).to(dtype=dtype)
-        graph_module = export_to_edge(Transpose(), (x,)).exported_program().graph_module
-
+        if dtype == torch.float32:
+            x = torch.randn(shape)
+        else:
+            x = torch.randint(low=0, high=100, size=shape, dtype=torch.int64)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.transpose_copy.int,
+            args=(x, dim0, dim1),
+        )
         p = ReplaceNopTransposeOrPermuteWithViewPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
 
         # Assert that transpose op was removed, and a view op was placed instead
         self.assertEqual(
@@ -1071,21 +1096,20 @@ def forward(self, x):
     @parameterized.expand(
         [
             # permutations that can be replaced by view
-            [(3, 1, 3, 1, 4), (0, 2, 4, 1, 3), torch.float32],
-            [(1, 3, 4), (1, 2, 0), torch.float32],
+            [(3, 1, 3, 1, 4), (0, 2, 4, 1, 3)],
+            [(1, 3, 4), (1, 2, 0)],
         ]
     )
     @torch.no_grad()
-    def test_replace_nop_permute_with_view(self, input_shape, dims, dtype):
-        class Permute(torch.nn.Module):
-            def forward(self, x):
-                return torch.permute(x, dims)
-
-        x = torch.randn(input_shape).to(dtype=dtype)
-        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module
-
+    def test_replace_nop_permute_with_view(self, shape, dims):
+        x = torch.randn(shape)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(x, dims),
+        )
         p = ReplaceNopTransposeOrPermuteWithViewPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
 
         # Assert that permute op was removed, and a view op was placed instead
         self.assertEqual(
@@ -1098,21 +1122,20 @@ def forward(self, x):
     @parameterized.expand(
        [
             # permutations replaced by transpose
-            [(3, 4), [1, 0], torch.float32],
-            [(3, 4, 6), (0, 2, 1), torch.float32],
+            [(3, 4), [1, 0]],
+            [(3, 4, 6), (0, 2, 1)],
         ]
     )
     @torch.no_grad()
-    def test_replace_permute_with_transpose(self, input_shape, dims, dtype):
-        class Permute(torch.nn.Module):
-            def forward(self, x):
-                return torch.permute(x, dims)
-
-        x = torch.randn(input_shape).to(dtype=dtype)
-        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module
-
+    def test_replace_permute_with_transpose(self, shape, dims):
+        x = torch.randn(shape)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(x, dims),
+        )
         p = ReplacePermuteWithTransposePass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
 
         # Assert that permute op was replaced by a transpose op
         self.assertEqual(
@@ -1122,23 +1145,18 @@ def forward(self, x):
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1
         )
 
-    @parameterized.expand(
-        [
-            # permutations replaced by transpose
-            [(3, 4), [0, 1], torch.float32],
-        ]
-    )
     @torch.no_grad()
-    def test_replace_permute_with_transpose_nop(self, input_shape, dims, dtype):
-        class Permute(torch.nn.Module):
-            def forward(self, x):
-                return torch.permute(x, dims)
-
-        x = torch.randn(input_shape).to(dtype=dtype)
-        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module
-
+    def test_replace_permute_with_transpose_nop(
+        self,
+    ):
+        x = torch.randn(3, 4)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(x, [0, 1]),
+        )
         p = ReplacePermuteWithTransposePass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
 
         # Assert that permute op was replaced by a transpose op
         self.assertEqual(
@@ -1148,25 +1166,29 @@ def forward(self, x):
             count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
         )
 
-    def test_replace_aten_where_with_cadence_where_Scalar(self):
-        class WhereScalarModel(torch.nn.Module):
-            def forward(self, cond: torch.Tensor):
-                a = torch.ops.aten.full.default(a_shape, val1)
-                b = torch.ops.aten.full.default(b_shape, val2)
-                return torch.where(cond > 0, a, b)
-
-        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (4, 8), (4, 8), 0.0, 1.0]
-        cond = torch.randn(cond_shape)
-
-        graph_module = (
-            export_to_edge(WhereScalarModel(), (cond,)).exported_program().graph_module
+    def test_replace_aten_where_with_cadence(self):
+        builder = GraphBuilder()
+        cond = builder.placeholder("cond", torch.randn(4, 8))
+        aten_gt_scalar = builder.call_operator(
+            op=exir_ops.edge.aten.gt.Scalar,
+            args=(cond, 0),
         )
-
+        aten_full_default = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([4, 8], 0.0),
+        )
+        aten_full_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=([4, 8], 1.0),
+        )
+        aten_where_self = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(aten_gt_scalar, aten_full_default, aten_full_default_1),
+        )
+        builder.output([aten_where_self])
+        original_gm = builder.get_graph_module()
         p = ReplaceWhereWithFullArgsWithWhereScalar()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
-        # Assert that aten.where op was replaced by a
-        # cadence.where_Scalar op
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(
                 graph_after_passes,
@@ -1179,50 +1201,39 @@ def forward(self, cond: torch.Tensor):
             1,
         )
 
-        class WhereBroadcastModel(torch.nn.Module):
-            def forward(self, cond: torch.Tensor):
-                a = torch.ops.aten.full.default(a_shape, val1)
-                b = torch.ops.aten.full.default(b_shape, val2)
-                return torch.where(cond > 0, a, b)
-
-        # a tensor bigger than cond and b
-        cond_shape, a_shape, b_shape, val1, val2 = [(8,), (4, 8), (8,), 0.0, 1.0]
-        cond = torch.randn(cond_shape)
-
-        graph_module = (
-            export_to_edge(WhereBroadcastModel(), (cond,))
-            .exported_program()
-            .graph_module
+    @parameterized.expand(
+        [
+            [(4, 8), (4, 8), (4, 8), 0.0, 1.0],
+            [(8,), (4, 8), (8,), 0.0, 1.0],
+            [(4, 8), (8,), (8,), 0.0, 1.0],
+        ]
+    )
+    def test_replace_aten_where_with_cadence_broadcast(
+        self, cond_shape, a_shape, b_shape, val1, val2
+    ):
+        # cond_shape, a_shape, b_shape, val1, val2 =
+        builder = GraphBuilder()
+        cond = builder.placeholder("cond", torch.randn(cond_shape))
+        aten_gt_scalar = builder.call_operator(
+            op=exir_ops.edge.aten.gt.Scalar,
+            args=(cond, 0),
        )
-
-        p = ReplaceWhereWithFullArgsWithWhereScalar()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
-        # Assert that aten.where op is still in the graph since where_Scalar does not
-        # support broadcast
-        self.assertEqual(
-            count_node(
-                graph_after_passes,
-                exir_ops.edge.aten.where.self,
-            ),
-            1,
+        aten_full_default = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=(a_shape, val1),
         )
-
-        # cond tensor bigger than a and b
-        cond_shape, a_shape, b_shape, val1, val2 = [(4, 8), (8,), (8,), 0.0, 1.0]
-        cond = torch.randn(cond_shape)
-
-        graph_module = (
-            export_to_edge(WhereBroadcastModel(), (cond,))
-            .exported_program()
-            .graph_module
+        aten_full_default_1 = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=(b_shape, val2),
        )
-
+        aten_where_self = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(aten_gt_scalar, aten_full_default, aten_full_default_1),
+        )
+        builder.output([aten_where_self])
+        original_gm = builder.get_graph_module()
         p = ReplaceWhereWithFullArgsWithWhereScalar()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
-        # Assert that aten.where op is still in the graph since where_Scalar does not
-        # support broadcast
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(
                 graph_after_passes,
@@ -1281,17 +1292,14 @@ def test_replace_split_with_sizes_with_slice(self):
 
     @parameterized.expand([[2], [3], [4]])
     def test_replace_pow_with_mul(self, exponent: int):
-        class Pow(torch.nn.Module):
-            def forward(self, input):
-                return torch.ops.aten.pow.Tensor_Scalar(input, exponent)
-
-        input = torch.randn(2, 1, 64)
-
-        graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module
-
+        x = torch.randn(2, 1, 64)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.pow.Tensor_Scalar,
+            args=(x, exponent),
+        )
         p = ReplacePowWithMulPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
-
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(
                 graph_after_passes,
@@ -1299,7 +1307,6 @@ def forward(self, input):
             ),
             0,
         )
-
         self.assertEqual(
             count_node(
                 graph_after_passes,
@@ -1315,16 +1322,14 @@ def forward(self, input):
         ]
     )
    def test_replace_pow_with_mul_not_applied(self, exponent):
-        class Pow(torch.nn.Module):
-            def forward(self, input):
-                return torch.ops.aten.pow.Tensor_Scalar(input, exponent)
-
-        input = torch.randn(2, 1, 64)
-
-        graph_module = export_to_edge(Pow(), (input,)).exported_program().graph_module
-
+        x = torch.randn(2, 1, 64)
+        original_gm = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.pow.Tensor_Scalar,
+            args=(x, exponent),
+        )
         p = ReplacePowWithMulPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original_gm)).graph_module
         self.assertEqual(
             count_node(