Skip to content

Commit 0e9d884

Browse files
authored
TorchTensorRTModule Serialization Fix (#3572)
1 parent 29f59f9 commit 0e9d884

File tree

3 files changed

+85
-1
lines changed

3 files changed

+85
-1
lines changed

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from dataclasses import dataclass, field
2-
from typing import Collection, Optional, Set, Tuple, Union
2+
from typing import Any, Collection, Optional, Set, Tuple, Union
33

44
from torch.fx.node import Target
55
from torch_tensorrt._Device import Device
@@ -141,6 +141,21 @@ class CompilationSettings:
141141
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
142142
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
143143

144+
def __getstate__(self) -> dict[str, Any]:
145+
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
146+
ConverterRegistry,
147+
)
148+
149+
state = self.__dict__.copy()
150+
state["torch_executed_ops"] = {
151+
op if isinstance(op, str) else ConverterRegistry.qualified_name_or_str(op)
152+
for op in state["torch_executed_ops"]
153+
}
154+
return state
155+
156+
def __setstate__(self, state: dict[str, Any]) -> None:
157+
self.__dict__.update(state)
158+
144159

145160
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
146161
"enabled_precisions",

tests/py/dynamo/models/test_export_serde.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,38 @@ def test_resnet18_dynamic(ir):
386386
)
387387

388388

389+
@unittest.skipIf(
390+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
391+
)
392+
def test_resnet18_torch_exec_ops_serde(ir):
393+
"""
394+
This tests export save and load functionality on Resnet18 model
395+
"""
396+
model = models.resnet18().eval().cuda()
397+
input = torch.randn((1, 3, 224, 224)).to("cuda")
398+
399+
compile_spec = {
400+
"inputs": [input],
401+
"ir": ir,
402+
"min_block_size": 1,
403+
"cache_built_engines": False,
404+
"reuse_cached_engines": False,
405+
"torch_executed_ops": {torch.ops.aten.addmm, "torch.ops.aten.add"},
406+
}
407+
408+
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
409+
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
410+
torchtrt.save(trt_module, trt_ep_path)
411+
deser_trt_module = torchtrt.load(trt_ep_path).module()
412+
outputs_pyt = deser_trt_module(input)
413+
outputs_trt = trt_module(input)
414+
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
415+
assertions.assertTrue(
416+
cos_sim > COSINE_THRESHOLD,
417+
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
418+
)
419+
420+
389421
@pytest.mark.unit
390422
def test_hybrid_conv_fallback(ir):
391423
"""

tests/py/dynamo/models/test_models.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,43 @@ def test_resnet18_cpu_offload(ir):
9494
torch._dynamo.reset()
9595

9696

97+
@unittest.skipIf(
98+
not importlib.util.find_spec("torchvision"), "torchvision not installed"
99+
)
100+
def test_resnet18_torch_exec_ops(ir):
101+
model = models.resnet18(pretrained=True).eval().to("cuda")
102+
input = torch.randn((1, 3, 224, 224)).to("cuda")
103+
104+
compile_spec = {
105+
"inputs": [
106+
torchtrt.Input(
107+
min_shape=(1, 3, 224, 224),
108+
opt_shape=(8, 3, 224, 224),
109+
max_shape=(16, 3, 224, 224),
110+
dtype=torch.float32,
111+
)
112+
],
113+
"ir": ir,
114+
"enabled_precisions": {torch.float32, torch.float16, torch.bfloat16},
115+
"min_block_size": 1,
116+
"debug": True,
117+
"output_format": "exported_program",
118+
"cache_built_engines": True,
119+
"reuse_cached_engines": True,
120+
"torch_executed_ops": {torch.ops.aten.matmul, "torch.ops.aten.add"},
121+
}
122+
123+
trt_mod = torchtrt.compile(model, **compile_spec)
124+
cos_sim = cosine_similarity(model(input), trt_mod(input))
125+
assertions.assertTrue(
126+
cos_sim > COSINE_THRESHOLD,
127+
msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
128+
)
129+
130+
# Clean up model env
131+
torch._dynamo.reset()
132+
133+
97134
@pytest.mark.unit
98135
@unittest.skipIf(
99136
not importlib.util.find_spec("torchvision"),

0 commit comments

Comments
 (0)