Skip to content

Commit 7d658d6

Browse files
committed
fix: Reorganize Dynamo backends
- Rename key backends to establish default backend and optional alternatives
- Update function headers and docstrings, as well as key imports
- Rename `torch_compile` folder to `backend` in accordance with `torch.compile` terminology
- Update references throughout codebase
- Specify certain functions as private/helper via underscore
1 parent d4e5ed0 commit 7d658d6

File tree

17 files changed

+77
-39
lines changed

17 files changed

+77
-39
lines changed

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def compile(
157157
**kwargs,
158158
)
159159
elif target_ir == _IRType.torch_compile:
160-
return torch_tensorrt.dynamo.torch_compile(
160+
return torch_tensorrt.dynamo.compile(
161161
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
162162
)
163163
elif target_ir == _IRType.fx_ts_compat:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
from torch_tensorrt.dynamo import fx_ts_compat
2-
from .torch_compile import compile as torch_compile
2+
from .backend import compile

py/torch_tensorrt/dynamo/torch_compile/__init__.py renamed to py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88
from torch_tensorrt import EngineCapability, Device
99
from torch_tensorrt.fx.utils import LowerPrecision
1010

11-
from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
12-
from torch_tensorrt.dynamo.torch_compile.utils import prepare_inputs, prepare_device
13-
from torch_tensorrt.dynamo.torch_compile.backends import tensorrt_backend
14-
from torch_tensorrt.dynamo.torch_compile._defaults import (
11+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
12+
from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
13+
from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
14+
from torch_tensorrt.dynamo.backend._defaults import (
1515
PRECISION,
1616
DEBUG,
1717
MAX_WORKSPACE_SIZE,
@@ -121,6 +121,6 @@ def create_backend(
121121
)
122122

123123
return partial(
124-
tensorrt_backend,
124+
torch_tensorrt_backend,
125125
settings=settings,
126126
)

py/torch_tensorrt/dynamo/torch_compile/_settings.py renamed to py/torch_tensorrt/dynamo/backend/_settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from dataclasses import dataclass
22

33
from torch_tensorrt.fx.utils import LowerPrecision
4-
from torch_tensorrt.dynamo.torch_compile._defaults import (
4+
from torch_tensorrt.dynamo.backend._defaults import (
55
PRECISION,
66
DEBUG,
77
MAX_WORKSPACE_SIZE,

py/torch_tensorrt/dynamo/torch_compile/backends.py renamed to py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,42 @@
44
from functools import partial
55
import torch._dynamo as td
66

7-
from torch_tensorrt.dynamo.torch_compile._settings import CompilationSettings
8-
from torch_tensorrt.dynamo.torch_compile.lowering._decompositions import (
7+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
8+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
99
get_decompositions,
1010
)
11-
from torch_tensorrt.dynamo.torch_compile.lowering._partition import (
11+
from torch_tensorrt.dynamo.backend.lowering._partition import (
1212
partition,
1313
get_submod_inputs,
1414
)
15-
from torch_tensorrt.dynamo.torch_compile.conversion import convert_module
15+
from torch_tensorrt.dynamo.backend.conversion import convert_module
1616

1717
from torch._dynamo.backends.common import fake_tensor_unsupported
1818

1919
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
2020

2121

22-
@td.register_backend(name="tensorrt")
22+
@td.register_backend(name="torch_tensorrt")
2323
@fake_tensor_unsupported
24-
def tensorrt_backend(
25-
gm: torch.nn.Module,
24+
def torch_tensorrt_backend(
25+
gm: torch.fx.GraphModule,
26+
sample_inputs: Sequence[torch.Tensor],
27+
settings: CompilationSettings = CompilationSettings(),
28+
):
29+
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
30+
31+
return DEFAULT_BACKEND(gm=gm, sample_inputs=sample_inputs, settings=settings)
32+
33+
34+
@td.register_backend(name="aot_torch_tensorrt_aten")
35+
@fake_tensor_unsupported
36+
def aot_torch_tensorrt_aten_backend(
37+
gm: torch.fx.GraphModule,
2638
sample_inputs: Sequence[torch.Tensor],
2739
settings: CompilationSettings = CompilationSettings(),
2840
):
2941
custom_backend = partial(
30-
fx_dynamo_backend,
42+
_pretraced_backend,
3143
settings=settings,
3244
)
3345

@@ -40,14 +52,12 @@ def tensorrt_backend(
4052
)
4153

4254

43-
@td.register_backend(name="fx_tensorrt")
44-
@fake_tensor_unsupported
45-
def fx_dynamo_backend(
55+
def _pretraced_backend(
4656
gm: torch.fx.GraphModule,
47-
example_inputs: Sequence[torch.Tensor],
57+
sample_inputs: Sequence[torch.Tensor],
4858
settings: CompilationSettings = CompilationSettings(),
4959
):
50-
"""Helper function to manage translation of FX module to TRT engines
60+
"""Helper function to manage translation of traced FX module to TRT engines
5161
5262
Args:
5363
module: FX GraphModule to convert
@@ -57,9 +67,9 @@ def fx_dynamo_backend(
5767
Compiled FX GraphModule
5868
"""
5969
try:
60-
trt_compiled = compile_module(
70+
trt_compiled = _compile_module(
6171
gm,
62-
example_inputs,
72+
sample_inputs,
6373
settings=settings,
6474
)
6575
return trt_compiled
@@ -72,12 +82,12 @@ def fx_dynamo_backend(
7282
return gm.forward
7383

7484

75-
def compile_module(
85+
def _compile_module(
7686
gm: torch.fx.GraphModule,
77-
example_inputs: Sequence[torch.Tensor],
87+
sample_inputs: Sequence[torch.Tensor],
7888
settings: CompilationSettings = CompilationSettings(),
7989
) -> torch.fx.GraphModule:
80-
"""Compile an FX module
90+
"""Compile a traced FX module
8191
8292
Includes: Partitioning + Conversion Phases
8393
@@ -100,7 +110,7 @@ def compile_module(
100110

101111
# Get submodule inputs
102112
submodule_inputs = get_submod_inputs(
103-
partitioned_module, submodule, example_inputs
113+
partitioned_module, submodule, sample_inputs
104114
)
105115

106116
# Create TRT Module from submodule
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from torch_tensorrt.dynamo.backend.lowering._decompositions import (
2+
get_decompositions,
3+
)
4+
from torch_tensorrt.dynamo.backend.lowering._partition import (
5+
partition,
6+
get_submod_inputs,
7+
)

py/torch_tensorrt/dynamo/torch_compile/lowering/_partition.py renamed to py/torch_tensorrt/dynamo/backend/lowering/_partition.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from torch_tensorrt.dynamo.torch_compile._defaults import MAX_NUM_TRT_ENGINES
5+
from torch_tensorrt.dynamo.backend._defaults import MAX_NUM_TRT_ENGINES
66
from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
77
from torch.fx.passes.operator_support import OperatorSupport
88

0 commit comments

Comments (0)