From e746243d6ccfe7d8a0429c900a37db65bdc862aa Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Wed, 27 Sep 2023 18:22:21 -0700 Subject: [PATCH 1/2] perf: Add lowering passes to improve TRT conversion - Focus on variance and sum converters, reducing instances of extraneous layers from unnecessary reshapes - Add test cases to validate new additions --- .../writing_dynamo_aten_lowering_passes.rst | 10 +-- py/torch_tensorrt/dynamo/aten_tracer.py | 2 +- py/torch_tensorrt/dynamo/backend/backends.py | 2 +- .../dynamo/lowering/_decomposition_groups.py | 3 + .../dynamo/lowering/_decompositions.py | 50 ++++++++++- .../lowering/passes/_aten_lowering_pass.py | 18 ++-- .../lowering/passes/constant_folding.py | 5 +- .../lowering/passes/fuse_prims_broadcast.py | 82 +++++++++++++++++++ .../dynamo/lowering/passes/pass_manager.py | 28 +++++-- .../remove_input_alias_fixing_clones.py | 5 +- .../lowering/passes/repair_input_as_output.py | 5 +- .../lowering/test_aten_lowering_passes.py | 65 +++++++++++++++ .../py/dynamo/lowering/test_decompositions.py | 68 +++++++++++++++ tests/py/dynamo/testing_utilities.py | 2 +- 14 files changed, 322 insertions(+), 23 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py diff --git a/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst index d64f81d4aa..4c29bc9b75 100644 --- a/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst +++ b/docsrc/contributors/writing_dynamo_aten_lowering_passes.rst @@ -12,7 +12,7 @@ Lowering Pass Requirements ------------ An ATen lowering pass function in Torch-TRT must satisfy two requirements: -- The function must take as input a single `torch.fx.GraphModule` and return the lowered `torch.fx.GraphModule` +- The function must take as input a `torch.fx.GraphModule` and a sequence of torch Tensors, `Sequence[torch.Tensor]`, and return the lowered `torch.fx.GraphModule` - The function must leave the graph in a valid and invoke-able state, including performing any necessary linting and recompilation See this link for information on `Graph Manipulations `_ in FX. See below for an example of a lowering pass which repairs graphs that have inputs which are also outputs, a disallowed configuration for TRT Engines. @@ -22,7 +22,7 @@ Example Lowering Pass .. code-block:: python - def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + def repair_input_as_output(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule: """Repair scenarios where inputs are also outputs of the graph TRT does not allow such cases, so we insert a clone (identity) layer @@ -82,7 +82,7 @@ For instance, to insert the pass at the default location (end of the list), the .. code-block:: python @_aten_lowering_pass - def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule: ... Alternatively, to insert the pass at a custom index (such as the front of the list) in the passlist, the following code can be used: @@ -90,7 +90,7 @@ Alternatively, to insert the pass at a custom index (such as the front of the li .. code-block:: python @_aten_lowering_pass(index=0) - def my_custom_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + def my_custom_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule: ... 
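Since every registered pass now receives the sample inputs alongside the graph module, a pass can use those inputs to propagate shapes before rewriting the graph (this is how the new ``fuse_prims_broadcast`` pass resolves broadcast shapes). A minimal sketch of such a pass follows; the pass name and the elided rewrite logic are illustrative only:

.. code-block:: python

    from typing import Sequence

    import torch
    from torch.fx.passes.shape_prop import ShapeProp

    @_aten_lowering_pass
    def my_shape_aware_pass(gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]) -> torch.fx.GraphModule:
        # Populate node.meta["tensor_meta"] so the rewrite can inspect shapes
        try:
            ShapeProp(gm).propagate(*sample_inputs)
        except (RuntimeError, AssertionError):
            return gm

        # ... rewrite the graph using the recorded shape metadata ...

        # Leave the graph in a valid, invoke-able state
        gm.graph.lint()
        gm.recompile()
        return gm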
There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for displaying the currently-available lowering pass list, applying those passes to an arbitrary `torch.fx.GraphModule`, and removing the lowering pass at a specific index. @@ -101,7 +101,7 @@ There are also provided utilities in `torch_tensorrt.dynamo.lowering.passes` for print(dump_lowering_passes()) # Apply lowering passes to a GraphModule - apply_lowering_passes(graph_module) + apply_lowering_passes(graph_module, sample_inputs) # Remove the lowering pass at index 1 _remove_lowering_pass(index=1) diff --git a/py/torch_tensorrt/dynamo/aten_tracer.py b/py/torch_tensorrt/dynamo/aten_tracer.py index b271d0d6fb..be2f2efd9c 100644 --- a/py/torch_tensorrt/dynamo/aten_tracer.py +++ b/py/torch_tensorrt/dynamo/aten_tracer.py @@ -28,6 +28,6 @@ def trace( "torch._export.DECOMP_TABLE", get_decompositions(experimental_decompositions) ): graph_module = export(model, tuple(inputs)).module() - graph_module = apply_lowering_passes(graph_module) + graph_module = apply_lowering_passes(graph_module, inputs) logger.debug("Post export graph: " + str(graph_module.graph)) return graph_module diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index f8508d752e..7b98079564 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -87,7 +87,7 @@ def _pretraced_backend( logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph)) - gm = apply_lowering_passes(gm) + gm = apply_lowering_passes(gm, sample_inputs) trt_compiled = compile_module( gm, diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py index 8a5df8988e..f1cfaae348 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py +++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py @@ -149,6 +149,7 @@ aten.special_log_ndtr, aten.special_xlog1py, aten.stack, + aten.std, aten.t, aten.tanh_backward, aten.threshold, @@ -163,6 +164,8 @@ aten.upsample_bilinear2d, aten.upsample_bilinear2d.vec, aten.upsample_nearest2d_backward, + aten.var, + aten.var_mean, aten.xlogy, aten.zero, aten.zero_, diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 57e1954575..eeddb5f41b 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, List, Optional import torch from torch._decomp import register_decomposition @@ -135,6 +135,54 @@ def reciprocal_replacement( return torch.div(1, input_) +@register_torch_trt_decomposition( + torch.ops.prims.var.default, registry=TORCH_TRT_DECOMPOSITIONS +) +def var_decomposition( + input_tensor: torch.Tensor, + dims: Optional[List[int]], + correction: int, + output_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + if dims is None: + dims = [] + + # If the dimensions are empty, variance is taken over all dimensions + if isinstance(dims, (tuple, list)) and len(dims) == 0: + N = input_tensor.numel() + # Otherwise, the number of samples is the product of the dimensions reduced over + else: + N = 1 + for dim_i in dims: + N *= input_tensor.shape[dim_i] + + # Compute the mean, difference, and correction term as per the formula: + # https://pytorch.org/docs/stable/generated/torch.var.html + + # Additionally, 
prims does not support keepdim, and so we only keep dimensions + # on the first reduction, then remove it for the second + sample_mean = torch.mean(input_tensor, dims, keepdim=True) + diff = input_tensor - sample_mean + squared_diff = diff * diff + variance_unnormalized = torch.sum(squared_diff, dims, keepdim=False) + + if correction is None: + correction_term = float(N - 1) + elif isinstance(correction, int): + correction_term = float(N - correction) + elif isinstance(correction, float): + correction_term = float(N) - correction + else: + raise RuntimeError("correction must be int or float") + + if correction_term <= 0: + raise RuntimeError(f"correction term was non-positive, got: {correction_term}") + + variance = variance_unnormalized / correction_term + + return variance + + def get_decompositions( enable_experimental_decompositions: bool = False, ) -> Dict[OpOverload, Callable[[Any], Any]]: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index 43d70a4cac..a73254f487 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -1,9 +1,10 @@ import logging -from typing import Callable, Optional +from typing import Callable, Optional, Sequence, Union import torch from .constant_folding import constant_fold +from .fuse_prims_broadcast import fuse_prims_broadcast from .pass_manager import DynamoPassManager from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output @@ -13,19 +14,24 @@ remove_input_alias_fixing_clones, constant_fold, repair_input_as_output, + fuse_prims_broadcast, ] ) logger = logging.getLogger(__name__) -LoweringPassSignature = Callable[[torch.fx.GraphModule], torch.fx.GraphModule] +LoweringPassSignature = Callable[ + [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule +] def _aten_lowering_pass( *args: LoweringPassSignature, index: Optional[int] = None, -) -> LoweringPassSignature: +) -> Union[ + LoweringPassSignature, Callable[[LoweringPassSignature], LoweringPassSignature] +]: """Adds a lowering pass to the registry, at a specified index if desired If no index is specified, the lowering pass is inserted at the end of the list @@ -65,12 +71,14 @@ def _remove_lowering_pass(*, index: int) -> None: return -def apply_lowering_passes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: +def apply_lowering_passes( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: """Applies the lowering passes to a graph module, returns the modified GraphModule""" logging.debug( f"Invoking DynamoPassManager and applying lowering passes: {ATEN_LOWERING_PASSES}" ) - return ATEN_LOWERING_PASSES(gm) + return ATEN_LOWERING_PASSES(gm, sample_inputs) def dump_lowering_passes() -> str: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py index ea2547f6bf..94398b7c6b 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py @@ -1,4 +1,5 @@ import logging +from typing import Sequence import torch from torch_tensorrt._utils import sanitized_torch_version @@ -21,7 +22,9 @@ @torch.utils._python_dispatch._disable_current_modes() # type: ignore -def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: +def 
constant_fold( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: """Adapted from: https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197 diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py new file mode 100644 index 0000000000..db407d7fdb --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py @@ -0,0 +1,82 @@ +import logging +from typing import Sequence + +import torch +from torch.fx.passes.shape_prop import ShapeProp +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = logging.getLogger(__name__) + + +# TODO: Add relevant prims to this fusion +def fuse_prims_broadcast( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + """Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True""" + modified_graph = False + + # Propagate shapes through the graph to determine if broadcast can be resolved + try: + ShapeProp(gm).propagate(*sample_inputs) + except (RuntimeError, AssertionError): + logger.warning( + "Shape Propagation Failed on Graph, skipping fuse_prims_broadcast lowering pass", + exc_info=True, + ) + return gm + + for node in gm.graph.nodes: + # If the node is a sum prims operator, with broadcast_in_dim being the only consumer + # it is a candidate for fusing + if ( + node.target in (torch.ops.prims.sum.default,) + and len(node.users) == 1 + and list(node.users)[0].target == torch.ops.prims.broadcast_in_dim.default + ): + # Get broadcasted shape, reduced dimensions, and original tensor shape + broadcast_node = list(node.users)[0] + broadcasted_shape = broadcast_node.args[1] + reduced_dims = node.args[1] + original_shape = node.args[0].meta["tensor_meta"].shape + + # If the rank of the broadcasted shape is the same as the original + # and the broadcasts are all singletons for the reduced dimensions + # and all of the non-reduced dimensions are identical to the originals + + # Then the broadcast is effectively performing a "keep_dim=True" operation + if ( + len(broadcasted_shape) == len(original_shape) + and all(broadcasted_shape[i] == 1 for i in reduced_dims) + and all( + broadcasted_shape[j] == original_shape[j] + for j in range(len(original_shape)) + if j not in reduced_dims + ) + ): + # Fuse the operator to its convertible alternative + with gm.graph.inserting_after(broadcast_node): + modified_graph = True + + if node.target == torch.ops.prims.sum.default: + fused_node = gm.graph.call_function( + torch.ops.aten.sum.dim_IntList, + args=(node.args[0], reduced_dims, True), + ) + + # Replace all uses of the placeholder except the cloned node + # with the cloned placeholder + broadcast_node.replace_all_uses_with( + fused_node, + ) + + # Erase uses of the broadcast node and original + gm.graph.erase_node(broadcast_node) + gm.graph.erase_node(node) + + if modified_graph: + gm = clean_up_graph_after_modifications(gm) + logger.debug(f"Fused prims-broadcast paradigm:\n{gm.graph}") + + return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py index 51e2584364..64e03147a2 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/pass_manager.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, List, 
Optional +from typing import Any, Callable, List, Optional, Sequence import torch from torch.fx.passes.pass_manager import PassManager @@ -8,7 +8,11 @@ class DynamoPassManager(PassManager): # type: ignore[misc] def __init__( self, passes: Optional[ - List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]] + List[ + Callable[ + [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule + ] + ] ] = None, ): super().__init__(passes) @@ -16,14 +20,22 @@ def __init__( @classmethod def build_from_passlist( cls, - passes: Optional[List[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]], + passes: Optional[ + List[ + Callable[ + [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule + ] + ] + ], ) -> Any: pm = DynamoPassManager(passes) return pm def add_pass_with_index( self, - lowering_pass: Callable[[torch.fx.GraphModule], torch.fx.GraphModule], + lowering_pass: Callable[ + [torch.fx.GraphModule, Sequence[torch.Tensor]], torch.fx.GraphModule + ], index: Optional[int] = None, ) -> None: if index is None: @@ -35,8 +47,12 @@ def add_pass_with_index( def remove_pass_with_index(self, index: int) -> None: del self.passes[index] - def __call__(self, source: Any) -> Any: - return super().__call__(source) + def __call__(self, gm: Any, sample_inputs: Any) -> Any: + self.validate() + out, example_inputs = gm, sample_inputs + for _pass in self.passes: + out = _pass(out, example_inputs) + return out def __str__(self) -> str: return str(self.passes) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py index dce88ad109..7630f3c1a5 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_input_alias_fixing_clones.py @@ -1,4 +1,5 @@ import logging +from typing import Sequence import torch from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( @@ -9,7 +10,9 @@ # TODO: Delete this lowering pass once aot_export_joint_simple is patched -def remove_input_alias_fixing_clones(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: +def remove_input_alias_fixing_clones( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: """Remove the auxiliary clone nodes inserted to fix input aliasing See: https://github.com/pytorch/pytorch/issues/108079 diff --git a/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py index ec2f5b0ae0..b97b95e686 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/repair_input_as_output.py @@ -1,4 +1,5 @@ import logging +from typing import Sequence import torch from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( @@ -9,7 +10,9 @@ logger = logging.getLogger(__name__) -def repair_input_as_output(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: +def repair_input_as_output( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: """Repair scenarios where inputs are also outputs of the graph TRT does not allow such cases, so we insert a clone (identity) layer diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index a63c5e3439..2183910f84 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ 
b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -91,5 +91,70 @@ def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: self.assertNotIn(identity_pass, ATEN_LOWERING_PASSES.passes) +class TestPrimBroadcastFusion(TestCase): + def test_input_as_output(self): + class InputAsOutput(torch.nn.Module): + def forward(self, x): + return torch.var_mean(x, keepdim=True)[1] + + inputs = [ + torch.rand( + 5, + 7, + ).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(InputAsOutput()) + expected_ops = {torch.ops.aten.sum.dim_IntList} + unexpected_ops = {torch.ops.aten.var.default, torch.ops.prims.var.default} + + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = torch.cat( + [tensor.detach().cpu() for tensor in optimized_model(*inputs)] + ) + torch_model_results = torch.cat( + [tensor.detach().cpu() for tensor in fx_graph(*inputs)] + ) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"InputAsOutput TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/lowering/test_decompositions.py b/tests/py/dynamo/lowering/test_decompositions.py index fd834394c1..40e5a8f3e8 100644 --- a/tests/py/dynamo/lowering/test_decompositions.py +++ b/tests/py/dynamo/lowering/test_decompositions.py @@ -245,6 +245,74 @@ def forward(self, x): f"Reciprocal TRT outputs don't match with the original model.", ) + def test_lowering_prims_var(self): + class Var(torch.nn.Module): + def forward(self, x): + y = torch.var(x) + return y + + # Operations expected to be removed in the traced graph after decompositions + expected_ops = { + torch.ops.aten.mean.dim, + torch.ops.aten.sub.Tensor, + torch.ops.aten.mul.Tensor, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.div.Tensor, + } + unexpected_ops = {torch.ops.aten.var.default, torch.ops.prims.div.default} + + inputs = [ + torch.randn( + 5, + 10, + 1, + ).cuda() + ] + + fx_graph = torch.fx.symbolic_trace(Var()) + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = 
fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + f"Var TRT outputs don't match with the original model.", + ) + if __name__ == "__main__": run_tests() diff --git a/tests/py/dynamo/testing_utilities.py b/tests/py/dynamo/testing_utilities.py index 344cd6bc1d..b55194fa4c 100644 --- a/tests/py/dynamo/testing_utilities.py +++ b/tests/py/dynamo/testing_utilities.py @@ -53,7 +53,7 @@ def fx_dynamo_testing_backend( decompositions=get_decompositions(), ) - gm = apply_lowering_passes(gm) + gm = apply_lowering_passes(gm, sample_inputs) trt_compiled = custom_backend( gm, From b42203e0eb15db09164ce5978cb213d657a6f855 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 28 Sep 2023 14:05:48 -0700 Subject: [PATCH 2/2] perf: Add efficient attention lowering pass --- .../dynamo/conversion/aten_ops_converters.py | 15 +++ .../dynamo/conversion/impl/__init__.py | 1 + .../dynamo/conversion/impl/attention.py | 50 ++++++++ .../dynamo/lowering/_decompositions.py | 5 - .../lowering/passes/_aten_lowering_pass.py | 2 + .../lowering/passes/fuse_prims_broadcast.py | 2 +- .../passes/lower_efficient_attention.py | 74 +++++++++++ .../lowering/test_aten_lowering_passes.py | 119 +++++++++++++++++- 8 files changed, 258 insertions(+), 10 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/impl/attention.py create mode 100644 py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py index d844fe1995..f6567488e3 100644 --- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py @@ -1517,3 +1517,18 @@ def aten_ops_max_pool( dilation=args_bounds_check(args, 4, replacement=1), ceil_mode=args_bounds_check(args, 5, replacement=False), ) + + +@dynamo_tensorrt_converter( + torch.nn.functional.scaled_dot_product_attention, +) # type: ignore[misc] +def tensorrt_scaled_dot_product_attention( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + return impl.attention.scaled_dot_product_attention( + ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2] + ) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py index b247bd2cf9..3f49377619 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py @@ -2,6 +2,7 @@ from . 
import ( activation, + attention, cast, condition, conv, diff --git a/py/torch_tensorrt/dynamo/conversion/impl/attention.py b/py/torch_tensorrt/dynamo/conversion/impl/attention.py new file mode 100644 index 0000000000..6221357ca2 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/impl/attention.py @@ -0,0 +1,50 @@ +import math +from typing import Optional, Union + +import tensorrt as trt +from torch.fx.node import Target +from torch_tensorrt.dynamo.conversion import impl +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion.converter_utils import SourceIR +from torch_tensorrt.fx.types import TRTTensor + + +def scaled_dot_product_attention( + ctx: ConversionContext, + target: Union[Target, str], + source_ir: Optional[SourceIR], + name: str, + query: TRTTensor, + key: TRTTensor, + value: TRTTensor, +) -> TRTTensor: + mm = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_mm", + query, + key, + other_matrix_op=trt.MatrixOperation.TRANSPOSE, + ) + div = impl.elementwise.div( + ctx, + target, + source_ir, + name + "_scale", + mm, + math.sqrt(query.shape[-1]), + ) + softmax = impl.normalization.softmax( + ctx, target, source_ir, name + "_softmax", div, -1 + ) + out = impl.matmul.matrix_multiply( + ctx, + target, + source_ir, + name + "_out", + softmax, + value, + ) + + return out diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index eeddb5f41b..158523a7e1 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -83,11 +83,6 @@ def inplace_op(*args, **kwargs): # type: ignore replace_inplace_op(aten.scatter_reduce_, aten.scatter_reduce) -@register_torch_trt_decomposition(aten.std, registry=TORCH_TRT_DECOMPOSITIONS) -def std_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore - return torch.sqrt(torch.var(*args, **kwargs)) - - @register_torch_trt_decomposition(aten.rsqrt, registry=TORCH_TRT_DECOMPOSITIONS) def rsqrt_replacement(*args, **kwargs) -> torch.Tensor: # type: ignore return torch.reciprocal(torch.sqrt(*args, **kwargs)) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index a73254f487..ffbe1c7f44 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -5,6 +5,7 @@ from .constant_folding import constant_fold from .fuse_prims_broadcast import fuse_prims_broadcast +from .lower_efficient_attention import lower_efficient_attention from .pass_manager import DynamoPassManager from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output @@ -14,6 +15,7 @@ remove_input_alias_fixing_clones, constant_fold, repair_input_as_output, + lower_efficient_attention, fuse_prims_broadcast, ] ) diff --git a/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py b/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py index db407d7fdb..312926e870 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py @@ -77,6 +77,6 @@ def fuse_prims_broadcast( if modified_graph: gm = clean_up_graph_after_modifications(gm) - logger.debug(f"Fused prims-broadcast paradigm:\n{gm.graph}") + logger.debug(f"Graph after 
fusing prims-broadcast paradigm:\n{gm.graph}") return gm diff --git a/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py b/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py new file mode 100644 index 0000000000..944b0788b0 --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/lower_efficient_attention.py @@ -0,0 +1,74 @@ +import logging +import operator +from typing import Callable, Sequence, Tuple + +import torch +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, + get_tensor_placeholders, +) + +logger = logging.getLogger(__name__) + + +def lower_efficient_attention( + gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor] +) -> torch.fx.GraphModule: + """Replace a specific version of scaled_dot_product_attention with an equivalent + implementation which can be easily converted to TRT + """ + orig, replacement = efficient_attention_replacement() + + if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement): + gm = clean_up_graph_after_modifications(gm) + logger.debug( + f"Graph after lowering _scaled_dot_product_efficient_attention:\n{gm.graph}" + ) + + return gm + + +def efficient_attention_replacement() -> ( + Tuple[ + torch.fx.GraphModule, + Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor], + ] +): + """Constructs the original and replacement functions for efficient attention""" + + # Empty boilerplate function taking in three Tensors and returning one + def boilerplate( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + ... + + # Trace boilerplate function and extract placeholder and output nodes + orig = torch.fx.symbolic_trace(boilerplate) + q, k, v = get_tensor_placeholders(orig) + output = [node for node in orig.graph.nodes if node.op == "output"][0] + + # Graph types to replace are those which use the _scaled_dot_product_efficient_attention + # function and extract only the first element + with orig.graph.inserting_before(output): + att = orig.graph.call_function( + torch.ops.aten._scaled_dot_product_efficient_attention.default, + args=(q, k, v, None, False), + ) + out = orig.graph.call_function( + operator.getitem, + args=(att, 0), + ) + + # Assign the output of the graph to be the single getitem output + output.args = (out,) + + orig.graph.lint() + orig.recompile() + + # Replacement graph consists of the functional version of scaled_dot_product_attention + def replacement( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor + ) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention(query, key, value) + + return orig, replacement diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index 2183910f84..1bbb54192c 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -92,8 +92,8 @@ def identity_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: class TestPrimBroadcastFusion(TestCase): - def test_input_as_output(self): - class InputAsOutput(torch.nn.Module): + def test_broadcast_fusion(self): + class BroadcastFusion(torch.nn.Module): def forward(self, x): return torch.var_mean(x, keepdim=True)[1] @@ -104,7 +104,7 @@ def forward(self, x): ).cuda(), ] - fx_graph = torch.fx.symbolic_trace(InputAsOutput()) + fx_graph = torch.fx.symbolic_trace(BroadcastFusion()) expected_ops = {torch.ops.aten.sum.dim_IntList} unexpected_ops = 
{torch.ops.aten.var.default, torch.ops.prims.var.default} @@ -151,7 +151,118 @@ def forward(self, x): max_diff, 0, DECIMALS_OF_AGREEMENT, - msg=f"InputAsOutput TRT outputs don't match with the original model.", + msg=f"BroadcastFusion TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + +class TestLowerEfficientAttention(TestCase): + def test_lower_efficient_attention(self): + class EfficientAttention(torch.nn.Module): + def forward(self, q, k, v): + attn = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, None, False + ) + return attn[0] + + inputs = [ + torch.rand(8, 4, 5, 4).cuda(), + torch.rand(8, 4, 2, 4).cuda(), + torch.rand(8, 4, 2, 4).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(EfficientAttention()) + expected_ops = {torch.nn.functional.scaled_dot_product_attention} + unexpected_ops = { + torch.ops.aten._scaled_dot_product_efficient_attention.default + } + + unexpected_ops_seen, expected_ops_unseen = lower_graph_testing( + fx_graph, + inputs, + expected_ops=expected_ops, + unexpected_ops=unexpected_ops, + min_block_size=1, + ) + + self.assertEquals( + len(unexpected_ops_seen), + 0, + f"The following unexpected ops were encountered: {unexpected_ops_seen}", + ) + + self.assertEquals( + len(expected_ops_unseen), + 0, + f"The following expected ops were not encountered: {expected_ops_unseen}", + ) + torch._dynamo.reset() + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = torch.cat( + [tensor.detach().cpu() for tensor in optimized_model(*inputs)] + ) + torch_model_results = torch.cat( + [tensor.detach().cpu() for tensor in fx_graph(*inputs)] + ) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"EfficientAttention TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + def test_efficient_attention_converter(self): + class EfficientAttention(torch.nn.Module): + def forward(self, q, k, v): + attn = torch.ops.aten._scaled_dot_product_efficient_attention.default( + q, k, v, None, False + ) + return attn[0] + + inputs = [ + torch.rand(1, 3, 6, 4).cuda(), + torch.rand(1, 3, 2, 4).cuda(), + torch.rand(1, 3, 2, 4).cuda(), + ] + + fx_graph = torch.fx.symbolic_trace(EfficientAttention()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = torch.cat( + [tensor.detach().cpu() for tensor in optimized_model(*inputs)] + ) + torch_model_results = torch.cat( + [tensor.detach().cpu() for tensor in fx_graph(*inputs)] + ) + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + DECIMALS_OF_AGREEMENT, + msg=f"EfficientAttention TRT outputs don't match with the original model.", ) torch._dynamo.reset()
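The formulas implemented above (the prims.var decomposition and the scaled dot-product attention converter) can be checked against their ATen counterparts with plain PyTorch. The following standalone sanity check is illustrative only and is not part of the patch:

    import math

    import torch

    # 1) prims.var decomposition: mean -> squared difference -> sum -> divide
    #    by the correction term (N - correction), per torch.var semantics.
    x = torch.randn(5, 10, 3)
    dims = [1, 2]
    correction = 1
    N = 1
    for d in dims:
        N *= x.shape[d]
    mean = torch.mean(x, dims, keepdim=True)
    var_manual = torch.sum((x - mean) * (x - mean), dims) / float(N - correction)
    assert torch.allclose(var_manual, torch.var(x, dims, correction=correction), atol=1e-5)

    # 2) Attention converter: softmax(Q @ K^T / sqrt(d_k)) @ V matches
    #    torch.nn.functional.scaled_dot_product_attention (no mask, no dropout).
    q, k, v = (torch.randn(8, 4, 5, 16) for _ in range(3))
    attn_manual = torch.softmax(
        q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1]), dim=-1
    ) @ v
    attn_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    assert torch.allclose(attn_manual, attn_ref, atol=1e-5)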