32 changes: 17 additions & 15 deletions py/torch_tensorrt/_compile.py
@@ -1,12 +1,12 @@
from __future__ import annotations

import logging
from enum import Enum
from typing import Any, Callable, List, Optional, Sequence, Set

import torch
import torch.fx
import torch_tensorrt.ts
from torch_tensorrt import logging
from torch_tensorrt._enums import dtype
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo.compile import compile as dynamo_compile
@@ -16,6 +16,13 @@
from torch_tensorrt.ts._compiler import compile as torchscript_compile
from typing_extensions import TypeGuard

logger = logging.getLogger(__name__)

__all__ = [
"compile",
"convert_method_to_trt_engine",
]


def _non_fx_input_interface(
inputs: Sequence[Input | torch.Tensor | InputTensorSpec],
@@ -30,7 +37,7 @@ def _fx_input_interface(


class _IRType(Enum):
"""Enum to set the minimum required logging level to print a message to stdout"""
"""Enum to determine the type of IR selected for model compilation"""

ts = 0
fx = 1
@@ -39,7 +46,7 @@ class _IRType(Enum):


class _ModuleType(Enum):
"""Enum to set the minimum required logging level to print a message to stdout"""
"""Enum to determine the type of model provided as input"""

nn = 0
ts = 1
@@ -81,14 +88,11 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
if ir == "default":
# Options are listed in order of preference
if module_is_fxable:
logging.log(
logging.Level.Info, "ir was set to default, using dynamo as ir"
)
logger.info("ir was set to default, using dynamo as ir")
return _IRType.dynamo
elif module_is_tsable:
logging.log(
logging.Level.Warning,
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript",
logger.warning(
"Input graph is a Torchscript module but the ir provided is default (dynamo). Please set ir=torchscript to suppress the warning. Compiling the module with ir=torchscript"
)
return _IRType.ts
else:
@@ -151,9 +155,8 @@ def compile(
if target_ir == _IRType.ts:
ts_mod = module
if module_type == _ModuleType.nn:
logging.log(
logging.Level.Info,
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
logger.info(
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
)
ts_mod = torch.jit.script(module)
assert _non_fx_input_interface(input_list)
@@ -274,9 +277,8 @@ def convert_method_to_trt_engine(
if target_ir == _IRType.ts:
ts_mod = module
if module_type == _ModuleType.nn:
logging.log(
logging.Level.Info,
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
logger.info(
"Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript"
)
ts_mod = torch.jit.script(module)
return torch_tensorrt.ts.convert_method_to_trt_engine( # type: ignore[no-any-return]
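Note on the logging change above: the per-file logger = logging.getLogger(__name__) replaces the package's custom logging.log(logging.Level.Info, ...) helper with the standard library. A minimal sketch of that pattern, with an illustrative function name that is not taken from the PR:

    import logging

    # One logger per module, named after the module path (e.g. "torch_tensorrt._compile"),
    # so verbosity can be raised or lowered per package via standard logging config.
    logger = logging.getLogger(__name__)

    def pick_ir(ir: str) -> str:
        if ir == "default":
            logger.info("ir was set to default, using dynamo as ir")
            return "dynamo"
        return ir

    # Usage: enable INFO output once at application startup.
    logging.basicConfig(level=logging.INFO)
    pick_ir("default")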
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/__init__.py
@@ -1,7 +1,11 @@
import logging

from torch_tensorrt._utils import sanitized_torch_version

from packaging import version

logger = logging.getLogger(__name__)

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
from ._settings import * # noqa: F403
from ._SourceIR import SourceIR # noqa: F403
11 changes: 7 additions & 4 deletions py/torch_tensorrt/dynamo/aten_tracer.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import logging
import sys
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
@@ -26,6 +27,8 @@

Value: TypeAlias = Union[Tuple["Value", ...], List["Value"], Dict[str, "Value"]]

logger = logging.getLogger(__name__)


class DynamoConfig:
"""
@@ -85,7 +88,7 @@ def setting_python_recursive_limit(limit: int = 10000) -> Generator[None, None,
sys.setrecursionlimit(default)


@req_torch_version("2.dev")
@req_torch_version
def dynamo_trace(
f: Callable[..., Value],
# pyre-ignore
@@ -121,7 +124,7 @@ def dynamo_trace(
) from exc


@req_torch_version("2.dev")
@req_torch_version
def trace(
model: torch.nn.Module | torch.fx.GraphModule,
inputs: Tuple[Any, ...],
@@ -145,13 +148,13 @@ def trace(
]

fx_module, __package__ = dynamo_trace(model, inputs, True, "symbolic")
print(fx_module.graph)

for passes in passes_list:
pr: PassResult = passes(fx_module)
fx_module = pr.graph_module

fx_module(*inputs)

fx_module = run_const_fold(fx_module)
print(fx_module.graph)
logger.info("Post export graph : %s\n", fx_module.graph)
return fx_module
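Aside (not part of the diff): the print(fx_module.graph) calls above become a deferred logger call. Passing the graph as a %s argument means str(fx_module.graph) is only built when the INFO level is enabled, which matters for large graphs. A short sketch under that assumption, with an illustrative stand-in for the FX graph:

    import logging

    logger = logging.getLogger(__name__)

    def report_graph(graph: object) -> None:
        # str(graph) is deferred until the record is actually emitted at INFO level.
        logger.info("Post export graph : %s\n", graph)

    logging.basicConfig(level=logging.INFO)
    report_graph("graph(x): return x")  # stand-in object, not a real FX graph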
69 changes: 2 additions & 67 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -8,12 +8,8 @@
import torch._dynamo as td
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo.conversion import (
convert_module,
repair_long_or_double_inputs,
)
from torch_tensorrt.dynamo.compile import compile_module
from torch_tensorrt.dynamo.lowering._decompositions import get_decompositions
from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs

@@ -70,7 +66,7 @@ def _pretraced_backend(
try:
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

trt_compiled = _compile_module(
trt_compiled = compile_module(
gm,
sample_inputs,
settings=settings,
@@ -93,64 +89,3 @@ def _pretraced_backend(
+ "specify pass_through_build_failures=False."
)
raise


def _compile_module(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule:
"""Compile a traced FX module

Includes: Partitioning + Conversion Phases

Args:
module: FX GraphModule to convert
inputs: Inputs to the module
settings: Compilation settings
Returns:
Compiled FX GraphModule
"""
# Partition module into components that can be TRT-accelerated
partitioned_module = partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)

# Store TRT replicas of Torch subgraphs
trt_modules = {}

# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)

# Get submodule inputs
submodule_inputs = get_submod_inputs(
partitioned_module, submodule, sample_inputs
)

assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
partitioned_module, submodule, submodule_inputs, name
)

# Create TRT Module from submodule
trt_mod = convert_module(
submodule,
submodule_inputs,
settings=settings,
name=name,
)

trt_modules[name] = trt_mod

# Replace all FX Modules with TRT Modules
for name, trt_mod in trt_modules.items():
setattr(partitioned_module, name, trt_mod)

return partitioned_module
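For orientation (not part of the diff): backends.py now keeps only the torch.compile backend glue and delegates partitioning and conversion to compile_module in dynamo/compile.py. A hedged usage sketch, assuming the backend is registered under the "torch_tensorrt" name as in contemporary releases:

    import torch
    import torch_tensorrt  # noqa: F401  (importing registers the dynamo backend)

    model = torch.nn.Linear(16, 32).eval().cuda()
    example = torch.randn(8, 16).cuda()

    # torch.compile traces the module and hands the graph to _pretraced_backend,
    # which calls compile_module for partitioning plus TRT conversion.
    optimized = torch.compile(model, backend="torch_tensorrt")
    print(optimized(example).shape)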
93 changes: 66 additions & 27 deletions py/torch_tensorrt/dynamo/compile.py
@@ -2,11 +2,10 @@

import collections.abc
import logging
from typing import Any, List, Optional, Set, Tuple
from typing import Any, List, Optional, Sequence, Set, Tuple

import torch
import torch_tensorrt
from torch.fx.passes.pass_manager import PassManager
from torch.fx.passes.splitter_base import SplitResult
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import ( # TODO: Should probabably be the TRT EngineCapability Enum
@@ -25,12 +24,11 @@
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
)
from torch_tensorrt.dynamo.backend.backends import _compile_module
from torch_tensorrt.dynamo.conversion import convert_module
from torch_tensorrt.dynamo.lowering._fusers import (
fuse_permute_linear,
fuse_permute_matmul,
from torch_tensorrt.dynamo.conversion import (
convert_module,
repair_long_or_double_inputs,
)
from torch_tensorrt.dynamo.lowering._partition import get_submod_inputs, partition
from torch_tensorrt.dynamo.utils import prepare_device, prepare_inputs
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

@@ -67,17 +65,11 @@ def compile(
**kwargs: Any,
) -> torch.fx.GraphModule:
if debug:
logger.setLevel(logging.DEBUG)
if logger.parent:
logger.parent.setLevel(logging.DEBUG)

enabled_precisions = set(enabled_precisions)

logger.warning(
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures}"
)

if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]

@@ -118,9 +110,10 @@
}

settings = CompilationSettings(**compilation_options)
logger.debug("Compilation Settings: %s\n", settings)

if kwargs.get("use_capability_partitioner", None):
model = lower_model(gm, torch_inputs)
return _compile_module(model, torch_inputs, settings)
return compile_module(gm, torch_inputs, settings)
else:
split_result = lower_model_using_trt_splitter(gm, torch_inputs)
trt_module = _compile_graph(split_result, torch_inputs, settings)
@@ -153,8 +146,6 @@ def _compile_graph(
def lower_model_using_trt_splitter(
model: torch.nn.Module, inputs: Any, **kwargs: Any
) -> SplitResult:
# Perform basic lowering
model = lower_model(model, inputs)
splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = False
splitter_setting.min_acc_module_size = 1
@@ -166,14 +157,62 @@ def lower_model_using_trt_splitter(
return split_result


def lower_model(
model: torch.nn.Module, inputs: Any, **kwargs: Any
def compile_module(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
) -> torch.fx.GraphModule:
graph_optimization_pm = PassManager.build_from_passlist(
[fuse_permute_matmul, fuse_permute_linear]
"""Compile a traced FX module

Includes: Partitioning + Conversion Phases

Args:
module: FX GraphModule to convert
inputs: Inputs to the module
settings: Compilation settings
Returns:
Compiled FX GraphModule
"""
# Partition module into components that can be TRT-accelerated
partitioned_module = partition(
gm,
verbose=settings.debug,
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
)
lowered_model: torch.fx.GraphModule = graph_optimization_pm(model)
# if isinstance(lowered_model, torch.fx.GraphModule):
# ShapeProp(lowered_model).propagate(*inputs)

return lowered_model
# Store TRT replicas of Torch subgraphs
trt_modules = {}

# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
for name, _ in partitioned_module.named_children():
submodule = getattr(partitioned_module, name)

# Get submodule inputs
submodule_inputs = get_submod_inputs(
partitioned_module, submodule, sample_inputs
)

assert submodule_inputs is not None
# Handle long/double inputs if requested by the user
if settings.truncate_long_and_double:
submodule_inputs = repair_long_or_double_inputs(
partitioned_module, submodule, submodule_inputs, name
)

# Create TRT Module from submodule
trt_mod = convert_module(
submodule,
submodule_inputs,
settings=settings,
name=name,
)

trt_modules[name] = trt_mod

# Replace all FX Modules with TRT Modules
for name, trt_mod in trt_modules.items():
setattr(partitioned_module, name, trt_mod)

return partitioned_module
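As a usage illustration (not from the PR): the relocated compile_module path above is what a direct dynamo compile call exercises. The model, shapes, and exact keyword set below are assumptions and should be checked against the installed release:

    import torch
    import torch_tensorrt

    model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn(8, 16).cuda()]

    # debug=True now raises the standard logger to DEBUG; the partition/convert
    # loop runs in dynamo/compile.py::compile_module.
    trt_gm = torch_tensorrt.compile(
        model,
        ir="dynamo",
        inputs=inputs,
        enabled_precisions={torch.float32},
        debug=True,
        min_block_size=1,
    )
    print(trt_gm(*inputs).shape)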