From 0f0463a7f38e194137feda1fb3bf2ba41c9934b8 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Wed, 6 Mar 2024 07:53:43 -0800 Subject: [PATCH] Fix tests. Summary: as title. Differential Revision: D54588001 --- exir/backend/test/test_backends_lifted.py | 54 ++++++++++++++--------- exir/program/test/test_program.py | 12 ++++- 2 files changed, 43 insertions(+), 23 deletions(-) diff --git a/exir/backend/test/test_backends_lifted.py b/exir/backend/test/test_backends_lifted.py index c712ab19c3a..e219198027c 100644 --- a/exir/backend/test/test_backends_lifted.py +++ b/exir/backend/test/test_backends_lifted.py @@ -1012,17 +1012,19 @@ def false_fn(x, y): x = x - y return x - def f(x, y): - x = x + y - x = control_flow.cond(x[0][0] == 1, true_fn, false_fn, [x, y]) - x = x - y - return x + class Module(torch.nn.Module): + def forward(self, x, y): + x = x + y + x = control_flow.cond(x[0][0] == 1, true_fn, false_fn, [x, y]) + x = x - y + return x + f = Module() inputs = (torch.ones(2, 2), torch.ones(2, 2)) orig_res = f(*inputs) orig = to_edge( export( - torch.export.WrapperModule(f), + f, inputs, ) ) @@ -1066,15 +1068,17 @@ def map_fn(x, y): x = x + y return x - def f(xs, y): - y = torch.mm(y, y) - return control_flow.map(map_fn, xs, y) + class Module(torch.nn.Module): + def forward(self, xs, y): + y = torch.mm(y, y) + return control_flow.map(map_fn, xs, y) + f = Module() inputs = (torch.ones(2, 2), torch.ones(2, 2)) orig_res = f(*inputs) orig = to_edge( export( - torch.export.WrapperModule(f), + f, inputs, ) ) @@ -1132,9 +1136,10 @@ def map_fn(x, pred1, pred2, y): x = x + y return x.sin() - def f(xs, pred1, pred2, y): - y = torch.mm(y, y) - return control_flow.map(map_fn, xs, pred1, pred2, y) + class Module(torch.nn.Module): + def forward(self, xs, pred1, pred2, y): + y = torch.mm(y, y) + return control_flow.map(map_fn, xs, pred1, pred2, y) inputs = ( torch.ones(2, 2), @@ -1143,10 +1148,11 @@ def f(xs, pred1, pred2, y): torch.ones(2, 2), ) + f = Module() orig_res = f(*inputs) orig = to_edge( export( - torch.export.WrapperModule(f), + f, inputs, ) ) @@ -1205,12 +1211,14 @@ def f(xs, pred1, pred2, y): ) def test_list_input(self): - def f(x: List[torch.Tensor]): - y = x[0] + x[1] - return y + class Module(torch.nn.Module): + def forward(self, x: List[torch.Tensor]): + y = x[0] + x[1] + return y + f = Module() inputs = ([torch.randn(2, 2), torch.randn(2, 2)],) - edge_prog = to_edge(export(torch.export.WrapperModule(f), inputs)) + edge_prog = to_edge(export(f, inputs)) lowered_gm = to_backend( BackendWithCompilerDemo.__name__, edge_prog.exported_program(), [] ) @@ -1227,12 +1235,14 @@ def forward(self, x: List[torch.Tensor]): gm.exported_program().module()(*inputs) def test_dict_input(self): - def f(x: Dict[str, torch.Tensor]): - y = x["a"] + x["b"] - return y + class Module(torch.nn.Module): + def forward(self, x: Dict[str, torch.Tensor]): + y = x["a"] + x["b"] + return y + f = Module() inputs = ({"a": torch.randn(2, 2), "b": torch.randn(2, 2)},) - edge_prog = to_edge(export(torch.export.WrapperModule(f), inputs)) + edge_prog = to_edge(export(f, inputs)) lowered_gm = to_backend( BackendWithCompilerDemo.__name__, edge_prog.exported_program(), [] ) diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 8c2ddddb7c2..01de1f3befd 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -30,6 +30,16 @@ from torch.library import impl, Library + +class WrapperModule(torch.nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + lib = Library("test_op", "DEF") # Fake a operator for testing. @@ -374,7 +384,7 @@ def _test_edge_dialect_verifier(self, callable, validate_ir=True): two, ) if not isinstance(callable, torch.nn.Module): - callable = torch.export.WrapperModule(callable) + callable = WrapperModule(callable) exported_foo = export(callable, inputs) _ = to_edge(exported_foo, compile_config=edge_compile_config)