diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
similarity index 89%
rename from py/torch_tensorrt/dynamo/backend/_defaults.py
rename to py/torch_tensorrt/dynamo/_defaults.py
index 0afbc60f8c..c130a2154d 100644
--- a/py/torch_tensorrt/dynamo/backend/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -10,3 +10,4 @@
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
 USE_PYTHON_RUNTIME = None
+TRUNCATE_LONG_AND_DOUBLE = False
diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index 38e60fce41..03b52f518d 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -5,12 +5,13 @@
 from functools import partial
 from typing import Any, Optional, Sequence
 
-from torch_tensorrt import EngineCapability, Device
+from torch_tensorrt._Device import Device
+from torch_tensorrt._enums import EngineCapability
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt.dynamo.backend.utils import prepare_inputs, prepare_device
 from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend
-from torch_tensorrt.dynamo.backend._defaults import (
+from torch_tensorrt.dynamo._defaults import (
     PRECISION,
     DEBUG,
     WORKSPACE_SIZE,
@@ -20,6 +21,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )
@@ -43,7 +45,7 @@ def compile(
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
     calibrator=None,
-    truncate_long_and_double=False,
+    truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation=False,
     min_block_size=MIN_BLOCK_SIZE,
     torch_executed_ops=[],
@@ -62,7 +64,7 @@ def compile(
         "The Dynamo backend is an experimental feature, for which only the "
         + "following arguments are supported: "
         + "{enabled_precisions, debug, workspace_size, min_block_size, "
-        + "torch_executed_ops, pass_through_build_failures}"
+        + "truncate_long_and_double, torch_executed_ops, pass_through_build_failures}"
     )
 
 if not isinstance(inputs, collections.abc.Sequence):
@@ -103,6 +105,7 @@ def compile(
         version_compatible=version_compatible,
         optimization_level=optimization_level,
         use_python_runtime=use_python_runtime,
+        truncate_long_and_double=truncate_long_and_double,
         **kwargs,
     )
@@ -130,6 +133,7 @@ def create_backend(
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -163,5 +167,6 @@ def create_backend(
         version_compatible=version_compatible,
         optimization_level=optimization_level,
         use_python_runtime=use_python_runtime,
+        truncate_long_and_double=truncate_long_and_double,
         **kwargs,
     )
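For context on the hunks above: `truncate_long_and_double` now flows from the user-facing `compile` entrypoint through `create_backend` into the backend settings. A minimal usage sketch, assuming a CUDA-enabled Torch-TensorRT build (the model and inputs here are illustrative, not part of this patch):

```python
# Illustrative only: compiling through the Dynamo entrypoint shown above,
# with the new truncate_long_and_double flag enabled.
import torch
from torch_tensorrt.dynamo import compile

model = torch.nn.Linear(8, 4).cuda().eval()
inputs = [torch.randn(2, 8).cuda()]

trt_model = compile(
    model,
    inputs,
    min_block_size=1,
    truncate_long_and_double=True,  # now defaults to TRUNCATE_LONG_AND_DOUBLE (False)
)
```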
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index b97079948e..2ebb783ad2 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -4,7 +4,7 @@
 from functools import partial
 import torch._dynamo as td
 
-from torch_tensorrt.dynamo.backend._settings import CompilationSettings
+from torch_tensorrt.dynamo.common import CompilationSettings
 from torch_tensorrt.dynamo.backend.lowering._decompositions import (
     get_decompositions,
 )
@@ -16,6 +16,7 @@
     get_submod_inputs,
 )
 from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
+from torch_tensorrt.dynamo.common import repair_long_or_double_inputs
 from torch_tensorrt.dynamo.backend.conversion import convert_module
 from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
@@ -134,6 +135,16 @@ def _compile_module(
             partitioned_module, submodule, sample_inputs
         )
 
+        # Ensure all submodule inputs do not require a gradient
+        for param in submodule_inputs:
+            param.requires_grad = False
+
+        # Handle long/double inputs if requested by the user
+        if settings.truncate_long_and_double:
+            submodule_inputs = repair_long_or_double_inputs(
+                partitioned_module, submodule, submodule_inputs, name
+            )
+
         # Create TRT Module from submodule
         trt_mod = convert_module(
             submodule,
diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index 425fb0941e..a4e25c5231 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -2,8 +2,8 @@
 import torch
 import io
 from torch_tensorrt.fx.trt_module import TRTModule
-from torch_tensorrt.dynamo.backend._settings import CompilationSettings
-from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
+from torch_tensorrt.dynamo.common import (
+    CompilationSettings,
     InputTensorSpec,
     TRTInterpreter,
 )
diff --git a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
index 4d82bf4be5..2158a940aa 100644
--- a/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
+++ b/py/torch_tensorrt/dynamo/backend/lowering/_partition.py
@@ -3,8 +3,8 @@
 
 import torch
 
-from torch_tensorrt.dynamo.backend._defaults import MIN_BLOCK_SIZE
 from torch_tensorrt.dynamo.backend.lowering import SUBSTITUTION_REGISTRY
+from torch_tensorrt.dynamo._defaults import MIN_BLOCK_SIZE
 from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner, Partition
 from torch.fx.graph_module import GraphModule
 from torch.fx.node import _get_qualified_name
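The `_compile_module` hunk above consumes the flag via the shared settings object. A quick sketch of the new dataclass field (the import path and default follow directly from the `common/_settings.py` hunk later in this diff):

```python
# The flag travels on the shared CompilationSettings dataclass; its default
# (False) comes from TRUNCATE_LONG_AND_DOUBLE in dynamo/_defaults.py.
from torch_tensorrt.dynamo.common import CompilationSettings

settings = CompilationSettings(truncate_long_and_double=True)
print(settings.truncate_long_and_double)  # True; False when the field is omitted
```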
diff --git a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
index 2af251adbc..c5275f4a7b 100644
--- a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
+++ b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
@@ -4,7 +4,7 @@
 from copy import deepcopy
 from torch_tensorrt.dynamo import compile
 from utils import lower_graph_testing
-from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
+from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT
 
 
 class TestTRTModuleNextCompilation(TestCase):
@@ -169,5 +169,116 @@ def forward(self, x, y):
         )
 
 
+class Test64BitInput(TestCase):
+    def test_float64_input_full_support(self):
+        class FullySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.mean.dim(
+                    torch.ops.aten.mul.Tensor(torch.ops.aten.add.Tensor(x, y), 2), [0]
+                )
+
+        fx_graph = torch.fx.symbolic_trace(FullySupportedMultiOp())
+        partitioned_graph = partition(deepcopy(fx_graph), min_block_size=3)
+
+        self.assertEquals(
+            len(list(partitioned_graph.named_children())),
+            1,
+            "All operators are supported, there should be one segment",
+        )
+
+        inputs = [
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+            torch.randint(-5, 5, (16, 7), dtype=torch.double).cuda(),
+        ]
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+    def test_int64_input_partial_support(self):
+        class PartiallySupportedMultiOp(torch.nn.Module):
+            def forward(self, x, y):
+                return torch.ops.aten.div.Tensor_mode(
+                    x, torch.ops.aten.add.Tensor(y, y), rounding_mode="floor"
+                )
+
+        fx_graph = torch.fx.symbolic_trace(PartiallySupportedMultiOp())
+        unexpected_ops = {torch.ops.aten.add.Tensor}
+
+        inputs = [
+            torch.randint(-40, 40, (16, 7, 5), dtype=torch.long).cuda(),
+            torch.randint(1, 40, (16, 7, 5), dtype=torch.long).cuda(),
+        ]
+
+        (unexpected_ops_seen, _, partitioned_graphs,) = lower_graph_testing(
+            fx_graph,
+            inputs,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+            torch_executed_ops={"torch.ops.aten.add.Tensor"},
+            testing_partitioning=True,
+        )
+
+        self.assertEquals(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+        self.assertEquals(
+            len(partitioned_graphs),
+            1,
+            "Without control flow breaks, there should only be a single graph",
+        )
+        self.assertEquals(
+            len(list(partitioned_graphs[0].named_children())),
+            1,
+            "Certain operators are set to run in Torch, expected 1 segment",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = compile(
+            fx_graph,
+            inputs,
+            min_block_size=1,
+            pass_through_build_failures=True,
+            truncate_long_and_double=True,
+            debug=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"TRT outputs don't match with the original model.",
+        )
+
+
 if __name__ == "__main__":
     run_tests()
diff --git a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py
index d947c955e0..340797aa69 100644
--- a/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py
+++ b/py/torch_tensorrt/dynamo/backend/test/test_decompositions.py
@@ -3,7 +3,7 @@
 from torch.testing._internal.common_utils import run_tests, TestCase
 import torch
 from torch_tensorrt.dynamo import compile
-from torch_tensorrt.dynamo.common_utils.test_utils import DECIMALS_OF_AGREEMENT
+from torch_tensorrt.dynamo.common.test_utils import DECIMALS_OF_AGREEMENT
 
 
 class TestLowering(TestCase):
diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py
index 23a1cd4795..ed8c8ab932 100644
--- a/py/torch_tensorrt/dynamo/backend/utils.py
+++ b/py/torch_tensorrt/dynamo/backend/utils.py
@@ -2,10 +2,10 @@
 import logging
 from dataclasses import replace, fields
 
-from torch_tensorrt.dynamo.backend._settings import CompilationSettings
+from torch_tensorrt.dynamo.common import CompilationSettings, use_python_runtime_parser
 from typing import Any, Union, Sequence, Dict
-from torch_tensorrt import _Input, Device
-from ..common_utils import use_python_runtime_parser
+from torch_tensorrt import _Input
+from torch_tensorrt._Device import Device
 
 logger = logging.getLogger(__name__)
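The tests above draw integer inputs from small ranges (e.g. [-40, 40]) so that the 64-to-32-bit casts inserted by truncation are exact. A freestanding sketch of why the range matters:

```python
# A 64-bit -> 32-bit cast is only lossless when values fit in 32 bits.
import torch

small = torch.randint(-40, 40, (4,), dtype=torch.long)
assert torch.equal(small.to(torch.int32).to(torch.int64), small)  # exact round-trip

big = torch.tensor([2**40], dtype=torch.long)
print(big.to(torch.int32))  # wraps to 0: information is lost outside the int32 range
```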
diff --git a/py/torch_tensorrt/dynamo/common_utils/__init__.py b/py/torch_tensorrt/dynamo/common/__init__.py
similarity index 84%
rename from py/torch_tensorrt/dynamo/common_utils/__init__.py
rename to py/torch_tensorrt/dynamo/common/__init__.py
index de0ce0a48a..68120aadd9 100644
--- a/py/torch_tensorrt/dynamo/common_utils/__init__.py
+++ b/py/torch_tensorrt/dynamo/common/__init__.py
@@ -1,6 +1,11 @@
 import logging
 from typing import Optional
 
+from ._settings import CompilationSettings
+from .input_tensor_spec import InputTensorSpec
+from .fx2trt import TRTInterpreter, TRTInterpreterResult
+from .truncate_long_and_double import repair_long_or_double_inputs
+
 logger = logging.getLogger(__name__)
diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/common/_settings.py
similarity index 86%
rename from py/torch_tensorrt/dynamo/backend/_settings.py
rename to py/torch_tensorrt/dynamo/common/_settings.py
index d074a6b079..34e9be62c6 100644
--- a/py/torch_tensorrt/dynamo/backend/_settings.py
+++ b/py/torch_tensorrt/dynamo/common/_settings.py
@@ -2,7 +2,7 @@
 from typing import Optional, Sequence
 
 from torch_tensorrt.fx.utils import LowerPrecision
-from torch_tensorrt.dynamo.backend._defaults import (
+from torch_tensorrt.dynamo._defaults import (
     PRECISION,
     DEBUG,
     WORKSPACE_SIZE,
@@ -12,6 +12,7 @@
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
     USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
 )
@@ -27,3 +28,4 @@ class CompilationSettings:
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
     use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
+    truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/common/fx2trt.py
similarity index 100%
rename from py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
rename to py/torch_tensorrt/dynamo/common/fx2trt.py
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py b/py/torch_tensorrt/dynamo/common/input_tensor_spec.py
similarity index 100%
rename from py/torch_tensorrt/dynamo/fx_ts_compat/input_tensor_spec.py
rename to py/torch_tensorrt/dynamo/common/input_tensor_spec.py
diff --git a/py/torch_tensorrt/dynamo/common_utils/test_utils.py b/py/torch_tensorrt/dynamo/common/test_utils.py
similarity index 100%
rename from py/torch_tensorrt/dynamo/common_utils/test_utils.py
rename to py/torch_tensorrt/dynamo/common/test_utils.py
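After this rename, both compile paths share one import surface; the names below are exactly those re-exported by the new `common/__init__.py` above:

```python
# Shared machinery for both the Dynamo backend and the fx_ts_compat path now
# lives under a single package, torch_tensorrt.dynamo.common:
from torch_tensorrt.dynamo.common import (
    CompilationSettings,
    InputTensorSpec,
    TRTInterpreter,
    TRTInterpreterResult,
    repair_long_or_double_inputs,
)
```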
diff --git a/py/torch_tensorrt/dynamo/common/truncate_long_and_double.py b/py/torch_tensorrt/dynamo/common/truncate_long_and_double.py
new file mode 100644
index 0000000000..fc3263de57
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/common/truncate_long_and_double.py
@@ -0,0 +1,207 @@
+import torch
+from torch.fx.node import _get_qualified_name
+from typing import Optional, Sequence, Union
+
+
+def _extract_downstream_get_nodes(
+    module_node: torch.fx.Node, output_indices: Sequence[int]
+) -> Sequence[torch.fx.Node]:
+    """Extracts downstream users of a node which get the item at a particular index
+
+    Certain module-type nodes have multiple outputs (tuple of outputs). This function
+    returns downstream nodes which call the _operator.getitem function, which extracts
+    the element at a particular index in the tuple
+
+    Args:
+        module_node: FX module-type node to analyze
+        output_indices: Indices in the module node output to search for
+    Returns:
+        List of nodes which get the item at the specified index in the module node output
+    """
+    get_nodes = []
+
+    # Iterate over all downstream users of the node object
+    for user in module_node.users:
+        # If the user is a "get" node accessing the specified index, store it
+        if _get_qualified_name(user.target) == "_operator.getitem" and (
+            user.args[1] in output_indices
+        ):
+            get_nodes.append(user)
+
+    return get_nodes
+
+
+def _repair_64bit_input(
+    gm: torch.fx.GraphModule,
+    position: int,
+    submodule_name: str,
+    submodule_outputs: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]],
+    dtype: torch.dtype,
+):
+    """Fixes a single Long/Double input to a TRT-accelerated subgraph
+
+    In-place modifies the provided graph
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        gm: FX GraphModule enclosing the TRT subgraph
+        position: Index in the submodule inputs at which the long or double input is found
+        submodule_name: Name of TRT-accelerated subgraph module in FX graph
+        submodule_outputs: Output tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        dtype: Data type of tensor at position in submodule (double/long)
+    """
+    assert dtype in (
+        torch.int64,
+        torch.float64,
+    ), f"dtype argument must be torch.int64 or torch.float64, got {dtype}"
+
+    # Determine target data type in 32 and 64 bit forms
+    dtype_64bit = dtype
+    dtype_32bit = torch.int32 if (dtype == torch.int64) else torch.float32
+
+    # Find the node representing the submodule in the graph
+    module_node = None
+
+    # Iterate over all nodes in the graph, seeking target module name match
+    for n in gm.graph.nodes:
+        if n.op == "call_module" and str(n.target) == submodule_name:
+            module_node = n
+            break
+
+    if module_node is None:
+        raise AssertionError(
+            f"Sought module node {submodule_name}, could not find in graph:\n{gm.graph}"
+        )
+
+    # Extract the 64-bit node of the input
+    node_64bit = module_node.all_input_nodes[position]
+
+    # Prior to the module, insert a cast to the 32-bit equivalent node
+    with gm.graph.inserting_before(module_node):
+        node_32bit = gm.graph.call_function(
+            torch.ops.aten._to_copy.default,
+            args=(node_64bit,),
+            kwargs={"dtype": dtype_32bit},
+        )
+
+    # Replace 64-bit input to TRT module with new 32-bit cast node
+    module_node.replace_input_with(node_64bit, node_32bit)
+
+    output_positions_64bit = set()
+    outputs_list = (
+        [submodule_outputs]
+        if isinstance(submodule_outputs, torch.Tensor)
+        else submodule_outputs
+    )
+
+    # Determine if any outputs of the module are 64-bit type and store their indices
+    if submodule_outputs is not None:
+        for output_position, output in enumerate(outputs_list):
+            if output.dtype == dtype_64bit:
+                output_positions_64bit.add(output_position)
+
+    # Only enter this code block if there exists a 64-bit output
+    # This implies a cast is needed, since TRT cannot output 64-bit tensors
+    if output_positions_64bit:
+        # Determine whether the outputs of the module are tuple-type or not
+        is_collection_output = False
+        if isinstance(submodule_outputs, tuple):
+            is_collection_output = True
+
+        if not is_collection_output:
+            # If the output is a single tensor, insert a cast back to the 64-bit type
+            with gm.graph.inserting_after(module_node):
+                cast_node_64bit = gm.graph.call_function(
+                    torch.ops.aten._to_copy.default,
+                    args=(module_node,),
+                    kwargs={"dtype": dtype_64bit},
+                )
+
+            # Replace all uses of the TRT module (except the cast node) with the 64-bit equivalent
+            module_node.replace_all_uses_with(
+                cast_node_64bit, delete_user_cb=lambda user: (user != cast_node_64bit)
+            )
+
+        else:
+            # If the output is a tuple of tensors, extract downstream users for each 64-bit output
+            get_nodes = _extract_downstream_get_nodes(
+                module_node, output_positions_64bit
+            )
+
+            # For each downstream user, append a cast node back to the 64-bit precision
+            for get_node in get_nodes:
+                with gm.graph.inserting_after(get_node):
+                    cast_node_64bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(get_node,),
+                        kwargs={"dtype": dtype_64bit},
+                    )
+
+                get_node.replace_all_uses_with(
+                    cast_node_64bit,
+                    delete_user_cb=lambda user: (user != cast_node_64bit),
+                )
+
+    # Clean up graph and ensure invariants are preserved
+    gm.graph.eliminate_dead_code()
+    gm.graph.lint()
+    gm.recompile()
+
+
+def repair_long_or_double_inputs(
+    parent_graph: torch.fx.GraphModule,
+    submodule: torch.fx.GraphModule,
+    submodule_inputs: Sequence[torch.Tensor],
+    submodule_name: Optional[str] = None,
+) -> Sequence[torch.Tensor]:
+    """Fixes all Long/Double type inputs to a TRT-accelerated subgraph
+
+    In-place modifies the provided graph
+
+    Inserts a cast to the 32-bit equivalent type for TRT, then if necessary,
+    inserts an upcast back to the 64-bit type for subsequent Torch operations
+
+    Args:
+        parent_graph: FX GraphModule enclosing the TRT subgraph
+        submodule: Child submodule to repair inputs on
+        submodule_inputs: Input tensor(s) of TRT-accelerated subgraph (used for dtypes/structure)
+        submodule_name: Optionally specify the name of the submodule target in the parent graph
+    Returns:
+        New submodule inputs, updated accordingly with long/double truncation
+    """
+    num_submodule_inputs = len(submodule_inputs)
+    repaired_outputs_once = False
+
+    # For each input to the TRT subgraph, check if its type is long/double
+    for position in range(num_submodule_inputs):
+        param = submodule_inputs[position]
+
+        # If the data type of the input is long/double, insert necessary
+        # casts to replace the operation
+        if param.dtype in (torch.int64, torch.float64):
+            # Ensure outputs are only repaired once per submodule to avoid
+            # unnecessary ops showing up in the graph
+            if not repaired_outputs_once:
+                submodule_outputs = submodule(*submodule_inputs)
+
+            _repair_64bit_input(
+                parent_graph,
+                position,
+                submodule_name if submodule_name is not None else submodule._get_name(),
+                None if repaired_outputs_once else submodule_outputs,
+                param.dtype,
+            )
+
+            repaired_outputs_once = True
+
+            # Repair submodule inputs in accordance with inserted casts
+            dtype_32bit = torch.int32 if (param.dtype == torch.int64) else torch.float32
+            submodule_inputs = (
+                submodule_inputs[:position]
+                + (param.to(dtype_32bit),)
+                + submodule_inputs[position + 1 :]
+            )
+
+    return submodule_inputs
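The core idiom of the new pass above is FX graph surgery: insert an `aten._to_copy` cast before the TRT submodule node and rewire its input (and, symmetrically, cast back after it when a 64-bit output must be restored). A self-contained toy version of the input-side rewrite, using only public `torch.fx` APIs (the module and node names are illustrative, not taken from the pass):

```python
# Toy illustration of the insert-cast-and-rewire idiom used by
# _repair_64bit_input, applied to a traced single-op module.
import torch
import torch.fx


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2


gm = torch.fx.symbolic_trace(Toy())
mul_node = next(n for n in gm.graph.nodes if n.op == "call_function")
input_node = mul_node.all_input_nodes[0]

# Insert a 32-bit cast before the op, then point the op at the cast
with gm.graph.inserting_before(mul_node):
    cast = gm.graph.call_function(
        torch.ops.aten._to_copy.default,
        args=(input_node,),
        kwargs={"dtype": torch.int32},
    )
mul_node.replace_input_with(input_node, cast)

gm.graph.lint()
gm.recompile()
print(gm(torch.ones(2, dtype=torch.int64)).dtype)  # torch.int32
```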
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py
index 85ce01ef20..3c17701d5d 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/__init__.py
@@ -6,8 +6,6 @@
     NO_IMPLICIT_BATCH_DIM_SUPPORT,
     tensorrt_converter,
 )
-from .fx2trt import TRTInterpreter, TRTInterpreterResult  # noqa
-from .input_tensor_spec import InputTensorSpec  # noqa
 from .lower_setting import LowerSetting  # noqa
 from .lower import compile  # usort: skip #noqa
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
index c0f1ae7870..96b343cf29 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
@@ -10,11 +10,14 @@
 import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
 from torch.fx.passes.splitter_base import SplitResult
 
-from .fx2trt import TRTInterpreter, TRTInterpreterResult
+from torch_tensorrt.dynamo.common import (
+    TRTInterpreter,
+    TRTInterpreterResult,
+    use_python_runtime_parser,
+)
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
 from .passes.pass_utils import PassFunc, validate_inference
-from ..common_utils import use_python_runtime_parser
 from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
 
@@ -22,6 +25,18 @@
 from torch_tensorrt.fx.trt_module import TRTModule
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt._Device import Device
+from torch_tensorrt.dynamo._defaults import (
+    PRECISION,
+    DEBUG,
+    WORKSPACE_SIZE,
+    MIN_BLOCK_SIZE,
+    PASS_THROUGH_BUILD_FAILURES,
+    MAX_AUX_STREAMS,
+    VERSION_COMPATIBLE,
+    OPTIMIZATION_LEVEL,
+    USE_PYTHON_RUNTIME,
+    TRUNCATE_LONG_AND_DOUBLE,
+)
 
 logger = logging.getLogger(__name__)
@@ -35,24 +50,25 @@ def compile(
     disable_tf32=False,
     sparse_weights=False,
     enabled_precisions=set(),
-    min_block_size: int = 3,
-    workspace_size=0,
+    min_block_size: int = MIN_BLOCK_SIZE,
+    workspace_size=WORKSPACE_SIZE,
     dla_sram_size=1048576,
     dla_local_dram_size=1073741824,
     dla_global_dram_size=536870912,
     calibrator=None,
-    truncate_long_and_double=False,
+    truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
     require_full_compilation=False,
-    debug=False,
+    explicit_batch_dimension=False,
+    debug=DEBUG,
     refit=False,
     timing_cache_prefix="",
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
     is_aten=False,
-    use_python_runtime=None,
-    max_aux_streams=None,
-    version_compatible=False,
-    optimization_level=None,
+    use_python_runtime=USE_PYTHON_RUNTIME,
+    max_aux_streams=MAX_AUX_STREAMS,
+    version_compatible=VERSION_COMPATIBLE,
+    optimization_level=OPTIMIZATION_LEVEL,
     num_avg_timing_iters=1,
     torch_executed_ops=[],
     torch_executed_modules=[],
@@ -77,6 +93,7 @@ def compile(
         max_aux_streams: max number of aux stream to use
         version_compatible: enable version compatible feature
         optimization_level: builder optimization level
+        truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
     Returns:
         A torch.nn.Module lowered by TensorRT.
     """
@@ -133,6 +150,7 @@
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
+        truncate_long_and_double=truncate_long_and_double,
     )
     lowerer = Lowerer.create(lower_setting=lower_setting)
     return lowerer(module, inputs)
@@ -209,6 +227,7 @@ def default_split_function(
     splitter_setting.use_implicit_batch_dim = False
     splitter_setting.min_block_size = lower_setting.min_block_size
     splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime
+    splitter_setting.truncate_long_and_double = lower_setting.truncate_long_and_double
     splitter = TRTSplitter(model, inputs, settings=splitter_setting)
     splitter.node_support_preview()
     return splitter.generate_split_results()
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
index 9301a2cd90..c9d097b50d 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
@@ -4,7 +4,7 @@
 from torch import nn
 from torch.fx.passes.pass_manager import PassManager
 
-from .input_tensor_spec import InputTensorSpec
+from torch_tensorrt.dynamo.common import InputTensorSpec
 from torch_tensorrt.fx.passes.lower_basic_pass import (
     fuse_permute_linear,
     fuse_permute_matmul,
@@ -73,6 +73,7 @@ class LowerSetting(LowerSettingBasic):
         max_aux_streams: max number of aux stream to use
         version_compatible: enable version compatible feature
         optimization_level: builder optimization level
+        truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -100,3 +101,4 @@ class LowerSetting(LowerSettingBasic):
     max_aux_streams: Optional[int] = None
     version_compatible: bool = False
     optimization_level: Optional[int] = None
+    truncate_long_and_double: bool = False
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py
index 0fd3777254..04e367d1b3 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/passes/lower_pass_manager_builder.py
@@ -10,7 +10,10 @@
 from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
 from torch_tensorrt.fx.utils import LowerPrecision
 from torch_tensorrt import _Input
-from ..input_tensor_spec import InputTensorSpec
+from torch_tensorrt.dynamo.common import (
+    InputTensorSpec,
+    repair_long_or_double_inputs,
+)
 from ..lower_setting import LowerSetting
 from torch_tensorrt.fx.observer import Observer
@@ -196,6 +199,14 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                 _LOGGER.info(f"Now lowering submodule {submod_name}")
                 lowering_start_time = datetime.datetime.now()
 
+                if self.lower_setting.truncate_long_and_double:
+                    submod_inputs = repair_long_or_double_inputs(
+                        parent_graph=split_result.split_module,
+                        submodule=submod,
+                        submodule_inputs=submod_inputs,
+                        submodule_name=submod_name,
+                    )
+
                 self.lower_setting.input_specs = self._trt_input
 
                 lowered_module = self._lower_func(
@@ -5,7 +5,7 @@
 import torch
 import torch_tensorrt
 from torch.testing._internal.common_utils import run_tests, TestCase
-from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, LowerSetting
+from torch_tensorrt.dynamo.common import InputTensorSpec
 
 
 class TestTRTModule(TestCase):
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py
index 334243fef4..a1857a7677 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/common_fx2trt.py
@@ -13,7 +13,7 @@
 from torch.fx.passes import shape_prop
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.testing._internal.common_utils import TestCase
-from torch_tensorrt.dynamo.fx_ts_compat import InputTensorSpec, TRTInterpreter
+from torch_tensorrt.dynamo.common import InputTensorSpec, TRTInterpreter
 from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
     compose_bmm,
     compose_chunk,
diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
index f34aad6caf..fc9bf634c2 100644
--- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
+++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -8,7 +8,7 @@
 
 from transformers import BertModel
 
-from torch_tensorrt.dynamo.common_utils.test_utils import (
+from torch_tensorrt.dynamo.common.test_utils import (
     COSINE_THRESHOLD,
     cosine_similarity,
 )
diff --git a/py/torch_tensorrt/fx/lower.py b/py/torch_tensorrt/fx/lower.py
index 5f66519e05..961d817661 100644
--- a/py/torch_tensorrt/fx/lower.py
+++ b/py/torch_tensorrt/fx/lower.py
@@ -43,6 +43,7 @@ def compile(
     use_experimental_fx_rt=False,
     correctness_atol=1e-1,
     correctness_rtol=1e-1,
+    truncate_long_and_double=False,
 ) -> nn.Module:
     """
     Takes in original module, input and lowering setting, run lowering workflow to turn module
@@ -62,6 +63,7 @@ def compile(
         cuda_graph_batch_size: Cuda graph batch size, default to be -1.
         dynamic_batch: batch dimension (dim=0) is dynamic.
         use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
     Returns:
         A torch.nn.Module lowered by TensorRT.
     """
@@ -85,6 +87,7 @@ def compile(
         use_experimental_rt=use_experimental_fx_rt,
         correctness_atol=correctness_atol,
         correctness_rtol=correctness_rtol,
+        truncate_long_and_double=truncate_long_and_double,
     )
     lowerer = Lowerer.create(lower_setting=lower_setting)
     return lowerer(module, input)
@@ -159,6 +162,7 @@ def default_split_function(
     splitter_setting.use_implicit_batch_dim = not lower_setting.explicit_batch_dimension
     splitter_setting.min_acc_module_size = lower_setting.min_acc_module_size
     splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
+    splitter_setting.truncate_long_and_double = lower_setting.truncate_long_and_double
     splitter = TRTSplitter(model, inputs, settings=splitter_setting)
     splitter.node_support_preview()
     return splitter.generate_split_results()
diff --git a/py/torch_tensorrt/fx/lower_setting.py b/py/torch_tensorrt/fx/lower_setting.py
index 07e7bf0dac..d57fbced2b 100644
--- a/py/torch_tensorrt/fx/lower_setting.py
+++ b/py/torch_tensorrt/fx/lower_setting.py
@@ -74,6 +74,7 @@ class LowerSetting(LowerSettingBasic):
         correctness_atol: absolute tolerance for correctness check
         correctness_rtol: relative tolerance for correctness check
         use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        truncate_long_and_double: Whether to truncate long and double inputs to TRT engines automatically
     """
 
     input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -101,3 +102,4 @@ class LowerSetting(LowerSettingBasic):
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
     use_experimental_rt: bool = False
+    truncate_long_and_double: bool = False
diff --git a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
index 6e6b40d42f..a6cc152f51 100644
--- a/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
+++ b/py/torch_tensorrt/fx/passes/lower_pass_manager_builder.py
@@ -10,6 +10,9 @@
 from torch.fx.passes.splitter_base import generate_inputs_for_submodules, SplitResult
 from torch_tensorrt.fx.passes.pass_utils import apply_bfloat_float_conversion
 from torch_tensorrt.fx.utils import LowerPrecision
+from torch_tensorrt.dynamo.common import (
+    repair_long_or_double_inputs,
+)
 
 from ..input_tensor_spec import generate_input_specs
@@ -193,6 +196,14 @@ def lower_func(split_result: SplitResult) -> nn.Module:
                 _LOGGER.info(f"Now lowering submodule {submod_name}")
                 lowering_start_time = datetime.datetime.now()
 
+                if self.lower_setting.truncate_long_and_double:
+                    submod_inputs = repair_long_or_double_inputs(
+                        parent_graph=split_result.split_module,
+                        submodule=submod,
+                        submodule_inputs=submod_inputs,
+                        submodule_name=submod_name,
+                    )
+
                 self.lower_setting.input_specs = generate_input_specs(
                     submod_inputs,
                     self.lower_setting,
diff --git a/py/torch_tensorrt/fx/test/tracer/test_aten_long_and_double_inputs.py b/py/torch_tensorrt/fx/test/tracer/test_aten_long_and_double_inputs.py
new file mode 100644
index 0000000000..b917bbffe0
--- /dev/null
+++ b/py/torch_tensorrt/fx/test/tracer/test_aten_long_and_double_inputs.py
@@ -0,0 +1,75 @@
+import unittest
+
+import torch
+
+from torch_tensorrt.fx.lower import compile
+from torch_tensorrt.fx.utils import LowerPrecision
+
+
+class LongInputTest(unittest.TestCase):
+    def test_long_input(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                out = x + 1
+                out = out * 2
+                out = out - 1
+                return out
+
+        mod = Model().cuda().eval()
+
+        inputs = [torch.randint(-40, 40, (3, 4, 7)).cuda().long()]
+
+        aten_mod = compile(
+            mod,
+            inputs,
+            min_acc_module_size=3,
+            explicit_batch_dimension=True,
+            verbose_log=True,
+            lower_precision=LowerPrecision.FP16,
+            truncate_long_and_double=True,
+            dynamic_batch=False,
+            is_aten=True,
+        )
+
+        aten_output = aten_mod(*inputs)[0].detach().cpu()
+        torch_output = mod(*inputs).detach().cpu()
+
+        max_diff = float(torch.max(torch.abs(aten_output - torch_output)))
+
+        self.assertAlmostEqual(
+            max_diff, 0, 4, msg="Torch outputs don't match with TRT outputs"
+        )
+
+
+class DoubleInputTest(unittest.TestCase):
+    def test_double_input(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                out = x + 1
+                out = out * 2
+                return torch.mean(out, dim=-1)
+
+        mod = Model().cuda().eval()
+
+        inputs = [torch.rand((3, 4, 1)).cuda().double()]
+
+        aten_mod = compile(
+            mod,
+            inputs,
+            min_acc_module_size=3,
+            explicit_batch_dimension=True,
+            verbose_log=True,
+            lower_precision=LowerPrecision.FP32,
+            truncate_long_and_double=True,
+            dynamic_batch=False,
+            is_aten=True,
+        )
+
+        aten_output = aten_mod(*inputs)[0].detach().cpu()
+        torch_output = mod(*inputs).detach().cpu()
+
+        max_diff = float(torch.max(torch.abs(aten_output - torch_output)))
+
+        self.assertAlmostEqual(
+            max_diff, 0, 4, msg="Torch outputs don't match with TRT outputs"
+        )
diff --git a/py/torch_tensorrt/fx/tools/trt_splitter.py b/py/torch_tensorrt/fx/tools/trt_splitter.py
index 6fcb40c0d8..a122c4d3b6 100644
--- a/py/torch_tensorrt/fx/tools/trt_splitter.py
+++ b/py/torch_tensorrt/fx/tools/trt_splitter.py
@@ -19,6 +19,7 @@
 def create_trt_operator_support(
     use_implicit_batch_dim=True,
     exclude_support_node_name: set = (),
+    truncate_long_and_double: bool = False,
 ) -> ops.OperatorSupportBase:
     """Creates an `OperatorSupportBase` instance used for TRT splitting purpose."""
     # Create an `OperatorSupport` that declares a node supported if it
@@ -32,14 +33,17 @@ def create_trt_operator_support(
         support_dict[get_acc_ops_name(k)] = None
     supported_if_converter_registered = ops.OperatorSupport(support_dict=support_dict)
 
-    return ops.chain(
-        ops.OpSupports.decline_if_node_in_names(exclude_support_node_name),
-        # 1. Node is not supported if it has args with int64 or float64 dtype:
-        ops.OpSupports.decline_if_input_dtype(torch.int64),
-        ops.OpSupports.decline_if_input_dtype(torch.float64),
-        # 2. Node is supported if it has TRT converter:
-        supported_if_converter_registered,
-    )
+    op_support_checks = [
+        ops.OpSupports.decline_if_node_in_names(exclude_support_node_name)
+    ]
+
+    if not truncate_long_and_double:
+        op_support_checks.append(ops.OpSupports.decline_if_input_dtype(torch.int64))
+        op_support_checks.append(ops.OpSupports.decline_if_input_dtype(torch.float64))
+
+    op_support_checks.append(supported_if_converter_registered)
+
+    return ops.chain(*op_support_checks)
 
 
 class TRTSplitterSetting(splitter_base._SplitterSettingBase):
@@ -52,6 +56,7 @@ def __init__(self):
         self.use_implicit_batch_dim: bool = True
         self.exclude_support_node_name: set = set()
         self.use_experimental_rt: bool = False
+        self.truncate_long_and_double: bool = False
 
         if self.use_experimental_rt and self.use_implicit_batch_dim:
             raise ValueError(
@@ -71,7 +76,9 @@ def __init__(
             settings = TRTSplitterSetting()
         if not operator_support:
             operator_support = create_trt_operator_support(
-                settings.use_implicit_batch_dim, settings.exclude_support_node_name
+                settings.use_implicit_batch_dim,
+                settings.exclude_support_node_name,
+                settings.truncate_long_and_double,
             )
         super().__init__(
             module,
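Tying the splitter pieces together, a sketch of how `TRTSplitterSetting` would be configured so that 64-bit nodes stay TRT-eligible and are later repaired during lowering (the traced module and inputs are omitted; all names follow the diff above):

```python
# Illustrative splitter configuration using the new setting.
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

setting = TRTSplitterSetting()
setting.use_implicit_batch_dim = False
setting.truncate_long_and_double = True  # keep int64/float64 nodes TRT-eligible

# With an acc-traced module and sample inputs, the splitter is then built
# exactly as in default_split_function above:
# splitter = TRTSplitter(traced_mod, inputs, settings=setting)
# splitter.node_support_preview()
# split_result = splitter.generate_split_results()
```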