diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index 8f6408492a..15e238a2d3 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -14,8 +14,6 @@
 )
 from torch_tensorrt.dynamo.backend.conversion import convert_module
 
-from torch._dynamo.backends.common import fake_tensor_unsupported
-
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
 
 
@@ -23,7 +21,6 @@
 
 
 @td.register_backend(name="torch_tensorrt")
-@fake_tensor_unsupported
 def torch_tensorrt_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
@@ -35,7 +32,6 @@ def torch_tensorrt_backend(
 
 
 @td.register_backend(name="aot_torch_tensorrt_aten")
-@fake_tensor_unsupported
 def aot_torch_tensorrt_aten_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
@@ -55,7 +51,6 @@ def aot_torch_tensorrt_aten_backend(
     )
 
 
-@fake_tensor_unsupported
 def _pretraced_backend(
     gm: torch.fx.GraphModule,
     sample_inputs: Sequence[torch.Tensor],
diff --git a/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py b/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py
new file mode 100644
index 0000000000..17df523ab8
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/backend/test/test_specialized_models.py
@@ -0,0 +1,114 @@
+from utils import lower_graph_testing
+from torch.testing._internal.common_utils import run_tests, TestCase
+import torch
+from torch_tensorrt.dynamo import compile
+
+
+class TestFakeTensors(TestCase):
+    def test_lowering_mul_int(self):
+        class MulInt(torch.nn.Module):
+            def forward(self, x):
+                return x * 7
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.mul.Tensor,
+        }
+
+        inputs = [
+            torch.rand(
+                3,
+                5,
+                7,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(MulInt())
+        _, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg="MulInt TRT outputs don't match the original model.",
+        )
+        torch._dynamo.reset()
+
+    def test_lowering_add_float(self):
+        class AddFloat(torch.nn.Module):
+            def forward(self, x):
+                return x + 84.0
+
+        # Operations expected to be included in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.add.Tensor,
+        }
+
+        inputs = [
+            torch.rand(
+                1,
+                5,
+                7,
+                9,
+            ).cuda(),
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(AddFloat())
+        _, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph, inputs, min_block_size=1, pass_through_build_failures=True
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            msg="AddFloat TRT outputs don't match the original model.",
+        )
+
+        torch._dynamo.reset()
+
+
+if __name__ == "__main__":
+    run_tests()
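
Aside — a minimal way to exercise the repaired fake-tensor path end to end, mirroring the new tests above (the toy module and shapes are illustrative, not part of the patch):

    import torch

    class Scale(torch.nn.Module):
        def forward(self, x):
            # Scalar constants like these previously failed with this
            # backend once Dynamo traced them with fake tensors.
            return x * 7 + 84.0

    model = Scale().cuda().eval()
    inputs = torch.rand(3, 5, 7).cuda()

    # With @fake_tensor_unsupported removed, the registered backend now
    # runs under Dynamo's default fake-tensor tracing.
    optimized = torch.compile(model, backend="torch_tensorrt")
    print(torch.max(torch.abs(optimized(inputs) - model(inputs))))
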
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
index e4298600cb..4f58aec995 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -16,7 +16,12 @@
 from torch_tensorrt.dynamo.fx_ts_compat import CONVERTERS
 from .input_tensor_spec import InputTensorSpec
 from torch_tensorrt.fx.observer import Observer
-from torch_tensorrt.fx.utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
+from torch_tensorrt.fx.utils import (
+    get_dynamic_dims,
+    LowerPrecision,
+    unified_dtype_converter,
+    Frameworks,
+)
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -321,7 +326,9 @@ def placeholder(self, target, args, kwargs):
                 self.optimization_profiles[i].set_shape(target, *shape_range)
 
         return self.network.add_input(
-            name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
+            name=target,
+            shape=tuple(shape),
+            dtype=unified_dtype_converter(dtype, Frameworks.TRT),
         )
 
     def call_module(self, target, args, kwargs):
diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index f270ce3ea8..52168b4161 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -18,7 +18,7 @@
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
 
-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
+from ..utils import get_dynamic_dims, unified_dtype_converter, Frameworks
 
 from .converter_utils import *  # noqa: F403
 from torch_tensorrt.fx.passes.lower_basic_pass import (
@@ -400,7 +400,7 @@ def acc_ops_pad_with_slice_layer(
     )
 
     # cast value to TRTensor
-    dt = torch_dtype_from_trt(input_val.dtype)
+    dt = unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
     value = 0 if value == None else value
     value_const = get_trt_tensor(
         network, torch.tensor([value], dtype=dt), f"{name}_value"
@@ -1550,7 +1550,7 @@ def acc_ops_to_dtype(
     input_t = get_trt_tensor(network, input_val, f"{name}_input_t")
     if input_dtype:
         if isinstance(input_dtype, torch.dtype):
-            input_dtype = torch_dtype_to_trt(input_dtype)
+            input_dtype = unified_dtype_converter(input_dtype, Frameworks.TRT)
         input_t = type_cast(network, target, f"{name}_input", input_t, input_dtype)
     return input_t
 
@@ -1811,7 +1811,7 @@ def acc_ops_logical_xor(
 #     f"isinf received input {input_t} that is not part "
 #     "of the TensorRT region!"
 #     )
-#     tdtype = torch_dtype_from_trt(input_t.dtype)
+#     tdtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
 
 #     inf_t = torch.ones(tuple(input_t.shape))
 #     inf_t = inf_t * float("inf")
@@ -1849,7 +1849,7 @@ def acc_ops_any(
 
     if input_t.dtype in (trt.float32, trt.float16, trt.int32):
         comp_t = torch.zeros(tuple([*input_t.shape])).to(
-            torch_dtype_from_trt(input_t.dtype)
+            unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
         )
         comp_t = get_trt_tensor(network, comp_t, f"{name}_comp_t")
         kwargs_new = {"input": input_t, "other": comp_t}
@@ -2738,7 +2738,7 @@ def acc_ops_masked_fill_tensor(
     if type(value_t) is torch.Tensor:
         value_t = value_t.cpu().numpy()
     # cast to input type
-    input_dtype = torch_dtype_from_trt(input_t.dtype)
+    input_dtype = unified_dtype_converter(input_t.dtype, Frameworks.TORCH)
     value_t = (torch.ones(shape) * value_t).to(input_dtype)
     input_val = get_trt_tensor(network, input_t, f"{name}_input")
     value_val = get_trt_tensor(network, value_t, f"{name}_input")
@@ -2872,7 +2872,11 @@ def add_clamp(network, input, val, op, name):
         # clamping scalar
         acc_ops_clamp_trt = get_trt_tensor(
             network,
-            squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))),
+            squeeze_left(
+                torch.tensor(
+                    [val], dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH)
+                )
+            ),
             f"{name}_clamp_{val}",
         )
     else:
@@ -2881,7 +2885,8 @@ def add_clamp(network, input, val, op, name):
             (
                 val
                 * torch.ones(
-                    acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)
+                    acc_ops_clamp_shape,
+                    dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
                 )
             )
             .cpu()
@@ -3527,7 +3532,9 @@ def acc_ops_cumsum(
     iterator = loop.add_iterator(input_val, dim, False)
     data = iterator.get_output(0)
     new_dims = tuple(data.shape)
-    zero_tensor = torch.zeros(new_dims, dtype=trt_dtype_to_torch_dtype(input_val.dtype))
+    zero_tensor = torch.zeros(
+        new_dims, dtype=unified_dtype_converter(input_val.dtype, Frameworks.TORCH)
+    )
     zero_tensor = network.add_constant(
         zero_tensor.shape, to_numpy(zero_tensor)
     ).get_output(0)
@@ -3670,7 +3677,7 @@ def acc_ops_new_ones(
     dtype_val = kwargs.get("dtype")
     if dtype_val is None:
         dtype_val = input_val.dtype
-        dtype_val = torch_dtype_from_trt(dtype_val)
+        dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
 
     device_val = kwargs.get("device")
     assert (
@@ -3694,7 +3701,7 @@ def acc_ops_new_empty(
     dtype_val = kwargs.get("dtype")
     if dtype_val is None:
         dtype_val = input_val.dtype
-        dtype_val = torch_dtype_from_trt(dtype_val)
+        dtype_val = unified_dtype_converter(dtype_val, Frameworks.TORCH)
 
     device_val = kwargs.get("device")
     assert (
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 4f93d98a26..42a988d0ad 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -18,8 +18,6 @@
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
 
-from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
-
 from .converter_utils import *  # noqa: F403
 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters.impl import activation
diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py
index d13be41d05..2d4cdbef86 100644
--- a/py/torch_tensorrt/fx/converters/converter_utils.py
+++ b/py/torch_tensorrt/fx/converters/converter_utils.py
@@ -20,7 +20,7 @@
     TRTPluginFieldCollection,
     TRTTensor,
 )
-from ..utils import torch_dtype_from_trt
+from ..utils import unified_dtype_converter, Frameworks
 
 
 class SourceIR(Enum):
@@ -151,28 +151,50 @@ def extend_mod_attr_to_tuple(mod: torch.nn.Module, name: str, size: int):
     return extend_attr_to_tuple(val, size)
 
 
-def to_numpy(tensor: Optional[torch.Tensor]) -> Optional[np.ndarray]:
+def to_numpy(
+    value: Optional[Union[torch.Tensor, np.ndarray, int, float]],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None,
+) -> Optional[np.ndarray]:
     """
-    Convert a PyTorch Tensor to a Numpy Array. If the tensor is
+    Convert a PyTorch Tensor, Numpy array, or scalar to a Numpy Array. If the tensor is
     quantized it will be dequantized first.
 
     Args:
-        tensor (Optional[torch.Tensor]): A PyTorch tensor or None.
+        value (Optional[Union[torch.Tensor, np.ndarray, int, float]]):
+            A PyTorch tensor, Numpy array, int, or float
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If a dtype is given, we will convert the type of the given `value` to this dtype.
 
     Returns:
         A Numpy array.
     """
+    output = None
 
-    if tensor is None:
-        return tensor
+    if value is None or isinstance(value, np.ndarray):
+        output = value
 
-    assert isinstance(
-        tensor, torch.Tensor
-    ), f"to_numpy can only be called on None or a torch.Tensor, got: {tensor}"
-    if tensor.is_quantized:
-        tensor = tensor.dequantize()
+    elif isinstance(value, torch.Tensor):
+        if value.is_quantized:
+            value = value.dequantize()
 
-    return tensor.cpu().detach().contiguous().numpy()
+        output = value.cpu().detach().contiguous().numpy()
+
+    elif isinstance(value, int):
+        output = np.array([value], dtype=np.int32)
+
+    elif isinstance(value, float):
+        output = np.array([value], dtype=np.float32)
+
+    else:
+        raise AssertionError(
+            f"to_numpy can only be called on None, int, float, np.ndarray, or torch.Tensor, got: {value}"
+        )
+
+    return (
+        output
+        if dtype is None or output is None
+        else output.astype(unified_dtype_converter(dtype, Frameworks.NUMPY))
+    )
 
 
 def has_dynamic_shape(shape: Shape) -> bool:
@@ -225,9 +247,9 @@ def get_axes_for_reduce_op(
 
 def create_constant(
     network: TRTNetwork,
-    value: Union[int, float, torch.Tensor],
+    value: Union[int, float, np.ndarray, torch.Tensor],
     name: str,
-    dtype: Optional[torch.dtype],
+    dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]],
 ) -> TRTTensor:
     """
     Add a TensorRT constant layer whose value is `value` to `network`.
 
     Args:
         network (TRTNetwork): A TensorRT network to which we want to add
             a constant layer.
-        value (Union[int, float, torch.Tensor]): A literal value or a PyTorch tensor
-            that will be used as value of the added TensorRT Constant layer.
+        value (Union[int, float, np.ndarray, torch.Tensor]): A literal value, Numpy array,
+            or a PyTorch tensor that will be used as value of the added TensorRT Constant layer.
         name (str): Name of the added TensorRT Constant layer.
-        dtype (Optional[torch.dtype]): If a dtype is given, we will convert the type
-            of the given `value` to this dtype.
+        dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]):
+            If a dtype is given, we will convert the type of the given `value` to this dtype.
 
     Returns:
         A TensorRT ITensor that represents the given value.
""" - if isinstance(value, int): - value = torch.IntTensor([value]) - - if isinstance(value, float): - value = torch.Tensor([value]) - - if dtype: - value = value.to(dtype) - constant = network.add_constant(value.shape, to_numpy(value)) + constant = network.add_constant( + (1,) if isinstance(value, (int, float)) else value.shape, + to_numpy(value, dtype), + ) constant.name = name return constant.get_output(0) def get_trt_tensor( - network: TRTNetwork, input_val: Any, name: str, dtype: Optional[torch.dtype] = None + network: TRTNetwork, + input_val: Any, + name: str, + dtype: Optional[Union[torch.dtype, np.dtype, TRTDataType]] = None, ) -> TRTTensor: """ Given a value of random type, we try to convert it to a TensorRT ITensor. @@ -270,33 +290,36 @@ def get_trt_tensor( input_val (Any): An value that we want to convert to a TensorRT ITensor. name (str): The name of the created TensorRT Constant layer if there's one. - dtype (Optional[torch.dtype]): If dtype is provided, the given value - will be converted to this dtype. + dtype (Optional[Union[torch.dtype, np.dtype, TRTDataType]]): + If dtype is provided, the given value will be converted to this dtype. Returns: A TensorRT ITensor that represents the given value. """ # TRT can not add constant for bool type. We do a work around to 1) cast it to int and 2)cast to bool later # This is useful for logical operations which require input to be bool type - if isinstance(input_val, np.ndarray): - input_val = torch.from_numpy(input_val) if isinstance(input_val, bool): input_val = int(input_val) - if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.bool: - input_val = input_val.to(torch.int32) - if isinstance(input_val, torch.Tensor) and input_val.dtype == torch.int64: + + if isinstance(input_val, torch.Tensor) and ( + input_val.dtype == torch.bool or input_val.dtype == torch.int64 + ): input_val = input_val.to(torch.int32) + elif isinstance(input_val, np.ndarray) and ( + input_val.dtype == np.bool_ or input_val.dtype == np.int64 + ): + input_val = input_val.to(np.int32) - if isinstance(input_val, (torch.Tensor, int, float)): + if isinstance(input_val, (torch.Tensor, np.ndarray, int, float)): return create_constant(network, input_val, name, dtype) - elif not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"Received input {input_val} of name {name} that " - "is not part of the TensorRT region!" - ) - else: + elif isinstance(input_val, TRTTensor): return input_val + raise RuntimeError( + f"Received input {input_val} of name {name} that " + "is not part of the TensorRT region!" + ) + def prepend_ones( network: TRTNetwork, @@ -478,10 +501,10 @@ def add_binary_elementwise_layer( is_rhs_trt_tensor = False if isinstance(lhs_val, TRTTensor): - lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) + lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH) is_lhs_trt_tensor = True if isinstance(rhs_val, TRTTensor): - rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) + rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH) is_rhs_trt_tensor = True if not is_lhs_trt_tensor and not is_rhs_trt_tensor: @@ -506,9 +529,13 @@ def add_binary_elementwise_layer( # dtype but we don't have a way to detect whether it makes sense for the # scalar to be float or half. Hence we go with the lhs dtype. 
     if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)):
-        rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype)
+        rhs_val = np.array(
+            [rhs_val], dtype=unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY)
+        )
     if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)):
-        lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype)
+        lhs_val = np.array(
+            [lhs_val], dtype=unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY)
+        )
 
     # When lhs is scalar, and rhs has shape [1,], then currently the assert
     # will fail because lhs shape has fewer dimensions than rhs shape. This
@@ -519,9 +546,9 @@
     # tensor. This is safe because broadcast will pad dimensions on the left
     # (prepend) to make lhs and rhs shape compatible.
     if network.has_implicit_batch_dimension:
-        if isinstance(lhs_val, torch.Tensor):
+        if isinstance(lhs_val, (torch.Tensor, np.ndarray)):
             lhs_val = squeeze_left(lhs_val)
-        if isinstance(rhs_val, torch.Tensor):
+        if isinstance(rhs_val, (torch.Tensor, np.ndarray)):
             rhs_val = squeeze_left(rhs_val)
 
     lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype)
@@ -548,14 +575,19 @@
     return output
 
 
-def squeeze_left(const: torch.Tensor):
+def squeeze_left(const: Union[torch.Tensor, np.ndarray]):
     """
     Squeeze the size-1 dimensions on the left side of the shape tuple.
     PyTorch's `squeeze()` doesn't support passing multiple `dim`s at once, so
     we do it iteratively.
     """
     while len(const.shape) > 0 and const.shape[0] == 1:
-        const = const.squeeze(dim=0)
+        if isinstance(const, torch.Tensor):
+            const = const.squeeze(dim=0)
+        elif isinstance(const, np.ndarray):
+            const = const.squeeze(axis=0)
+        else:
+            raise AssertionError(f"Expected torch Tensor or Numpy array, got: {const}")
     return const
 
 
@@ -782,7 +814,10 @@ def trunc_div(
         input = get_trt_tensor(network, input, f"{name}_input")
     if not isinstance(other, trt.tensorrt.ITensor):
         other = get_trt_tensor(
-            network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype)
+            network,
+            other,
+            f"{name}_other",
+            dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH),
         )
 
     abs_input_output = add_unary_layer(
@@ -871,13 +906,3 @@ def type_cast(
     layer_i.set_output_type(0, cast_type)
     set_layer_name(layer_i, target, f"{name}_dtype_change")
     return layer_i.get_output(0)
-
-
-def trt_dtype_to_torch_dtype(trt_dtype):
-    table = {
-        trt.bool: torch.bool,
-        trt.int32: torch.int32,
-        trt.float16: torch.float16,
-        trt.float32: torch.float32,
-    }
-    return table[trt_dtype]
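
Aside — a short sketch of what the reworked constant handling in converter_utils.py accepts after this change (values are illustrative):

    import numpy as np
    import torch

    from torch_tensorrt.fx.converters.converter_utils import squeeze_left, to_numpy

    # Scalars now become 1-element Numpy arrays directly (int32/float32 by
    # default), optionally coerced to a requested dtype, so create_constant
    # no longer routes through an intermediate torch.Tensor.
    assert to_numpy(7).dtype == np.int32
    assert to_numpy(84.0, dtype=torch.float16).dtype == np.float16

    # squeeze_left now handles Numpy arrays as well as Torch tensors.
    assert squeeze_left(np.ones((1, 1, 3))).shape == (3,)
    assert squeeze_left(torch.ones(1, 3)).shape == (3,)
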
diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py
index 846c90bdd5..d7ef976fba 100644
--- a/py/torch_tensorrt/fx/fx2trt.py
+++ b/py/torch_tensorrt/fx/fx2trt.py
@@ -17,7 +17,7 @@
 from .converter_registry import CONVERTERS
 from .input_tensor_spec import InputTensorSpec
 from .observer import Observer
-from .utils import get_dynamic_dims, LowerPrecision, torch_dtype_to_trt
+from .utils import get_dynamic_dims, LowerPrecision, unified_dtype_converter, Frameworks
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -306,7 +306,9 @@ def placeholder(self, target, args, kwargs):
                 self.optimization_profiles[i].set_shape(target, *shape_range)
 
         return self.network.add_input(
-            name=target, shape=tuple(shape), dtype=torch_dtype_to_trt(dtype)
+            name=target,
+            shape=tuple(shape),
+            dtype=unified_dtype_converter(dtype, Frameworks.TRT),
         )
 
     def call_module(self, target, args, kwargs):
diff --git a/py/torch_tensorrt/fx/trt_module.py b/py/torch_tensorrt/fx/trt_module.py
index 099bbfcdc9..ab2d9ac348 100644
--- a/py/torch_tensorrt/fx/trt_module.py
+++ b/py/torch_tensorrt/fx/trt_module.py
@@ -4,7 +4,7 @@
 import tensorrt as trt
 import torch
 
-from .utils import torch_dtype_from_trt
+from .utils import unified_dtype_converter, Frameworks
 
 
 class TRTModule(torch.nn.Module):
@@ -53,7 +53,9 @@ def _initialize(self):
         )
 
         self.input_dtypes: Sequence[torch.dtype] = [
-            torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
+            unified_dtype_converter(
+                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+            )
             for idx in self.input_binding_indices_in_order
         ]
         self.input_shapes: Sequence[Sequence[int]] = [
@@ -61,7 +63,9 @@ def _initialize(self):
             for idx in self.input_binding_indices_in_order
         ]
         self.output_dtypes: Sequence[torch.dtype] = [
-            torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
+            unified_dtype_converter(
+                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+            )
             for idx in self.output_binding_indices_in_order
         ]
         self.output_shapes = [
@@ -71,7 +75,9 @@ def _initialize(self):
             for idx in self.output_binding_indices_in_order
         ]
         self.hidden_output_dtypes: Sequence[torch.dtype] = [
-            torch_dtype_from_trt(self.engine.get_binding_dtype(idx))
+            unified_dtype_converter(
+                self.engine.get_binding_dtype(idx), Frameworks.TORCH
+            )
             for idx in self.hidden_output_binding_indices_in_order
         ]
         self.hidden_output_shapes = [
diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py
index a8a3851655..e70fc862d0 100644
--- a/py/torch_tensorrt/fx/utils.py
+++ b/py/torch_tensorrt/fx/utils.py
@@ -1,5 +1,6 @@
 from enum import Enum
-from typing import List, Optional, Callable
+from typing import Dict, List, Optional, Callable, Union
+import numpy as np
 from packaging import version
 
 # @manual=//deeplearning/trt/python:py_tensorrt
@@ -15,6 +16,45 @@
 from .types import Shape, TRTDataType
 
 
+class Frameworks(Enum):
+    NUMPY = "numpy"
+    TORCH = "torch"
+    TRT = "trt"
+
+
+DataTypeEquivalence: Dict[
+    TRTDataType, Dict[Frameworks, Union[TRTDataType, np.dtype, torch.dtype]]
+] = {
+    trt.int8: {
+        Frameworks.NUMPY: np.int8,
+        Frameworks.TORCH: torch.int8,
+        Frameworks.TRT: trt.int8,
+    },
+    trt.int32: {
+        Frameworks.NUMPY: np.int32,
+        Frameworks.TORCH: torch.int32,
+        Frameworks.TRT: trt.int32,
+    },
+    trt.float16: {
+        Frameworks.NUMPY: np.float16,
+        Frameworks.TORCH: torch.float16,
+        Frameworks.TRT: trt.float16,
+    },
+    trt.float32: {
+        Frameworks.NUMPY: np.float32,
+        Frameworks.TORCH: torch.float32,
+        Frameworks.TRT: trt.float32,
+    },
+}
+
+if trt.__version__ >= "7.0":
+    DataTypeEquivalence[trt.bool] = {
+        Frameworks.NUMPY: np.bool_,
+        Frameworks.TORCH: torch.bool,
+        Frameworks.TRT: trt.bool,
+    }
+
+
 class LowerPrecision(Enum):
     FP32 = "fp32"
     FP16 = "fp16"
@@ -35,52 +75,33 @@ def from_str(label: str) -> Optional["LowerPrecision"]:
             return None
 
 
-def torch_dtype_to_trt(dtype: torch.dtype) -> TRTDataType:
-    """
-    Convert PyTorch data types to TensorRT data types.
-
-    Args:
-        dtype (torch.dtype): A PyTorch data type.
-
-    Returns:
-        The equivalent TensorRT data type.
- """ - if trt.__version__ >= "7.0" and dtype == torch.bool: - return trt.bool - elif dtype == torch.int8: - return trt.int8 - elif dtype == torch.int32: - return trt.int32 - elif dtype == torch.float16: - return trt.float16 - elif dtype == torch.float32: - return trt.float32 - else: - raise TypeError("%s is not supported by tensorrt" % dtype) - - -def torch_dtype_from_trt(dtype: TRTDataType) -> torch.dtype: +def unified_dtype_converter( + dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks +) -> Union[np.dtype, torch.dtype, TRTDataType]: """ - Convert TensorRT data types to PyTorch data types. + Convert TensorRT, Numpy, or Torch data types to any other of those data types. Args: - dtype (TRTDataType): A TensorRT data type. + dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type. + to (Frameworks): The framework to convert the data type to. Returns: - The equivalent PyTorch data type. + The equivalent data type in the requested framework. """ - if dtype == trt.int8: - return torch.int8 - elif trt.__version__ >= "7.0" and dtype == trt.bool: - return torch.bool - elif dtype == trt.int32: - return torch.int32 - elif dtype == trt.float16: - return torch.float16 - elif dtype == trt.float32: - return torch.float32 + assert to in Frameworks, f"Expected valid Framework for translation, got {to}" + + if dtype in (np.int8, torch.int8, trt.int8): + return DataTypeEquivalence[trt.int8][to] + elif trt.__version__ >= "7.0" and dtype in (np.bool_, torch.bool, trt.bool): + return DataTypeEquivalence[trt.bool][to] + elif dtype in (np.int32, torch.int32, trt.int32): + return DataTypeEquivalence[trt.int32][to] + elif dtype in (np.float16, torch.float16, trt.float16): + return DataTypeEquivalence[trt.float16][to] + elif dtype in (np.float32, torch.float32, trt.float32): + return DataTypeEquivalence[trt.float32][to] else: - raise TypeError("%s is not supported by torch" % dtype) + raise TypeError("%s is not a supported dtype" % dtype) def get_dynamic_dims(shape: Shape) -> List[int]: