Skip to content

Commit 807cfeb

Browse files
committed
Changed the way to deserialize and move it to settings.py
1 parent f9235ed commit 807cfeb

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

py/torch_tensorrt/dynamo/_settings.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from dataclasses import dataclass, field
2-
from typing import Collection, Optional, Set, Tuple, Union
2+
from typing import Any, Collection, Optional, Set, Tuple, Union
33

4+
import torch
45
from torch.fx.node import Target
56
from torch_tensorrt._Device import Device
67
from torch_tensorrt._enums import EngineCapability, dtype
@@ -143,6 +144,16 @@ class CompilationSettings:
143144
use_distributed_mode_trace: bool = USE_DISTRIBUTED_MODE_TRACE
144145
offload_module_to_cpu: bool = OFFLOAD_MODULE_TO_CPU
145146

147+
def __getstate__(self) -> dict[str, Any]:
148+
state = self.__dict__.copy()
149+
state["torch_executed_ops"] = {str(op) for op in state["torch_executed_ops"]}
150+
return state
151+
152+
def __setstate__(self, state: dict[str, Any]) -> None:
153+
self.__dict__.update(state)
154+
ops_str = self.torch_executed_ops
155+
self.torch_executed_ops = {getattr(torch.ops, op) for op in ops_str}
156+
146157

147158
_SETTINGS_TO_BE_ENGINE_INVARIANT = (
148159
"enabled_precisions",

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import logging
66
import pickle
7-
from typing import Any, List, Optional, Set, Tuple, Union
7+
from typing import Any, List, Optional, Tuple, Union
88

99
import torch
1010
from torch_tensorrt._Device import Device
@@ -227,11 +227,6 @@ def setup_engine(self) -> None:
227227

228228
def encode_metadata(self, metadata: Any) -> str:
229229
metadata = copy.deepcopy(metadata)
230-
metadata["settings"].torch_executed_ops = (
231-
TorchTensorRTModule.serialize_aten_ops(
232-
metadata["settings"].torch_executed_ops
233-
)
234-
)
235230
dumped_metadata = pickle.dumps(metadata)
236231
encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
237232
return encoded_metadata
@@ -240,21 +235,8 @@ def encode_metadata(self, metadata: Any) -> str:
240235
def decode_metadata(encoded_metadata: bytes) -> Any:
241236
dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8"))
242237
metadata = pickle.loads(dumped_metadata)
243-
metadata["settings"].torch_executed_ops = (
244-
TorchTensorRTModule.deserialize_aten_ops(
245-
metadata["settings"].torch_executed_ops
246-
)
247-
)
248238
return metadata
249239

250-
@staticmethod
251-
def serialize_aten_ops(aten_ops: Set[torch._ops.OpOverload]) -> Set[str]:
252-
return {str(op) for op in aten_ops}
253-
254-
@staticmethod
255-
def deserialize_aten_ops(aten_ops: Set[str]) -> Set[torch._ops.OpOverload]:
256-
return {eval("torch.ops." + str(v)) for v in aten_ops}
257-
258240
def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
259241
if self.engine:
260242
return (

0 commit comments

Comments
 (0)