Skip to content

Commit 8d32f9c

Browse files
committed
feat: Add preliminary support for freezing tensors in Dynamo
1 parent 65277c5 commit 8d32f9c

File tree

4 files changed

+49
-6
lines changed

4 files changed

+49
-6
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch
44
from functools import partial
55
import torch._dynamo as td
6+
from torch._guards import TracingContext
67

78
from torch_tensorrt.dynamo import CompilationSettings
89
from torch_tensorrt.dynamo.lowering._decompositions import (
@@ -15,10 +16,12 @@
1516
partition,
1617
get_submod_inputs,
1718
)
19+
from torch_tensorrt.dynamo.lowering._freeze_aot_graph import freeze_autograd_gm
1820
from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs
1921
from torch_tensorrt.dynamo.conversion import convert_module
2022

21-
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
23+
from torch._functorch.aot_autograd import make_boxed_compiler
24+
from .aot_module import aot_module
2225

2326

2427
logger = logging.getLogger(__name__)
@@ -30,6 +33,8 @@ def torch_tensorrt_backend(
3033
):
3134
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3235

36+
TracingContext.get().fake_mode.allow_non_fake_inputs = True
37+
3338
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
3439

3540

@@ -48,7 +53,7 @@ def aot_torch_tensorrt_aten_backend(
4853
gm = pre_aot_substitutions(gm)
4954

5055
# Invoke AOTAutograd to translate operators to aten
51-
return aot_module_simplified(
56+
return aot_module(
5257
gm,
5358
sample_inputs,
5459
fw_compiler=make_boxed_compiler(custom_backend),
@@ -73,9 +78,16 @@ def _pretraced_backend(
7378
try:
7479
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
7580

81+
frozen_gm, unfrozen_indices = freeze_autograd_gm(gm, sample_inputs)
82+
nonfrozen_inputs = [sample_inputs[idx] for idx in unfrozen_indices]
83+
84+
frozen_gm.graph.eliminate_dead_code()
85+
frozen_gm.graph.lint()
86+
frozen_gm.recompile()
87+
7688
trt_compiled = _compile_module(
77-
gm,
78-
sample_inputs,
89+
frozen_gm,
90+
nonfrozen_inputs,
7991
settings=settings,
8092
)
8193
return trt_compiled

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
unified_dtype_converter,
2323
Frameworks,
2424
)
25+
from torch.utils._python_dispatch import _disable_current_modes
26+
2527

2628
_LOGGER: logging.Logger = logging.getLogger(__name__)
2729

@@ -296,6 +298,21 @@ def call_function(self, target, args, kwargs):
296298
assert self._cur_node_name is not None
297299
return converter(self.network, target, args, kwargs, self._cur_node_name)
298300

301+
def get_attr(self, target, args, kwargs):
302+
with _disable_current_modes():
303+
from torch_tensorrt.fx.converters import to_numpy
304+
305+
frozen_attr = self.fetch_attr(target)
306+
307+
if isinstance(frozen_attr, torch.nn.Parameter):
308+
constant_tensor = frozen_attr.data
309+
else:
310+
constant_tensor = frozen_attr
311+
312+
network_constant = to_numpy(constant_tensor)
313+
314+
return network_constant
315+
299316
def call_method(self, target, args, kwargs):
300317
assert isinstance(target, str)
301318
converter = CONVERTERS.get(target)
@@ -317,6 +334,17 @@ def output(self, target, args, kwargs):
317334
else:
318335
outputs = (args[0],)
319336

337+
for output_idx in range(len(outputs)):
338+
from torch_tensorrt.fx.converters import get_trt_tensor
339+
340+
output = outputs[output_idx]
341+
342+
if not isinstance(output, trt.tensorrt.ITensor):
343+
new_output = get_trt_tensor(self.network, output, target)
344+
outputs = (
345+
outputs[:output_idx] + (new_output,) + outputs[output_idx + 1 :]
346+
)
347+
320348
if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
321349
raise RuntimeError("TensorRT requires all outputs to be Tensor!")
322350

@@ -356,3 +384,5 @@ def output(self, target, args, kwargs):
356384
elif self.output_fp16 and output.dtype == trt.float32:
357385
output.dtype = trt.float16
358386
self._output_names.append(name)
387+
388+
return list(outputs)

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@
88
from ._partition import partition, get_submod_inputs, DEFAULT_SINGLE_NODE_PARTITIONS
99
from .substitutions import *
1010
from ._fusers import *
11+
from ._freeze_aot_graph import *

py/torch_tensorrt/dynamo/lowering/_partition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,8 @@ def is_node_supported(
125125

126126
if (
127127
node.target in CONVERTERS.keys()
128-
and node_name not in self.torch_executed_ops
129-
):
128+
or (node.op == "get_attr" and "frozen" in node_name)
129+
) and node_name not in self.torch_executed_ops:
130130
# If node is a proper, supported computational node, store the operator
131131
if not node.is_impure():
132132
self.supported_operators.add(node_name)

0 commit comments

Comments
 (0)