Skip to content

Commit 38e80b1

Browse files
committed
fix: Add support for fake tensors
- Refactor `to_numpy` function to handle non-tensor inputs, avoiding fake tensor issue during compilation of constants - Add regression test case to elicit behavior
1 parent 15b7765 commit 38e80b1

File tree

3 files changed

+130
-20
lines changed

3 files changed

+130
-20
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,13 @@
1414
)
1515
from torch_tensorrt.dynamo.backend.conversion import convert_module
1616

17-
from torch._dynamo.backends.common import fake_tensor_unsupported
18-
1917
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
2018

2119

2220
logger = logging.getLogger(__name__)
2321

2422

2523
@td.register_backend(name="torch_tensorrt")
26-
@fake_tensor_unsupported
2724
def torch_tensorrt_backend(
2825
gm: torch.fx.GraphModule,
2926
sample_inputs: Sequence[torch.Tensor],
@@ -35,7 +32,6 @@ def torch_tensorrt_backend(
3532

3633

3734
@td.register_backend(name="aot_torch_tensorrt_aten")
38-
@fake_tensor_unsupported
3935
def aot_torch_tensorrt_aten_backend(
4036
gm: torch.fx.GraphModule,
4137
sample_inputs: Sequence[torch.Tensor],
@@ -55,7 +51,6 @@ def aot_torch_tensorrt_aten_backend(
5551
)
5652

5753

58-
@fake_tensor_unsupported
5954
def _pretraced_backend(
6055
gm: torch.fx.GraphModule,
6156
sample_inputs: Sequence[torch.Tensor],
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from utils import lower_graph_testing
2+
from torch.testing._internal.common_utils import run_tests, TestCase
3+
import torch
4+
from torch_tensorrt.dynamo import compile
5+
6+
7+
class TestFakeTensors(TestCase):
8+
def test_lowering_mul_int(self):
9+
class MulInt(torch.nn.Module):
10+
def forward(self, x):
11+
return x * 7
12+
13+
# Operations expected to be included in the traced graph after decompositions
14+
expected_ops = {
15+
torch.ops.aten.mul.Tensor,
16+
}
17+
18+
inputs = [
19+
torch.rand(
20+
3,
21+
5,
22+
7,
23+
).cuda(),
24+
]
25+
26+
fx_graph = torch.fx.symbolic_trace(MulInt())
27+
_, expected_ops_unseen = lower_graph_testing(
28+
fx_graph,
29+
inputs,
30+
expected_ops=expected_ops,
31+
min_block_size=1,
32+
)
33+
34+
self.assertEquals(
35+
len(expected_ops_unseen),
36+
0,
37+
f"The following expected ops were not encountered: {expected_ops_unseen}",
38+
)
39+
40+
torch._dynamo.reset()
41+
42+
# Validate that the results between Torch and Torch-TRT are similar
43+
optimized_model = compile(
44+
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
45+
)
46+
optimized_model_results = optimized_model(*inputs).detach().cpu()
47+
torch_model_results = fx_graph(*inputs).detach().cpu()
48+
49+
max_diff = float(
50+
torch.max(torch.abs(optimized_model_results - torch_model_results))
51+
)
52+
self.assertAlmostEqual(
53+
max_diff,
54+
0,
55+
msg=f"MulInt TRT outputs don't match with the original model.",
56+
)
57+
58+
def test_lowering_add_float(self):
59+
class AddFloat(torch.nn.Module):
60+
def forward(self, x):
61+
return x + 84.0
62+
63+
# Operations expected to be included in the traced graph after decompositions
64+
expected_ops = {
65+
torch.ops.aten.add.Tensor,
66+
}
67+
68+
inputs = [
69+
torch.rand(
70+
1,
71+
5,
72+
7,
73+
9,
74+
).cuda(),
75+
]
76+
77+
fx_graph = torch.fx.symbolic_trace(AddFloat())
78+
_, expected_ops_unseen = lower_graph_testing(
79+
fx_graph,
80+
inputs,
81+
expected_ops=expected_ops,
82+
min_block_size=1,
83+
)
84+
85+
self.assertEquals(
86+
len(expected_ops_unseen),
87+
0,
88+
f"The following expected ops were not encountered: {expected_ops_unseen}",
89+
)
90+
91+
torch._dynamo.reset()
92+
93+
# Validate that the results between Torch and Torch-TRT are similar
94+
optimized_model = compile(
95+
fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
96+
)
97+
optimized_model_results = optimized_model(*inputs).detach().cpu()
98+
torch_model_results = fx_graph(*inputs).detach().cpu()
99+
100+
max_diff = float(
101+
torch.max(torch.abs(optimized_model_results - torch_model_results))
102+
)
103+
self.assertAlmostEqual(
104+
max_diff,
105+
0,
106+
msg=f"AddFloat TRT outputs don't match with the original model.",
107+
)
108+
109+
110+
if __name__ == "__main__":
111+
run_tests()

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -151,28 +151,37 @@ def extend_mod_attr_to_tuple(mod: torch.nn.Module, name: str, size: int):
151151
return extend_attr_to_tuple(val, size)
152152

153153

154-
def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
154+
def to_numpy(value: Optional[Union[torch.Tensor, int, float]]) -> Optional[np.ndarray]:
155155
"""
156156
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
157157
quantized it will be dequantized first.
158158
159159
Args:
160-
tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
160+
value (Optional[Union[torch.Tensor, int, float]]): A PyTorch tensor, int, or float
161161
162162
Returns:
163163
A Numpy array.
164164
"""
165165

166-
if tensor is None:
167-
return tensor
166+
if value is None:
167+
return value
168168

169-
assert isinstance(
170-
tensor, torch.Tensor
171-
), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
172-
if tensor.is_quantized:
173-
tensor = tensor.dequantize()
169+
elif isinstance(value, torch.Tensor):
170+
if value.is_quantized:
171+
value = value.dequantize()
174172

175-
return tensor.cpu().detach().contiguous().numpy()
173+
return value.cpu().detach().contiguous().numpy()
174+
175+
elif isinstance(value, int):
176+
return np.array([value], dtype=np.int32)
177+
178+
elif isinstance(value, float):
179+
return np.array([value], dtype=np.float32)
180+
181+
else:
182+
raise AssertionError(
183+
f"to_numpy can only be called on None, int, float, or torch.Tensor, got: {value}"
184+
)
176185

177186

178187
def has_dynamic_shape(shape: Shape) -> bool:
@@ -244,11 +253,6 @@ def create_constant(
244253
Returns:
245254
A TensorRT ITensor that represents the given value.
246255
"""
247-
if isinstance(value, int):
248-
value = torch.IntTensor([value])
249-
250-
if isinstance(value, float):
251-
value = torch.Tensor([value])
252256

253257
if dtype:
254258
value = value.to(dtype)

0 commit comments

Comments
 (0)