diff --git a/.circleci/config.yml b/.circleci/config.yml index ae4261ac43..d46c695678 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -780,6 +780,21 @@ commands: - store_artifacts: path: /tmp/testlogs + test-dynamo-converters: + description: "Test the Dynamo aten converters" + steps: + - run: + name: Run Dynamo converter tests + command: | + cd tests/py/dynamo/converters + TESTS_TO_RUN=$(circleci tests glob "test_*.py" | circleci tests split --split-by=timings) + pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/test_results.xml $TESTS_TO_RUN + + - store_test_results: + path: /tmp/artifacts + - store_artifacts: + path: /tmp/testlogs + # =================== Dynamo tests end ======================== # # Define a job to be invoked later in a workflow. @@ -1036,6 +1051,7 @@ jobs: command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl # We install torch after torch-trt because pip automatically enforces the version constraint otherwise - dump-test-env + - test-dynamo-converters - test-dynamo-torch_compile - test-dynamo-models_torch_compile - test-dynamo-models_torch_export diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 1ea87c5a4e..8d3a842a47 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -68,6 +68,17 @@ def __init__(self, *args, **kwargs): - Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW) - Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW """ + # Compatibility code for switching over from InputTensorSpec + if "shape" in kwargs and "shape_ranges" in kwargs: + assert ( + len(kwargs["shape_ranges"]) == 1 and len(kwargs["shape_ranges"][0]) == 3 + ) + del kwargs["shape"] + + kwargs["min_shape"] = kwargs["shape_ranges"][0][0] + kwargs["opt_shape"] = kwargs["shape_ranges"][0][1] + kwargs["max_shape"] = kwargs["shape_ranges"][0][2] + if len(args) == 1: if not Input._supported_input_size_type(args[0]): raise TypeError( diff --git a/py/torch_tensorrt/dynamo/__init__.py b/py/torch_tensorrt/dynamo/__init__.py index 63a3308fe2..5918bad806 100644 --- a/py/torch_tensorrt/dynamo/__init__.py +++ b/py/torch_tensorrt/dynamo/__init__.py @@ -3,8 +3,9 @@ if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"): from ._settings import * + from .conversion import * from .aten_tracer import trace - from .converter_registry import ( + from .conversion.converter_registry import ( DYNAMO_CONVERTERS, dynamo_tensorrt_converter, ) diff --git a/py/torch_tensorrt/dynamo/conversion/SourceIR.py b/py/torch_tensorrt/dynamo/conversion/SourceIR.py new file mode 100644 index 0000000000..c0547986c4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/SourceIR.py @@ -0,0 +1,24 @@ +from enum import Enum, auto + + +class SourceIR(Enum): + NN = auto() + ACC = auto() + ATEN = auto() + PRIM = auto() + TORCHTRT_LOWERED = auto() + UNKNOWN = auto() + + def __str__(self): + if self == SourceIR.NN: + return "nn" + elif self == SourceIR.ACC: + return "acc" + elif self == SourceIR.ATEN: + return "aten" + elif self == SourceIR.PRIM: + return "prim" + elif self == SourceIR.TORCHTRT_LOWERED: + return "torchtrt_lowered" + else: + return "unknown_ir" diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index f50b22f27d..d201665a5b 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,3 +1,5 @@ +from .SourceIR import SourceIR +from .aten_ops_converters import * from .trt_interpreter import * from .conversion import * from .truncate_long_and_double import repair_long_or_double_inputs diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py new file mode 100644 index 0000000000..38f8692852 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -0,0 +1,378 @@ +import logging +from typing import Dict, Sequence, Tuple, Union +import torch +import tensorrt as trt +from torch_tensorrt.fx.converters import acc_ops_converters +from .converter_registry import dynamo_tensorrt_converter +from torch.fx.node import Argument, Target, Node + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR, impl +from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import cast_int_int_div_trt_tensor + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def args_bounds_check(args, i, replacement=None): + return args[i] if len(args) > i else replacement + + +@dynamo_tensorrt_converter(torch.ops.aten.batch_norm) +def aten_ops_batch_norm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.batch_norm( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[6], + args[7], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.div.default) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode) +@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor) +def aten_ops_div( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + # If both are TRTTensor, both are cast to float32 + if isinstance(args[0], TRTTensor) and isinstance(args[1], TRTTensor): + kwargs_new["input"], kwargs_new["other"] = cast_int_int_div_trt_tensor( + network, + kwargs_new["input"], + kwargs_new["other"], + name, + ) + # If one is TRTTensor, it is cast to float32 + elif isinstance(args[0], TRTTensor) and ( + kwargs_new["input"].dtype == trt.int8 or kwargs_new["input"].dtype == trt.int32 + ): + kwargs_new["input"] = cast_trt_tensor( + network, kwargs_new["input"], trt.float32, name + ) + elif isinstance(args[1], TRTTensor) and ( + kwargs_new["other"].dtype == trt.int8 or kwargs_new["other"].dtype == trt.int32 + ): + kwargs_new["other"] = cast_trt_tensor( + network, kwargs_new["other"], trt.float32, name + ) + rounding_mode = kwargs.get("rounding_mode") + if rounding_mode is None: + return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name) + elif rounding_mode == "floor": + return acc_ops_converters.acc_ops_floor_div( + network, target, None, kwargs_new, name + ) + elif rounding_mode == "trunc": + return impl.elementwise.trunc_div( + network, target, SourceIR.ATEN, name, args[0], args[1] + ) + else: + raise RuntimeError( + f"Target {target} does not support rounding mode {rounding_mode}" + ) + + +def embedding_param_validator(embedding_node: Node): + + max_norm = args_bounds_check(embedding_node.args, 2) + norm_type = args_bounds_check(embedding_node.args, 3) + scale_grad_by_freq = args_bounds_check(embedding_node.args, 4) + sparse = args_bounds_check(embedding_node.args, 5) + + if max_norm is not None: + _LOGGER.debug( + f"Currently we don't support specifying max_norm, got {max_norm}." + ) + return False + + if norm_type is not None and norm_type != 2.0: + _LOGGER.debug( + f"Currently we don't support specifying norm_type, got {norm_type}." + ) + return False + + if scale_grad_by_freq is not None: + _LOGGER.debug( + f"Currently we don't support specifying scale gradient by word frequency, got {scale_grad_by_freq}." + ) + return False + + if sparse is not None: + _LOGGER.debug(f"Currently we don't support sparse gradient, got {sparse}.") + return False + + return True + + +@dynamo_tensorrt_converter( + torch.ops.aten.embedding.default, capability_validator=embedding_param_validator +) +def aten_ops_embedding( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.embedding.embedding( + network, + target, + SourceIR.ATEN, + name, + input=args[1], + weight=args[0], + max_norm=args_bounds_check(args, 2), + norm_type=args_bounds_check(args, 3), + scale_grad_by_freq=args_bounds_check(args, 4), + sparse=args_bounds_check(args, 5), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar) +@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor) +def aten_ops_fmod( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1]) + + +@dynamo_tensorrt_converter(torch.ops.aten.gelu.default) +def aten_ops_gelu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.activation.gelu( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.matmul) +@dynamo_tensorrt_converter(torch.ops.aten.mm.default) +def aten_ops_matmul( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.matmul.matrix_multiply( + network, target, SourceIR.ATEN, name, args[0], args[1] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default) +def aten_ops_layernorm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.layer_norm( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args[4], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.relu.default) +def aten_ops_relu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return impl.activation.relu( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default) +def aten_ops_rsqrt( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return impl.elementwise.rsqrt( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim) +@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims) +def aten_ops_squeeze( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.squeeze.squeeze(network, target, SourceIR.ATEN, name, args[0], args[1]) + + +@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default) +def aten_ops_unsqueeze( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.unsqueeze.unsqueeze( + network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten._softmax.default) +def aten_ops_softmax( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.normalization.softmax( + network, target, SourceIR.ATEN, name, args[0], args[1] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.where.self) +def aten_ops_where( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.condition.where( + network, + target, + SourceIR.ATEN, + name, + args[1], + args[2], + args[0], + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.clamp.default) +def aten_ops_clamp( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.elementwise.clamp( + network, + target, + SourceIR.ATEN, + name, + input_val=args[0], + min_val=args_bounds_check(args, 1), + max_val=args_bounds_check(args, 2), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.select.int) +def aten_ops_select( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.select.select( + network, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor) +def aten_ops_slice( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.slice.slice_op( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + args[2], + args[3], + args_bounds_check(args, 4, replacement=1), + ) + + +@dynamo_tensorrt_converter(torch.ops.aten.permute.default) +def aten_ops_permute( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.permutation.permute( + network, + target, + SourceIR.ATEN, + name, + args[0], + args[1], + ) diff --git a/py/torch_tensorrt/dynamo/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py similarity index 100% rename from py/torch_tensorrt/dynamo/converter_registry.py rename to py/torch_tensorrt/dynamo/conversion/converter_registry.py diff --git a/py/torch_tensorrt/dynamo/conversion/converter_utils.py b/py/torch_tensorrt/dynamo/conversion/converter_utils.py index 4471931e4c..584e15b263 100644 --- a/py/torch_tensorrt/dynamo/conversion/converter_utils.py +++ b/py/torch_tensorrt/dynamo/conversion/converter_utils.py @@ -1,5 +1,19 @@ import torch +from torch_tensorrt.fx.types import ( + TRTDataType, + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.fx.converters.converter_utils import ( + unified_dtype_converter, + Frameworks, +) + +import tensorrt as trt +from typing import List + def dynamic_unsupported(node: torch.fx.Node) -> bool: # Validate that none of the inputs to the node have Dynamic shapes @@ -28,3 +42,86 @@ def dynamic_unsupported(node: torch.fx.Node) -> bool: return False return True + + +def cast_trt_tensor( + network: TRTNetwork, + input_val: TRTTensor, + dtype: TRTDataType, + name: str, +) -> TRTTensor: + """ + Given a TRT Tensor, convert that Tensor to the specified dtype + Adds an Identity layer to the network which performs the conversion + Args: + network (TRTNetwork): A TensorRT network + input_val (TRTTensor): A TRT Tensor to cast to a new data type + dtype (TRTDataType): The TRTDataType to cast the input Tensor to + name (str): Name of the calling layer + Returns: + A TensorRT ITensor which has been casted to the specified dtype + """ + trt_dtype = unified_dtype_converter(dtype, Frameworks.TRT) + + if input_val.dtype != trt_dtype: + identity_layer = network.add_identity(input_val) + identity_layer.set_output_type(0, trt_dtype) + identity_layer.name = f"Cast ITensor {input_val.name} from {input_val.dtype} to {trt_dtype} - {name}" + return identity_layer.get_output(0) + else: + return input_val + + +def cast_int_int_div_trt_tensor( + network: TRTNetwork, + lhs_val: TRTTensor, + rhs_val: TRTTensor, + name: str, +) -> List[TRTTensor]: + """ + Given two `int` data type TRT Tensor to div operation, cast the TRT Tensor to float type + Args: + network (TRTNetwork): A TensorRT network + lhs_val (TRTTensor): A TRT Tensor numerator + rhs_val (TRTTensor): A TRT Tensor numerator + name (str): Name of calling layer + Returns: + A list of lhs_val and rhs_val casted to the approriate datatype + """ + if (lhs_val.dtype == trt.int8 or lhs_val.dtype == trt.int32) and ( + rhs_val.dtype == trt.int8 or rhs_val.dtype == trt.int32 + ): + lhs_val = cast_trt_tensor(network, lhs_val, trt.float32, name) + rhs_val = cast_trt_tensor(network, rhs_val, trt.float32, name) + return list((lhs_val, rhs_val)) + + +def broadcastable( + a: TRTTensor, + b: TRTTensor, +) -> bool: + "Check if two tensors are broadcastable according to torch rules" + a_shape = tuple(a.shape) + b_shape = tuple(b.shape) + # check from the trailing + diff = len(a_shape) - len(b_shape) + if diff == 0: + return True + if diff > 0: + max = len(a_shape) + min = len(b_shape) + greater_tensor = a_shape + lesser_tensor = b_shape + elif diff < 0: + max = len(b_shape) + min = len(a_shape) + greater_tensor = b_shape + lesser_tensor = a_shape + j = min - 1 + for i in range(max - 1, diff - 1, -1): + if not ( + greater_tensor[i] != lesser_tensor[j] + and (greater_tensor[i] == 1 or lesser_tensor[i] == 1) + ): + return False + return True diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py new file mode 100644 index 0000000000..db6e405978 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -0,0 +1,14 @@ +from torch_tensorrt.fx.converters.impl import convolution +from . import condition +from . import elementwise +from . import embedding +from . import normalization +from . import slice +from . import unary +from . import activation +from . import matmul +from . import select +from . import shape +from . import squeeze +from . import unsqueeze +from . import permutation diff --git a/py/torch_tensorrt/dynamo/conversion/impl/activation.py b/py/torch_tensorrt/dynamo/conversion/impl/activation.py new file mode 100644 index 0000000000..ec3e078820 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/activation.py @@ -0,0 +1,65 @@ +import numpy as np +from typing import Any, Optional +import math + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.converters.impl.activation import * +from torch_tensorrt.fx.converters.converter_utils import ( + mark_as_int8_layer, + set_layer_name, + get_trt_plugin, +) +from torch_tensorrt.dynamo.conversion import SourceIR + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, + TRTPluginFieldCollection, +) + + +def gelu( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, + alpha: Optional[Any] = None, +): + approximate = alpha + if approximate is not None: + raise RuntimeError("GeLU converter currently doesn't support fast gelu compute") + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"GELU received input {input_val} that is not part " + "of the TensorRT region!" + ) + if network.has_implicit_batch_dimension: + raise RuntimeError( + "GeLU converter currently doesn't support implicit batch dimension" + ) + plugin_name = "CustomGeluPluginDynamic" + # type_id 0 for float32, 1 for float16 + type_id = trt.PluginField( + "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32 + ) + field_collection = TRTPluginFieldCollection([type_id]) + plugin_version = "1" + + plugin = get_trt_plugin(plugin_name, field_collection, plugin_version) + + layer = network.add_plugin_v2([input_val], plugin) + + def gelu_dyn_range_fn(dyn_range): + return ( + dyn_range[0] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0))) + ), (dyn_range[1] * 0.5 * (1.0 + torch.erf(dyn_range[0] / math.sqrt(2.0)))) + + if input_val.dynamic_range is not None: + dyn_range = gelu_dyn_range_fn(input_val.dynamic_range) + mark_as_int8_layer(layer, dyn_range) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py new file mode 100644 index 0000000000..79472fa2e7 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py @@ -0,0 +1,108 @@ +from typing import Optional + + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable +from torch_tensorrt.fx.converters.converter_utils import ( + broadcast, + get_trt_tensor, + set_layer_name, +) +from torch_tensorrt.dynamo.conversion.impl.slice import expand + + +def where( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, + condition: TRTTensor, +) -> TRTTensor: + input_dim = len(tuple(input.shape)) + other_dim = len(tuple(other.shape)) + condition_dim = len(tuple(condition.shape)) + + if type(input) != TRTTensor: + assert type(input) is torch.Tensor, f"value {input} is not torch.Tensor!" + + if type(other) != TRTTensor: + assert type(other) is torch.Tensor, f"value {other} is not torch.Tensor!" + + if not (broadcastable(input, other)): + assert f"The two torch tensors should be broadcastable" + + # get output shape + # purpose of this is to bring input and other rank same as + # output_shape to input it to the add_expand operation + # condition will have dimension of either input or other + input, other = broadcast(network, input, other, f"{name}_x", f"{name}_y") + if len(tuple(condition.shape)) != len(tuple(input.shape)): + condition, input = broadcast( + network, condition, input, f"{name}_condition", f"{name}_x" + ) + + x_shape = list(input.shape) + y_shape = list(other.shape) + condition_shape = list(condition.shape) + output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape)) + + # expand shape + if type(condition) != TRTTensor: + assert condition.dtype == torch.bool, "condition dtype is not bool" + if condition_shape != output_shape: + condition.expand(output_shape) + condition = condition.to(torch.int32) + condition_const = get_trt_tensor(network, condition, f"{name}_condition") + condition_layer = network.add_identity(condition_const) + condition_layer.set_output_type(0, trt.bool) + set_layer_name(condition_layer, target, f"{name}_condition") + condition_val = condition_layer.get_output(0) + else: + assert condition.dtype == trt.bool, "mask dtype is not bool!" + if condition_shape != condition_dim: + condition_val = expand( + network, target, source_ir, f"{name}_expand", condition, output_shape + ) + else: + condition_val = condition + + if type(input) != TRTTensor: + if x_shape != input_dim: + # special case where 1 element in input + if len(input.shape) == 0: + input = input.unsqueeze(0) + input = input.expand(output_shape) + x_val = get_trt_tensor(network, input, f"{name}_x") + else: + x_val = input + if x_shape != output_shape: + x_val = expand( + network, target, source_ir, f"{name}_x_expand", input, output_shape + ) + + if type(other) != TRTTensor: + if y_shape != output_shape: + # special case where 1 element in other + if len(other.shape) == 0: + other = other.unsqueeze(0) + other = other.expand(output_shape) + y_val = get_trt_tensor(network, other, f"{name}_y") + else: + y_val = other + if y_shape != other_dim: + y_val = expand( + network, target, source_ir, f"{name}_y_expand", y_val, output_shape + ) + + select_layer = network.add_select(condition_val, x_val, y_val) + + set_layer_name(select_layer, target, f"{name}_select") + + return select_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py new file mode 100644 index 0000000000..25d71e3702 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/__init__.py @@ -0,0 +1,2 @@ +from .ops import * +from .clamp import clamp diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py new file mode 100644 index 0000000000..9b15ebd4c4 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py @@ -0,0 +1,162 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.dynamo.conversion.converter_utils import ( + cast_trt_tensor, +) +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + broadcast, + squeeze_left, + get_trt_tensor, +) + + +def get_python_op_from_trt_elementwise_op( + trt_op: TRTElementWiseOp, +) -> Callable[[Any, Any], Any]: + if trt_op == trt.ElementWiseOperation.SUM: + return operator.add + elif trt_op == trt.ElementWiseOperation.PROD: + return operator.mul + elif trt_op == trt.ElementWiseOperation.SUB: + return operator.sub + elif trt_op == trt.ElementWiseOperation.DIV: + return operator.truediv + elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: + return operator.floordiv + else: + raise RuntimeError(f"{trt_op} is not supported yet!") + + +def convert_binary_elementwise( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + op_type: trt.ElementWiseOperation, + lhs_val: Union[int, float, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, TRTTensor, torch.Tensor], +) -> TRTTensor: + """ + This function adds a TensorRT elementwise layer. We allow both operands to be + constant (not a trt tensor) because in implicit batch dimension mode, we could + introduce constant via .size() op. Other scenario should be const folded first. + If any operand is not a trt tensor, we make it a trt constant layer while preserve + its dtype. Then we broadcast these two inputs to have the same number of dimensions. + We also promote the types of the two tensors to avoid dtype errors in TRT. + + Limitation: + If we are using implicit batch dim mode, the operand that is not a trt + tensor are not allowed to have larger ranks than the trt tensor operand. + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): Target of fx node. + source_ir (SourceIR): The IR that is calling the function. + name (str): The name we want to assign to the created TensorRT layer. + lhs_val (TRTTensor): Left operand of the binary operation. Could + be a TensorRT tensor, a PyTorch tensor or a simple value. + rhs_val (TRTTensor): Right operand of the binary operation. Similar + to lhs_val. + op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. + + Returns: + The output of TensorRT Elementwise layer. + """ + lhs_dtype = None + rhs_dtype = None + is_lhs_trt_tensor = False + is_rhs_trt_tensor = False + + if isinstance(lhs_val, TRTTensor): + lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH) + is_lhs_trt_tensor = True + if isinstance(rhs_val, TRTTensor): + rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH) + is_rhs_trt_tensor = True + + if not is_lhs_trt_tensor and not is_rhs_trt_tensor: + warnings.warn( + f"Both operands of the binary elementwise op {name} " + "are constant. In this case, please consider constant fold the model first." + ) + return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) + + # If the following conditions are true: + # 1. the network has implicit batch dimension, + # 2. one operand has shape [] (real shape is [batch_size]), + # 3. another operand is a scalar, + # then the result should also have shape [] (real shape is [batch_size]). + # + # In such case, we need to convert the scalar operand to tensor, because + # this way the shape will become [1], and then will be properly squeezed + # into [], meaning that the result will have shape [], which is what we + # expect. + # + # Note that the dtype here is supposed to be the same as the scalar + # dtype but we don't have a way to detect whether it makes sense for the + # scalar to be float or half. Hence we go with the lhs dtype. + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) + + # When lhs is scalar, and rhs has shape [1,], then currently the assert + # will fail because lhs shape has fewer dimensions than rhs shape. This + # happens when using implicit batch dimension, when we removed the 1st + # dimension from input tensor, causing it to have shape [] - a scalar. We + # fix it by reducing the rhs constant with a squeeze_left, so it becomes a + # scalar too. More generally, we squeeze_left on input if it's a constant + # tensor. This is safe because broadcast will pad dimensions on the left + # (prepend) to make lhs and rhs shape compatible. + if network.has_implicit_batch_dimension: + if isinstance(lhs_val, torch.Tensor): + lhs_val = squeeze_left(lhs_val) + if isinstance(rhs_val, torch.Tensor): + rhs_val = squeeze_left(rhs_val) + + lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) + + promoted_type = torch.promote_types( + unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH), + unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH), + ) + trt_promoted_type = unified_dtype_converter(promoted_type, Frameworks.TRT) + + if trt_promoted_type != lhs_val.dtype: + lhs_val = cast_trt_tensor(network, lhs_val, trt_promoted_type, name) + if trt_promoted_type != rhs_val.dtype: + rhs_val = cast_trt_tensor(network, rhs_val, trt_promoted_type, name) + + # Check the limitation in the doc string. + if network.has_implicit_batch_dimension: + if is_lhs_trt_tensor and not is_rhs_trt_tensor: + assert len(lhs_val.shape) >= len( + rhs_val.shape + ), f"{lhs_val.shape} >= {rhs_val.shape}" + elif not is_lhs_trt_tensor and is_rhs_trt_tensor: + assert len(rhs_val.shape) >= len( + lhs_val.shape + ), f"{rhs_val.shape} >= {lhs_val.shape}" + + lhs_val, rhs_val = broadcast( + network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ) + layer = network.add_elementwise(lhs_val, rhs_val, op_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return output diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py new file mode 100644 index 0000000000..59e1b0f723 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/clamp.py @@ -0,0 +1,78 @@ +import numpy as np +from typing import Optional +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) + +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + squeeze_left, + get_trt_tensor, +) + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + + +def add_clamp(network, input, val, op, name): + if not len(input.shape): + # clamping scalar + acc_ops_clamp_trt = get_trt_tensor( + network, + squeeze_left( + np.array( + [val], dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY) + ) + ), + f"{name}_clamp_{val}", + ) + else: + acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions + acc_ops_clamp_tensor = np.full( + acc_ops_clamp_shape, + val, + dtype=unified_dtype_converter(input.dtype, Frameworks.NUMPY), + ) + acc_ops_clamp_trt = network.add_constant( + acc_ops_clamp_shape, acc_ops_clamp_tensor + ).get_output(0) + layer = network.add_elementwise(input, acc_ops_clamp_trt, op) + return layer + + +def clamp( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val, + min_val=None, + max_val=None, +) -> TRTTensor: + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"Clamp received input {input_val} that is not part " + "of the TensorRT region!" + ) + + if min_val is not None: + clamp_min_layer = add_clamp( + network, input_val, min_val, trt.ElementWiseOperation.MAX, name + ) + set_layer_name(clamp_min_layer, target, f"{name}_clamp_min") + input_val = clamp_min_layer.get_output(0) + if max_val is not None: + clamp_max_layer = add_clamp( + network, input_val, max_val, trt.ElementWiseOperation.MIN, name + ) + set_layer_name(clamp_max_layer, target, f"{name}_clamp_max") + input_val = clamp_max_layer.get_output(0) + + return input_val diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py new file mode 100644 index 0000000000..089fcf223c --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/ops.py @@ -0,0 +1,177 @@ +from typing import Any, Optional + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary +from torch_tensorrt.dynamo.conversion.impl.unary import sign + + +def trunc_div( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + """ + Perform trunc divide on Tensor, result of divide will be round toward zero. + This means for positive number, it will be floor round; for negative number, + it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. + + Args: + network: INetworkDefinition. + target: node target + source_ir (SourceIR): Source IR calling the function. + name: namespace for the op + input: divisor. + other: dividend. + + Returns: + A TensorRT tensor represent the result of trunc divide. + """ + prod_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_prod", + trt.ElementWiseOperation.PROD, + input, + other, + ) + + sign_output = sign( + network, + target, + source_ir, + name, + prod_output, + ) + + # Convert constant input into ITensor for UnaryOperation + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, + other, + f"{name}_other", + dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), + ) + + abs_input_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_input", + trt.UnaryOperation.ABS, + input, + ) + abs_other_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_other", + trt.UnaryOperation.ABS, + other, + ) + abs_floor_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + abs_input_output, + abs_other_output, + ) + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.PROD, + abs_floor_output, + sign_output, + ) + + return output + + +def rsqrt( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + + sqrt_trt_output = convert_unary( + network, + target, + source_ir, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + input, + ) + + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.DIV, + 1, + sqrt_trt_output, + ) + + return output + + +def fmod( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it + trunc_div_value = trunc_div( + network, + target, + source_ir, + name + "_trunc_div", + input, + other, + ) + prod_value = convert_binary_elementwise( + network, + target, + source_ir, + name + "_prod", + trt.ElementWiseOperation.PROD, + trunc_div_value, + other, + ) + sub_value = convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name + "_sub", + trt.ElementWiseOperation.SUB, + input, + prod_value, + ) + return sub_value diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py new file mode 100644 index 0000000000..a68d2455ee --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -0,0 +1,73 @@ +import operator +import warnings +from typing import Optional, cast, Any + +import numpy as np + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor + +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + set_layer_name, +) + +from torch_tensorrt.fx.converters.converter_utils import get_trt_tensor + + +def embedding( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + weight: TRTTensor, + max_norm: None, + norm_type: None, + scale_grad_by_freq: bool, + sparse: bool, +) -> TRTTensor: + + if network.has_implicit_batch_dimension: + raise RuntimeError( + "The `embedding` function should be called with explicit batch dimension." + ) + + indices_tensor = input + embedding_tensor = weight + if isinstance(indices_tensor, torch.Tensor) and indices_tensor.dtype == torch.int64: + raise RuntimeError( + "The `embedding` op has indices_tensor dtype=int64. This is incorrect since it has to be int32 to run on TRT." + ) + indices_tensor = get_trt_tensor(network, indices_tensor, f"{name}_indices_tensor") + embedding_tensor = get_trt_tensor( + network, embedding_tensor, f"{name}_embedding_tensor" + ) + # unsupported parameters + # ignore padding_idx since it is meaningful for training only + + if max_norm is not None: + raise RuntimeError( + f"Currently we don't support specifying max_norm, got {max_norm}." + ) + + if norm_type is not None and norm_type != 2.0: + raise RuntimeError( + f"Currently we don't support specifying max_norm, got {norm_type} for norm_type." + ) + + if scale_grad_by_freq: + raise RuntimeError( + "Currently we don't support scale gradient by word frequency." + ) + + if sparse: + raise RuntimeError("Currently we don't support sparse gradient.") + + # Implement embedding lookup with gather layer + gather_layer = network.add_gather(embedding_tensor, indices_tensor, axis=0) + set_layer_name(gather_layer, target, name + "_gather") + return gather_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/matmul.py b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py new file mode 100644 index 0000000000..846f4ab2ee --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/matmul.py @@ -0,0 +1,54 @@ +from typing import Optional + + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import ( + unified_dtype_converter, + Frameworks, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_trt_tensor, + broadcast, + set_layer_name, +) + + +def matrix_multiply( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, + other, + f"{name}_other", + dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), + ) + + input_matrix_op = other_matrix_op = trt.MatrixOperation.NONE + preset_diff = 0 + + if len(input.shape) == 1: + preset_diff -= 1 + input_matrix_op = trt.MatrixOperation.VECTOR + + if len(other.shape) == 1: + preset_diff += 1 + other_matrix_op = trt.MatrixOperation.VECTOR + + input, other = broadcast( + network, input, other, f"{name}_input", f"{name}_other", preset_diff + ) + layer = network.add_matrix_multiply(input, input_matrix_op, other, other_matrix_op) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py new file mode 100644 index 0000000000..9d193fdf92 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py @@ -0,0 +1,313 @@ +from typing import cast, Union, Any, Optional, Sequence + +import numpy as np + +import tensorrt as trt +import torch +from torch.fx.node import Target + +import logging + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.fx.utils import get_dynamic_dims +from torch_tensorrt.dynamo.conversion import SourceIR + +from torch_tensorrt.fx.converters.converter_utils import ( + get_trt_plugin, + set_layer_name, + to_numpy, + has_dynamic_shape, + get_positive_dim, +) + +from torch_tensorrt.dynamo.conversion.impl.unary.base import ( + convert_unary, +) + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def batch_norm( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + weight: torch.Tensor, + bias: torch.Tensor, + running_mean: torch.Tensor, + running_var: torch.Tensor, + training: torch.Tensor, + momentum: torch.Tensor, + eps: list, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"BatchNorm2d received input {input} that is not part " + "of the TensorRT region!" + ) + + if has_dynamic_shape(input.shape): + assert input.shape[1] != -1, "Channel dim can't be dynamic for batch norm." + + scale = cast(torch.Tensor, to_numpy(cast(torch.Tensor, weight))) / np.sqrt( + cast(torch.Tensor, to_numpy(cast(torch.Tensor, running_var))) + cast(float, eps) + ) + + bias = ( + to_numpy(cast(torch.Tensor, bias)) + - to_numpy(cast(torch.Tensor, running_mean)) * scale + ) + power = np.ones_like(scale) + + # For BatchNorm1d, reshape 1d to 2d + output_shape = input.shape + if not network.has_implicit_batch_dimension and len(input.shape) < 4: + assert ( + len(get_dynamic_dims(input.shape)) <= 1 + ), "BatchNorm1D with more than one dynamic dims is not currently supported." + reshape_layer = network.add_shuffle(input) + if len(input.shape) == 2: + reshape_layer.reshape_dims = (input.shape[0], input.shape[1], 1, 1) + else: # len(input_val.shape) == 3 + reshape_layer.reshape_dims = ( + input.shape[0], + input.shape[1], + input.shape[2], + 1, + ) + set_layer_name(reshape_layer, target, f"{name}_reshape_2d") + input = reshape_layer.get_output(0) + layer = network.add_scale(input, trt.ScaleMode.CHANNEL, bias, scale, power) + set_layer_name(layer, target, name) + + # For BatchNorm1d, reshape output back to 1d + if not network.has_implicit_batch_dimension and len(output_shape) < 4: + reshape_output_layer = network.add_shuffle(layer.get_output(0)) + reshape_output_layer.reshape_dims = tuple(output_shape) + set_layer_name(reshape_output_layer, target, f"{name}_reshape_1d") + layer = reshape_output_layer + return layer.get_output(0) + + +def layer_norm( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + normalized_shape: list, + weight: torch.Tensor, + bias: torch.Tensor, + eps: list, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + if not isinstance(input, trt.tensorrt.ITensor): + raise RuntimeError( + f"LayerNorm received input {input} that is not part " + "of the TensorRT region!" + ) + + gamma = weight.detach().cpu().float().numpy() + gamma_field = trt.PluginField("gamma", gamma, trt.PluginFieldType.FLOAT32) + beta = bias.detach().cpu().float().numpy() + beta_field = trt.PluginField("beta", beta, trt.PluginFieldType.FLOAT32) + eps_field = trt.PluginField( + "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32 + ) + try: + normalized_shape = np.array(normalized_shape, dtype=np.int32) + except TypeError: + _LOGGER.error("Unable to convert normalized_shape to a field, fall back to []") + normalized_shape = np.array([], dtype=np.int32) + + normalized_shape_filed = trt.PluginField( + "normalized_shape", normalized_shape, trt.PluginFieldType.INT32 + ) + field_collection = trt.PluginFieldCollection( + [gamma_field, beta_field, eps_field, normalized_shape_filed] + ) + + try: + if network.has_implicit_batch_dimension: + plugin = get_trt_plugin("layer_norm", field_collection, "1", "fx2trt") + else: + plugin = get_trt_plugin("LayerNormDynamic", field_collection, "1", "fx2trt") + except AssertionError: + _LOGGER.error( + "Unable to find layer norm plugin, fall back to TensorRT implementation." + ) + return layer_norm_no_plugin( + network, target, source_ir, name, input, normalized_shape, weight, bias, eps + ) + layer = network.add_plugin_v2([input], plugin) + layer.name = name + return layer.get_output(0) + + +def layer_norm_no_plugin( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + normalized_shape: list, + weight: torch.Tensor, + bias: torch.Tensor, + eps: list, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"LayerNorm received input {input} that is not part " + "of the TensorRT region!" + ) + + shape = weight.shape # type: ignore[union-attr] + broadcasted_shape = (1,) * (len(input.shape) - len(shape)) + shape + gamma = to_numpy(weight.reshape(*shape)) # type: ignore[union-attr] + beta = to_numpy(bias.reshape(*shape)) # type: ignore[union-attr] + + axes = 0 + for d in range(len(shape)): + axes |= 1 << (len(input.shape) - d - 1) + + # E[x] + mean_expected_layer = network.add_reduce( + input, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") + + # X-E[x] + sub_trt = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_sub", + trt.ElementWiseOperation.SUB, + input, + mean_expected_layer.get_output(0), + ) + # Variance = mean(pow(x_sub_mean,2)) + pow_tensor = network.add_constant( + (1,) * len(input.shape), + trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), + ) + pow_tensor.name = f"{name}_power" + pow_var = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_pow_var", + trt.ElementWiseOperation.POW, + sub_trt, + pow_tensor.get_output(0), + ) + mean_trt_layer = network.add_reduce( + pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True + ) + set_layer_name(mean_trt_layer, target, f"{name}_mean") + # Variance + eps + eps_tensor = network.add_constant( + (1,) * len(input.shape), + trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), + ) + eps_tensor.name = f"{name}_eps" + add_trt = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_add", + trt.ElementWiseOperation.SUM, + mean_trt_layer.get_output(0), + eps_tensor.get_output(0), + ) + # SQRT((Var + eps)) + sqrt_trt = convert_unary( + network, + target, + source_ir, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + add_trt, + ) + # (x - E[x]) / sqrt((var + eps)) + div_trt = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_div_trt", + trt.ElementWiseOperation.DIV, + sub_trt, + sqrt_trt, + ) + + assert gamma is not None + gamma_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(gamma))) # type: ignore[attr-defined] + gamma_tensor.name = f"{name}_gamma" + assert beta is not None + beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] + beta_tensor.name = f"{name}_beta" + # y * gamma + beta + scale_layer = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_scale", + trt.ElementWiseOperation.PROD, + div_trt, + gamma_tensor.get_output(0), + ) + return convert_binary_elementwise( + network, + target, + source_ir, + name, + trt.ElementWiseOperation.SUM, + scale_layer, + beta_tensor.get_output(0), + ) + + +def softmax( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[Any] = None, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr] + + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"softmax received input {input} that is not part " + "of the TensorRT region!" + ) + + # Used to get dim when dim is None. Copied from PyTorch softmax implementation. + def get_softmax_dim(ndim: int) -> int: + if ndim == 0 or ndim == 1 or ndim == 3: + ret = 0 + else: + ret = 1 + return ret + + if dim is None: + dim = get_softmax_dim(input_ranks) + else: + dim = cast(int, dim) + + dim = get_positive_dim(dim, input_ranks) + if network.has_implicit_batch_dimension: + assert dim != 0, "Can't apply softmax on batch dimension when it's implicit." + dim -= 1 + + layer = network.add_softmax(input) + layer.axes = 1 << dim + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/permutation.py b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py new file mode 100644 index 0000000000..492e35ba97 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/permutation.py @@ -0,0 +1,34 @@ +from typing import Optional, Sequence, cast + + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + get_positive_dim, +) + + +def permute( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + permutation: Sequence[int], +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"permute received input {input} that is not a TensorRT ITensor" + ) + + permutation = [ + get_positive_dim(i, len(input.shape)) for i in cast(Sequence[int], permutation) + ] + + layer = network.add_shuffle(input) + layer.second_transpose = tuple(permutation) + set_layer_name(layer, target, name, source_ir) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/select.py b/py/torch_tensorrt/dynamo/conversion/impl/select.py new file mode 100644 index 0000000000..26ad175104 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/select.py @@ -0,0 +1,64 @@ +from typing import Optional, cast + +import numpy as np +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + has_dynamic_shape, + to_numpy, +) +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape + + +def select( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Shape, + index: Shape, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, dim), ranks) + dynamic_shape = has_dynamic_shape(input.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't select on negative shape dimension!" + index = index + + if index >= input.shape[dim]: + raise RuntimeError( + f"cannot have index greater than the dimension length! {input.shape[dim]}" + ) + output_shape = list(input.shape) + output_shape[dim] = 1 + if dynamic_shape > 0: + output_shape = get_shape_with_dynamic_shape( + network, target, source_ir, name, output_shape, input + ) + index_value = np.array(index, dtype=np.int32) + indices_tensor = network.add_constant( + index_value.shape, to_numpy(index_value) + ).get_output(0) + layer = network.add_gather(input, indices_tensor, dim) + out = layer.get_output(0) + if len(out.shape) != 1: + layer = network.add_shuffle(out) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/shape.py b/py/torch_tensorrt/dynamo/conversion/impl/shape.py new file mode 100644 index 0000000000..7f122f5646 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/shape.py @@ -0,0 +1,77 @@ +from typing import Union + +import numpy as np + +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + set_layer_name, + to_numpy, +) + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) + + +def get_shape_with_dynamic_shape( + network: TRTNetwork, + target: Target, + source_ir: SourceIR, + name: str, + shape: Union[list, tuple, torch.Tensor], + input_val: TRTTensor, +) -> TRTTensor: + """ + Prepare the real output tensor shape for dynamic shape mode tensor input. + How this functions works: + Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation + output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual + reduce operation output shape. Steps of calculations are: + 1. get the actual tensor shape of input_val via add_shape layer; + 2. create a all 0 tensor [0, 0, 0]; + 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; + 4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace + all -1 dynamic shape dimensions with actual batch_size value; + 5. output shape with actual batch_size as [2048, 128, 256] + + Args: + network (TRTNetwork): TensorRT network object. + shape: calculated shape of the expected output tensor + input_val (TRTTensor): A TensorRT ITensor. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + Returns: + TensorRT ITensors that represents the actual shape of the input_val + """ + # Ger real shape info for input_val + input_shape = network.add_shape(input_val).get_output(0) + + scale_layer = network.add_constant( + input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) + ) + set_layer_name(scale_layer, target, f"{name}_scale") + scale_res = scale_layer.get_output(0) + + length = input_shape.shape[0] + zero_layer = network.add_constant( + input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) + ) + set_layer_name(zero_layer, target, f"{name}_zeros") + + condition_val = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_shape", + trt.ElementWiseOperation.LESS, + scale_res, + zero_layer.get_output(0), + ) + select_layer = network.add_select(condition_val, input_shape, scale_res) + set_layer_name(select_layer, target, f"{name}_select") + return select_layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py new file mode 100644 index 0000000000..97cc0d1404 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/base.py @@ -0,0 +1,39 @@ +from typing import Optional + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + has_dynamic_shape, + set_layer_name, +) + +from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape + + +def slice( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + start: Shape, + shape: Shape, + stride: Shape, +) -> TRTTensor: + dynamic_shape = has_dynamic_shape(input.shape) + if dynamic_shape: + shape = get_shape_with_dynamic_shape( + network, target, source_ir, name, shape, input + ) + layer = network.add_slice( + input, + start=start, + shape=[] if dynamic_shape else shape, + stride=stride, + ) + if dynamic_shape: + layer.set_input(2, shape) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py new file mode 100644 index 0000000000..848e13ba4b --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py @@ -0,0 +1,96 @@ +from typing import Optional, cast +import math + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, Shape +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + has_dynamic_shape, + broadcast, + get_trt_tensor, +) +from torch_tensorrt.dynamo.conversion.impl.slice.base import slice + + +def slice_op( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: int, + start: int, + stop: int, + step: int, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"slice_tensor received input {input} that is not part " + "of the TensorRT region!" + ) + + ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) + dim = get_positive_dim(cast(int, dim), ranks) + dynamic_shape = has_dynamic_shape(input.shape) + if network.has_implicit_batch_dimension: + if dim == 0: + raise RuntimeError( + f"We do not support slice_tensor at batch dim when it's implicit, got {dim}!" + ) + dim = dim - 1 + else: + if dynamic_shape: + # Check whether slice target dim is dynamic shape dim + assert input.shape[dim] != -1, "Can't chunk on dynamic shape dimension!" + start_int = cast(int, start) + stop_int = cast(int, stop) + if stop_int == 2**63 - 1: + stop_int = input.shape[dim] + step_int = cast(int, step) + start = [0] * len(input.shape) + start[dim] = start_int + stride = [1] * len(start) + stride[dim] = step_int + output_shape = list(input.shape) + output_shape[dim] = math.ceil((stop_int - start_int) / step_int) + + return slice(network, target, source_ir, name, input, start, output_shape, stride) + + +def expand( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + sizes: Shape, +) -> TRTTensor: + shape = list(sizes) + + input_val = get_trt_tensor(network, input, f"{name}_input") + + if network.has_implicit_batch_dimension: + shape = shape[1:] + + ranks = len(input_val.shape) + # TRT does not support different dimension size + # though this condition is not seen in the case of bmm + # where input_t and shape dimensions are not equal + assert len(shape) >= ranks + if len(shape) != ranks: + shape_tuple = tuple([0] * len(shape)) + shape_tensor = get_trt_tensor(network, input, f"{name}_shape") + input_val, shape_tensor = broadcast( + network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val" + ) + ranks = len(shape) + + inshape = tuple(input_val.shape) + shape = tuple(shape) + start = tuple([0] * ranks) + stride = tuple( + [int(i == o) for i, o in zip(inshape, shape)] + ) # stride == 1 if dimensions match, 0 otherwise + return slice(network, target, source_ir, name, input_val, start, shape, stride) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py new file mode 100644 index 0000000000..4c5ad200ad --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/squeeze.py @@ -0,0 +1,63 @@ +from typing import Optional, cast, Any + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + set_layer_name, +) + +from torch_tensorrt.fx.utils import get_dynamic_dims + + +def squeeze( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + dim: Optional[Any] = None, +) -> TRTTensor: + if not isinstance(input, TRTTensor): + raise RuntimeError( + f"squeeze received input {input} that is not part " + "of the TensorRT region!" + ) + dims = [] + if dim is not None: + if isinstance(dim, int): + dims.append(cast(Optional[int], dim)) + else: + for dim in dim: + dims.append(cast(Optional[int], dim)) + + # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic + # dim, which is a very rare case. For now we just claim not supporting dim=None. + assert not (len(dims) == 0), "We don't support dim=None right now for squeeze." + + for dim in dims: + dim = cast(Optional[int], dim) + dim = get_positive_dim( + dim, + len(input.shape) + (1 if network.has_implicit_batch_dimension else 0), + ) + if network.has_implicit_batch_dimension: + assert dim != 0, "We don't support squeeze batch dim when it's implicit." + dim -= 1 + + assert input.shape[dim] != -1, "We don't support squeeze dynamic dim." + assert ( + len(get_dynamic_dims(input.shape)) <= 1 + ), "Currently more than one dynamic dim for input to squeeze is not supported." + + output_shape = [] + for i, s in enumerate(input.shape): + if (i in dims) and s == 1: + continue + output_shape.append(s) + layer = network.add_shuffle(input) + layer.reshape_dims = tuple(output_shape) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py new file mode 100644 index 0000000000..0ee1185850 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/base.py @@ -0,0 +1,44 @@ +from typing import Optional + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import set_layer_name + + +def convert_unary( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + operation_type: trt.UnaryOperation, + input_val: TRTTensor, +) -> TRTTensor: + """ + Add a TensorRT Unary layer to `network`. + + Args: + network (TRTNetwork): TensorRT network object. + input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. + op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + + Returns: + The output of TensorRT Unary layer. + """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_unary(input_val, operation_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py new file mode 100644 index 0000000000..e0a255f800 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -0,0 +1,98 @@ +from typing import Optional + +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.dynamo.conversion import SourceIR + + +from torch_tensorrt.dynamo.conversion.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.dynamo.conversion.impl.unary.base import convert_unary + + +def sign( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Sign is calculated as below: + x = input + sign = (exp(x) // exp(abs(x))) * 2 - 1 + For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. + With multiply 2, the value become 2(for pos and 0) and 0(for neg). + Finally minus 1, the value become 1(for pos and 0) and -1(for neg). + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): fx node target. + source_ir (SourceIR): Source IR calling the function + name (str): Name of the fx node with optional suffix. + input_val (TRTTensor): The input tensor. + + Returns: + A TensorRT tensor represent the result of sign operator. + """ + input_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_exp", + trt.UnaryOperation.EXP, + input_val, + ) + input_abs_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs", + trt.UnaryOperation.ABS, + input_val, + ) + input_abs_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs_exp", + trt.UnaryOperation.EXP, + input_abs_output, + ) + + floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_exp_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + input_exp_output, + input_abs_exp_output, + ) + + double_floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div*2", + trt.ElementWiseOperation.PROD, + floor_div_output, + 2, + ) + + return convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_sign", + trt.ElementWiseOperation.SUB, + double_floor_div_output, + 1, + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py new file mode 100644 index 0000000000..d1559ef324 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/unsqueeze.py @@ -0,0 +1,52 @@ +from typing import Optional, cast + +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor +from torch_tensorrt.dynamo.conversion import SourceIR +from torch_tensorrt.fx.converters.converter_utils import ( + get_positive_dim, + get_trt_tensor, + set_layer_name, +) + +from torch_tensorrt.fx.utils import get_dynamic_dims + + +def unsqueeze( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_t, + dim, +) -> TRTTensor: + input_val = get_trt_tensor(network, input_t, f"{name}_input_t") + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"unsqueeze received input {input_val} that is not part " + "of the TensorRT region!" + ) + + dim = cast(int, dim) + input_shape = input_val.shape + input_shape_size = ( + len(input_val.shape) + 1 + if network.has_implicit_batch_dimension + else len(input_val.shape) + ) + dim = get_positive_dim(dim, input_shape_size + 1) + + if network.has_implicit_batch_dimension: + assert dim != 0 + dim -= 1 + + assert ( + len(get_dynamic_dims(input_val.shape)) <= 1 + ), "Currently we don't support unsqueeze with more than one dynamic dims." + layer = network.add_shuffle(input_val) + layer.reshape_dims = ( + tuple(input_val.shape)[:dim] + (1,) + tuple(input_val.shape)[dim:] + ) + set_layer_name(layer, target, name) + return layer.get_output(0) diff --git a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py index 6b72d87ff6..4293fb65eb 100644 --- a/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/trt_interpreter.py @@ -13,7 +13,7 @@ from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata -from torch_tensorrt.dynamo import DYNAMO_CONVERTERS as CONVERTERS +from .converter_registry import DYNAMO_CONVERTERS as CONVERTERS from torch_tensorrt import Input from torch_tensorrt.fx.observer import Observer from torch_tensorrt.fx.utils import ( @@ -64,7 +64,15 @@ def __init__( + "\n".join(f"{i}" for i in missing_ops) ) - self.optimization_profiles: Optional[List] = None + self.optimization_profiles = ( + [self.builder.create_optimization_profile()] + if any( + input_spec.shape_mode == Input._ShapeMode.DYNAMIC + for input_spec in input_specs + ) + else None + ) + self.input_specs = input_specs self.input_specs_iter = 0 self._cur_node_name: Optional[str] = None @@ -257,7 +265,7 @@ def placeholder(self, target, args, kwargs): opt_shape = current_input.shape["opt_shape"] max_shape = current_input.shape["max_shape"] self.optimization_profiles[0].set_shape( - target, [min_shape, opt_shape, max_shape] + target, min_shape, opt_shape, max_shape ) assert len(min_shape) == len(opt_shape) == len(max_shape) for i in range(len(min_shape)): diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 7021b55518..d56a3a8616 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -46,6 +46,16 @@ def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: return torch.reciprocal(torch.sqrt(*args, **kwargs)) +@register_decomposition(aten._unsafe_view, registry=DECOMPOSITIONS) +def unsafe_view_replacement(x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return torch.reshape(x, *args, **kwargs) + + +@register_decomposition(torch.ops.aten.lift_fresh_copy, registry=DECOMPOSITIONS) +def lift_fresh_copy_replacement(x: torch.Tensor) -> torch.Tensor: + return x + + @register_decomposition(aten.alias, registry=DECOMPOSITIONS) def alias_replacement(x: torch.Tensor) -> torch.Tensor: return x @@ -60,5 +70,12 @@ def addmm_replacement( ) +@register_decomposition(torch.ops.aten.reciprocal.default, registry=DECOMPOSITIONS) +def reciprocal_replacement( + input_: torch.Tensor, +) -> torch.Tensor: + return torch.div(1, input_) + + def get_decompositions(): return DECOMPOSITIONS diff --git a/py/torch_tensorrt/dynamo/test_utils.py b/py/torch_tensorrt/dynamo/test_utils.py new file mode 100644 index 0000000000..a3d742c70a --- /dev/null +++ b/py/torch_tensorrt/dynamo/test_utils.py @@ -0,0 +1,310 @@ +import time +import unittest +import torch +import logging +from typing import Callable, List, Optional, Set, Tuple +from torch.testing._internal.common_utils import TestCase + +import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer +from torch_tensorrt.fx.passes.lower_basic_pass_aten import ( + compose_bmm, + compose_chunk, + compose_getitem_slice, + remove_ops, + replace_aten_op_with_indices, + replace_aten_reshape_alias_with_replace, + replace_builtin_ops, + replace_native_layernorm_with_layernorm, + replace_transpose_mm_op_with_linear, + run_const_fold, +) +from torch.fx.passes.infra.pass_base import PassResult +from torch_tensorrt.fx.passes.pass_utils import chain_passes + +# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry +from torch_tensorrt.dynamo.conversion.trt_interpreter import TRTInterpreter +from torch_tensorrt.dynamo.runtime._PythonTorchTRTModule import PythonTorchTRTModule +from torch_tensorrt import Input + + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +def fetch_attr(mod, target): + """ + Fetch an attribute from the ``Module`` hierarchy of ``mod.module``. + + Args: + target (str): The fully-qualfiied name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split(".") + attr_itr = mod + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + +@unittest.skipIf(not torch.cuda.is_available(), "Skip because CUDA is not available") +class TRTTestCase(TestCase): + def setUp(self): + super().setUp() + torch.manual_seed(3) + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops, + interpreter, + rtol, + atol, + precision=torch.float, + check_dtype=True, + ): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(mod, unexpected_ops) + start = time.perf_counter() + interpreter_result = interpreter.run(precision=precision) + sec = time.perf_counter() - start + _LOGGER.info(f"Interpreter run time(s): {sec}") + trt_mod = PythonTorchTRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + + ref_outputs = mod(*inputs) + + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + outputs = trt_mod(*cuda_inputs) + end_event.record() + torch.cuda.synchronize() + _LOGGER.info( + f"TRT run time(s)= {(start_event.elapsed_time(end_event) * 1.0e-3)}" + ) + + if type(outputs) not in (list, tuple): + outputs = [outputs] + if type(ref_outputs) not in ( + list, + tuple, + torch.return_types.max, + torch.return_types.min, + ): + ref_outputs = [ref_outputs] + for out, ref in zip(outputs, ref_outputs): + if not isinstance(ref, torch.Tensor): + ref = torch.tensor([ref]) + ref = ref.cpu() # to_dtype test has cases with gpu output + if ref.dtype == torch.int64: + ref = ref.int() # convert torch.max's index output tensor to int32 + torch.testing.assert_close( + out.cpu(), + ref, + rtol=rtol, + atol=atol, + equal_nan=True, + check_dtype=check_dtype, + ) + + def run_test_custom_compare_results( + self, + mod, + inputs, + expected_ops, + interpreter, + comparators: List[Tuple[Callable, List]], + fp16_mode=False, + ): + """ + Runs the test and compares the result using the provided comparators. + The size of comparators must be equal to the number of outputs from 'mod'. + + mod - a model to run. + inputs - a list of the model inputs. + expected ops - a list of ops that should be verified. + interpreter - used for converting the model to TRT. + comparators - a list of (func, args) pairs corresponding to each of + the module outputs. usage: func(x, y, *args) + + """ + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + if len(expected_ops): + self.assert_has_op(mod, expected_ops) + + interpreter_result = interpreter.run( + precision=torch.half if fp16_mode else torch.float + ) + trt_mod = PythonTorchTRTModule( + interpreter_result.engine, + interpreter_result.input_names, + interpreter_result.output_names, + ) + res_trt = trt_mod(*cuda_inputs).cpu() + res_cpu = mod(*inputs) + assert len(res_trt) == len(res_cpu) + assert len(res_cpu) == len(comparators) + for output_trt, output_cpu, comparator in zip( + res_trt, res_cpu, comparators + ): + comp_func = comparator[0] + args = comparator[1] + self.assertTrue(comp_func(output_trt, output_cpu, *args)) + + def run_test_with_error(self, mod, inputs, interpreter, expect_error): + with self.assertRaises(expect_error): + with torch.no_grad(): + cuda_inputs = [] + for i in inputs: + cuda_inputs.append(i.cuda()) + + mod.eval() + interpreter.run(precision=torch.float) + + def assert_has_op(self, mod, ops): + ops_in_mod = set() + + for node in mod.graph.nodes: + if node.op == "call_module": + ops_in_mod.add(type(fetch_attr(mod, node.target))) + elif node.op in {"call_function", "call_method"}: + ops_in_mod.add(node.target) + + self.assertTrue( + ops_in_mod >= ops, f"expected ops {ops}, actuall ops {ops_in_mod}" + ) + + def assert_unexpected_op(self, mod, ops): + for node in mod.graph.nodes: + if node.op == "call_module": + if type(fetch_attr(mod, node.target)) in ops: + return False + elif node.op in {"call_function", "call_method"}: + if node.target in ops: + return False + return True + + +class DispatchTestCase(TRTTestCase): + def generate_graph( + self, + mod: torch.nn.Module, + original_inputs: List[torch.Tensor], + expected_ops: Set[Callable], + unexpected_ops: Optional[Set[Callable]] = None, + customized_passes: List[Callable] = None, + ): + # Torchdynamo+aot proxytensor tracer + # Below are common passes + passes_list = [ + compose_bmm, + compose_chunk, + compose_getitem_slice, + replace_aten_reshape_alias_with_replace, + replace_aten_op_with_indices, + replace_transpose_mm_op_with_linear, # after compose_bmm + replace_native_layernorm_with_layernorm, + remove_ops, + replace_builtin_ops, # after replace_native_layernorm_with_layernorm + ] + # Combine with customized passes specific to any model + if customized_passes: + passes_list.extend(customized_passes) + fx_module, _ = aten_tracer.trace(mod, original_inputs) + for passes in passes_list: + pr: PassResult = passes(fx_module) + fx_module = pr.graph_module + fx_module(*original_inputs) + + fx_module = run_const_fold(fx_module) + _LOGGER.info(f"FX graph= {fx_module.graph}") + + if len(expected_ops): + self.assert_has_op(fx_module, expected_ops) + if unexpected_ops: + self.assert_unexpected_op(fx_module, unexpected_ops) + + return fx_module + + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops=None, + apply_passes=None, + rtol=1e-03, + atol=1e-03, + precision=torch.float, + check_dtype=True, + ): + mod.eval() + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + + if apply_passes is not None: + pass_tracer = chain_passes(*apply_passes) + mod = pass_tracer(mod, inputs) + + interp = TRTInterpreter( + mod, + Input.from_tensors(inputs), + ) + super().run_test( + mod, + inputs, + expected_ops, + unexpected_ops, + interp, + rtol, + atol, + precision, + check_dtype, + ) + + def run_test_with_dynamic_shape( + self, + mod, + input_specs, + expected_ops, + unexpected_ops=None, + rtol=1e-03, + atol=1e-03, + ): + mod.eval() + inputs = [spec.example_tensor("opt_shape") for spec in input_specs] + mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None) + + interp = TRTInterpreter( + mod, + input_specs, + ) + # Since the lowering is based on optimal shape. We need to test with + # different shape(for ex. max shape) for testing dynamic shape + inputs_max = [spec.example_tensor("max_shape") for spec in input_specs] + super().run_test( + mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol + ) diff --git a/tests/py/dynamo/backend/test_backend_compiler.py b/tests/py/dynamo/backend/test_backend_compiler.py index 282fcdbfd2..fc9ba5a232 100644 --- a/tests/py/dynamo/backend/test_backend_compiler.py +++ b/tests/py/dynamo/backend/test_backend_compiler.py @@ -222,7 +222,7 @@ def test_int64_input_partial_support(self): class PartiallySupportedMultiOp(torch.nn.Module): def forward(self, x, y): return torch.ops.aten.div.Tensor_mode( - x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor" + x, torch.ops.aten.add.Tensor(y, y), rounding_mode=None ) fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp()) diff --git a/tests/py/dynamo/backend/test_decompositions.py b/tests/py/dynamo/backend/test_decompositions.py index a9578e4ed8..0e11bfd2b1 100644 --- a/tests/py/dynamo/backend/test_decompositions.py +++ b/tests/py/dynamo/backend/test_decompositions.py @@ -78,15 +78,14 @@ def forward(self, x): return y # Operations expected to be removed in the traced graph after decompositions - expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.reciprocal.default} - unexpected_ops = {torch.ops.aten.rsqrt.default} + expected_ops = {torch.ops.aten.sqrt.default, torch.ops.aten.div.Tensor} + unexpected_ops = { + torch.ops.aten.rsqrt.default, + torch.ops.aten.reciprocal.default, + } inputs = [ - torch.randint( - 1, - 10, - (5,), - ), + torch.randint(1, 10, (5,), dtype=torch.int32), ] fx_graph = torch.fx.symbolic_trace(Rsqrt()) @@ -182,6 +181,69 @@ def forward(self, x, y, z): f"AddMM TRT outputs don't match with the original model.", ) + def test_lowering_reciprocal(self): + class Reciprocal(torch.nn.Module): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + def forward(self, x): + y = torch.ops.aten.reciprocal.default(x) + return y + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = {torch.ops.aten.div.Tensor} + unexpected_ops = {torch.ops.aten.reciprocal.default} + + inputs = [ + torch.randn( + 5, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(Reciprocal()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"Reciprocal TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py b/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py new file mode 100644 index 0000000000..68ce24c20f --- /dev/null +++ b/tests/py/dynamo/converters/test_adaptive_avgpool_aten.py @@ -0,0 +1,127 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestAdaptiveAvgPoolConverter(DispatchTestCase): + def test_adaptive_avgpool_mean(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.mean.dim}, + ) + + @parameterized.expand( + [ + ((64, 64),), + ((128, 64),), + (64,), + ] + ) + def test_adaptive_avgpool( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + def test_adaptive_avgpool_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + Input( + shape=(-1, -1, 256, 256), + dtype=torch.float32, + shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + @parameterized.expand( + [ + ((16, 16, 16),), + ((32, 16, 4),), + (32,), + ] + ) + def test_adaptive_avgpool3d( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 32, 64, 64)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + def test_adaptive_avgpool3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + Input( + shape=(-1, -1, 32, 64, 64), + dtype=torch.float32, + shape_ranges=[ + ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) + ], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_batchnorm_aten.py b/tests/py/dynamo/converters/test_batchnorm_aten.py new file mode 100644 index 0000000000..c39f14abfe --- /dev/null +++ b/tests/py/dynamo/converters/test_batchnorm_aten.py @@ -0,0 +1,66 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestBatchNormConverter(DispatchTestCase): + def test_batchnorm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.batch_norm}) + + def test_batchnorm1d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + Input( + shape=(-1, 3, 5), + dtype=torch.float32, + shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + def test_batchnorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_binary_ops_aten.py b/tests/py/dynamo/converters/test_binary_ops_aten.py new file mode 100644 index 0000000000..19fa02721c --- /dev/null +++ b/tests/py/dynamo/converters/test_binary_ops_aten.py @@ -0,0 +1,263 @@ +from typing import Callable +import unittest + +import torch +import torch.nn as nn + +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + +NEED_TEST_BOTH_CONSTANTS_CASE = True + +elementwise_ops = [ + ((lambda x, y: x + y), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: torch.add(x, y)), + torch.ops.aten.add.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x.add(y)), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: x - y), torch.ops.aten.sub.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: torch.sub(x, y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x.sub(y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x / y), torch.ops.aten.div.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: x // y), + torch.ops.aten.floor_divide.default, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="trunc")), + torch.ops.aten.div.Tensor_mode, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="floor")), + torch.ops.aten.div.Tensor_mode, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y)), + torch.ops.aten.div.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.fmod(x, y)), + torch.ops.aten.fmod.Tensor, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ## torch.floor_divide rounds result toward zero, rather than -Inf. + ## https://github.com/pytorch/pytorch/issues/43874 + ( + (lambda x, y: torch.floor_divide(x, y)), + torch.ops.aten.floor_divide.default, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x * y), torch.ops.aten.mul.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + (torch.pow, torch.ops.aten.pow.Tensor_Tensor, not NEED_TEST_BOTH_CONSTANTS_CASE), +] + + +class TestBinaryOpConverters(DispatchTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops(self, name, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, x) + + m = TestModule(orig_op) + # Avoid dividing by 0. + inputs = [torch.rand(1, 1) + 1] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + @unittest.skip("Pending reimplementation of all binary converters in Dynamo") + def test_elementwise_ops_mismatched_dtypes( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x, y): + return self.orig_op(x, y) + + m = TestModule(orig_op) + # Avoid dividing by 0. + inputs = [ + 2 * torch.rand(1, 1, dtype=torch.float) + 1, + torch.randint(1, 3, (1, 1), dtype=torch.int), + ] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops_with_one_constant( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x, self.constant) + return self.orig_op(x, -2) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand( + [(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]] + ) + def test_elementwise_op_with_both_constants( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant0 = torch.nn.Parameter(torch.randn(1)) + self.constant1 = torch.nn.Parameter(torch.randn(1)) + self.orig_op = orig_op + + def forward(self, x): + const = self.orig_op(self.constant0, self.constant1) + return self.orig_op(x, const) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([((lambda x, y: x / y), torch.ops.aten.div.Tensor)]) + def test_elementwise_op_div_with_two_ints(self, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, x + 1) + + m = TestModule(orig_op) + inputs = [torch.randint(1, 10, (5,), dtype=torch.int32)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([((lambda x, y: x / y), torch.ops.aten.div.Tensor)]) + def test_elementwise_op_div_with_one_int_one_constant( + self, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant1 = torch.nn.Parameter( + torch.randn( + 5, + ) + ) + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, self.constant1) + + m = TestModule(orig_op) + inputs = [torch.randint(1, 10, (5,), dtype=torch.int32)] + self.run_test(m, inputs, expected_ops={expected_op}) + + # Dynamic shape test + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + (-1, -1, -1), + ((1, 1, 1), (2, 2, 2), (3, 3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape( + self, _, x_shape, x_shape_ranges, y_shape, y_shape_ranges, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + Input( + shape=x_shape, + dtype=torch.float32, + shape_ranges=[x_shape_ranges], + ), + Input( + shape=y_shape, + dtype=torch.float32, + shape_ranges=[y_shape_ranges], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape_four_dimensions( + self, _, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_cat_aten.py b/tests/py/dynamo/converters/test_cat_aten.py new file mode 100644 index 0000000000..d9d107de89 --- /dev/null +++ b/tests/py/dynamo/converters/test_cat_aten.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestCatConverter(DispatchTestCase): + @parameterized.expand( + [ + ("pos", 1), + ("neg", -2), + ] + ) + def test_cat(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z), dim) + + inputs = [torch.randn(1, 2, 3), torch.randn(1, 1, 3), torch.randn(1, 3, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + @parameterized.expand( + [ + ("pos", 1), + ("neg", -2), + ] + ) + def test_cat_dynamic_shape(self, _, dim): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y), dim) + + input_specs = [ + Input( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 3, 3), (16, 32, 3))], + ), + Input( + shape=(16, -1, 3), + dtype=torch.float32, + shape_ranges=[((16, 2, 3), (16, 16, 3), (16, 32, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + + def test_cat_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y, z): + return torch.cat((x, y, z)) + + inputs = [torch.randn(2, 1, 3), torch.randn(1, 1, 3), torch.randn(3, 1, 3)] + self.run_test( + Cat(), + inputs, + expected_ops={torch.ops.aten.cat.default}, + ) + + def test_cat_dynamic_shape_no_dim(self): + class Cat(nn.Module): + def forward(self, x, y): + return torch.cat((x, y)) + + input_specs = [ + Input( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + Input( + shape=(-1, 16, 3), + dtype=torch.float32, + shape_ranges=[((2, 16, 3), (3, 16, 3), (32, 16, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Cat(), + input_specs, + expected_ops={torch.ops.aten.cat.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_clamp_aten.py b/tests/py/dynamo/converters/test_clamp_aten.py new file mode 100644 index 0000000000..05716c1657 --- /dev/null +++ b/tests/py/dynamo/converters/test_clamp_aten.py @@ -0,0 +1,71 @@ +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestClampConverter(DispatchTestCase): + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + param("float32Boundary", min=-3.4028234663852886e38), + ] + ) + def test_clamp( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min, max) + + inputs = [torch.randn(3, 4)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.clamp.default}) + + @parameterized.expand( + [ + param("default", min=-1, max=0), + param("min", min=0.5), + param("max", max=0.5), + param("minBiggerThanMax", min=1, max=0), + ] + ) + def test_clamp_with_dynamic_shape_four_dimensions( + self, + test_name, + min=None, + max=None, + ): + class TestModule(torch.nn.Module): + def forward(self, x): + return torch.clamp(x, min, max) + + class TestScalarModule(torch.nn.Module): + def forward(self, x): + y = torch.mean(x) + return torch.clamp(y, min, max) + + input_specs = [ + Input( + shape=(-1, -1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 1, 3, 3), (3, 3, 3, 3), (5, 5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.clamp.default} + ) + self.run_test_with_dynamic_shape( + TestScalarModule(), input_specs, expected_ops={torch.ops.aten.clamp.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_convolution_aten.py b/tests/py/dynamo/converters/test_convolution_aten.py new file mode 100644 index 0000000000..a906d70d43 --- /dev/null +++ b/tests/py/dynamo/converters/test_convolution_aten.py @@ -0,0 +1,203 @@ +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestConvolutionConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv1d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.convolution.default}, + ) + + def test_conv1d_with_dynamic_shape( + self, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + Input( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv2d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, + 6, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} + ) + + # Testing with (-1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv2d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + ## TODO TRT 8.4.1 will trigger issue with this test. T127981773 + # param("groups", 1, groups=3), + ] + ) + def test_conv3d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} + ) + + # Testing with (-1, -1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv3d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_elu_aten.py b/tests/py/dynamo/converters/test_elu_aten.py new file mode 100644 index 0000000000..dfaf2db5a6 --- /dev/null +++ b/tests/py/dynamo/converters/test_elu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestELUConverter(DispatchTestCase): + def test_elu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_elu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_elu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.elu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_embedding_aten.py b/tests/py/dynamo/converters/test_embedding_aten.py new file mode 100644 index 0000000000..4d36478303 --- /dev/null +++ b/tests/py/dynamo/converters/test_embedding_aten.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from parameterized import param, parameterized +from torch_tensorrt import Input + + +class TestEmbeddingConverter(DispatchTestCase): + @parameterized.expand( + [ + param( + test_name="1d_indices", + indices_tensor=torch.tensor([3, 1, 2]), + weights_tensor=torch.randn(5, 10), + ), + param( + test_name="2d_indices", + indices_tensor=torch.tensor([[3, 1, 2], [4, 1, 3]]), + weights_tensor=torch.randn(5, 10), + ), + param( + test_name="3d_indices", + indices_tensor=torch.tensor([[[0, 1], [2, 3]], [[3, 4], [4, 0]]]), + weights_tensor=torch.randn(5, 10), + ), + ] + ) + def test_embedding( + self, + test_name, + indices_tensor, + weights_tensor, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + class TestEmbedding(torch.nn.Module): + def forward(self, indices, weights): + return torch.nn.functional.embedding( + input=indices, + weight=weights, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + self.run_test( + TestEmbedding(), + inputs=[indices_tensor.int(), weights_tensor.float()], + expected_ops={torch.ops.aten.embedding.default}, + ) + + def test_embedding_with_dynamic_shape_four_dimensions( + self, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, + ): + class TestEmbedding(torch.nn.Module): + def forward(self, input, weights): + return torch.nn.functional.embedding( + input=input, + weight=weights, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + sparse=sparse, + ) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.int, + shape_ranges=[((1, 1, 1, 1), (2, 3, 4, 5), (2, 3, 10, 10))], + ), + Input( + shape=(-1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1), (2, 3), (2, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestEmbedding(), + input_specs, + expected_ops={torch.ops.aten.embedding.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_expand_aten.py b/tests/py/dynamo/converters/test_expand_aten.py new file mode 100644 index 0000000000..1b1f3d1c14 --- /dev/null +++ b/tests/py/dynamo/converters/test_expand_aten.py @@ -0,0 +1,31 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestExpandConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (2, 3), (2, 1)), + ("3d_dim", (2, 3, 4), (2, 1, 1)), + ("4d_dim", (2, 3, 4, 5), (2, 1, 1, 1)), + ("keep_dim", (2, 3, -1, -1), (2, 1, 5, 5)), + ] + ) + def test_expand(self, _, sizes, init_size): + class Expand(nn.Module): + def forward(self, x): + return x.expand(*sizes) + + inputs = [torch.randn(*init_size)] + self.run_test( + Expand(), + inputs, + expected_ops={torch.ops.aten.expand.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_gelu_aten.py b/tests/py/dynamo/converters/test_gelu_aten.py new file mode 100644 index 0000000000..c62a028c0e --- /dev/null +++ b/tests/py/dynamo/converters/test_gelu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestGeLUConverter(DispatchTestCase): + def test_gelu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.gelu.default}) + + def test_gelu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + def test_gelu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.gelu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.gelu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_hardtanh_aten.py b/tests/py/dynamo/converters/test_hardtanh_aten.py new file mode 100644 index 0000000000..8401dd17a9 --- /dev/null +++ b/tests/py/dynamo/converters/test_hardtanh_aten.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestHardTanHConverter(DispatchTestCase): + def test_hardtanh(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + def test_hardtanh_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.hardtanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.hardtanh.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_layer_norm_aten.py b/tests/py/dynamo/converters/test_layer_norm_aten.py new file mode 100644 index 0000000000..a4766bd030 --- /dev/null +++ b/tests/py/dynamo/converters/test_layer_norm_aten.py @@ -0,0 +1,45 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestLayerNormConverter(DispatchTestCase): + def test_layer_norm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.layer_norm.default} + ) + + def test_layernorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.ln = torch.nn.LayerNorm([3, 224, 224]) + + def forward(self, x): + return self.ln(x) + + input_specs = [ + Input( + shape=(-1, 3, 224, 224), + dtype=torch.float32, + shape_ranges=[((1, 3, 224, 224), (1, 3, 224, 224), (2, 3, 224, 224))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.layer_norm.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_leaky_relu_aten.py b/tests/py/dynamo/converters/test_leaky_relu_aten.py new file mode 100644 index 0000000000..aa3d56641b --- /dev/null +++ b/tests/py/dynamo/converters/test_leaky_relu_aten.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestLeakyReLUConverter(DispatchTestCase): + def test_leaky_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x, negative_slope=0.05) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + def test_leaky_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x, negative_slope=0.05) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + def test_leaky_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.leaky_relu(x, negative_slope=0.05) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.leaky_relu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_linear_aten.py b/tests/py/dynamo/converters/test_linear_aten.py new file mode 100644 index 0000000000..b9e3261642 --- /dev/null +++ b/tests/py/dynamo/converters/test_linear_aten.py @@ -0,0 +1,71 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestLinearConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", [1, 512], True, torch.ops.aten.linear), + ("matrix", [5, 512], True, torch.ops.aten.linear), + ("no_bias", [1, 512], False, torch.ops.aten.linear), + ( + "multi_dim_matrix", + [4, 5, 512], + True, + torch.ops.aten.linear, + ), + ( + "multi_dim_matrix", + [4, 5, 512], + False, + torch.ops.aten.linear, + ), + ] + ) + def test_linear(self, test_name, shape, bias, op): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 256, bias) + + def forward(self, x): + return self.linear(x) + + inputs = [torch.randn(shape)] + self.run_test(TestModule(), inputs, expected_ops={op}) + + # linear will be decomposed to P531484488 and view(reshape) can not handle reshape pattern + # like (2, 3, n)->(6, n) in implicit mode which is similar to dynamic shape test below. + + # Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now. + + # def test_linear_with_dynamic_shape(self): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.linear = torch.nn.Linear(512, 256) + + # def forward(self, x): + # return self.linear(x) + + # input_specs = [ + # Input( + # shape=(-1, 3, 512), + # dtype=torch.float32, + # shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))], + # ), + # ] + # self.run_test_with_dynamic_shape( + # TestModule(), + # input_specs, + # expected_ops={torch.ops.aten.addmm.default}, + # ) + + ## Testing with (-1, -1, 512) results into following error: + ## AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_matmul_aten.py b/tests/py/dynamo/converters/test_matmul_aten.py new file mode 100644 index 0000000000..f01325fb10 --- /dev/null +++ b/tests/py/dynamo/converters/test_matmul_aten.py @@ -0,0 +1,97 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestMatMulConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + ("2_2", (2, 3), (3, 1)), + # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? + # (2,3), (3,) torch.ops.aten.mv.default + # Following cases use torch.ops.aten.bmm.defauly + # ("4_3", (3,1,3,2), (2,2,3)), + # ("3_4", (3,1,3,2), (2,2,3)), + # ("3_4", (2, 2, 3), (3, 1, 3, 3)), + # ("4_2", (1, 2, 2, 3), (3, 2)), + ] + ) + def test_matmul_other_constant(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def __init__(self): + super().__init__() + self.other = nn.Parameter(torch.randn(*other_shape)) + + def forward(self, input): + return torch.matmul(input, self.other) + + inputs = [torch.randn(*input_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + ) + + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + ("1_2", (1, 3), (3, 2)), + # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm? + # (2,3), (3,) torch.ops.aten.mv.default + # Following cases use torch.ops.aten.bmm.defauly + # ("4_3", (3,1,3,2), (2,2,3)), + # ("3_4", (3,1,3,2), (2,2,3)), + # ("3_4", (2, 2, 3), (3, 1, 3, 3)), + # ("4_2", (1, 2, 2, 3), (3, 2)), + ] + ) + def test_matmul_input_constant(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def __init__(self): + super().__init__() + self.input = nn.Parameter(torch.randn(*input_shape)) + + def forward(self, other): + return torch.matmul(self.input, other) + + inputs = [torch.randn(*other_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + ) + + @parameterized.expand( + [ + ("2_2", (2, 3), (3, 2)), + # ("2_3", (2, 3), (2, 3, 4)), + # ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)), + # ("4_2", (2, 1, 2, 3), (3, 2)), + # ("2_1", (2, 3), (3,)), + # ("1_2", (3,), (3, 2)), + # ("1_1", (3,), (3,)), + ] + ) + def test_matmul(self, _, input_shape, other_shape): + class MatMul(nn.Module): + def forward(self, input, other): + return torch.matmul(input, other) + + inputs = [torch.randn(*input_shape), torch.randn(*other_shape)] + + self.run_test( + MatMul(), + inputs, + expected_ops={torch.ops.aten.mm.default}, + ) + + # FIXME: dynamic shape is giving bmm + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_mean_aten.py b/tests/py/dynamo/converters/test_mean_aten.py new file mode 100644 index 0000000000..fe31d90a24 --- /dev/null +++ b/tests/py/dynamo/converters/test_mean_aten.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestMeanDimConverter(DispatchTestCase): + def test_mean_dim_keepdims(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=[0, 1], keepdim=True) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim}) + + def test_mean_dim_keepdims_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=[0, 1, 2], keepdim=True) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim} + ) + + def test_mean_dim_keepdims_false(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=0, keepdim=False) + + inputs = [torch.randn(3, 5, 7)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.dim}) + + def test_mean_dim_keepdims_false_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x, dim=-1, keepdim=False) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.mean.dim} + ) + + +class TestMeanConverter(DispatchTestCase): + def test_mean(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x) + + inputs = [torch.randn(3, 8, 5, 7, 1)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.mean.default}) + + def test_mean_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return torch.mean(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 5, 8), (3, 10, 10))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.mean.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_permutation_aten.py b/tests/py/dynamo/converters/test_permutation_aten.py new file mode 100644 index 0000000000..f9d614ae68 --- /dev/null +++ b/tests/py/dynamo/converters/test_permutation_aten.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestPermuteConverter(DispatchTestCase): + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute_list(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default}) + + @parameterized.expand( + [ + ("positive", [0, 2, 1]), + ("negative", [0, -1, -2]), + ] + ) + def test_permute(self, _, permutation): + class Permute(nn.Module): + def forward(self, x): + return x.permute(*permutation) + + inputs = [torch.randn(1, 3, 2)] + self.run_test(Permute(), inputs, expected_ops={torch.ops.aten.permute.default}) + + def test_permute_with_dynamic_shape(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 0) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={torch.ops.aten.permute.default} + ) + + def test_permute_with_dynamic_shape_four_dimensions(self): + class Permute(nn.Module): + def forward(self, x): + return x.permute(1, 2, 3, 0) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + Permute(), input_specs, expected_ops={torch.ops.aten.permute.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_relu_aten.py b/tests/py/dynamo/converters/test_relu_aten.py new file mode 100644 index 0000000000..08ab04014d --- /dev/null +++ b/tests/py/dynamo/converters/test_relu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestReLUConverter(DispatchTestCase): + def test_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.relu.default}) + + def test_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + def test_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_reshape_aten.py b/tests/py/dynamo/converters/test_reshape_aten.py new file mode 100644 index 0000000000..1df71abc1a --- /dev/null +++ b/tests/py/dynamo/converters/test_reshape_aten.py @@ -0,0 +1,103 @@ +import unittest + +import tensorrt as trt +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestReshapeConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 20),), + ((1, 10, -1),), + ] + ) + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + inputs = [torch.randn(1, 2, 10)] + self.run_test( + TestModule(target_shape), + inputs, + expected_ops={torch.ops.aten.view.default}, + ) + + @parameterized.expand( + [ + ((-1, 10),), + ((-1, 5),), + ((2, 2, -1),), + ] + ) + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape_with_dynamic_shape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + input_specs = [ + Input( + shape=(-1, 2, 5), + dtype=torch.float32, + shape_ranges=[((1, 2, 5), (10, 2, 5), (10, 2, 5))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(target_shape), + input_specs, + expected_ops={torch.ops.aten.view.default}, + ) + + @unittest.skipIf( + trt.__version__ < "8.5", + "Shape tensor supported well in TensorRT 8.5 and later", + ) + def test_reshape_with_dynamic_shape_size(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + shape_y = y.shape + t = shape_y[1] + return torch.reshape(x, [-1, t, 3]) + + input_specs = [ + Input( + shape=(-1, 5, 6), + dtype=torch.float32, + shape_ranges=[((1, 5, 6), (3, 5, 6), (3, 5, 6))], + ), + Input( + shape=(-1, 5), + dtype=torch.float32, + shape_ranges=[((1, 5), (3, 5), (3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.view.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_rsqrt_aten.py b/tests/py/dynamo/converters/test_rsqrt_aten.py new file mode 100644 index 0000000000..5770e697fc --- /dev/null +++ b/tests/py/dynamo/converters/test_rsqrt_aten.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestRSqrtConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsqrt(self, _, x, alpha): + class rsqrt(nn.Module): + def forward(self, input): + return torch.rsqrt(input) + + inputs = [torch.randn(x) + 1] + self.run_test( + rsqrt(), + inputs, + expected_ops={torch.ops.aten.rsqrt.default}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_select_aten.py b/tests/py/dynamo/converters/test_select_aten.py new file mode 100644 index 0000000000..049cd9c7e6 --- /dev/null +++ b/tests/py/dynamo/converters/test_select_aten.py @@ -0,0 +1,79 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSelectConverterOne(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input = [torch.randn(1, 2)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.select.int}, + ) + + +class TestSelectConverterTwo(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input = [torch.randn(4, 4, 4, 4)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.select.int}, + ) + + +class TestSelectConverterWithDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_index", 1, 0), + ] + ) + def test_select_with_dynamic_shape(self, _, dim, index): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.select(input, dim, index) + + input_spec = [ + Input( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_spec, expected_ops={torch.ops.aten.select.int} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_selu_aten.py b/tests/py/dynamo/converters/test_selu_aten.py new file mode 100644 index 0000000000..7fb6afda76 --- /dev/null +++ b/tests/py/dynamo/converters/test_selu_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSeLUConverter(DispatchTestCase): + def test_selu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.elu.default}) + + def test_selu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + def test_selu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.selu(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.elu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_sigmoid_aten.py b/tests/py/dynamo/converters/test_sigmoid_aten.py new file mode 100644 index 0000000000..37bbea1730 --- /dev/null +++ b/tests/py/dynamo/converters/test_sigmoid_aten.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSigmoidConverter(DispatchTestCase): + def test_sigmoid(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.sigmoid.default} + ) + + def test_sigmoid_fp16(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.sigmoid(x) + + inputs = [torch.randn(1, 10)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.sigmoid.default}, + precision=torch.half, + check_dtype=False, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_slice_aten.py b/tests/py/dynamo/converters/test_slice_aten.py new file mode 100644 index 0000000000..86de36d351 --- /dev/null +++ b/tests/py/dynamo/converters/test_slice_aten.py @@ -0,0 +1,86 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSelectConverterImplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 0, 0, 7, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 2, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +class TestSelectConverterExplicitBatch(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step_exact", 1, 0, 10, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input = [torch.randn(10, 10, 3, 1)] + self.run_test( + TestModule(), + input, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +class TestSelectConverterDynamicShape(DispatchTestCase): + @parameterized.expand( + [ + ("select_dim_start_stop_step", 1, 0, 7, 2), + ("select_dim_start_stop_step", 1, 0, 10, 2), + ] + ) + def test_slice(self, _, dim, start, stop, step): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, input): + out = torch.ops.aten.slice.Tensor(input, dim, start, stop, step) + return out + + input_specs = [ + Input( + shape=(1, 10, -1), + dtype=torch.float32, + shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.slice.Tensor}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_softmax_aten.py b/tests/py/dynamo/converters/test_softmax_aten.py new file mode 100644 index 0000000000..8d33f3ebe0 --- /dev/null +++ b/tests/py/dynamo/converters/test_softmax_aten.py @@ -0,0 +1,45 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSoftMaxConverter(DispatchTestCase): + def test_softmax(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(1) + + def forward(self, x): + return self.softmax(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten._softmax.default} + ) + + def test_softmax_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.softmax = torch.nn.Softmax(2) + + def forward(self, x): + return self.softmax(x) + + input_specs = [ + Input( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_squeeze_aten.py b/tests/py/dynamo/converters/test_squeeze_aten.py new file mode 100644 index 0000000000..152fe86300 --- /dev/null +++ b/tests/py/dynamo/converters/test_squeeze_aten.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (0), (2, 1)), + ("3d_one_dim", (0), (2, 2, 1)), + ("3d_two_dim", (0, 1), (2, 1, 1)), + ("4d_dim", (0, 1, 2), (2, 2, 1, 1)), + ] + ) + def test_squeeze(self, _, dim, init_size): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + inputs = [torch.randn(*init_size)] + expected_op = {} + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + self.run_test( + Squeeze(), + inputs, + expected_ops=expected_op, + ) + + +class TestSqueezeConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim", (1), (-1, 1), [((1, 1), (1, 1), (3, 1))]), + ("3d_one_dim", (1), (-1, 2, 1), [((1, 2, 1), (1, 2, 1), (3, 2, 1))]), + # ("3d_two_dim", (0, 1), (-1, -1, 1), [((1, 3, 1, 1), (1, 3, 1, 1))]), + ] + ) + def test_squeeze(self, _, dim, init_size, shape_range): + class Squeeze(nn.Module): + def forward(self, x): + return torch.squeeze(x, dim) + + if isinstance(dim, int) == 1: + expected_op = {torch.ops.aten.squeeze.dim} + else: + expected_op = {torch.ops.aten.squeeze.dims} + input_specs = [ + Input( + shape=init_size, + dtype=torch.float32, + shape_ranges=shape_range, + ), + ] + self.run_test_with_dynamic_shape( + Squeeze(), + input_specs, + expected_ops=expected_op, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_tanh_aten.py b/tests/py/dynamo/converters/test_tanh_aten.py new file mode 100644 index 0000000000..f9aa94a7bc --- /dev/null +++ b/tests/py/dynamo/converters/test_tanh_aten.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestTanhConverter(DispatchTestCase): + def test_tanh(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.tanh.default}) + + def test_tanh_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} + ) + + def test_tanh_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.tanh(x) + + input_specs = [ + Input( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.tanh.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_unsqueeze_aten.py b/tests/py/dynamo/converters/test_unsqueeze_aten.py new file mode 100644 index 0000000000..db8ae7151f --- /dev/null +++ b/tests/py/dynamo/converters/test_unsqueeze_aten.py @@ -0,0 +1,62 @@ +import torch +import torch.fx +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase +from torch_tensorrt import Input + + +class TestUnsqueeze(DispatchTestCase): + @parameterized.expand( + [ + ("negative_dim", -2), + ("positive_dim", 2), + ] + ) + def test_unsqueeze(self, _, dim): + class Unsqueeze(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + inputs = [torch.randn(1, 2, 3)] + self.run_test( + Unsqueeze(dim), inputs, expected_ops={torch.ops.aten.unsqueeze.default} + ) + + # Testing with more than one dynamic dims results in following error: + # AssertionError: Currently we don't support unsqueeze with more than one dynamic dims. + + @parameterized.expand( + [ + ("negative_dim_dynamic", -4), + ("positive_dim_dynamic", 1), + ] + ) + def test_unsqueeze_with_dynamic_shape(self, _, dim): + class Unsqueeze(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.unsqueeze(x, self.dim) + + input_specs = [ + Input( + shape=(-1, 2, 3), + dtype=torch.float32, + shape_ranges=[((1, 2, 3), (2, 2, 3), (3, 2, 3))], + ), + ] + self.run_test_with_dynamic_shape( + Unsqueeze(dim), input_specs, expected_ops={torch.ops.aten.unsqueeze.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/converters/test_where_aten.py b/tests/py/dynamo/converters/test_where_aten.py new file mode 100644 index 0000000000..39ba0500b9 --- /dev/null +++ b/tests/py/dynamo/converters/test_where_aten.py @@ -0,0 +1,33 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.dynamo.test_utils import DispatchTestCase + + +class TestWhereConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_condition_xshape_yshape", (2, 2), (2, 2)), + ("2d_broadcast_condition_xshape_yshape", (2, 2), (2, 1)), + ("3d_condition_xshape_yshape", (2, 2, 1), (2, 2, 1)), + ("2d_3d_condition_xshape_yshape", (2, 2), (1, 2, 2)), + ] + ) + def test_(self, _, x_size, y_size): + class Where(nn.Module): + def forward(self, condition, x, y): + return torch.where(condition, x, y) + + inputX = torch.randn(*x_size) + inputOther = torch.randn(*y_size) + condition = inputX < 0 + self.run_test( + Where(), + (condition, inputX, inputOther), + expected_ops={torch.ops.aten.where.self}, + ) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py index 1141d54a7b..0fdfcb3fd0 100644 --- a/tests/py/dynamo/models/test_models.py +++ b/tests/py/dynamo/models/test_models.py @@ -32,7 +32,7 @@ def test_resnet18(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } @@ -66,7 +66,7 @@ def test_mobilenet_v2(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } @@ -100,7 +100,7 @@ def test_efficientnet_b0(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } @@ -143,7 +143,7 @@ def test_bert_base_uncased(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", } trt_mod = torchtrt.compile(model, **compile_spec) @@ -181,7 +181,7 @@ def test_resnet18_half(ir): "ir": ir, "pass_through_build_failures": True, "optimization_level": 1, - "min_block_size": 8, + "min_block_size": 10, "ir": "torch_compile", }