Skip to content

Commit 6e53805

Browse files
committed
fix: Add support for truncate_long_and_double in FX
- Add support and testing for `double` type inputs
1 parent 5382916 commit 6e53805

File tree

6 files changed

+104
-6
lines changed

6 files changed

+104
-6
lines changed

py/torch_tensorrt/fx/fx2trt.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
explicit_batch_dimension: bool = False,
4242
explicit_precision: bool = False,
4343
logger_level=None,
44+
truncate_long_and_double=False,
4445
):
4546
super().__init__(module)
4647

@@ -70,6 +71,7 @@ def __init__(
7071

7172
self.optimization_profiles: Optional[List] = None
7273
self.input_specs = input_specs
74+
self.truncate_long_and_double = truncate_long_and_double
7375
self.input_specs_iter = 0
7476
self.validate_input_specs()
7577
self._cur_node_name: Optional[str] = None
@@ -306,7 +308,9 @@ def placeholder(self, target, args, kwargs):
306308
self.optimization_profiles[i].set_shape(target, *shape_range)
307309

308310
return self.network.add_input(
309-
name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
311+
name=target,
312+
shape=tuple(shape),
313+
dtype=torch_dtype_to_trt(dtype, self.truncate_long_and_double),
310314
)
311315

312316
def call_module(self, target, args, kwargs):

py/torch_tensorrt/fx/lower.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def compile(
4343
use_experimental_fx_rt=False,
4444
correctness_atol=1e-1,
4545
correctness_rtol=1e-1,
46+
truncate_long_and_double=False,
4647
) -> nn.Module:
4748
"""
4849
Takes in the original module, input and lowering settings, and runs the lowering workflow to turn the module
@@ -62,6 +63,7 @@ def compile(
6263
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
6364
dynamic_batch: batch dimension (dim=0) is dynamic.
6465
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
66+
truncate_long_and_double: Whether to automatically truncate long (int64) and double (float64) tensor inputs to int32 and float32, respectively, for TRT engines. When False, such inputs raise a TypeError.
6567
Returns:
6668
A torch.nn.Module lowered by TensorRT.
6769
"""
@@ -85,6 +87,7 @@ def compile(
8587
use_experimental_rt=use_experimental_fx_rt,
8688
correctness_atol=correctness_atol,
8789
correctness_rtol=correctness_rtol,
90+
truncate_long_and_double=truncate_long_and_double,
8891
)
8992
lowerer = Lowerer.create(lower_setting=lower_setting)
9093
return lowerer(module, input)
@@ -129,6 +132,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
129132
logger_level=trt.Logger.VERBOSE
130133
if self.lower_setting.verbose_log
131134
else trt.Logger.WARNING,
135+
truncate_long_and_double=self.lower_setting.truncate_long_and_double,
132136
)
133137

