
🐛 [Bug] Using torch_executed_ops causes the metadata packing to fail #3566

Bug Description

If you use torch_executed_ops to force an op to run in PyTorch, the runtime fails to set up the engine: packing the engine metadata raises a pickling error.
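
For reference, the snippet below is a minimal probe (not the Torch-TensorRT code path) of whether the op object passed in torch_executed_ops can be pickled at all; the dictionary layout is only illustrative.

import pickle

import torch

try:
    # torch.ops.aten.matmul is an OpOverloadPacket; on this build its object
    # graph appears to reach a PyCapsule, which the pickle module rejects.
    pickle.dumps({"torch_executed_ops": {torch.ops.aten.matmul}})
except TypeError as err:
    print(f"pickling failed: {err}")  # e.g. "cannot pickle 'PyCapsule' object"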

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt
import torchvision.models as models

# Initialize a pretrained ResNet-18 in half precision on the GPU
model = models.resnet18(pretrained=True).half().eval().to("cuda")
compile_spec = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(8, 3, 224, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.half,
        )
    ],
    "ir": "dynamo",
    "enabled_precisions": {torch.float32, torch.float16, torch.bfloat16},
    "min_block_size": 1,
    "debug": True,
    "output_format": "exported_program",
    "cache_built_engines": True,
    "reuse_cached_engines": True,
    "torch_executed_ops": {torch.ops.aten.matmul}

}
trt_model = torch_tensorrt.compile(model, **compile_spec)

This produces the following failure:

Traceback (most recent call last):
  File "/home/naren/pytorch_org/tensorrt/examples/dynamo/aoti_resnet.py", line 56, in <module>
    trt_model = torch_tensorrt.compile(model, **compile_spec)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_compile.py", line 289, in compile
    trt_graph_module = dynamo_compile(
                       ^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 712, in compile
    trt_gm = compile_module(
             ^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/_compiler.py", line 918, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 108, in convert_module
    return rt_cls(
           ^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 148, in __init__
    self.setup_engine()
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 226, in setup_engine
    self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())
                                                ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 182, in _pack_engine_info
    engine_info[SERIALIZED_METADATA_IDX] = self.encode_metadata(metadata)
                                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/_features.py", line 68, in wrapper
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/naren/pytorch_org/tensorrt/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py", line 230, in encode_metadata
    dumped_metadata = pickle.dumps(metadata)
                      ^^^^^^^^^^^^^^^^^^^^^^
TypeError: cannot pickle 'PyCapsule' object
DEBUG Command exited with code: 1
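
The failure originates in TorchTensorRTModule.encode_metadata, which pickles the module metadata (including the compilation settings); with torch.ops.aten.matmul in torch_executed_ops, that object graph contains a PyCapsule that pickle rejects. One possible direction for a fix, sketched below under the assumption that the op objects themselves are not needed after partitioning, is to normalize torch_executed_ops to qualified string names before the settings are serialized; _stringify_executed_ops is a hypothetical helper, not an existing API.

import torch

def _stringify_executed_ops(ops):
    # Hypothetical helper: map op objects to their printable qualified names so
    # that only plain strings end up in the pickled settings. str() on an
    # OpOverload/OpOverloadPacket yields a readable name (e.g. "aten.matmul").
    return {op if isinstance(op, str) else str(op) for op in ops}

print(_stringify_executed_ops({torch.ops.aten.matmul, "torch.ops.aten.add"}))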

Expected behavior

The module should set up successfully, with the op listed in torch_executed_ops running in PyTorch and the rest of the graph in TensorRT.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.8.0.dev0+f09be7245
  • PyTorch Version (e.g. 1.0): 2.8.0.dev20250606+cu128
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): PyTorch index
  • Build command you used (if compiling from source): uv run
  • Are you using local sources or building from archives:
  • Python version: 3.11
  • CUDA version: 12.8
  • GPU models and configuration: 3080Ti
  • Any other relevant information:

Additional context
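
A possible workaround (untested here, and assuming the dynamo partitioner also accepts ops given by string name, as the torch_executed_ops type hint suggests) is to name the op as a string so that nothing unpicklable ends up in the serialized settings:

# Untested workaround sketch: refer to the excluded op by name. Whether the
# packet name is enough or a specific overload such as
# "torch.ops.aten.matmul.default" is required may depend on the build.
compile_spec["torch_executed_ops"] = {"torch.ops.aten.matmul"}
trt_model = torch_tensorrt.compile(model, **compile_spec)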
