diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 8f6408492a..cf869562b6 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -12,6 +12,7 @@ partition, get_submod_inputs, ) +from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs from torch_tensorrt.dynamo.backend.conversion import convert_module from torch._dynamo.backends.common import fake_tensor_unsupported @@ -25,22 +26,20 @@ @td.register_backend(name="torch_tensorrt") @fake_tensor_unsupported def torch_tensorrt_backend( - gm: torch.fx.GraphModule, - sample_inputs: Sequence[torch.Tensor], - settings: CompilationSettings = CompilationSettings(), + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs ): DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend - return DEFAULT_BACKEND(gm, sample_inputs, settings=settings) + return DEFAULT_BACKEND(gm, sample_inputs, **kwargs) @td.register_backend(name="aot_torch_tensorrt_aten") @fake_tensor_unsupported def aot_torch_tensorrt_aten_backend( - gm: torch.fx.GraphModule, - sample_inputs: Sequence[torch.Tensor], - settings: CompilationSettings = CompilationSettings(), + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs ): + settings = parse_dynamo_kwargs(kwargs) + custom_backend = partial( _pretraced_backend, settings=settings, diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py index e6e22d5f96..9396373790 100644 --- a/py/torch_tensorrt/dynamo/backend/utils.py +++ b/py/torch_tensorrt/dynamo/backend/utils.py @@ -1,9 +1,15 @@ import torch +import logging +from dataclasses import replace, fields +from torch_tensorrt.dynamo.backend._settings import CompilationSettings from typing import Any, Union, Sequence, Dict from torch_tensorrt import _Input, Device +logger = logging.getLogger(__name__) + + def prepare_inputs( inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict], device: torch.device = torch.device("cuda"), @@ -66,3 +72,36 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device: ) return device + + +def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings: + """Parses the kwargs field of a Dynamo backend + + Args: + kwargs: Keyword arguments dictionary provided to the backend + Returns: + CompilationSettings object with relevant kwargs + """ + + # Initialize an empty CompilationSettings object + settings = CompilationSettings() + + # If the user specifies keyword args, overwrite those fields in settings + # Validate all specified kwargs to ensure they are true fields of the dataclass + # + # Note: kwargs provided by torch.compile are wrapped in the "options" key + if kwargs: + if "options" in kwargs and len(kwargs) == 1: + kwargs = kwargs["options"] + + valid_attrs = {attr.name for attr in fields(settings)} + valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs} + settings = replace(settings, **valid_kwargs) + + # Enable debug/verbose mode if requested + if settings.debug: + logger.setLevel(logging.DEBUG) + + logger.debug(f"Compiling with Settings:\n{settings}") + + return settings