Skip to content

Commit 5fe7c23

Browse files
committed
fix: Add support for torch.int64 inputs in FXTRT
- Add utility capabilities for accepting `int64` inputs to TRTModules to support multiple use cases
- Support cases include situations where internal tensors in split modules are `int64` (generally used for indexing torch Tensors)
- This also supports cases where the user wants to input `long` tensors as `forward` inputs
- Add test cases to verify functionality and accuracy
- Enable tests for `TRTModuleNext`, which are now fully supported on `main`
1 parent b3f433a commit 5fe7c23

File tree

3 files changed

+163
-88
lines changed

3 files changed

+163
-88
lines changed

py/torch_tensorrt/fx/test/core/test_trt_module.py

Lines changed: 141 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from torch.testing._internal.common_utils import run_tests, TestCase
1111
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule
1212

13-
# from torch_tensorrt import TRTModuleNext
14-
# from torch_tensorrt import Device
13+
from torch_tensorrt import TRTModuleNext
14+
from torch_tensorrt import Device
1515
from torch_tensorrt.fx.utils import LowerPrecision
1616

1717

@@ -58,89 +58,145 @@ def forward(self, x):
5858
)
5959

6060

61-
# TODO add unittest.skip later
62-
# class TestTRTModuleNext(TestCase):
63-
# def test_save_and_load_trt_module(self):
64-
# class TestModule(torch.nn.Module):
65-
# def forward(self, x):
66-
# return x + x
67-
68-
# inputs = [torch.randn(1, 1)]
69-
# mod = TestModule().eval()
70-
# ref_output = mod(*inputs)
71-
72-
# mod = acc_tracer.trace(mod, inputs)
73-
74-
# interp = TRTInterpreter(
75-
# mod,
76-
# input_specs=InputTensorSpec.from_tensors(inputs),
77-
# explicit_batch_dimension=True,
78-
# )
79-
# interp_res = interp.run(lower_precision=LowerPrecision.FP32)
80-
81-
# with io.BytesIO() as engine_bytes:
82-
# engine_bytes.write(interp_res.engine.serialize())
83-
# engine_str = engine_bytes.getvalue()
84-
85-
# trt_mod = TRTModuleNext(
86-
# name="TestModule",
87-
# serialized_engine=engine_str,
88-
# input_binding_names=interp_res.input_names,
89-
# output_binding_names=interp_res.output_names,
90-
# target_device=Device(f"cuda:{torch.cuda.current_device()}"),
91-
# )
92-
93-
# torch.save(trt_mod, "trt.pt")
94-
# reload_trt_mod = torch.load("trt.pt")
95-
96-
# torch.testing.assert_allclose(
97-
# reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
98-
# ref_output,
99-
# rtol=1e-04,
100-
# atol=1e-04,
101-
# )
102-
# os.remove(f"{os.getcwd()}/trt.pt")
103-
104-
# def test_save_and_load_state_dict(self):
105-
# class TestModule(torch.nn.Module):
106-
# def forward(self, x):
107-
# return x + x
108-
109-
# inputs = [torch.randn(1, 1)]
110-
# mod = TestModule().eval()
111-
# ref_output = mod(*inputs)
112-
113-
# mod = acc_tracer.trace(mod, inputs)
114-
# interp = TRTInterpreter(
115-
# mod,
116-
# input_specs=InputTensorSpec.from_tensors(inputs),
117-
# explicit_batch_dimension=True,
118-
# )
119-
# interp_res = interp.run(lower_precision=LowerPrecision.FP32)
120-
121-
# with io.BytesIO() as engine_bytes:
122-
# engine_bytes.write(interp_res.engine.serialize())
123-
# engine_str = engine_bytes.getvalue()
124-
125-
# trt_mod = TRTModuleNext(
126-
# name="TestModule",
127-
# serialized_engine=engine_str,
128-
# input_binding_names=interp_res.input_names,
129-
# output_binding_names=interp_res.output_names,
130-
# target_device=Device(f"cuda:{torch.cuda.current_device()}"),
131-
# )
132-
133-
# st = trt_mod.state_dict()
134-
135-
# new_trt_mod = TRTModuleNext()
136-
# new_trt_mod.load_state_dict(st)
137-
138-
# torch.testing.assert_allclose(
139-
# new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
140-
# ref_output,
141-
# rtol=1e-04,
142-
# atol=1e-04,
143-
# )
61+
class TestTRTModuleInt64Input(TestCase):
62+
def test_save_and_load_trt_module(self):
63+
class TestModule(torch.nn.Module):
64+
def forward(self, x):
65+
return x + x
66+
67+
inputs = [torch.randn(5, 5).long()]
68+
mod = TestModule().eval()
69+
ref_output = mod(*inputs)
70+
71+
mod = acc_tracer.trace(mod, inputs)
72+
interp = TRTInterpreter(
73+
mod,
74+
input_specs=InputTensorSpec.from_tensors(inputs),
75+
)
76+
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
77+
torch.save(trt_mod, "trt.pt")
78+
reload_trt_mod = torch.load("trt.pt")
79+
80+
torch.testing.assert_close(
81+
reload_trt_mod(inputs[0].cuda()).cpu(),
82+
ref_output,
83+
rtol=1e-04,
84+
atol=1e-04,
85+
check_dtype=False,
86+
)
87+
os.remove(f"{os.getcwd()}/trt.pt")
88+
89+
def test_save_and_load_state_dict(self):
90+
class TestModule(torch.nn.Module):
91+
def forward(self, x):
92+
return x + x
93+
94+
inputs = [torch.randn(5, 5).long()]
95+
mod = TestModule().eval()
96+
ref_output = mod(*inputs)
97+
98+
mod = acc_tracer.trace(mod, inputs)
99+
interp = TRTInterpreter(
100+
mod,
101+
input_specs=InputTensorSpec.from_tensors(inputs),
102+
)
103+
trt_mod = TRTModule(*interp.run(lower_precision=LowerPrecision.FP32))
104+
st = trt_mod.state_dict()
105+
106+
new_trt_mod = TRTModule()
107+
new_trt_mod.load_state_dict(st)
108+
109+
torch.testing.assert_close(
110+
new_trt_mod(inputs[0].cuda()).cpu(),
111+
ref_output,
112+
rtol=1e-04,
113+
atol=1e-04,
114+
check_dtype=False,
115+
)
116+
117+
118+
class TestTRTModuleNext(TestCase):
119+
def test_save_and_load_trt_module(self):
120+
class TestModule(torch.nn.Module):
121+
def forward(self, x):
122+
return x + x
123+
124+
inputs = [torch.randn(1, 1)]
125+
mod = TestModule().eval()
126+
ref_output = mod(*inputs)
127+
128+
mod = acc_tracer.trace(mod, inputs)
129+
130+
interp = TRTInterpreter(
131+
mod,
132+
input_specs=InputTensorSpec.from_tensors(inputs),
133+
explicit_batch_dimension=True,
134+
)
135+
interp_res = interp.run(lower_precision=LowerPrecision.FP32)
136+
137+
with io.BytesIO() as engine_bytes:
138+
engine_bytes.write(interp_res.engine.serialize())
139+
engine_str = engine_bytes.getvalue()
140+
141+
trt_mod = TRTModuleNext(
142+
name="TestModule",
143+
serialized_engine=engine_str,
144+
input_binding_names=interp_res.input_names,
145+
output_binding_names=interp_res.output_names,
146+
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
147+
)
148+
149+
torch.save(trt_mod, "trt.pt")
150+
reload_trt_mod = torch.load("trt.pt")
151+
152+
torch.testing.assert_allclose(
153+
reload_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
154+
ref_output,
155+
rtol=1e-04,
156+
atol=1e-04,
157+
)
158+
os.remove(f"{os.getcwd()}/trt.pt")
159+
160+
def test_save_and_load_state_dict(self):
161+
class TestModule(torch.nn.Module):
162+
def forward(self, x):
163+
return x + x
164+
165+
inputs = [torch.randn(1, 1)]
166+
mod = TestModule().eval()
167+
ref_output = mod(*inputs)
168+
169+
mod = acc_tracer.trace(mod, inputs)
170+
interp = TRTInterpreter(
171+
mod,
172+
input_specs=InputTensorSpec.from_tensors(inputs),
173+
explicit_batch_dimension=True,
174+
)
175+
interp_res = interp.run(lower_precision=LowerPrecision.FP32)
176+
177+
with io.BytesIO() as engine_bytes:
178+
engine_bytes.write(interp_res.engine.serialize())
179+
engine_str = engine_bytes.getvalue()
180+
181+
trt_mod = TRTModuleNext(
182+
name="TestModule",
183+
serialized_engine=engine_str,
184+
input_binding_names=interp_res.input_names,
185+
output_binding_names=interp_res.output_names,
186+
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
187+
)
188+
189+
st = trt_mod.state_dict()
190+
191+
new_trt_mod = TRTModuleNext()
192+
new_trt_mod.load_state_dict(st)
193+
194+
torch.testing.assert_allclose(
195+
new_trt_mod(inputs[0].cuda()).cpu().reshape_as(ref_output),
196+
ref_output,
197+
rtol=1e-04,
198+
atol=1e-04,
199+
)
144200

