Skip to content

Commit 510552b

Browse files
committed
fix: Refactor data type handling in FX
1 parent e1555bc commit 510552b

File tree

8 files changed

+172
-112
lines changed

8 files changed

+172
-112
lines changed

py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def forward(self, x):
5454
0,
5555
msg=f"MulInt TRT outputs don't match with the original model.",
5656
)
57+
torch._dynamo.reset()
5758

5859
def test_lowering_add_float(self):
5960
class AddFloat(torch.nn.Module):
@@ -106,6 +107,8 @@ def forward(self, x):
106107
msg=f"AddFloat TRT outputs don't match with the original model.",
107108
)
108109

110+
torch._dynamo.reset()
111+
109112

110113
if __name__ == "__main__":
111114
run_tests()

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,12 @@
1616
from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
1717
from .input_tensor_spec import InputTensorSpec
1818
from torch_tensorrt.fx.observer import Observer
19-
from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
19+
from torch_tensorrt.fx.utils import (
20+
get_dynamic_dims,
21+
LowerPrecision,
22+
unified_dtype_converter,
23+
Frameworks,
24+
)
2025

2126
_LOGGER: logging.Logger = logging.getLogger(__name__)
2227

@@ -305,7 +310,9 @@ def placeholder(self, target, args, kwargs):
305310
self.optimization_profiles[i].set_shape(target, *shape_range)
306311

307312
return self.network.add_input(
308-
name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
313+
name=target,
314+
shape=tuple(shape),
315+
dtype=unified_dtype_converter(dtype, Frameworks.TRT),
309316
)
310317

311318
def call_module(self, target, args, kwargs):

py/torch_tensorrt/fx/converters/acc_ops_converters.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from torch.fx.immutable_collections import immutable_list
1919
from torch.fx.node import Argument, Target
2020

21-
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
21+
from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
2222

2323
from .converter_utils import * # noqa: F403
2424
from torch_tensorrt.fx.passes.lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
400400
)
401401

