4
4
import copy
5
5
import logging
6
6
import pickle
7
- from typing import Any , List , Optional , Set , Tuple , Union
7
+ from typing import Any , List , Optional , Tuple , Union
8
8
9
9
import torch
10
10
from torch_tensorrt ._Device import Device
@@ -227,11 +227,6 @@ def setup_engine(self) -> None:
227
227
228
228
def encode_metadata (self , metadata : Any ) -> str :
229
229
metadata = copy .deepcopy (metadata )
230
- metadata ["settings" ].torch_executed_ops = (
231
- TorchTensorRTModule .serialize_aten_ops (
232
- metadata ["settings" ].torch_executed_ops
233
- )
234
- )
235
230
dumped_metadata = pickle .dumps (metadata )
236
231
encoded_metadata = base64 .b64encode (dumped_metadata ).decode ("utf-8" )
237
232
return encoded_metadata
@@ -240,21 +235,8 @@ def encode_metadata(self, metadata: Any) -> str:
240
235
def decode_metadata (encoded_metadata : bytes ) -> Any :
241
236
dumped_metadata = base64 .b64decode (encoded_metadata .encode ("utf-8" ))
242
237
metadata = pickle .loads (dumped_metadata )
243
- metadata ["settings" ].torch_executed_ops = (
244
- TorchTensorRTModule .deserialize_aten_ops (
245
- metadata ["settings" ].torch_executed_ops
246
- )
247
- )
248
238
return metadata
249
239
250
- @staticmethod
251
- def serialize_aten_ops (aten_ops : Set [torch ._ops .OpOverload ]) -> Set [str ]:
252
- return {str (op ) for op in aten_ops }
253
-
254
- @staticmethod
255
- def deserialize_aten_ops (aten_ops : Set [str ]) -> Set [torch ._ops .OpOverload ]:
256
- return {eval ("torch.ops." + str (v )) for v in aten_ops }
257
-
258
240
def get_extra_state (self ) -> SerializedTorchTensorRTModuleFmt :
259
241
if self .engine :
260
242
return (
0 commit comments