Skip to content

TorchTensorRTModule Serialization Fix #3572

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 24, 2025
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import copy
import logging
import pickle
from typing import Any, List, Optional, Tuple, Union
from typing import Any, List, Optional, Set, Tuple, Union

import torch
from torch_tensorrt._Device import Device
Expand Down Expand Up @@ -227,6 +227,11 @@ def setup_engine(self) -> None:

def encode_metadata(self, metadata: Any) -> str:
    """Encode engine metadata as a base64 text string.

    The metadata is deep-copied first so the caller's object is left
    untouched. On the copy, ``settings.torch_executed_ops`` is replaced by
    the result of ``TorchTensorRTModule.serialize_aten_ops`` (a set of
    strings) before the whole structure is pickled.

    Args:
        metadata: Mapping with a ``"settings"`` entry exposing a
            ``torch_executed_ops`` attribute.

    Returns:
        The pickled metadata, base64-encoded and decoded to a UTF-8 ``str``.
    """
    snapshot = copy.deepcopy(metadata)
    settings = snapshot["settings"]
    # Swap the op objects for their string form on the copy only.
    settings.torch_executed_ops = TorchTensorRTModule.serialize_aten_ops(
        settings.torch_executed_ops
    )
    pickled = pickle.dumps(snapshot)
    return base64.b64encode(pickled).decode("utf-8")
Expand All @@ -235,8 +240,21 @@ def encode_metadata(self, metadata: Any) -> str:
def decode_metadata(encoded_metadata: Union[str, bytes]) -> Any:
    """Decode metadata produced by ``encode_metadata``.

    Args:
        encoded_metadata: Base64 text of the pickled metadata. ``str`` input
            is UTF-8 encoded first; ``bytes`` input is used as-is (the
            previous implementation called ``.encode`` unconditionally and
            therefore failed on ``bytes`` despite the annotation).

    Returns:
        The unpickled metadata, with ``settings.torch_executed_ops`` replaced
        by the result of ``TorchTensorRTModule.deserialize_aten_ops``.
    """
    if isinstance(encoded_metadata, str):
        raw = encoded_metadata.encode("utf-8")
    else:
        raw = encoded_metadata
    dumped_metadata = base64.b64decode(raw)
    # NOTE(security): pickle.loads can execute arbitrary code embedded in the
    # payload. Only decode metadata that comes from a trusted source.
    metadata = pickle.loads(dumped_metadata)
    settings = metadata["settings"]
    settings.torch_executed_ops = TorchTensorRTModule.deserialize_aten_ops(
        settings.torch_executed_ops
    )
    return metadata

@staticmethod
def serialize_aten_ops(aten_ops: Set[torch._ops.OpOverload]) -> Set[str]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

call this _serialize_torch_ops

return {str(op) for op in aten_ops}

@staticmethod
def deserialize_aten_ops(aten_ops: Set[str]) -> Set[torch._ops.OpOverload]:
return {eval("torch.ops." + str(v)) for v in aten_ops}

def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
if self.engine:
return (
Expand Down
Loading