diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 97c02f34fb..6a55e6455d 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Collection, Optional, Set, Tuple, Union
+from typing import Any, Collection, Optional, Set, Tuple, Union

 from torch.fx.node import Target
 from torch_tensorrt._Device import Device
@@ -143,6 +143,21 @@ class CompilationSettings:
     use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
     offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU

+    def __getstate__(self) -> dict[str, Any]:
+        from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+            ConverterRegistry,
+        )
+
+        state = self.__dict__.copy()
+        state["torch_executed_ops"] = {
+            op if isinstance(op, str) else ConverterRegistry.qualified_name_or_str(op)
+            for op in state["torch_executed_ops"]
+        }
+        return state
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        self.__dict__.update(state)
+

 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
     "enabled_precisions",
diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py
index 52e5eefb63..91ed2a2b99 100644
--- a/tests/py/dynamo/models/test_export_serde.py
+++ b/tests/py/dynamo/models/test_export_serde.py
@@ -1,3 +1,4 @@
+import importlib
 import os
 import tempfile
 import unittest
@@ -372,6 +373,38 @@ def test_resnet18_dynamic(ir):
     )


+@unittest.skipIf(
+    not importlib.util.find_spec("torchvision"), "torchvision not installed"
+)
+def test_resnet18_torch_exec_ops_serde(ir):
+    """
+    This tests export save and load functionality on a ResNet18 model with torch_executed_ops
+    """
+    model = models.resnet18().eval().cuda()
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [input],
+        "ir": ir,
+        "min_block_size": 1,
+        "cache_built_engines": False,
+        "reuse_cached_engines": False,
+        "torch_executed_ops": {torch.ops.aten.addmm, "torch.ops.aten.add"},
+    }
+
+    exp_program = torchtrt.dynamo.trace(model, **compile_spec)
+    trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
+    torchtrt.save(trt_module, trt_ep_path)
+    deser_trt_module = torchtrt.load(trt_ep_path).module()
+    outputs_pyt = deser_trt_module(input)
+    outputs_trt = trt_module(input)
+    cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"test_resnet18_torch_exec_ops_serde TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+
 @pytest.mark.unit
 def test_hybrid_conv_fallback(ir):
     """
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index b0ebbf5fa4..8a9a0c3775 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -84,6 +84,43 @@ def test_resnet18_cpu_offload(ir):
     torch._dynamo.reset()


+@unittest.skipIf(
+    not importlib.util.find_spec("torchvision"), "torchvision not installed"
+)
+def test_resnet18_torch_exec_ops(ir):
+    model = models.resnet18(pretrained=True).eval().to("cuda")
+    input = torch.randn((1, 3, 224, 224)).to("cuda")
+
+    compile_spec = {
+        "inputs": [
+            torchtrt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(8, 3, 224, 224),
+                max_shape=(16, 3, 224, 224),
+                dtype=torch.float32,
+            )
+        ],
+        "ir": ir,
+        "enabled_precisions": {torch.float32, torch.float16, torch.bfloat16},
+        "min_block_size": 1,
+        "debug": True,
+        "output_format": "exported_program",
+        "cache_built_engines": True,
+        "reuse_cached_engines": True,
+        "torch_executed_ops": {torch.ops.aten.matmul, "torch.ops.aten.add"},
+    }
+
+    trt_mod = torchtrt.compile(model, **compile_spec)
+    cos_sim = cosine_similarity(model(input), trt_mod(input))
+    assertions.assertTrue(
+        cos_sim > COSINE_THRESHOLD,
+        msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
+    )
+
+    # Clean up model env
+    torch._dynamo.reset()
+
+
 @pytest.mark.unit
 def test_mobilenet_v2(ir):
     model = models.mobilenet_v2(pretrained=True).eval().to("cuda")
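
For reference, below is a minimal sketch (not part of the patch) of the round-trip the new __getstate__/__setstate__ hooks enable. It assumes CompilationSettings can be constructed with only torch_executed_ops overridden (all other fields have defaults) and that the remaining default fields pickle cleanly:

    import pickle

    import torch
    from torch_tensorrt.dynamo._settings import CompilationSettings

    # torch_executed_ops may mix FX Target objects and plain strings,
    # mirroring the sets used in the tests above.
    settings = CompilationSettings(
        torch_executed_ops={torch.ops.aten.matmul, "torch.ops.aten.add"}
    )

    # __getstate__ normalizes every non-string target to its qualified name
    # via ConverterRegistry.qualified_name_or_str, so pickling no longer
    # trips over unpicklable Target objects.
    restored = pickle.loads(pickle.dumps(settings))

    # After the round-trip, every op is stored in string form.
    assert all(isinstance(op, str) for op in restored.torch_executed_ops)

Note that __setstate__ deliberately keeps the string form rather than re-resolving the original Target objects; torch_executed_ops already accepts qualified-name strings (as the tests above show), so the restored settings remain usable.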