134138
interp_result: TRTInterpreterResult = interpreter.run(

py/torch_tensorrt/fx/lower_setting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,4 @@ class LowerSetting(LowerSettingBasic):
101101
correctness_atol: float = 0.1
102102
correctness_rtol: float = 0.1
103103
use_experimental_rt: bool = False
104+
truncate_long_and_double: bool = False

py/torch_tensorrt/fx/test/core/test_trt_module.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def forward(self, x):
7272
interp = TRTInterpreter(
7373
mod,
7474
input_specs=InputTensorSpec.from_tensors(inputs),
75+
truncate_long_and_double=True,
7576
)
7677
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
7778
torch.save(trt_mod, "trt.pt")
@@ -99,6 +100,66 @@ def forward(self, x):
99100
interp = TRTInterpreter(
100101
mod,
101102
input_specs=InputTensorSpec.from_tensors(inputs),
103+
truncate_long_and_double=True,
104+
)
105+
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
106+
st = trt_mod.state_dict()
107+
108+
new_trt_mod = TRTModule()
109+
new_trt_mod.load_state_dict(st)
110+
111+
torch.testing.assert_close(
112+
new_trt_mod(inputs[0].cuda()).cpu(),
113+
ref_output,
114+
rtol=1e-04,
115+
atol=1e-04,
116+
check_dtype=False,
117+
)
118+
119+
120+
class TestTRTModuleFloat64Input(TestCase):
121+
def test_save_and_load_trt_module(self):
122+
class TestModule(torch.nn.Module):
123+
def forward(self, x):
124+
return x + x
125+
126+
inputs = [torch.randn(5, 5).double()]
127+
mod = TestModule().eval()
128+
ref_output = mod(*inputs)
129+
130+
mod = acc_tracer.trace(mod, inputs)
131+
interp = TRTInterpreter(
132+
mod,
133+
input_specs=InputTensorSpec.from_tensors(inputs),
134+
truncate_long_and_double=True,
135+
)
136+
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
137+
torch.save(trt_mod, "trt.pt")
138+
reload_trt_mod = torch.load("trt.pt")
139+
140+
torch.testing.assert_close(
141+
reload_trt_mod(inputs[0].cuda()).cpu(),
142+
ref_output,
143+
rtol=1e-04,
144+
atol=1e-04,
145+
check_dtype=False,
146+
)
147+
os.remove(f"{os.getcwd()}/trt.pt")
148+
149+
def test_save_and_load_state_dict(self):
150+
class TestModule(torch.nn.Module):
151+
def forward(self, x):
152+
return x + x
153+
154+
inputs = [torch.randn(5, 5).double()]
155+
mod = TestModule().eval()
156+
ref_output = mod(*inputs)
157+
158+
mod = acc_tracer.trace(mod, inputs)
159+
interp = TRTInterpreter(
160+
mod,
161+
input_specs=InputTensorSpec.from_tensors(inputs),
162+
truncate_long_and_double=True,
102163
)
103164
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
104165
st = trt_mod.state_dict()

py/torch_tensorrt/fx/trt_module.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,15 @@ def forward(self, *inputs):
156156
inputs = (
157157
inputs[:i] + (inputs[i].to(torch.int32),) + inputs[i + 1 :]
158158
)
159+
elif (
160+
inputs[i].dtype == torch.float64
161+
and self.input_dtypes[i] == torch.float32
162+
):
163+
inputs = (
164+
inputs[:i]
165+
+ (inputs[i].to(torch.float32),)
166+
+ inputs[i + 1 :]
167+
)
159168

160169
assert (
161170
inputs[i].dtype == self.input_dtypes[i]

py/torch_tensorrt/fx/utils.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def from_str(label: str) -> Optional["LowerPrecision"]:
3939
return None
4040

4141

42-
def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
42+
def torch_dtype_to_trt(
43+
dtype: torch.dtype, truncate_long_and_double: bool = False
44+
) -> TRTDataType:
4345
"""
4446
Convert PyTorch data types to TensorRT data types.
4547
@@ -56,14 +58,31 @@ def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
5658
elif dtype == torch.int32:
5759
return trt.int32
5860
elif dtype == torch.int64:
59-
_LOGGER.warn(
60-
"Detected Int64 Input, Casting to Int32 for TRT Engine Compatibility"
61-
)
62-
return trt.int32
61+
if truncate_long_and_double:
62+
_LOGGER.warn(
63+
"Detected Int64 Input, Casting to Int32 for TRT Engine Compatibility"
64+
)
65+
return trt.int32
66+
else:
67+
raise TypeError(
68+
"Detected Int64 Input which is not supported by tensorrt, enable compilation"
69+
+ "option truncate_long_and_double=True to cast input to Int32 for TRT Engine"
70+
)
6371
elif dtype == torch.float16:
6472
return trt.float16
6573
elif dtype == torch.float32:
6674
return trt.float32
75+
elif dtype == torch.float64:
76+
if truncate_long_and_double:
77+
_LOGGER.warn(
78+
"Detected Float64 Input, Casting to Float32 for TRT Engine Compatibility"
79+
)
80+
return trt.float32
81+
else:
82+
raise TypeError(
83+
"Detected Float64 Input which is not supported by tensorrt, enable compilation"
84+
+ "option truncate_long_and_double=True to cast input to Float32 for TRT Engine"
85+
)
6786
else:
6887
raise TypeError("%s is not supported by tensorrt" % dtype)
6988

0 commit comments

Comments
 (0)