Skip to content

Commit 9a9820a

Browse files
committed
fix/feat: Add support for 64-bit Tensor inputs in FX
- Add `truncate_long_and_double` argument in FX settings to allow 64-bit inputs
- Utilize existing Dynamo functionality to repair FX aten graphs with 64-bit inputs
- Refactor imports in Dynamo to avoid a circular import introduced by the new dependency
- Add test cases to validate the new feature
1 parent 921dd2f commit 9a9820a

File tree

10 files changed

+131
-13
lines changed

10 files changed

+131
-13
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from functools import partial
66

77
from typing import Any, Optional, Sequence
8-
from torch_tensorrt import EngineCapability, Device
8+
from torch_tensorrt._Device import Device
9+
from torch_tensorrt._enums import EngineCapability
910
from torch_tensorrt.fx.utils import LowerPrecision
1011

1112
from torch_tensorrt.dynamo.common import CompilationSettings

py/torch_tensorrt/dynamo/backend/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
from torch_tensorrt.dynamo.common import CompilationSettings
66
from typing import Any, Union, Sequence, Dict
7-
from torch_tensorrt import _Input, Device
7+
from torch_tensorrt import _Input
8+
from torch_tensorrt._Device import Device
89

910

1011
logger = logging.getLogger(__name__)

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
VERSION_COMPATIBLE,
3232
OPTIMIZATION_LEVEL,
3333
USE_EXPERIMENTAL_RT,
34+
TRUNCATE_LONG_AND_DOUBLE,
3435
)
3536

