Fixed SDPA slow down and linear slow down #3700

Merged: 6 commits, Aug 4, 2025
2 changes: 1 addition & 1 deletion examples/apps/flux_demo.py
@@ -120,7 +120,7 @@ def forward_loop(mod):
"enabled_precisions": enabled_precisions,
"truncate_double": True,
"min_block_size": 1,
"use_python_runtime": False,
"use_python_runtime": True,
"immutable_weights": False,
"offload_module_to_cpu": args.low_vram_mode,
"use_explicit_typing": use_explicit_typing,
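Note on the flag flip above: `use_python_runtime=True` runs the compiled TRT engines through the Python runtime wrapper rather than the C++ runtime. A minimal sketch of passing such settings to the dynamo frontend, assuming a pre-exported model; `exported_program` and `example_inputs` are placeholders, not the demo's actual code:

```python
import torch
import torch_tensorrt

# Hypothetical, minimal settings dict; the real demo builds a much larger one.
settings = {
    "enabled_precisions": {torch.float16},
    "truncate_double": True,
    "min_block_size": 1,
    "use_python_runtime": True,  # execute TRT engines via the Python runtime wrapper
    "immutable_weights": False,
}
# trt_gm = torch_tensorrt.dynamo.compile(exported_program, example_inputs, **settings)
```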
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -532,6 +532,7 @@ def aten_ops_gelu(


@dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
Review comment (Collaborator): Is this actually needed? Consider removing it if unnecessary.

@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
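Context for the review comment above: `torch.ops.aten.matmul` is the overload packet, while `torch.ops.aten.matmul.default` is the concrete overload that dynamo-produced FX graphs normally carry as node targets, which is presumably why the packet registration looks redundant. A quick check:

```python
import torch

# OpOverloadPacket vs. the concrete OpOverload that appears as an FX node target.
print(type(torch.ops.aten.matmul))          # <class 'torch._ops.OpOverloadPacket'>
print(type(torch.ops.aten.matmul.default))  # <class 'torch._ops.OpOverload'>
```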
10 changes: 8 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -1,8 +1,8 @@
import logging
import operator
import warnings
from typing import Any, Callable, Optional, Union

import numpy as np
import tensorrt as trt
import torch
from torch.fx.node import Target
@@ -20,6 +20,8 @@
)
from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor

logger = logging.getLogger(__name__)


def get_python_op_from_trt_elementwise_op(
trt_op: TRTElementWiseOp,
@@ -148,7 +150,11 @@ def convert_binary_elementwise(
ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
)

if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
if len(lhs_val.shape) == len(rhs_val.shape) and all(
a == b or a == 1 or b == 1 for a, b in zip(lhs_val.shape, rhs_val.shape)
):
logger.info(f"skip broadcast for {name}")
elif has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
lhs_val, rhs_val = broadcast(
ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
)
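The new branch above skips the explicit broadcast helper when TensorRT's elementwise layer can already broadcast the operands implicitly (same rank, each dimension pair equal or 1). A standalone sketch of the check, with a hypothetical helper name:

```python
from typing import Sequence


def can_skip_explicit_broadcast(lhs_shape: Sequence[int], rhs_shape: Sequence[int]) -> bool:
    # Same rank and every dimension pair equal or 1: the elementwise layer can
    # broadcast these itself, so no extra reshape/broadcast layers are needed.
    return len(lhs_shape) == len(rhs_shape) and all(
        a == b or a == 1 or b == 1 for a, b in zip(lhs_shape, rhs_shape)
    )


print(can_skip_explicit_broadcast((2, 1, 8), (2, 4, 8)))  # True
print(can_skip_explicit_broadcast((4, 8), (2, 4, 8)))     # False: ranks differ
```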
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -172,6 +172,7 @@
aten.upsample_trilinear3d.vec,
aten.upsample_bicubic2d.vec,
aten.linear.default,
aten.matmul.default,
}


20 changes: 12 additions & 8 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -9,7 +9,6 @@
_get_decomp_for_cia,
)
from torch._ops import OpOverload

from torch_tensorrt.dynamo._defaults import default_device
from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
from torch_tensorrt.dynamo.utils import to_torch_device
@@ -423,8 +422,8 @@ def instance_norm_decomposition(

@register_torch_trt_decomposition(
torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
) # type: ignore
def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
)
def full_like_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
input = args[0]
shape = args[0].shape
fill_value = args[1]
@@ -454,11 +453,13 @@ def scaled_dot_product_attention_decomposition(
) -> torch.Tensor:
L, S = query.size(-2), key.size(-2)
device = query.device
attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)

if is_causal or attn_mask is not None:
attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)

if is_causal:
assert attn_mask is None, "attn_mask must be None when is_causal=True"
temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))

if attn_mask is not None:
@@ -471,17 +472,20 @@
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1)
attn_weight = torch.matmul(query, key.transpose(-2, -1))

if scale is None:
scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
attn_weight = attn_weight / scale
else:
attn_weight = attn_weight * scale

attn_weight = attn_weight + attn_bias
if is_causal or attn_mask is not None:
# Only add attn_bias when it can be non-zero; adding an all-zero bias still hurts performance.
attn_weight = attn_weight + attn_bias

attn_weight = torch.softmax(attn_weight, dim=-1)
return attn_weight @ value
return torch.matmul(attn_weight, value)


@register_torch_trt_decomposition(
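The performance-relevant part of the decomposition change above is that the zero bias tensor is no longer created and added unconditionally. A simplified, self-contained sketch of the same idea (it assumes a float `attn_mask` and applies `scale` multiplicatively, which differs slightly from the code above):

```python
import math

import torch


def sdpa_decomposed(query, key, value, attn_mask=None, is_causal=False, scale=None):
    L, S = query.size(-2), key.size(-2)
    scale = scale if scale is not None else 1.0 / math.sqrt(query.size(-1))

    attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale

    # Only materialize and add the bias when it can be non-zero; in the common
    # is_causal=False, attn_mask=None case this whole block is skipped.
    if is_causal or attn_mask is not None:
        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=query.device)
        if is_causal:
            causal = torch.ones((L, S), dtype=torch.bool, device=query.device).tril()
            attn_bias = attn_bias.masked_fill(causal.logical_not(), float("-inf"))
        if attn_mask is not None:
            attn_bias = attn_bias + attn_mask
        attn_weight = attn_weight + attn_bias

    attn_weight = torch.softmax(attn_weight, dim=-1)
    return torch.matmul(attn_weight, value)
```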
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -10,6 +10,18 @@


def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
"""
Splits all `torch.ops.aten.addmm.default` nodes in the FX graph into separate
`add` and `mm` nodes. This is useful for passes that want to insert additional
logic (such as FP32 accumulation) specifically around the matrix multiplication
operation, rather than the fused addmm.

Args:
gm (torch.fx.GraphModule): The FX graph module to transform.

Returns:
torch.fx.GraphModule: The modified FX graph module with addmm nodes split.
"""
target = torch.ops.aten.addmm.default
addmm_nodes = [node for node in gm.graph.nodes if node.target == target]
for addmm_node in addmm_nodes:
@@ -52,6 +64,7 @@ def accumulate_fp32_matmul(
matmul_targets = [
torch.ops.aten.mm.default,
torch.ops.aten.bmm.default,
torch.ops.aten.matmul.default,
]

# Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes
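For readers unfamiliar with the pass: the docstring above describes splitting `addmm` so the FP32-accumulation casts can wrap only the matrix multiply. A rough FX sketch of that transformation (not the pass itself):

```python
import torch


def split_addmm(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Rewrite addmm(bias, mat1, mat2) as add(bias, mm(mat1, mat2)) so a later
    # pass can insert float32 casts around the mm node only.
    for node in list(gm.graph.nodes):
        if node.target == torch.ops.aten.addmm.default:
            bias, mat1, mat2 = node.args[:3]
            with gm.graph.inserting_before(node):
                mm = gm.graph.call_function(torch.ops.aten.mm.default, (mat1, mat2))
                out = gm.graph.call_function(torch.ops.aten.add.Tensor, (bias, mm))
            node.replace_all_uses_with(out)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```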
13 changes: 12 additions & 1 deletion tools/perf/Flux/benchmark.sh
@@ -1,9 +1,20 @@
#TODO: Enter the HF Token
huggingface-cli login --token HF_TOKEN

nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> pytorch_fp16_gpu_utilization.txt &
NVIDIA_SMI_PID=$!
python flux_perf.py --pytorch --max_batch_size 3 > pytorch_fp16_benchmark.txt
kill $NVIDIA_SMI_PID

nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp8_gpu_utilization.txt &
NVIDIA_SMI_PID=$!
python flux_perf.py --dtype fp8 --low_vram_mode> fp8_benchmark.txt
python flux_perf.py --dtype fp8 --max_batch_size 3 > fp8_benchmark.txt
kill $NVIDIA_SMI_PID


nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp16_gpu_utilization.txt &
NVIDIA_SMI_PID=$!
python flux_perf.py --dtype fp16 --max_batch_size 3 > fp16_benchmark.txt
kill $NVIDIA_SMI_PID


20 changes: 19 additions & 1 deletion tools/perf/Flux/flux_perf.py
@@ -44,9 +44,22 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1):
return


from diffusers import FluxPipeline


def main(args):
print(f"Running flux_perfwith args: {args}")
pipe, backbone, trt_gm = compile_model(args)
if not args.pytorch:
pipe, backbone, trt_gm = compile_model(args)
else:
pipe = (
FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float16,
)
.to(torch.float16)
.to("cuda:0")
)

benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3)

@@ -83,6 +96,11 @@ def main(args):
action="store_true",
help="Use dynamic shapes",
)
parser.add_argument(
"--pytorch",
action="store_true",
help="Use pytorch runtime and no tensorrt",
)
parser.add_argument("--max_batch_size", type=int, default=1)
args = parser.parse_args()
main(args)