Skip to content

Commit 21847e3

Browse files
committed
fix: Add truncate_long_and_double to Dynamo
- Add default, setting, and function arguments for `truncate_long_and_double` in Dynamo - Add utilities for repairing long/double inputs to TRT engines, including support for autocasting back to long/double after the computation completes - Add multiple helper functions to enable easy testing and diagnosis of long/double IO to TRT engines - Add necessary compiler code to enable usage of the `truncate_long_and_double` argument as a switch for the feature - Add Dynamo compile support for `truncate_long_and_double` compilation argument by intercepting long/double type inputs and casting them to their 32-bit counterparts prior to usage in TRT-accelerated subgraphs, then casting back if necessary - Add robust logic to handle 64-bit inputs and outputs - Add test cases for long and double scenarios - Centralize truncation utility for later use in Dynamo export path
1 parent db3523a commit 21847e3

File tree

7 files changed

+339
-2
lines changed

7 files changed

+339
-2
lines changed

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
VERSION_COMPATIBLE = False
1111
OPTIMIZATION_LEVEL = None
1212
USE_PYTHON_RUNTIME = None
13+
TRUNCATE_LONG_AND_DOUBLE = False

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
VERSION_COMPATIBLE,
2121
OPTIMIZATION_LEVEL,
2222
USE_PYTHON_RUNTIME,
23+
TRUNCATE_LONG_AND_DOUBLE,
2324
)
2425

2526