3637
logger = logging.getLogger(__name__)
@@ -51,7 +52,7 @@ def compile(
5152
dla_local_dram_size=1073741824,
5253
dla_global_dram_size=536870912,
5354
calibrator=None,
54-
truncate_long_and_double=False,
55+
truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
5556
require_full_compilation=False,
5657
explicit_batch_dimension=False,
5758
debug=DEBUG,
@@ -86,6 +87,7 @@ def compile(
8687
max_aux_streams: max number of aux stream to use
8788
version_compatible: enable version compatible feature
8889
optimization_level: builder optimization level
90+
truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
8991
Returns:
9092
A torch.nn.Module lowered by TensorRT.
9193
"""
@@ -144,6 +146,7 @@ def compile(
144146
max_aux_streams=max_aux_streams,
145147
version_compatible=version_compatible,
146148
optimization_level=optimization_level,
149+
truncate_long_and_double=truncate_long_and_double,
147150
)
148151
lowerer = Lowerer.create(lower_setting=lower_setting)
149152
return lowerer(module, inputs)
@@ -222,6 +225,7 @@ def default_split_function(
222225
splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
223226
splitter_setting.min_block_size = lower_setting.min_block_size
224227
splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
228+
splitter_setting.truncate_long_and_double = lower_setting.truncate_long_and_double
225229
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
226230
splitter.node_support_preview()
227231
return splitter.generate_split_results()

py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class LowerSetting(LowerSettingBasic):
7373
max_aux_streams: max number of aux stream to use
7474
version_compatible: enable version compatible feature
7575
optimization_level: builder optimization level
76+
truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
7677
"""
7778

7879
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -102,3 +103,4 @@ class LowerSetting(LowerSettingBasic):
102103
max_aux_streams: Optional[int] = None
103104
version_compatible: bool = False
104105
optimization_level: Optional[int] = None
106+
truncate_long_and_double: bool = False

py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
1111
from torch_tensorrt.fx.utils import LowerPrecision
1212
from torch_tensorrt import _Input
13-
from torch_tensorrt.dynamo.common import InputTensorSpec
13+
from torch_tensorrt.dynamo.common import (
14+
InputTensorSpec,
15+
repair_long_or_double_inputs,
16+
)
1417

1518
from ..lower_setting import LowerSetting
1619
from torch_tensorrt.fx.observer import Observer
@@ -196,6 +199,14 @@ def lower_func(split_result: SplitResult) -> nn.Module:
196199
_LOGGER.info(f"Now lowering submodule {submod_name}")
197200
lowering_start_time = datetime.datetime.now()
198201

202+
if self.lower_setting.truncate_long_and_double:
203+
submod_inputs = repair_long_or_double_inputs(
204+
parent_graph=split_result.split_module,
205+
submodule=submod,
206+
submodule_inputs=submod_inputs,
207+
submodule_name=submod_name,
208+
)
209+
199210
self.lower_setting.input_specs = self._trt_input
200211

201212
lowered_module = self._lower_func(

py/torch_tensorrt/fx/lower.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def compile(
4343
use_experimental_fx_rt=False,
4444
correctness_atol=1e-1,
4545
correctness_rtol=1e-1,
46+
truncate_long_and_double=False,
4647
) -> nn.Module:
4748
"""
4849
Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -62,6 +63,7 @@ def compile(
6263
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
6364
dynamic_batch: batch dimension (dim=0) is dynamic.
6465
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
66+
truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
6567
Returns:
6668
A torch.nn.Module lowered by TensorRT.
6769
"""
@@ -85,6 +87,7 @@ def compile(
8587
use_experimental_rt=use_experimental_fx_rt,
8688
correctness_atol=correctness_atol,
8789
correctness_rtol=correctness_rtol,
90+
truncate_long_and_double=truncate_long_and_double,
8891
)
8992
lowerer = Lowerer.create(lower_setting=lower_setting)
9093
return lowerer(module, input)
@@ -159,6 +162,7 @@ def default_split_function(
159162
splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
160163
splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
161164
splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
165+
splitter_setting.truncate_long_and_double = lower_setting.truncate_long_and_double
162166
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
163167
splitter.node_support_preview()
164168
return splitter.generate_split_results()

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ class LowerSetting(LowerSettingBasic):
7474
correctness_atol: absolute tolerance for correctness check
7575
correctness_rtol: relative tolerance for correctness check
7676
use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
77+
truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
7778
"""
7879

7980
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -101,3 +102,4 @@ class LowerSetting(LowerSettingBasic):
101102
correctness_atol: float = 0.1
102103
correctness_rtol: float = 0.1
103104
use_experimental_rt: bool = False
105+
truncate_long_and_double: bool = False

py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
1111
from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion
1212
from torch_tensorrt.fx.utils import LowerPrecision
13+
from torch_tensorrt.dynamo.common import (
14+
repair_long_or_double_inputs,
15+
)
1316

1417
from ..input_tensor_spec import generate_input_specs
1518

@@ -193,6 +196,14 @@ def lower_func(split_result: SplitResult) -> nn.Module:
193196
_LOGGER.info(f"Now lowering submodule {submod_name}")
194197
lowering_start_time = datetime.datetime.now()
195198

199+
if self.lower_setting.truncate_long_and_double:
200+
submod_inputs = repair_long_or_double_inputs(
201+
parent_graph=split_result.split_module,
202+
submodule=submod,
203+
submodule_inputs=submod_inputs,
204+
submodule_name=submod_name,
205+
)
206+
196207
self.lower_setting.input_specs = generate_input_specs(
197208
submod_inputs,
198209
self.lower_setting,
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import unittest
2+
3+
import torch
4+
5+
from torch_tensorrt.fx.lower import compile
6+
from torch_tensorrt.fx.utils import LowerPrecision
7+
8+
9+
class LongInputTest(unittest.TestCase):
10+
def test_long_input(self):
11+
class Model(torch.nn.Module):
12+
def forward(self, x):
13+
out = x + 1
14+
out = out * 2
15+
out = out - 1
16+
return out
17+
18+
mod = Model().cuda().eval()
19+
20+
inputs = [torch.randint(-40, 40, (3, 4, 7)).cuda().long()]
21+
22+
aten_mod = compile(
23+
mod,
24+
inputs,
25+
min_acc_module_size=3,
26+
explicit_batch_dimension=True,
27+
verbose_log=True,
28+
lower_precision=LowerPrecision.FP16,
29+
truncate_long_and_double=True,
30+
dynamic_batch=False,
31+
is_aten=True,
32+
)
33+
34+
aten_output = aten_mod(*inputs)[0].detach().cpu()
35+
torch_output = mod(*inputs).detach().cpu()
36+
37+
max_diff = float(torch.max(torch.abs(aten_output - torch_output)))
38+
39+
self.assertAlmostEqual(
40+
max_diff, 0, 4, msg="Torch outputs don't match with TRT outputs"
41+
)
42+
43+
44+
class DoubleInputTest(unittest.TestCase):
45+
def test_double_input(self):
46+
class Model(torch.nn.Module):
47+
def forward(self, x):
48+
out = x + 1
49+
out = out * 2
50+
return torch.mean(out, dim=-1)
51+
52+
mod = Model().cuda().eval()
53+
54+
inputs = [torch.rand((3, 4, 1)).cuda().double()]
55+
56+
aten_mod = compile(
57+
mod,
58+
inputs,
59+
min_acc_module_size=3,
60+
explicit_batch_dimension=True,
61+
verbose_log=True,
62+
lower_precision=LowerPrecision.FP32,
63+
truncate_long_and_double=True,
64+
dynamic_batch=False,
65+
is_aten=True,
66+
)
67+
68+
aten_output = aten_mod(*inputs)[0].detach().cpu()
69+
torch_output = mod(*inputs).detach().cpu()
70+
71+
max_diff = float(torch.max(torch.abs(aten_output - torch_output)))
72+
73+
self.assertAlmostEqual(
74+
max_diff, 0, 4, msg="Torch outputs don't match with TRT outputs"
75+
)

py/torch_tensorrt/fx/tools/trt_splitter.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
def create_trt_operator_support(
2020
use_implicit_batch_dim=True,
2121
exclude_support_node_name: set = (),
22+
truncate_long_and_double: bool = False,
2223
) -> ops.OperatorSupportBase:
2324
"""Creates an `OperatorSupportBase` instance used for TRT splitting purpose."""
2425
# Create an `OperatorSupport` that declares a node supported if it
@@ -32,14 +33,17 @@ def create_trt_operator_support(
3233
support_dict[get_acc_ops_name(k)] = None
3334
supported_if_converter_registered = ops.OperatorSupport(support_dict=support_dict)
3435

35-
return ops.chain(
36-
ops.OpSupports.decline_if_node_in_names(exclude_support_node_name),
37-
# 1. Node is not supported if it has args with int64 or float64 dtype:
38-
ops.OpSupports.decline_if_input_dtype(torch.int64),
39-
ops.OpSupports.decline_if_input_dtype(torch.float64),
40-
# 2. Node is supported if it has TRT converter:
41-
supported_if_converter_registered,
42-
)
36+
op_support_checks = [
37+
ops.OpSupports.decline_if_node_in_names(exclude_support_node_name)
38+
]
39+
40+
if not truncate_long_and_double:
41+
op_support_checks.append(ops.OpSupports.decline_if_input_dtype(torch.int64))
42+
op_support_checks.append(ops.OpSupports.decline_if_input_dtype(torch.float64))
43+
44+
op_support_checks.append(supported_if_converter_registered)
45+
46+
return ops.chain(*op_support_checks)
4347

4448

4549
class TRTSplitterSetting(splitter_base._SplitterSettingBase):
@@ -52,6 +56,7 @@ def __init__(self):
5256
self.use_implicit_batch_dim: bool = True
5357
self.exclude_support_node_name: set = set()
5458
self.use_experimental_rt: bool = False
59+
self.truncate_long_and_double: bool = False
5560

5661
if self.use_experimental_rt and self.use_implicit_batch_dim:
5762
raise ValueError(
@@ -71,7 +76,9 @@ def __init__(
7176
settings = TRTSplitterSetting()
7277
if not operator_support:
7378
operator_support = create_trt_operator_support(
74-
settings.use_implicit_batch_dim, settings.exclude_support_node_name
79+
settings.use_implicit_batch_dim,
80+
settings.exclude_support_node_name,
81+
settings.truncate_long_and_double,
7582
)
7683
super().__init__(
7784
module,

0 commit comments

Comments
 (0)