## Bug Description
Using `torch_executed_ops` to force an op to run in PyTorch causes the runtime to fail to set up the engine: the op object ends up in the engine metadata, and `pickle.dumps` cannot serialize the underlying `PyCapsule`.
## To Reproduce

Steps to reproduce the behavior:
```python
import torch
import torch_tensorrt
import torchvision.models as models

# Initialize model with half precision and sample inputs
model = models.resnet18(pretrained=True).half().eval().to("cuda")

compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.half,
        )
    ],
    "ir": "dynamo",
    "enabled_precisions": {torch.float32, torch.float16, torch.bfloat16},
    "min_block_size": 1,
    "debug": True,
    "output_format": "exported_program",
    "cache_built_engines": True,
    "reuse_cached_engines": True,
    # Forcing matmul to run in PyTorch is what triggers the failure below
    "torch_executed_ops": {torch.ops.aten.matmul},
}

trt_model = torch_tensorrt.compile(model, **compile_spec)
```
This produces the following failure:
```
Traceback (most recent call last):
  File "/home/naren/pytorch_org/tensorrt/examples/dynamo/aoti_resnet.py", line 56, in <module>
    trt_model = torch_tensorrt.compile(model, **compile_spec)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_compile.py", line 289, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 712, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 918, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 108, in convert_module
    return rt_cls(
           ^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 148, in __init__
    self.setup_engine()
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 226, in setup_engine
    self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
                  ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 182, in _pack_engine_info
    engine_info[SERIALIZED_METADATA_IDX] = self.encode_metadata(metadata)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 230, in encode_metadata
    dumped_metadata = pickle.dumps(metadata)
                      ^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'PyCapsule' object
DEBUG Command exited with code: 1
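```

The error can be reproduced in isolation. A minimal sketch of the suspected root cause, assuming (as the traceback suggests) that the op set passed to `torch_executed_ops` reaches `pickle.dumps` via the serialized engine metadata, and that `torch.ops.aten.matmul` carries a C-level `PyCapsule` in its state:

```python
import pickle

import torch

# torch.ops.aten.matmul is an OpOverloadPacket; pickling it (here, inside a
# set, as in compile_spec above) fails because its internal state is backed
# by a PyCapsule, which the pickle module refuses to serialize.
try:
    pickle.dumps({torch.ops.aten.matmul})
except TypeError as e:
    print(e)  # expected, per the traceback: cannot pickle 'PyCapsule' object
```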
## Expected behavior

The module should set up properly.
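Until the metadata encoder handles op objects, a possible workaround (an unverified assumption: that `torch_executed_ops` also accepts string target names) is to identify the op by its string name so nothing capsule-backed reaches the pickler:

```python
# Hypothetical workaround: pass the op as a string target name instead of
# the OpOverloadPacket object, keeping the PyCapsule out of pickle.dumps.
compile_spec["torch_executed_ops"] = {"torch.ops.aten.matmul"}
trt_model = torch_tensorrt.compile(model, **compile_spec)
```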
## Environment

> Build information about Torch-TensorRT can be found by turning on debug messages
- Torch-TensorRT Version (e.g. 1.0.0): 2.8.0.dev0+f09be7245
- PyTorch Version (e.g. 1.0): 2.8.0.dev20250606+cu128
- CPU Architecture: x86_64
- OS (e.g., Linux): Linux
- How you installed PyTorch (`conda`, `pip`, `libtorch`, source): PyTorch index
- Build command you used (if compiling from source): `uv run`
- Are you using local sources or building from archives:
- Python version: 3.11
- CUDA version: 12.8
- GPU models and configuration: 3080Ti
- Any other relevant information: