From 8c195ffb036f46a123c567b1a853aa628dbb657d Mon Sep 17 00:00:00 2001
From: Ekaterina Ignasheva
Date: Fri, 6 Jun 2025 14:09:35 -0700
Subject: [PATCH] Use GraphBuilder in test_replace_ops_passes. #2 (#11456)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/11456

Reviewed By: hsharma35

Differential Revision: D75982351
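Note on the pattern used throughout this diff: instead of exporting an
nn.Module wrapper via export_to_edge, the tests now build edge-dialect graph
modules directly with GraphBuilder / single_op_builder and run the passes on
the result. A minimal sketch of that pattern, for illustration only (the
import paths and the add op used here are my assumption, not taken from this
diff):

    import torch
    from executorch.backends.cadence.aot.graph_builder import (
        GraphBuilder,
        single_op_builder,
    )
    from executorch.exir.dialects._ops import ops as exir_ops

    # Multi-node graph: declare placeholders, chain edge ops, mark outputs.
    builder = GraphBuilder()
    x = builder.placeholder("x", torch.randn(2, 3, dtype=torch.float32))
    y = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(x, x))
    builder.output([y])
    gm = builder.get_graph_module()

    # Single-node graph: single_op_builder wraps the same steps for one op.
    inp = torch.randn(2, 3)
    gm2 = single_op_builder(
        placeholders=(inp,),
        op=exir_ops.edge.aten.add.Tensor,
        args=(inp, inp),
    )

A pass under test is then applied to the graph module directly, which keeps
each test focused on the node pattern being rewritten rather than on export
behavior.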
---
 .../aot/tests/test_replace_ops_passes.py | 459 +++++++++---------
 1 file changed, 216 insertions(+), 243 deletions(-)

diff --git a/backends/cadence/aot/tests/test_replace_ops_passes.py b/backends/cadence/aot/tests/test_replace_ops_passes.py
index c9762946cb5..81b628ef232 100644
--- a/backends/cadence/aot/tests/test_replace_ops_passes.py
+++ b/backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -11,7 +11,6 @@
 from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 from executorch.backends.cadence.aot import compiler
 from executorch.backends.cadence.aot.compiler import (
     export_to_edge,
@@ -343,7 +342,7 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split(
             [(1, 8, 33), 8, 16, 3],
             [(1, 8, 33), 8, 16, 5, 2],
             [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False],
-            # channel last
+            # # channel last
             [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True],
             [(1, 33, 8), 8, 16, 5, 2, 0, 1, False, True, True],
         ]
@@ -359,38 +358,56 @@ def test_replace_transposed_conv_with_linear(
         padding: int = 0,
         dilation: int = 1,
         depthwise: bool = False,
-        bias: bool = True,
+        bias_enabled: bool = True,
         channel_last: bool = False,
     ):
-        class TConv(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.tconv1d = torch.nn.ConvTranspose1d(
-                    in_channels,
-                    out_channels,
-                    kernel,
-                    stride=stride,
-                    padding=padding,
-                    dilation=dilation,
-                    groups=in_channels if depthwise else 1,
-                    bias=bias,
-                )
-
-            def forward(self, x: torch.Tensor):
-                if channel_last:
-                    x = x.permute([0, 2, 1])
-                x = self.tconv1d(x)
-                if channel_last:
-                    x = x.permute([0, 2, 1])
-                return x
+        transposed = True
+        output_padding = [0]
+        groups = in_channels if depthwise else 1
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights",
+            torch.randn([in_channels, out_channels, kernel], dtype=torch.float32),
+        )
+        bias = (
+            builder.placeholder(
+                "bias", torch.randn([out_channels], dtype=torch.float32)
+            )
+            if bias_enabled
+            else None
+        )
+        if channel_last:
+            x = builder.call_operator(
+                op=exir_ops.edge.aten.permute_copy.default,
+                args=(x, [0, 2, 1]),
+            )
+        convolution = builder.call_operator(
+            op=exir_ops.edge.aten.convolution.default,
+            args=(
+                x,
+                weights,
+                bias,
+                [stride],
+                [padding],
+                [dilation],
+                transposed,
+                output_padding,
+                groups,
+            ),
+        )
+        if channel_last:
+            convolution = builder.call_operator(
+                op=exir_ops.edge.aten.permute_copy.default,
+                args=(convolution, [0, 2, 1]),
+            )
+        builder.output([convolution])
+        original = builder.get_graph_module()
 
-        x = torch.randn(shape)
-        model = TConv()
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
         p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
         p2 = ReplaceTransposedConvWithLinearPass()
         graph_after_passes = cast(
-            PassResult, p2(cast(PassResult, p1(graph_module)).graph_module)
+            PassResult, p2(cast(PassResult, p1(original)).graph_module)
         ).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
@@ -427,39 +444,53 @@ def test_replace_convolution_optional_args_with_concrete_args(
         padding: int = 0,
         dilation: int = 1,
         depthwise: bool = False,
-        bias: bool = True,
+        bias_enabled: bool = True,
         channel_last: bool = False,
     ):
-        class Conv(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv1d = torch.nn.Conv1d(
-                    in_channels,
-                    out_channels,
-                    kernel,
-                    stride=stride,
-                    padding=padding,
-                    dilation=dilation,
-                    groups=in_channels if depthwise else 1,
-                    bias=bias,
-                )
-
-            def forward(self, x: torch.Tensor):
-                if channel_last:
-                    x = x.permute([0, 2, 1])
-                x = self.conv1d(x)
-                if channel_last:
-                    x = x.permute([0, 2, 1])
-                return x
-
-        x = torch.randn(shape)
-        model = Conv()
-
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        transposed = True
+        output_padding = [0]
+        groups = in_channels if depthwise else 1
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights",
+            torch.randn([in_channels, out_channels, kernel], dtype=torch.float32),
+        )
+        bias = (
+            builder.placeholder(
+                "bias", torch.randn([out_channels], dtype=torch.float32)
+            )
+            if bias_enabled
+            else None
+        )
+        if channel_last:
+            x = builder.call_operator(
+                op=exir_ops.edge.aten.permute_copy.default,
+                args=(x, [0, 2, 1]),
+            )
+        convolution = builder.call_operator(
+            op=exir_ops.edge.aten.convolution.default,
+            args=(
+                x,
+                weights,
+                bias,
+                [stride],
+                [padding],
+                [dilation],
+                transposed,
+                output_padding,
+                groups,
+            ),
+        )
+        if channel_last:
+            convolution = builder.call_operator(
+                op=exir_ops.edge.aten.permute_copy.default,
+                args=(convolution, [0, 2, 1]),
+            )
+        builder.output([convolution])
+        original = builder.get_graph_module()
         p = ReplaceConvolutionOptionalArgsWithConcreteArgsPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
             1,
@@ -471,58 +502,46 @@ def forward(self, x: torch.Tensor):
 
     @parameterized.expand(
         [
-            [(1, 2, 3), (1, 1)],
+            [(1, 2, 3), [1, 1]],
             [
                 (20, 1, 80),
-                (1, 4),
+                [1, 4],
             ],
         ]
     )
     @torch.no_grad()
     def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]):
-        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
-        class Padding(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.padding = padding
-
-            def forward(self, x: torch.Tensor):
-                return F.pad(x, self.padding)
-
-        model = Padding()
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.constant_pad_nd.default,
+            args=(x, padding),
+        )
         p = ReplacePadWithCatPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
             1,
         )
-
         self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.pad.default),
+            count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
             0,
         )
 
     @torch.no_grad()
     def test_replace_repeat_with_cat(self):
-        class Repeat(torch.nn.Module):
-            def forward(self, x):
-                x1 = torch.add(x, 2.4, 3.1)
-                return torch.ops.aten.repeat(x1, [1, 2])
-
-        x = torch.ones(3, 5)
-        graph_module = export_to_edge(Repeat(), (x,)).exported_program().graph_module
-
+        x = torch.randn([3, 5])
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.repeat.default,
+            args=(x, [1, 2]),
+        )
         p = ReplaceRepeatWithCatPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.repeat.default),
             0,
@@ -535,7 +554,7 @@ def forward(self, x):
             [(3, 4)],
             [(7, 8, 3)],
             [(3, 3, 2, 4)],
-            [(36, 1, 2, 80), (1)],
+            [(36, 1, 2, 80), (1,)],
             # tests where mask will be broadcasted
             [(36, 1, 2, 80), (1, 1, 2, 1)],
             [(36, 2, 8, 4), (36, 1, 1, 4)],
@@ -548,67 +567,57 @@ def test_replace_masked_scalar_tensor_with_full(
         shape: Tuple[int],
         mask_shape: Union[Tuple[int, ...], None] = None,
     ):
-        class MaskedFill(torch.nn.Module):
-            def __init__(self, value: float):
-                super().__init__()
-                self.value = value
-
-            def forward(self, x: torch.Tensor, mask: torch.Tensor):
-                return torch.masked_fill(x, mask, self.value)
-
-        x = torch.randn(shape)
-        mask = torch.randn(mask_shape if mask_shape else shape) > 0
-        value = 0.5 * torch.mean(x).item()
-        model = MaskedFill(value)
-        graph_module = export_to_edge(model, (x, mask)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        mask = builder.placeholder(
+            "mask",
+            torch.randint(0, 2, mask_shape if mask_shape else shape, dtype=torch.bool),
+        )
+        scalar_tensor = builder.call_operator(
+            op=exir_ops.edge.aten.scalar_tensor.default,
+            args=(0.123,),
+            kwargs={
+                "dtype": torch.float32,
+                "layout": torch.strided,
+                "device": torch.device("cpu"),
+            },
+        )
+        aten_where_self = builder.call_operator(
+            op=exir_ops.edge.aten.where.self,
+            args=(mask, scalar_tensor, x),
+        )
+        builder.output([aten_where_self])
+        original = builder.get_graph_module()
         p = ReplaceScalarTensorWithFullPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.where.self),
             1,
         )
-
         self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.masked_fill),
+            count_node(graph_after_passes, exir_ops.edge.aten.scalar_tensor.default),
             0,
         )
 
-    @parameterized.expand(
-        [
-            [(1), 1.5],
-            [(1), 0.0],
-        ]
-    )
     @torch.no_grad()
-    def test_replace_scalar_tensor_with_full(self, shape: Tuple[int], value: float):
-        class ScalarTensor(torch.nn.Module):
-            def __init__(self, shape: Tuple[int], value: float):
-                super().__init__()
-                self.shape = shape
-                self.value = value
-
-            def forward(self, x: torch.Tensor):
-                return torch.scalar_tensor(value)
-
-        model = ScalarTensor(shape, value)
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+    def test_replace_scalar_tensor_with_full(
+        self,
+    ):
+        original = single_op_builder(
+            placeholders=(),
+            op=exir_ops.edge.aten.scalar_tensor.default,
+            args=(0.123,),
+        )
         p = ReplaceScalarTensorWithFullPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.scalar_tensor.default),
             0,
         )
@@ -616,36 +625,39 @@ def forward(self, x: torch.Tensor):
 
     @torch.no_grad()
     def test_replace_linear_with_fully_connected(self):
-        shape, in_features, out_features, bias = (1, 14), 14, 128, False
-        model = torch.nn.Linear(in_features, out_features, bias=bias)
-        x = torch.randn(shape)
-
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-        permute_to_trans_pass = ReplacePermuteWithTransposePass()
-        mm_to_addmm_pass = ReplaceMMWithAddMMPass()
-        add_to_linear_pass = ReplaceAddMMWithLinearPass()
-        linear_to_fullyconnected_pass = ReplaceLinearWithFullyConnectedOpPass()
-        graph_after_passes = linear_to_fullyconnected_pass(
-            add_to_linear_pass(
-                mm_to_addmm_pass(
-                    permute_to_trans_pass(graph_module).graph_module
-                ).graph_module
-            ).graph_module
+        shape, in_channels, out_channels = (1, 14), 14, 128
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        weights = builder.placeholder(
+            "weights", torch.randn([out_channels, in_channels], dtype=torch.float32)
+        )
+        permute_copy = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(weights, [1, 0]),
+        )
+        mm = builder.call_operator(
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, permute_copy),
+        )
+        builder.output([mm])
+        original = builder.get_graph_module()
+        gm = cast(PassResult, ReplacePermuteWithTransposePass()(original)).graph_module
+        gm = cast(PassResult, ReplaceMMWithAddMMPass()(gm)).graph_module
+        gm = cast(PassResult, ReplaceAddMMWithLinearPass()(gm)).graph_module
+        graph_after_passes = cast(
+            PassResult, ReplaceLinearWithFullyConnectedOpPass()(gm)
         ).graph_module
         self.assertIsNotNone(graph_after_passes)
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.full.default),
             1,
         )
-
         self.assertEqual(
             count_node(
                 graph_after_passes, exir_ops.edge.cadence.fully_connected.default
             ),
             1,
         )
-
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.linear),
             0,
         )
@@ -661,41 +673,27 @@ def test_replace_linear_with_fully_connected(self):
     def test_replace_addmm_with_linear(
         self, shape: Tuple[int], in_features: int, out_features: int, bias: bool
     ):
-        class AddMM(torch.nn.Module):
-            def __init__(self, alpha: float = 1, beta: float = 1):
-                super().__init__()
-                self.alpha = alpha
-                self.beta = beta
-
-            def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
-                return torch.addmm(
-                    x, y, z.transpose(1, 0), alpha=self.alpha, beta=self.beta
-                )
-
-        # alpha, beta must be 1 to be 1 to enable ReplaceAddMMWithLinearPass
-        # get_attr will always turn into placeholders and mutable outputs in PT2
         M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0
-        x = torch.randn(N)
-        y = torch.randn(M, K)
-        z = torch.randn(N, K)
-
-        # test addmm
-        model = AddMM(alpha=alpha, beta=beta)
-        graph_module = export_to_edge(model, (x, y, z)).exported_program().graph_module
-
-        tp = ReplacePermuteWithTransposePass()
-        ap = ReplaceAddMMWithLinearPass()
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(N, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn([M, K], dtype=torch.float32))
+        z = builder.placeholder("z", torch.randn([N, K], dtype=torch.float32))
+        permute_copy = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default,
+            args=(z, [1, 0]),
+        )
+        addmm = builder.call_operator(
+            op=exir_ops.edge.aten.addmm.default,
+            args=(x, y, permute_copy),
+            kwargs={"beta": beta, "alpha": alpha},
+        )
+        builder.output([addmm])
+        original = builder.get_graph_module()
+        gm = cast(PassResult, ReplacePermuteWithTransposePass()(original)).graph_module
         graph_after_passes = cast(
-            PassResult, ap(cast(PassResult, tp(graph_module)).graph_module)
+            PassResult, ReplaceAddMMWithLinearPass()(gm)
        ).graph_module
         self.assertIsNotNone(graph_after_passes)
-
-        self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.addmm.default),
-            1,
-        )
-
-        # Assert that all the aten.addmm nodes are removed.
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
             1,
@@ -707,37 +705,23 @@ def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
 
     @torch.no_grad()
     def test_replace_mm_with_addmm(self):
-        # The mm ops will be convereted to addmm ops by Jarvis
-        class MM(torch.nn.Module):
-            def __init__(self, K, N):
-                super().__init__()
-                self.K = K
-                self.N = N
-
-            def forward(self, y: torch.Tensor, z: torch.Tensor):
-                return torch.ops.aten.mm(y, z)
-
         M, K, N = 14, 48, 24
-        y = torch.randn(M, K)
-        z = torch.randn(K, N)
-
-        # test addmm
-        model = MM(K, N)
-        graph_module = export_to_edge(model, (y, z)).exported_program().graph_module
-
-        # First, replace the aten.mm with an aten.addmm op
+        x = torch.randn([M, K])
+        y = torch.randn([K, N])
+        original = single_op_builder(
+            placeholders=(x, y),
+            op=exir_ops.edge.aten.mm.default,
+            args=(x, y),
+        )
         p = ReplaceMMWithAddMMPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertIsNotNone(graph_after_passes)
-
-        # Assert that all the aten.mm nodes are removed.
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.addmm.default),
             1,
         )
-
         self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.mm),
+            count_node(graph_after_passes, exir_ops.edge.aten.mm.default),
             0,
         )
@@ -755,36 +739,36 @@ def forward(self, y: torch.Tensor, z: torch.Tensor):
     )
     @torch.no_grad()
    def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
-        # The squeeze ops will be convereted to view ops by Jarvis
-        class Squeeze(torch.nn.Module):
-            def __init__(self, dim):
-                super().__init__()
-                self.dim = dim
-
-            def forward(self, x: torch.Tensor):
-                if self.dim is None:
-                    return torch.squeeze(x)
-                return torch.squeeze(x, self.dim)
-
-        model = Squeeze(dim)
         x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        # First, replace the aten.squeeze_copy with an aten.view_copy op
+        if dim:
+            original = single_op_builder(
+                placeholders=(x,),
+                op=exir_ops.edge.aten.squeeze_copy.dim,
+                args=(x, dim),
+            )
+        else:
+            original = single_op_builder(
+                placeholders=(x,),
+                op=exir_ops.edge.aten.squeeze_copy.default,
+                args=(x,),
+            )
         p = ReplaceSqueezeAndUnsqueezeWithViewPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertIsNotNone(graph_after_passes)
-
-        # Assert that all the aten.squeeze_copy nodes are removed.
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
             1,
         )
-        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.aten.squeeze_copy),
-            0,
-        )
+        if dim:
+            self.assertEqual(
+                count_node(graph_after_passes, exir_ops.edge.aten.squeeze_copy.dim),
+                0,
+            )
+        else:
+            self.assertEqual(
+                count_node(graph_after_passes, exir_ops.edge.aten.squeeze_copy.default),
+                0,
+            )
 
     @parameterized.expand(
         [
@@ -797,32 +781,21 @@ def forward(self, x: torch.Tensor):
     )
     @torch.no_grad()
     def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int):
-        class Unsqueeze(torch.nn.Module):
-            def __init__(self, dim):
-                super().__init__()
-                self.dim = dim
-
-            def forward(self, x: torch.Tensor):
-                return torch.unsqueeze(x, self.dim)
-
-        # Test that the pass works for all dims.
-        model = Unsqueeze(dim)
-        x = torch.randn(5, 6, 7)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
-        # First, replace the aten.unsqueeze_copy with an aten.view_copy op
+        x = torch.randn(shape)
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.unsqueeze_copy.default,
+            args=(x, dim),
+        )
         p = ReplaceSqueezeAndUnsqueezeWithViewPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
         self.assertIsNotNone(graph_after_passes)
-
-        # Assert that all the aten.unsqueeze_copy nodes are removed.
         self.assertEqual(
             count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
             1,
         )
         self.assertEqual(
-            count_node(graph_after_passes, exir_ops.aten.unsqueeze_copy),
+            count_node(graph_after_passes, exir_ops.edge.aten.unsqueeze_copy.default),
             0,
         )