Skip to content

Commit 2d7c87f

Browse files
committed
Added the string case
1 parent 807cfeb commit 2d7c87f

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from dataclasses import dataclass, field
22
from typing import Any, Collection, Optional, Set, Tuple, Union
33

4-
import torch
54
from torch.fx.node import Target
65
from torch_tensorrt._Device import Device
76
from torch_tensorrt._enums import EngineCapability, dtype
@@ -145,14 +144,19 @@ class CompilationSettings:
145144
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
146145

147146
def __getstate__(self) -> dict[str, Any]:
147+
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
148+
ConverterRegistry,
149+
)
150+
148151
state = self.__dict__.copy()
149-
state["torch_executed_ops"] = {str(op) for op in state["torch_executed_ops"]}
152+
state["torch_executed_ops"] = {
153+
op if isinstance(op, str) else ConverterRegistry.qualified_name_or_str(op)
154+
for op in state["torch_executed_ops"]
155+
}
150156
return state
151157

152158
def __setstate__(self, state: dict[str, Any]) -> None:
153159
self.__dict__.update(state)
154-
ops_str = self.torch_executed_ops
155-
self.torch_executed_ops = {getattr(torch.ops, op) for op in ops_str}
156160

157161

158162
_SETTINGS_TO_BE_ENGINE_INVARIANT = (

tests/py/dynamo/models/test_models.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,41 @@ def test_resnet18_cpu_offload(ir):
8484
torch._dynamo.reset()
8585

8686

87+
@pytest.mark.unit
88+
def test_resnet18_torch_exec_ops(ir):
89+
model = models.resnet18(pretrained=True).eval().to("cuda")
90+
input = torch.randn((1, 3, 224, 224)).to("cuda")
91+
92+
compile_spec = {
93+
"inputs": [
94+
torchtrt.Input(
95+
min_shape=(1, 3, 224, 224),
96+
opt_shape=(8, 3, 224, 224),
97+
max_shape=(16, 3, 224, 224),
98+
dtype=torch.float32,
99+
)
100+
],
101+
"ir": ir,
102+
"enabled_precisions": {torch.float32, torch.float16, torch.bfloat16},
103+
"min_block_size": 1,
104+
"debug": True,
105+
"output_format": "exported_program",
106+
"cache_built_engines": True,
107+
"reuse_cached_engines": True,
108+
"torch_executed_ops": {torch.ops.aten.matmul, "torch.ops.aten.add"},
109+
}
110+
111+
trt_mod = torchtrt.compile(model, **compile_spec)
112+
cos_sim = cosine_similarity(model(input), trt_mod(input))
113+
assertions.assertTrue(
114+
cos_sim > COSINE_THRESHOLD,
115+
msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
116+
)
117+
118+
# Clean up model env
119+
torch._dynamo.reset()
120+
121+
87122
@pytest.mark.unit
88123
def test_mobilenet_v2(ir):
89124
model = models.mobilenet_v2(pretrained=True).eval().to("cuda")

0 commit comments

Comments
 (0)