Skip to content

Commit 6bf3127

Browse files
authored
_TorchTensorRTModule Serialization Fix
1 parent 60863a3 commit 6bf3127

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import copy
55
import logging
66
import pickle
7-
from typing import Any, List, Optional, Tuple, Union
7+
from typing import Any, List, Optional, Set, Tuple, Union
88

99
import torch
1010
from torch_tensorrt._Device import Device
@@ -227,6 +227,7 @@ def setup_engine(self) -> None:
227227

228228
def encode_metadata(self, metadata: Any) -> str:
229229
metadata = copy.deepcopy(metadata)
230+
metadata["settings"].torch_executed_ops = TorchTensorRTModule.serialize_aten_ops(metadata["settings"].torch_executed_ops)
230231
dumped_metadata = pickle.dumps(metadata)
231232
encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
232233
return encoded_metadata
@@ -235,8 +236,17 @@ def encode_metadata(self, metadata: Any) -> str:
235236
def decode_metadata(encoded_metadata: bytes) -> Any:
236237
dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8"))
237238
metadata = pickle.loads(dumped_metadata)
239+
metadata["settings"].torch_executed_ops = TorchTensorRTModule.deserialize_aten_ops(metadata["settings"].torch_executed_ops)
238240
return metadata
239241

242+
@staticmethod
243+
def serialize_aten_ops(aten_ops: Set[torch._ops.OpOverload]) -> Set[str]:
244+
return {str(op) for op in aten_ops}
245+
246+
@staticmethod
247+
def deserialize_aten_ops(aten_ops: Set[str]) -> Set[torch._ops.OpOverload]:
248+
return {eval("torch.ops." + str(v)) for v in aten_ops}
249+
240250
def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
241251
if self.engine:
242252
return (

0 commit comments

Comments
 (0)