diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index 56fcda99a14..b7616b047d3 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -6,10 +6,9 @@ # pyre-unsafe -import logging import math import unittest -from typing import cast +from typing import cast, Optional import executorch.backends.cadence.aot.ops_registrations # noqa import torch @@ -24,7 +23,6 @@ get_default_memory_config, MemoryConfig, ) -from executorch.exir import memory from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.memory_planning import collect_specs_from_nodes from executorch.exir.passes.spec_prop_pass import SpecPropPass @@ -225,14 +223,23 @@ def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None: # Initializes the nodes metadata and runs the GenerateMemoryViewConstraints, # GenerateSliceAndSelectNopConstraints, and GenerateCatNopConstraints passes. def run_memory_planning( - self, original, opt_level=2, alloc_graph_input=True + self, + original, + opt_level=2, + mem_algo=1, # greedy_by_size_for_offset_calculation_with_hierarchy + alloc_graph_input=True, + alloc_graph_output=True, + memory_config: Optional[MemoryConfig] = None, ) -> GraphModule: + if memory_config is None: + memory_config = get_default_memory_config() graph_module = SpecPropPass().call(original).graph_module return CadenceMemoryPlanning( - get_default_memory_config(), + memory_config, opt_level=opt_level, - mem_algo=1, # greedy_by_size_for_offset_calculation_with_hierarchy + mem_algo=mem_algo, alloc_graph_input=alloc_graph_input, + alloc_graph_output=alloc_graph_output, )(graph_module).graph_module @parameterized.expand( @@ -241,10 +248,19 @@ def run_memory_planning( [3, 6], # x_shape [2, 6], # y_shape 0, # concat dim + False, # alloc_graph_input + ], + [ + [3, 6], # x_shape + [2, 6], # y_shape + 0, # concat dim + True, # alloc_graph_input ], ] ) - def test_optimize_cat_on_placeholders(self, x_shape, y_shape, concat_dim) -> None: + def test_optimize_cat_on_placeholders( + self, x_shape, y_shape, concat_dim, alloc_graph_input + ) -> None: concat_shape = [x_shape[concat_dim] + y_shape[concat_dim], x_shape[1]] builder = GraphBuilder() x = builder.placeholder("x", torch.ones(*x_shape)) @@ -262,12 +278,16 @@ def test_optimize_cat_on_placeholders(self, x_shape, y_shape, concat_dim) -> Non builder.output([graph_output]) original = builder.get_graph_module() - graph_module = self.run_memory_planning(original) + graph_module = self.run_memory_planning( + original, alloc_graph_input=alloc_graph_input + ) graph_module.graph.eliminate_dead_code() - # Assert that cat op is optimized away - self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) - # Assert that cat op is replaced by its nop version post optimization - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + if alloc_graph_input: + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) + else: + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) self.verify_nop_memory_alloc(graph_module) # Returns a GraphModule with the following structure: @@ -473,7 +493,13 @@ def test_optimize_cat_with_slice(self) -> None: self.assertEqual(count_node(graph_module, exir_ops.edge.aten.slice.Tensor), 1) self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_with_slice_infeasible(self) -> None: + @parameterized.expand( + [ + (True,), # alloc_graph_input + (False,), # alloc_graph_input + ], + ) + def test_optimize_cat_with_slice_infeasible(self, alloc_graph_input) -> None: x_shape = [5, 6] y_shape = [3, 6] concated_shape = [8, 6] @@ -527,14 +553,20 @@ def test_optimize_cat_with_slice_infeasible(self) -> None: ) builder.output([cat]) original = builder.get_graph_module() - graph_module = self.run_memory_planning(original, alloc_graph_input=False) - graph_module.graph.eliminate_dead_code() - # # Assert that slice op is optimized away. - self.assertEqual( - count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1 + graph_module = self.run_memory_planning( + original, opt_level=3, alloc_graph_input=alloc_graph_input ) - # # Assert that cat op is not optimized away - self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) + graph_module.graph.eliminate_dead_code() + if alloc_graph_input: + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) + self.assertEqual( + count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1 + ) + else: + self.assertEqual( + count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 1 + ) + self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) self.verify_nop_memory_alloc(graph_module) def test_optimize_slice_outermost(self) -> None: @@ -773,250 +805,123 @@ def test_optimize_select_depending_on_opt_level(self) -> None: ) self.verify_nop_memory_alloc(graph_module) - # TODO: Test fails due to memory planning - @unittest.expectedFailure - def test_optimize_cat_with_param(self) -> None: - class CatWithPadding(torch.nn.Module): - def __init__(self, padding_shape): - super().__init__() - zeros = torch.zeros(padding_shape) - self.register_buffer("padding", zeros) - - def forward(self, x, y): - x1 = torch.add(x, 2.4, 3.1) - y1 = torch.add(y, 1, 2) - # Cat along the outermost dimension cannot be optimized away - # because padding is a param - return torch.ops.aten.cat((x1, y1, self.padding)) - - x = torch.ones(3, 5) - y = torch.ones(2, 5) - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - CatWithPadding((1, 5)), (x, y), opt_level=2 - ) - .exported_program() - .graph_module - ) - graph_module.graph.eliminate_dead_code() - # Assert that cat op is not optimized away - self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1) - self.verify_nop_memory_alloc(graph_module) - def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None: - class CatWithPadding(torch.nn.Module): - def __init__(self, padding_shape): - super().__init__() - zeros = torch.zeros(padding_shape) - self.register_buffer("padding", zeros) - - def forward(self, x, y): - x = x.view(3, 5) - cat = torch.ops.aten.cat((x, self.padding.clone())) - slice_copy = torch.ops.aten.slice(cat, dim=0, start=x.shape[0]) - self.padding.copy_(slice_copy) - return cat.view(-1) + y - - x = torch.ones(15) - y = torch.ones(1) - et_prog_manager = compiler.export_to_executorch_gen_etrecord( - CatWithPadding((1, 5)), (x, y), opt_level=3 + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(3, 6, dtype=torch.float32)) + y = builder.placeholder("y", torch.ones(1, 6, dtype=torch.float32)) + pre_created_output = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([4, 6], 0.0), + kwargs={"dtype": torch.float32}, ) - graph_module = et_prog_manager.exported_program().graph_module - logging.info(f"graph_module: {graph_module.print_readable(print_output=False)}") - self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) - self.verify_nop_memory_alloc(graph_module) - - def test_optimize_cat_with_view(self) -> None: - class CatViewFeasible(torch.nn.Module): - def forward(self, x, y): - x1 = torch.add(x, 2.4, 3.1) - x2 = x1.view((5, 3)) - y1 = torch.add(y, 2.4, 3.1) - y2 = y1.view((2, 3)) - # Cat can be optimized away since x2 and y2 are not mem-equivalent - return torch.ops.aten.cat((y2, x2)) - - x = torch.ones(3, 5) - y = torch.ones(3, 2) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - CatViewFeasible(), (x, y), opt_level=2, mem_algo=1 - ) - .exported_program() - .graph_module + cat = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([x, y],), + kwargs={"dim": 0, "out": pre_created_output}, + ) + slice_out = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([1, 6], 0.0), + kwargs={"dtype": torch.float32}, ) + slice_result = builder.call_operator( + op=torch.ops.aten.slice_copy.Tensor_out, + args=( + cat, + 0, # dim + 3, # start + 4, # end + 1, # step + ), + kwargs={"out": slice_out}, + ) + builder.output([slice_result]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning(original, opt_level=3) graph_module.graph.eliminate_dead_code() - # Assert that cat op is optimized away - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) self.verify_nop_memory_alloc(graph_module) - def test_no_optimize_cat_with_repeated_args(self) -> None: - class CatViewInfeasible(torch.nn.Module): - def forward(self, x): - x1 = torch.add(x, 2.4, 3.1) - # Repeat will be decomposed into a cat. The cat cannot be optimized - # away since all its args are mem-equivalent - return torch.ops.aten.repeat(x1, [1, 2]) - - x = torch.ones(3, 5) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - CatViewInfeasible(), (x,), opt_level=2, mem_algo=1 - ) - .exported_program() - .graph_module + def test_cat_then_cat(self) -> None: + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(16, 16, dtype=torch.float32)) + to_add_to_x = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([16, 16], 1.0), + kwargs={"dtype": torch.float32}, ) - graph_module.graph.eliminate_dead_code() - # Assert that cat op is not optimized away - self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) - self.verify_nop_memory_alloc(graph_module) - - def test_no_optimize_cat_with_placeholder(self) -> None: - class CatViewInfeasible(torch.nn.Module): - def forward(self, x, y): - # Repeat will be decomposed into a cat. The cat cannot be optimized - # away since all its args are mem-equivalent - return torch.cat((x, y), dim=0) - - x = torch.ones(3, 5) - y = torch.ones(2, 5) - # Optimizing cat ops is only at opt_level 2+, and requires the memory planning - # pass to run: - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - CatViewInfeasible(), - (x, y), - opt_level=2, - mem_algo=1, - alloc_graph_input=False, - ) - .exported_program() - .graph_module + x1 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add_to_x), ) - graph_module.graph.eliminate_dead_code() - # Assert that cat op is not optimized away - self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1) - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) - self.verify_nop_memory_alloc(graph_module) - - def test_no_optimize_cat(self) -> None: - class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x) -> torch.Tensor: - x0 = torch.slice_copy(x, dim=0, start=0, end=4) - x0 = x0.view(-1) - x1 = torch.slice_copy(x, dim=0, start=4, end=8) - x1 = x1.view(-1) - return torch.cat((x0, x1), dim=0) - - model = Model() - inputs = (torch.randn(16, 16),) - - # Check that both view ops and slice copy are optimized. - # We can't optimize cat op in this case. - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - model, inputs, opt_level=3, alloc_graph_input=True - ) - .exported_program() - .graph_module + x2 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x1, to_add_to_x), ) - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0) - self.assertEqual( - count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 2 + x3 = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x2, to_add_to_x), ) - self.assertEqual(count_node(graph_module, memory.view), 2) - self.verify_nop_memory_alloc(graph_module) - - def test_optimize_slice_copy(self) -> None: - class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x) -> torch.Tensor: - x0 = torch.slice_copy(x, dim=0, start=0, end=4) - x0 = x0.view(-1) - x1 = torch.slice_copy(x, dim=0, start=4, end=8) - x1 = x1.view(-1) - return torch.cat((x0, x1), dim=0) - - model = Model() - inputs = (torch.randn(16, 16),) - - # Check that view ops and cat are optimized. - # We can't optimize slice_copy op in this case. - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - model, inputs, opt_level=3, alloc_graph_input=False - ) - .exported_program() - .graph_module + pre_created_output1 = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([32, 16], 0.0), + kwargs={"dtype": torch.float32}, ) - self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1) - self.assertEqual( - count_node(graph_module, torch.ops.aten._slice_copy_nop.Tensor_out), 0 + cat1 = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([x1, x2],), + kwargs={"dim": 0, "out": pre_created_output1}, ) - self.assertEqual(count_node(graph_module, memory.view), 2) - self.verify_nop_memory_alloc(graph_module) - - def test_cat_then_cat(self) -> None: - class Model(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x) -> torch.Tensor: - x1 = x + 1 - x2 = x1 + 1 - x3 = x2 + 1 - return torch.cat((torch.cat((x1, x2), dim=0), x3), dim=0) - - model = Model() - inputs = (torch.randn(16, 16),) - - # Check that both the cat ops can be optimized. - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - model, inputs, opt_level=3, alloc_graph_input=False - ) - .exported_program() - .graph_module + pre_created_output2 = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([32, 16], 0.0), + kwargs={"dtype": torch.float32}, + ) + cat2 = builder.call_operator( + op=torch.ops.aten.cat.out, + args=([cat1, x3],), + kwargs={"dim": 0, "out": pre_created_output2}, ) + builder.output([cat2]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning( + original, opt_level=3, alloc_graph_input=False + ) + self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 2) self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0) self.verify_nop_memory_alloc(graph_module) def test_view_for_unallocated_output(self) -> None: - class Model(torch.nn.Module): - def __init__(self, padding_shape): - super().__init__() - - def forward(self, x, y): - x = x + 1 - # x_view will be a memory.view. - x_view = torch.ops.aten.view_copy(x, [15]) - return x, x_view + y - - x = torch.ones(3, 5) - y = torch.ones(15) - # Check that memory planning passes for unallocated output `x`. - graph_module = ( - compiler.export_to_executorch_gen_etrecord( - Model((1, 5)), (x, y), opt_level=2, alloc_graph_output=False - ) - .exported_program() - .graph_module + builder = GraphBuilder() + x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32)) + y = builder.placeholder("y", torch.ones(15, dtype=torch.float32)) + to_add_to_x = builder.call_operator( + op=exir_ops.edge.aten.full.default, + args=([3, 5], 1.0), + kwargs={"dtype": torch.float32}, + ) + add_x = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x, to_add_to_x), + ) + add_x_view = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(add_x, [15]), + ) + add_x_y = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(add_x_view, y), + ) + builder.output([add_x, add_x_y]) + original = builder.get_graph_module() + graph_module = self.run_memory_planning( + original, opt_level=2, alloc_graph_output=False + ) + self.assertEqual( + count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1 ) - self.assertEqual(count_node(graph_module, memory.view), 1) self.verify_nop_memory_alloc(graph_module) def test_start_alignment_constraints(self) -> None: