From 4fab8a59d3b193d545d41c8a3816d92800a58a22 Mon Sep 17 00:00:00 2001
From: Chengzhe Xu
Date: Tue, 27 May 2025 20:54:04 +0000
Subject: [PATCH 1/5] feat: Implement SDPA op converter / lowering pass as extensions

---
 examples/dynamo/register_sdpa.py              |  93 +++++++++++
 examples/dynamo/sdpa_converter.py             | 149 ++++++++++++++++++
 examples/dynamo/torch_export_flux_dev.py      |   4 +-
 py/torch_tensorrt/dynamo/lowering/__init__.py |   1 +
 4 files changed, 245 insertions(+), 2 deletions(-)
 create mode 100644 examples/dynamo/register_sdpa.py
 create mode 100644 examples/dynamo/sdpa_converter.py

diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py
new file mode 100644
index 0000000000..cc0cd5ba4b
--- /dev/null
+++ b/examples/dynamo/register_sdpa.py
@@ -0,0 +1,93 @@
+import copy
+import logging
+import operator
+from typing import Callable, Sequence, Tuple
+
+import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
+from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
+from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
+    _aten_lowering_pass,
+)
+
+from sdpa_converter import *
+logger = logging.getLogger(__name__)
+
+# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
+# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_efficient_attention.default)
+TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)
+
+REPLACEABLE_ATEN_OPS = {
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+}
+
+@_aten_lowering_pass
+def replace_variants_of_sdpa(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
+    """Replace scaled_dot_product_attention with an equivalent
+    implementation which can be accurately converted to TRT
+    """
+    attn_mask = None
+    is_causal = True
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
+            if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+                breakpoint()
+                if len(node.args) == 7:
+                    query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal = node.args
+                elif len(node.args) == 5:
+                    query, key, value, attn_mask, is_causal = node.args
+                    dropout_p = 0.0
+                else:
+                    raise ValueError(f"Unexpected number of arguments for {node.target} in the graph")
+            elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
+                if len(node.args) == 6:
+                    query, key, value, dropout_p, is_causal, return_debug_mask = node.args
+                elif len(node.args) == 3:
+                    query, key, value = node.args
+                    dropout_p = 0.0
+                    is_causal = True
+                else:
+                    raise ValueError(f"Unexpected number of arguments for {node.target} in the graph")
+            if attn_mask is not None:
+                logger.warning(f"We do not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration.")
+
+            modified_input_args = (query, key, value, None, dropout_p, is_causal)
+
+            # Create a new node with torch.nn.functional.scaled_dot_product_attention
+            # The input args is (query, key, value, is_causal). kwargs has scale
+            with gm.graph.inserting_after(node):
+                new_node = gm.graph.call_function(
+                    torch.nn.functional.scaled_dot_product_attention,
+                    args=modified_input_args,
+                    kwargs={"scale": node.kwargs.get("scale", None)}
+                )
+
+                # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
+                new_node.meta = copy.copy(node.meta)
+                # Check if there's a getitem node following this attention node
+                for user in list(node.users):
+                    if user.op == "call_function" and user.target == operator.getitem:
+                        # If the getitem is extracting the first element (the output tensor)
+                        if user.args[1] == 0:
+                            # Replace all uses of the getitem with the new attention node
+                            user.replace_all_uses_with(new_node)
+                            new_node.meta['val'] = new_node.meta['val'][0]
+                # Replace all uses of the original node with the new node
+                node.replace_all_uses_with(new_node)
+
+            gm.graph.erase_node(node)
+
+    # Clean up the graph
+    clean_up_graph_after_modifications(gm)
+
+    logger.info("Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention")
+    return gm
\ No newline at end of file
diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py
new file mode 100644
index 0000000000..642d40b065
--- /dev/null
+++ b/examples/dynamo/sdpa_converter.py
@@ -0,0 +1,149 @@
+import math
+from typing import Optional, Union, Tuple, Any, Dict
+import torch
+import torch_tensorrt
+import numpy as np
+import tensorrt as trt
+from torch.fx.node import Target
+from torch_tensorrt._enums import dtype
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    SourceIR,
+    cast_trt_tensor,
+    get_trt_tensor,
+)
+from torch_tensorrt.fx.types import TRTTensor
+import logging
+logger = logging.getLogger(__name__)
+
+def tril(
+    ctx: ConversionContext,
+    target: Union[Target, str],
+    source_ir: Optional[SourceIR],
+    name: str,
+    row: TRTTensor,
+    col: TRTTensor,
+) -> TRTTensor:
+    row_arange_tensor = impl.arange.arange(ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1)
+    row_reshape_tensor = impl.shuffle.reshape(ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1])
+
+    col_arange_tensor = impl.arange.arange(ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1)
+    col_reshape_tensor = impl.shuffle.reshape(ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col])
+
+    mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor)
+    return mask
+
+
+@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.nn.functional.scaled_dot_product_attention, enabled=True, supports_dynamic_shapes=True)
+def scaled_dot_product_attention(
+    ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
+    target: Target,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    name: str,
+) -> TRTTensor:
+    # TODO: Handle attn_mask and is_causal arguments in the future
+    query, key, value, attn_mask, dropout_p, is_causal = args
+    logger.info("Ignoring attn_mask and is_causal arguments provided by the original graph. "
+                "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True "
+                "and for generate phase, is_causal=False since we pass only 1 input token at a time")
+
+
+    # TODO: remove this once we have a better way to handle the causal mask
+    scale = kwargs.get("scale", None)
+    source_ir = SourceIR.ATEN
+    # implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
+    mm = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_mm",
+        query,
+        key,
+        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
+    )
+    if scale is None:
+        scale = query.shape[-1]
+        if scale < 0:
+            # dynamic shape
+            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
+            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
+        else:
+            # static shape
+            sqrt_scaled = math.sqrt(scale)
+        scaled = impl.elementwise.div(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            sqrt_scaled,
+        )
+    else:
+        scaled = impl.elementwise.mul(
+            ctx,
+            target,
+            source_ir,
+            name + "_scale",
+            mm,
+            scale,
+        )
+
+    # If is_causal is True, we need to generate a causal mask
+    if is_causal:
+        L, S = query.shape[-2], key.shape[-2]
+        if L >= 0 and S >= 0:
+            # static shape
+            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
+            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
+            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
+            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
+        else:
+            # if any of the L or S is dynamic shape
+            if L < 0:
+                L = impl.shape.shape(
+                    ctx, target, source_ir, name + "_shape_0", query, 2
+                )
+            if S < 0:
+                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)
+
+            # generate the mask tensor
+            tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)
+
+            temp_mask = impl.unary.logical_not(
+                ctx, target, source_ir, name + "_logical_not", tril_tensor
+            )
+            temp_mask_casted = cast_trt_tensor(ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir)
+            one_minus_temp_mask = impl.elementwise.sub(
+                ctx, target, source_ir, name + "_one_minus_temp_mask", 1.0, temp_mask_casted
+            )
+            attn_bias = impl.unary.log(ctx, target, source_ir, name + "_log", one_minus_temp_mask)
+
+        scaled_add_attn_bias = impl.elementwise.add(
+            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
+        )
+    else:
+        scaled_add_attn_bias = scaled
+
+    # Create a if condition to check if is_causal is True
+    if isinstance(is_causal, TRTTensor):
+        if_layer = ctx.net.add_if_conditional()
+        condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled
+        if_layer.set_condition(condition)
+        output_layer = if_layer.add_output(true_branch, false_branch)
+        scaled_add_attn_bias = output_layer.get_output(0)
+
+    softmax = impl.normalization.softmax(
+        ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
+    )
+    out = impl.matmul.matrix_multiply(
+        ctx,
+        target,
+        source_ir,
+        name + "_out",
+        softmax,
+        value,
+    )
+
+    return out
\ No newline at end of file
diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py
index 3891fcbb9a..37b4f3b571 100644
--- a/examples/dynamo/torch_export_flux_dev.py
+++ b/examples/dynamo/torch_export_flux_dev.py
@@ -26,6 +26,7 @@
 import torch_tensorrt
 from diffusers import FluxPipeline
 from torch.export._trace import _export
+import register_sdpa  # Register SDPA as a standalone operator
 
 # %%
 # Define the FLUX-1.dev model
@@ -112,7 +113,6 @@
     min_block_size=1,
     use_fp32_acc=True,
     use_explicit_typing=True,
-)
 
 # %%
 # Post Processing
@@ -147,7 +147,7 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated using {image_name} model saved as {image_name}.png")
 
 
-generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
+generate_image(pipe, ["A golden retriever holding a sign to debug"], "dog_code")
 
 # %%
 # The generated image is as shown below
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index b276a7542d..86d3500a9d 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,6 +1,7 @@
 from ._decomposition_groups import (
     torch_disabled_decompositions,
     torch_enabled_decompositions,
+    TORCH_TRT_DECOMPOSITIONS
 )
 from ._decompositions import get_decompositions  # noqa: F401
 from .passes import *

From d1fd50457085b68d260981c83c2dd4ed84f47c5c Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Tue, 27 May 2025 15:35:37 -0700
Subject: [PATCH 2/5] chore: minor change

---
 examples/dynamo/torch_export_flux_dev.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/dynamo/torch_export_flux_dev.py b/examples/dynamo/torch_export_flux_dev.py
index 37b4f3b571..32e75b06d8 100644
--- a/examples/dynamo/torch_export_flux_dev.py
+++ b/examples/dynamo/torch_export_flux_dev.py
@@ -19,6 +19,8 @@
 we demonstrate optimizing the ``transformer`` component of the model (which typically
 consumes >95% of the e2e diffusion latency)
 """
+import register_sdpa  # Register SDPA as a standalone operator
+
 # %%
 # Import the following libraries
 # -----------------------------
@@ -26,7 +28,6 @@
 import torch_tensorrt
 from diffusers import FluxPipeline
 from torch.export._trace import _export
-import register_sdpa  # Register SDPA as a standalone operator
 
 # %%
 # Define the FLUX-1.dev model
@@ -113,6 +114,7 @@
     min_block_size=1,
     use_fp32_acc=True,
     use_explicit_typing=True,
+)
 
 # %%
 # Post Processing
@@ -147,7 +149,7 @@ def generate_image(pipe, prompt, image_name):
     print(f"Image generated using {image_name} model saved as {image_name}.png")
 
 
-generate_image(pipe, ["A golden retriever holding a sign to debug"], "dog_code")
+generate_image(pipe, ["A golden retriever holding a sign to code"], "dog_code")
 
 # %%
 # The generated image is as shown below

From 026ac07af50c54ad203e07892a862df8f4ba7893 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 28 May 2025 16:08:42 -0700
Subject: [PATCH 3/5] chore: minor msg update

---
 examples/dynamo/register_sdpa.py | 63 +++++++++++++++++++++++---------
 1 file changed, 45 insertions(+), 18 deletions(-)

diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py
index cc0cd5ba4b..6161c0afd4 100644
--- a/examples/dynamo/register_sdpa.py
+++ b/examples/dynamo/register_sdpa.py
@@ -4,23 +4,25 @@
 from typing import Callable, Sequence, Tuple
 
 import torch
+from sdpa_converter import *
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
-from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
-    clean_up_graph_after_modifications,
-)
 from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
 from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
     _aten_lowering_pass,
 )
+from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
+    clean_up_graph_after_modifications,
+)
 
-from sdpa_converter import *
 logger = logging.getLogger(__name__)
 
 # Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
 # This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
 TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_efficient_attention.default)
+TORCH_TRT_DECOMPOSITIONS.pop(
+    torch.ops.aten._scaled_dot_product_efficient_attention.default
+)
 TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)
 
 REPLACEABLE_ATEN_OPS = {
@@ -28,6 +30,7 @@
     torch.ops.aten._scaled_dot_product_flash_attention.default,
 }
 
+
 @_aten_lowering_pass
 def replace_variants_of_sdpa(
     gm: torch.fx.GraphModule, settings: CompilationSettings
@@ -39,26 +42,48 @@ def replace_variants_of_sdpa(
     attn_mask = None
     is_causal = True
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
-            if node.target == torch.ops.aten._scaled_dot_product_efficient_attention.default:
+            if (
+                node.target
+                == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            ):
                 breakpoint()
                 if len(node.args) == 7:
-                    query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal = node.args
+                    (
+                        query,
+                        key,
+                        value,
+                        attn_bias,
+                        compute_log_sumexp,
+                        dropout_p,
+                        is_causal,
+                    ) = node.args
                 elif len(node.args) == 5:
                     query, key, value, attn_mask, is_causal = node.args
                     dropout_p = 0.0
                 else:
-                    raise ValueError(f"Unexpected number of arguments for {node.target} in the graph")
-            elif node.target == torch.ops.aten._scaled_dot_product_flash_attention.default:
+                    raise ValueError(
+                        f"Unexpected number of arguments for {node.target} in the graph"
+                    )
+            elif (
+                node.target
+                == torch.ops.aten._scaled_dot_product_flash_attention.default
+            ):
                 if len(node.args) == 6:
-                    query, key, value, dropout_p, is_causal, return_debug_mask = node.args
+                    query, key, value, dropout_p, is_causal, return_debug_mask = (
+                        node.args
+                    )
                 elif len(node.args) == 3:
                     query, key, value = node.args
                     dropout_p = 0.0
                     is_causal = True
                 else:
-                    raise ValueError(f"Unexpected number of arguments for {node.target} in the graph")
+                    raise ValueError(
+                        f"Unexpected number of arguments for {node.target} in the graph"
+                    )
             if attn_mask is not None:
-                logger.warning(f"We do not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration.")
+                logger.warning(
+                    f"This current version of SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using is_causal=True configuration."
+                )
 
             modified_input_args = (query, key, value, None, dropout_p, is_causal)
@@ -68,7 +93,7 @@ def replace_variants_of_sdpa(
                 new_node = gm.graph.call_function(
                     torch.nn.functional.scaled_dot_product_attention,
                     args=modified_input_args,
-                    kwargs={"scale": node.kwargs.get("scale", None)}
+                    kwargs={"scale": node.kwargs.get("scale", None)},
                 )
 
                 # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
@@ -80,14 +105,16 @@
                         if user.args[1] == 0:
                             # Replace all uses of the getitem with the new attention node
                             user.replace_all_uses_with(new_node)
-                            new_node.meta['val'] = new_node.meta['val'][0]
+                            new_node.meta["val"] = new_node.meta["val"][0]
                 # Replace all uses of the original node with the new node
                 node.replace_all_uses_with(new_node)
 
             gm.graph.erase_node(node)
-    
+
     # Clean up the graph
     clean_up_graph_after_modifications(gm)
-    
-    logger.info("Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention")
-    return gm
\ No newline at end of file
+
+    logger.info(
+        "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
+    )
+    return gm

From da77ba0bfaa6e5a18a64d289b20a04267e0b5e97 Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 28 May 2025 16:09:53 -0700
Subject: [PATCH 4/5] chore: minor msg update

---
 examples/dynamo/register_sdpa.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/examples/dynamo/register_sdpa.py b/examples/dynamo/register_sdpa.py
index 6161c0afd4..7436f31939 100644
--- a/examples/dynamo/register_sdpa.py
+++ b/examples/dynamo/register_sdpa.py
@@ -46,7 +46,6 @@ def replace_variants_of_sdpa(
                 node.target
                 == torch.ops.aten._scaled_dot_product_efficient_attention.default
             ):
-                breakpoint()
                 if len(node.args) == 7:
                     (
                         query,

From 8ab4623a412da8678a20da73079c400ce910161e Mon Sep 17 00:00:00 2001
From: Dheeraj Peri
Date: Wed, 28 May 2025 16:19:38 -0700
Subject: [PATCH 5/5] chore: linter fixes

---
 examples/dynamo/sdpa_converter.py             | 67 +++++++++++++------
 py/torch_tensorrt/dynamo/lowering/__init__.py |  2 +-
 2 files changed, 48 insertions(+), 21 deletions(-)

diff --git a/examples/dynamo/sdpa_converter.py b/examples/dynamo/sdpa_converter.py
index 642d40b065..903324dff5 100644
--- a/examples/dynamo/sdpa_converter.py
+++ b/examples/dynamo/sdpa_converter.py
@@ -1,9 +1,11 @@
+import logging
 import math
-from typing import Optional, Union, Tuple, Any, Dict
-import torch
-import torch_tensorrt
+from typing import Any, Dict, Optional, Tuple, Union
+
 import numpy as np
 import tensorrt as trt
+import torch
+import torch_tensorrt
 from torch.fx.node import Target
 from torch_tensorrt._enums import dtype
 from torch_tensorrt.dynamo.conversion import impl
@@ -14,9 +16,10 @@
     get_trt_tensor,
 )
 from torch_tensorrt.fx.types import TRTTensor
-import logging
+
 logger = logging.getLogger(__name__)
 
+
 def tril(
     ctx: ConversionContext,
     target: Union[Target, str],
@@ -25,17 +28,31 @@ def tril(
     row: TRTTensor,
     col: TRTTensor,
 ) -> TRTTensor:
-    row_arange_tensor = impl.arange.arange(ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1)
-    row_reshape_tensor = impl.shuffle.reshape(ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1])
+    row_arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
+    )
+    row_reshape_tensor = impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
+    )
 
-    col_arange_tensor = impl.arange.arange(ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1)
-    col_reshape_tensor = impl.shuffle.reshape(ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col])
-    
-    mask = impl.elementwise.ge(ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor)
+    col_arange_tensor = impl.arange.arange(
+        ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
+    )
+    col_reshape_tensor = impl.shuffle.reshape(
+        ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
+    )
+
+    mask = impl.elementwise.ge(
+        ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
+    )
     return mask
 
 
-@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(torch.nn.functional.scaled_dot_product_attention, enabled=True, supports_dynamic_shapes=True)
+@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
+    torch.nn.functional.scaled_dot_product_attention,
+    enabled=True,
+    supports_dynamic_shapes=True,
+)
 def scaled_dot_product_attention(
     ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
     target: Target,
@@ -45,11 +62,12 @@ def scaled_dot_product_attention(
 ) -> TRTTensor:
     # TODO: Handle attn_mask and is_causal arguments in the future
     query, key, value, attn_mask, dropout_p, is_causal = args
-    logger.info("Ignoring attn_mask and is_causal arguments provided by the original graph. "
-                "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True "
-                "and for generate phase, is_causal=False since we pass only 1 input token at a time")
+    logger.info(
+        "Ignoring attn_mask and is_causal arguments provided by the original graph. "
+        "This converter expects is_causal to be an input to the graph. For prefill phase, is_causal=True "
+        "and for generate phase, is_causal=False since we pass only 1 input token at a time"
+    )
 
-
     # TODO: remove this once we have a better way to handle the causal mask
     scale = kwargs.get("scale", None)
     source_ir = SourceIR.ATEN
@@ -114,12 +132,21 @@ def scaled_dot_product_attention(
             temp_mask = impl.unary.logical_not(
                 ctx, target, source_ir, name + "_logical_not", tril_tensor
             )
-            temp_mask_casted = cast_trt_tensor(ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir)
+            temp_mask_casted = cast_trt_tensor(
+                ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
+            )
             one_minus_temp_mask = impl.elementwise.sub(
-                ctx, target, source_ir, name + "_one_minus_temp_mask", 1.0, temp_mask_casted
+                ctx,
+                target,
+                source_ir,
+                name + "_one_minus_temp_mask",
+                1.0,
+                temp_mask_casted,
             )
-            attn_bias = impl.unary.log(ctx, target, source_ir, name + "_log", one_minus_temp_mask)
-            
+            attn_bias = impl.unary.log(
+                ctx, target, source_ir, name + "_log", one_minus_temp_mask
+            )
+
         scaled_add_attn_bias = impl.elementwise.add(
             ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
         )
     else:
@@ -146,4 +173,4 @@ def scaled_dot_product_attention(
         value,
     )
 
-    return out
\ No newline at end of file
+    return out
diff --git a/py/torch_tensorrt/dynamo/lowering/__init__.py b/py/torch_tensorrt/dynamo/lowering/__init__.py
index 86d3500a9d..bec5e407b5 100644
--- a/py/torch_tensorrt/dynamo/lowering/__init__.py
+++ b/py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -1,7 +1,7 @@
 from ._decomposition_groups import (
+    TORCH_TRT_DECOMPOSITIONS,
     torch_disabled_decompositions,
     torch_enabled_decompositions,
-    TORCH_TRT_DECOMPOSITIONS
 )
 from ._decompositions import get_decompositions  # noqa: F401
 from .passes import *
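
Usage sketch (illustrative only, not part of the patches above): importing register_sdpa for its side effects is what removes the SDPA decompositions and registers both the lowering pass and the converter, exactly as the FLUX example does. The toy module, tensor shapes, and compile options below are assumptions for illustration; they presume examples/dynamo is on sys.path, a CUDA device is available, and the public torch_tensorrt.dynamo.compile API used elsewhere in the repository.

    # Hedged sketch: run from a directory containing register_sdpa.py and sdpa_converter.py
    import torch
    import torch_tensorrt

    import register_sdpa  # noqa: F401  (side effect: drops SDPA decompositions, registers pass + converter)


    class TinyAttention(torch.nn.Module):
        def forward(self, q, k, v):
            # Exported as an aten SDPA variant, rewritten by replace_variants_of_sdpa,
            # then handled by the custom scaled_dot_product_attention converter.
            return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


    q = k = v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.float16)
    ep = torch.export.export(TinyAttention().eval().cuda().half(), (q, k, v))
    trt_mod = torch_tensorrt.dynamo.compile(
        ep,
        inputs=[q, k, v],
        enabled_precisions={torch.float16},
        min_block_size=1,
    )
    print(trt_mod(q, k, v).shape)  # expected: torch.Size([1, 8, 16, 64])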