402402
# cast value to TRTensor
403-
dt = torch_dtype_from_trt(input_val.dtype)
403+
dt = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
404404
value = 0 if value == None else value
405405
value_const = get_trt_tensor(
406406
network, torch.tensor([value], dtype=dt), f"{name}_value"
@@ -1561,7 +1561,7 @@ def acc_ops_to_dtype(
15611561
input_t = get_trt_tensor(network, input_val, f"{name}_input_t")
15621562
if input_dtype:
15631563
if isinstance(input_dtype, torch.dtype):
1564-
input_dtype = torch_dtype_to_trt(input_dtype)
1564+
input_dtype = unified_dtype_converter(input_dtype, Frameworks.TRT)
15651565
input_t = type_cast(network, target, f"{name}_input", input_t, input_dtype)
15661566
return input_t
15671567

@@ -1822,7 +1822,7 @@ def acc_ops_logical_xor(
18221822
# f"isinf received input {input_t} that is not part "
18231823
# "of the TensorRT region!"
18241824
# )
1825-
# tdtype = torch_dtype_from_trt(input_t.dtype)
1825+
# tdtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
18261826

18271827
# inf_t = torch.ones(tuple(input_t.shape))
18281828
# inf_t = inf_t * float("inf")
@@ -1860,7 +1860,7 @@ def acc_ops_any(
18601860

18611861
if input_t.dtype in (trt.float32, trt.float16, trt.int32):
18621862
comp_t = torch.zeros(tuple([*input_t.shape])).to(
1863-
torch_dtype_from_trt(input_t.dtype)
1863+
unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
18641864
)
18651865
comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
18661866
kwargs_new = {"input": input_t, "other": comp_t}
@@ -2749,7 +2749,7 @@ def acc_ops_masked_fill_tensor(
27492749
if type(value_t) is torch.Tensor:
27502750
value_t = value_t.cpu().numpy()
27512751
# cast to input type
2752-
input_dtype = torch_dtype_from_trt(input_t.dtype)
2752+
input_dtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
27532753
value_t = (torch.ones(shape) * value_t).to(input_dtype)
27542754
input_val = get_trt_tensor(network, input_t, f"{name}_input")
27552755
value_val = get_trt_tensor(network, value_t, f"{name}_input")
@@ -2883,7 +2883,11 @@ def add_clamp(network, input, val, op, name):
28832883
# clamping scalar
28842884
acc_ops_clamp_trt = get_trt_tensor(
28852885
network,
2886-
squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
2886+
squeeze_left(
2887+
torch.tensor(
2888+
[val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
2889+
)
2890+
),
28872891
f"{name}_clamp_{val}",
28882892
)
28892893
else:
@@ -2892,7 +2896,8 @@ def add_clamp(network, input, val, op, name):
28922896
(
28932897
val
28942898
* torch.ones(
2895-
acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
2899+
acc_ops_clamp_shape,
2900+
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
28962901
)
28972902
)
28982903
.cpu()
@@ -3538,7 +3543,9 @@ def acc_ops_cumsum(
35383543
iterator = loop.add_iterator(input_val, dim, False)
35393544
data = iterator.get_output(0)
35403545
new_dims = tuple(data.shape)
3541-
zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
3546+
zero_tensor = torch.zeros(
3547+
new_dims, dtype=unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
3548+
)
35423549
zero_tensor = network.add_constant(
35433550
zero_tensor.shape, to_numpy(zero_tensor)
35443551
).get_output(0)
@@ -3689,7 +3696,7 @@ def acc_ops_new_ones(
36893696
dtype_val = kwargs.get("dtype")
36903697
if dtype_val is None:
36913698
dtype_val = input_val.dtype
3692-
dtype_val = torch_dtype_from_trt(dtype_val)
3699+
dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
36933700

36943701
device_val = kwargs.get("device")
36953702
assert (
@@ -3713,7 +3720,7 @@ def acc_ops_new_empty(
37133720
dtype_val = kwargs.get("dtype")
37143721
if dtype_val is None:
37153722
dtype_val = input_val.dtype
3716-
dtype_val = torch_dtype_from_trt(dtype_val)
3723+
dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
37173724

37183725
device_val = kwargs.get("device")
37193726
assert (

py/torch_tensorrt/fx/converters/aten_ops_converters.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
from torch.fx.immutable_collections import immutable_list
1919
from torch.fx.node import Argument, Target
2020

21-
from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
22-
2321
from .converter_utils import * # noqa: F403
2422
import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
2523
from torch_tensorrt.fx.converters.impl import activation

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 67 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
TRTPluginFieldCollection,
2121
TRTTensor,
2222
)
23-
from ..utils import torch_dtype_from_trt
23+
from ..utils import unified_dtype_converter, Frameworks
2424

2525

2626
class SourceIR(Enum):
@@ -151,38 +151,49 @@ def extend_mod_attr_to_tuple(mod: torch.nn.Module, name: str, size: int):
151151
return extend_attr_to_tuple(val, size)
152152

153153

154-
def to_numpy(value: Optional[Union[torch.Tensor, int, float]]) -> Optional[np.ndarray]:
154+
def to_numpy(
155+
value: Optional[Union[torch.Tensor, np.ndarray, int, float]],
156+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
157+
) -> Optional[np.ndarray]:
155158
"""
156159
Convert a PyTorch Tensor to a Numpy Array. If the tensor is
157160
quantized it will be dequantized first.
158161
159162
Args:
160-
value (Optional[Union[torch.Tensor, int, float]]): A PyTorch tensor, int, or float
163+
value (Optional[Union[torch.Tensor, np.ndarray, int, float]]):
164+
A PyTorch tensor, Numpy array, int, or float
161165
162166
Returns:
163167
A Numpy array.
164168
"""
169+
output = None
165170

166-
if value is None:
167-
return value
171+
if value is None or isinstance(value, np.ndarray):
172+
output = value
168173

169174
elif isinstance(value, torch.Tensor):
170175
if value.is_quantized:
171176
value = value.dequantize()
172177

173-
return value.cpu().detach().contiguous().numpy()
178+
output = value.cpu().detach().contiguous().numpy()
174179

175180
elif isinstance(value, int):
176-
return np.array([value], dtype=np.int32)
181+
output = np.array([value], dtype=np.int32)
177182

178183
elif isinstance(value, float):
179-
return np.array([value], dtype=np.float32)
184+
output = np.array([value], dtype=np.float32)
180185

181186
else:
182187
raise AssertionError(
183-
f"to_numpy can only be called on None, int, float, or torch.Tensor, got: {value}"
188+
f"to_numpy can only be called on None, int, float, np.ndarray, or torch.Tensor, got: {value}"
184189
)
185190

191+
return (
192+
output
193+
if dtype is None
194+
else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY))
195+
)
196+
186197

187198
def has_dynamic_shape(shape: Shape) -> bool:
188199
"""
@@ -234,35 +245,35 @@ def get_axes_for_reduce_op(
234245

235246
def create_constant(
236247
network: TRTNetwork,
237-
value: Union[int, float, torch.Tensor],
248+
value: Union[int, float, np.ndarray, torch.Tensor],
238249
name: str,
239-
dtype: Optional[torch.dtype],
250+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]],
240251
) -> TRTTensor:
241252
"""
242253
Add a TensorRT constant layer whose value is `value` to `network`.
243254
244255
Args:
245256
network (TRTNetwork): A TensorRT network to which we want to add
246257
a constant layer.
247-
value (Union[int, float, torch.Tensor]): A literal value or a PyTorch tensor
248-
that will be used as value of the added TensorRT Constant layer.
258+
value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array,
259+
or a PyTorch tensor that will be used as value of the added TensorRT Constant layer.
249260
name (str): Name of the added TensorRT Constant layer.
250-
dtype (Optional[torch.dtype]): If a dtype is given, we will convert the type
251-
of the given `value` to this dtype.
261+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
262+
If a dtype is given, we will convert the type of the given `value` to this dtype.
252263
253264
Returns:
254265
A TensorRT ITensor that represents the given value.
255266
"""
256-
257-
if dtype:
258-
value = value.to(dtype)
259-
constant = network.add_constant(value.shape, to_numpy(value))
267+
constant = network.add_constant(value.shape, to_numpy(value, dtype))
260268
constant.name = name
261269
return constant.get_output(0)
262270

263271

264272
def get_trt_tensor(
265-
network: TRTNetwork, input_val: Any, name: str, dtype: Optional[torch.dtype] = None
273+
network: TRTNetwork,
274+
input_val: Any,
275+
name: str,
276+
dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
266277
) -> TRTTensor:
267278
"""
268279
Given a value of random type, we try to convert it to a TensorRT ITensor.
@@ -274,33 +285,36 @@ def get_trt_tensor(
274285
input_val (Any): An value that we want to convert to a TensorRT ITensor.
275286
name (str): The name of the created TensorRT Constant layer if there's
276287
one.
277-
dtype (Optional[torch.dtype]): If dtype is provided, the given value
278-
will be converted to this dtype.
288+
dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
289+
If dtype is provided, the given value will be converted to this dtype.
279290
280291
Returns:
281292
A TensorRT ITensor that represents the given value.
282293
"""
283294
# TRT can not add constant for bool type. We do a work around to 1) cast it to int and 2)cast to bool later
284295
# This is useful for logical operations which require input to be bool type
285-
if isinstance(input_val, np.ndarray):
286-
input_val = torch.from_numpy(input_val)
287296
if isinstance(input_val, bool):
288297
input_val = int(input_val)
289-
if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.bool:
290-
input_val = input_val.to(torch.int32)
291-
if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.int64:
298+
299+
if isinstance(input_val, torch.Tensor) and (
300+
input_val.dtype == torch.bool or input_val.dtype == torch.int64
301+
):
292302
input_val = input_val.to(torch.int32)
303+
elif isinstance(input_val, np.ndarray) and (
304+
input_val.dtype == np.bool or input_val.dtype == np.int64
305+
):
306+
input_val = input_val.to(np.int32)
293307

294-
if isinstance(input_val, (torch.Tensor, int, float)):
308+
if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)):
295309
return create_constant(network, input_val, name, dtype)
296-
elif not isinstance(input_val, TRTTensor):
297-
raise RuntimeError(
298-
f"Received input {input_val} of name {name} that "
299-
"is not part of the TensorRT region!"
300-
)
301-
else:
310+
elif isinstance(input_val, TRTTensor):
302311
return input_val
303312

313+
raise RuntimeError(
314+
f"Received input {input_val} of name {name} that "
315+
"is not part of the TensorRT region!"
316+
)
317+
304318

305319
def prepend_ones(
306320
network: TRTNetwork,
@@ -482,10 +496,10 @@ def add_binary_elementwise_layer(
482496
is_rhs_trt_tensor = False
483497

484498
if isinstance(lhs_val, TRTTensor):
485-
lhs_dtype = torch_dtype_from_trt(lhs_val.dtype)
499+
lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH)
486500
is_lhs_trt_tensor = True
487501
if isinstance(rhs_val, TRTTensor):
488-
rhs_dtype = torch_dtype_from_trt(rhs_val.dtype)
502+
rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH)
489503
is_rhs_trt_tensor = True
490504

491505
if not is_lhs_trt_tensor and not is_rhs_trt_tensor:
@@ -510,9 +524,13 @@ def add_binary_elementwise_layer(
510524
# dtype but we don't have a way to detect whether it makes sense for the
511525
# scalar to be float or half. Hence we go with the lhs dtype.
512526
if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
513-
rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
527+
rhs_val = np.array(
528+
[rhs_val], dtype=unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
529+
)
514530
if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
515-
lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
531+
lhs_val = np.array(
532+
[lhs_val], dtype=unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
533+
)
516534

517535
# When lhs is scalar, and rhs has shape [1,], then currently the assert
518536
# will fail because lhs shape has fewer dimensions than rhs shape. This
@@ -552,14 +570,19 @@ def add_binary_elementwise_layer(
552570
return output
553571

554572

555-
def squeeze_left(const: torch.Tensor):
573+
def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
556574
"""
557575
Squeeze the size-1 dimensions on the left side of the shape tuple.
558576
PyTorch's `squeeze()` doesn't support passing multiple `dim`s at once, so
559577
we do it iteratively.
560578
"""
561579
while len(const.shape) > 0 and const.shape[0] == 1:
562-
const = const.squeeze(dim=0)
580+
if isinstance(const, torch.Tensor):
581+
const = const.squeeze(dim=0)
582+
elif isinstance(const, np.ndarray):
583+
const = const.squeeze(axis=0)
584+
else:
585+
raise AssertionError(f"Expected torch Tensor or Numpy array, got: {const}")
563586
return const
564587

565588

@@ -786,7 +809,10 @@ def trunc_div(
786809
input = get_trt_tensor(network, input, f"{name}_input")
787810
if not isinstance(other, trt.tensorrt.ITensor):
788811
other = get_trt_tensor(
789-
network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
812+
network,
813+
other,
814+
f"{name}_other",
815+
dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
790816
)
791817

792818
abs_input_output = add_unary_layer(
@@ -875,13 +901,3 @@ def type_cast(
875901
layer_i.set_output_type(0, cast_type)
876902
set_layer_name(layer_i, target, f"{name}_dtype_change")
877903
return layer_i.get_output(0)
878-
879-
880-
def trt_dtype_to_torch_dtype(trt_dtype):
881-
table = {
882-
trt.bool: torch.bool,
883-
trt.int32: torch.int32,
884-
trt.float16: torch.float16,
885-
trt.float32: torch.float32,
886-
}
887-
return table[trt_dtype]

0 commit comments

Comments
 (0)