From 0adf1fac1b0f15aba5638c1afbdcdcd9ad750a97 Mon Sep 17 00:00:00 2001 From: Guang Yang Date: Thu, 3 Aug 2023 18:33:54 -0700 Subject: [PATCH] Fix export config in examples (#34) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/34 ## Context It doesn't look correct that want to export models with one config but test the export with a different one. ## This DIff - Ensure canonical config is used for both export and tests - Ensure the test is loading the lowered model through `_load_for_executorch_from_buffer` for output consistency checking - Remove confusion about certain export flags by adding more inline comments to those flags, e.g. why is `enable_dynamic_shape` there not enabled by default, what does `enable_aot` do with dynamic shapes, etc. Reviewed By: JacobSzwejbka Differential Revision: D48018569 fbshipit-source-id: e9dd3747a332c08d4a85ce365c1aa3317fe60fb4 --- examples/export/test/TARGETS | 1 + examples/export/test/test_export.py | 15 +++++++++++---- examples/export/utils.py | 2 +- exir/capture/_config.py | 4 ++-- 4 files changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/export/test/TARGETS b/examples/export/test/TARGETS index c440f9a2037..1acac1cb034 100644 --- a/examples/export/test/TARGETS +++ b/examples/export/test/TARGETS @@ -10,5 +10,6 @@ python_unittest( "//executorch/examples/export:utils", "//executorch/examples/models:models", "//executorch/exir:lib", + "//executorch/extension/pybindings:portable", # @manual ], ) diff --git a/examples/export/test/test_export.py b/examples/export/test/test_export.py index 64bb69f57b7..35655b9dfd2 100644 --- a/examples/export/test/test_export.py +++ b/examples/export/test/test_export.py @@ -8,9 +8,14 @@ import torch -from executorch.examples.export.utils import _EDGE_COMPILE_CONFIG +from executorch.examples.export.utils import _CAPTURE_CONFIG, _EDGE_COMPILE_CONFIG from executorch.examples.models import MODEL_NAME_TO_MODEL +# pyre-ignore[21]: Could not find module `executorch.extension.pybindings.portable`. +from executorch.extension.pybindings.portable import ( # @manual + _load_for_executorch_from_buffer, +) + class ExportTest(unittest.TestCase): def _assert_eager_lowered_same_result( @@ -18,16 +23,18 @@ def _assert_eager_lowered_same_result( ): import executorch.exir as exir - capture_config = exir.CaptureConfig(enable_dynamic_shape=False) - edge_model = exir.capture(eager_model, example_inputs, capture_config).to_edge( + edge_model = exir.capture(eager_model, example_inputs, _CAPTURE_CONFIG).to_edge( _EDGE_COMPILE_CONFIG ) executorch_model = edge_model.to_executorch() + # pyre-ignore + pte_model = _load_for_executorch_from_buffer(executorch_model.buffer) + with torch.no_grad(): eager_output = eager_model(*example_inputs) with torch.no_grad(): - executorch_output = executorch_model.graph_module(*example_inputs) + executorch_output = pte_model.forward(example_inputs) self.assertTrue( torch.allclose(eager_output, executorch_output[0], rtol=1e-5, atol=1e-5) ) diff --git a/examples/export/utils.py b/examples/export/utils.py index 1a38acf984a..0255f64509f 100644 --- a/examples/export/utils.py +++ b/examples/export/utils.py @@ -11,7 +11,7 @@ # Reason is that there memory allocation ops with symbolic shape nodes. # and when evaulating shape, it doesnt seem that we presenting them with shape env # that contain those variables. -_CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True, _unlift=False) +_CAPTURE_CONFIG = exir.CaptureConfig(enable_aot=True) _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig( _check_ir_validity=False, ) diff --git a/exir/capture/_config.py b/exir/capture/_config.py index 7e008268b06..2b23e1d9f81 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -19,8 +19,8 @@ class CaptureConfig: pt2_mode: bool = True enable_functionalization: bool = True - enable_dynamic_shape: bool = False - enable_aot: bool = False + enable_dynamic_shape: bool = False # This flag does nothing if enable_aot is True + enable_aot: bool = False # When it's true it implies automatic dynamic shapes via default dynamo config _dynamo_config: "ExirDynamoConfig" = field(default_factory=ExirDynamoConfig) _unlift: bool = False _use_old_decomp_table: bool = False