145201

146202
if __name__ == "__main__":

py/torch_tensorrt/fx/trt_module.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,6 @@ def forward(self, *inputs):
137137

138138
# This is only used when the trt engine is using implicit batch dim.
139139
batch_size = inputs[0].shape[0]
140-
contiguous_inputs: List[torch.Tensor] = [i.contiguous() for i in inputs]
141140
bindings: List[Any] = [None] * (
142141
len(self.input_names)
143142
+ len(self.output_names)
@@ -148,16 +147,27 @@ def forward(self, *inputs):
148147
assert inputs[
149148
i
150149
].is_cuda, f"{i}th input({input_name}) is not on cuda device."
150+
151+
# Intercept int64 inputs to TRT Engines and cast them to int32
152+
if (
153+
inputs[i].dtype == torch.int64
154+
and self.input_dtypes[i] == torch.int32
155+
):
156+
inputs = (
157+
inputs[:i] + (inputs[i].to(torch.int32),) + inputs[i + 1 :]
158+
)
159+
151160
assert (
152161
inputs[i].dtype == self.input_dtypes[i]
153162
), f"Dtype mismatch for {i}th input({input_name}). Expect {self.input_dtypes[i]}, got {inputs[i].dtype}."
154163

164+
contiguous_input = inputs[i].contiguous()
155165
idx = self.input_binding_indices_in_order[i]
156-
bindings[idx] = contiguous_inputs[i].data_ptr()
166+
bindings[idx] = contiguous_input.data_ptr()
157167

158168
if not self.engine.has_implicit_batch_dimension:
159169
self.context.set_binding_shape(
160-
idx, tuple(contiguous_inputs[i].shape)
170+
idx, tuple(contiguous_input.shape)
161171
)
162172
else:
163173
assert inputs[i].size()[1:] == self.input_shapes[i], (

py/torch_tensorrt/fx/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# @manual=//deeplearning/trt/python:py_tensorrt
66
import tensorrt as trt
77
import torch
8+
import logging
89
from functorch import make_fx
910
from functorch.experimental import functionalize
1011
from torch_tensorrt.fx.passes.lower_basic_pass import (
@@ -15,6 +16,9 @@
1516
from .types import Shape, TRTDataType
1617

1718

19+
_LOGGER: logging.Logger = logging.getLogger(__name__)
20+
21+
1822
class LowerPrecision(Enum):
1923
FP32 = "fp32"
2024
FP16 = "fp16"
@@ -37,6 +41,11 @@ def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
3741
return trt.int8
3842
elif dtype == torch.int32:
3943
return trt.int32
44+
elif dtype == torch.int64:
45+
_LOGGER.warn(
46+
"Detected Int64 Input, Casting to Int32 for TRT Engine Compatibility"
47+
)
48+
return trt.int32
4049
elif dtype == torch.float16:
4150
return trt.float16
4251
elif dtype == torch.float32:

0 commit comments

Comments
 (0)