@@ -43,7 +44,7 @@ def compile(
4344
dla_local_dram_size=1073741824,
4445
dla_global_dram_size=536870912,
4546
calibrator=None,
46-
truncate_long_and_double=False,
47+
truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
4748
require_full_compilation=False,
4849
min_block_size=MIN_BLOCK_SIZE,
4950
torch_executed_ops=[],
@@ -62,7 +63,7 @@ def compile(
6263
"The Dynamo backend is an experimental feature, for which only the "
6364
+ "following arguments are supported: "
6465
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
65-
+ "torch_executed_ops, pass_through_build_failures}"
66+
+ "truncate_long_and_double, torch_executed_ops, pass_through_build_failures}"
6667
)
6768

6869
if not isinstance(inputs, collections.abc.Sequence):
@@ -103,6 +104,7 @@ def compile(
103104
version_compatible=version_compatible,
104105
optimization_level=optimization_level,
105106
use_python_runtime=use_python_runtime,
107+
truncate_long_and_double=truncate_long_and_double,
106108
**kwargs,
107109
)
108110

@@ -130,6 +132,7 @@ def create_backend(
130132
version_compatible: bool = VERSION_COMPATIBLE,
131133
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
132134
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
135+
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
133136
**kwargs,
134137
):
135138
"""Create torch.compile backend given specified arguments
@@ -163,5 +166,6 @@ def create_backend(
163166
version_compatible=version_compatible,
164167
optimization_level=optimization_level,
165168
use_python_runtime=use_python_runtime,
169+
truncate_long_and_double=truncate_long_and_double,
166170
**kwargs,
167171
)

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_submod_inputs,
1717
)
1818
from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
19+
from torch_tensorrt.dynamo.common import repair_long_or_double_inputs
1920
from torch_tensorrt.dynamo.backend.conversion import convert_module
2021

2122
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
@@ -134,6 +135,16 @@ def _compile_module(
134135
partitioned_module, submodule, sample_inputs
135136
)
136137

138+
# Ensure all submodule inputs do not require a gradient
139+
for param in submodule_inputs:
140+
param.requires_grad = False
141+
142+
# Handle long/double inputs if requested by the user
143+
if settings.truncate_long_and_double:
144+
submodule_inputs = repair_long_or_double_inputs(
145+
partitioned_module, submodule, submodule_inputs, name
146+
)
147+
137148
# Create TRT Module from submodule
138149
trt_mod = convert_module(
139150
submodule,

py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,5 +169,116 @@ def forward(self, x, y):
169169
)
170170

171171

172+
class Test64BitInput(TestCase):
173+
def test_float64_input_full_support(self):
174+
class FullySupportedMultiOp(torch.nn.Module):
175+
def forward(self, x, y):
176+
return torch.ops.aten.mean.dim(
177+
torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
178+
)
179+
180+
fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
181+
partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
182+
183+
self.assertEquals(
184+
len(list(partitioned_graph.named_children())),
185+
1,
186+
"All operators are supported, there should be one segment",
187+
)
188+
189+
inputs = [
190+
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
191+
torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
192+
]
193+
194+
torch._dynamo.reset()
195+
196+
# Validate that the results between Torch and Torch-TRT are similar
197+
optimized_model = compile(
198+
fx_graph,
199+
inputs,
200+
min_block_size=1,
201+
pass_through_build_failures=True,
202+
truncate_long_and_double=True,
203+
debug=True,
204+
)
205+
optimized_model_results = optimized_model(*inputs).detach().cpu()
206+
torch_model_results = fx_graph(*inputs).detach().cpu()
207+
208+
max_diff = float(
209+
torch.max(torch.abs(optimized_model_results - torch_model_results))
210+
)
211+
self.assertAlmostEqual(
212+
max_diff,
213+
0,
214+
DECIMALS_OF_AGREEMENT,
215+
f"TRT outputs don't match with the original model.",
216+
)
217+
218+
def test_int64_input_partial_support(self):
219+
class PartiallySupportedMultiOp(torch.nn.Module):
220+
def forward(self, x, y):
221+
return torch.ops.aten.div.Tensor_mode(
222+
x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
223+
)
224+
225+
fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
226+
unexpected_ops = {torch.ops.aten.add.Tensor}
227+
228+
inputs = [
229+
torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
230+
torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
231+
]
232+
233+
(unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
234+
fx_graph,
235+
inputs,
236+
unexpected_ops=unexpected_ops,
237+
min_block_size=1,
238+
torch_executed_ops={"torch.ops.aten.add.Tensor"},
239+
testing_partitioning=True,
240+
)
241+
242+
self.assertEquals(
243+
len(unexpected_ops_seen),
244+
0,
245+
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
246+
)
247+
self.assertEquals(
248+
len(partitioned_graphs),
249+
1,
250+
"Without control flow breaks, there should only be a single graph",
251+
)
252+
self.assertEquals(
253+
len(list(partitioned_graphs[0].named_children())),
254+
1,
255+
"Certain operators are set to run in Torch, expected 1 segment",
256+
)
257+
258+
torch._dynamo.reset()
259+
260+
# Validate that the results between Torch and Torch-TRT are similar
261+
optimized_model = compile(
262+
fx_graph,
263+
inputs,
264+
min_block_size=1,
265+
pass_through_build_failures=True,
266+
truncate_long_and_double=True,
267+
debug=True,
268+
)
269+
optimized_model_results = optimized_model(*inputs).detach().cpu()
270+
torch_model_results = fx_graph(*inputs).detach().cpu()
271+
272+
max_diff = float(
273+
torch.max(torch.abs(optimized_model_results - torch_model_results))
274+
)
275+
self.assertAlmostEqual(
276+
max_diff,
277+
0,
278+
DECIMALS_OF_AGREEMENT,
279+
f"TRT outputs don't match with the original model.",
280+
)
281+
282+
172283
if __name__ == "__main__":
173284
run_tests()

py/torch_tensorrt/dynamo/common/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ._settings import CompilationSettings
55
from .input_tensor_spec import InputTensorSpec
66
from .fx2trt import TRTInterpreter, TRTInterpreterResult
7+
from .truncate_long_and_double import repair_long_or_double_inputs
78

89

910
logger = logging.getLogger(__name__)

py/torch_tensorrt/dynamo/common/_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
VERSION_COMPATIBLE,
1313
OPTIMIZATION_LEVEL,
1414
USE_PYTHON_RUNTIME,
15+
TRUNCATE_LONG_AND_DOUBLE,
1516
)
1617

1718

@@ -27,3 +28,4 @@ class CompilationSettings:
2728
version_compatible: bool = VERSION_COMPATIBLE
2829
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
2930
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
31+
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE

0 commit comments

Comments
 (0)