diff --git a/examples/apps/flux_demo.py b/examples/apps/flux_demo.py
index c0028d1459..5220f38ec6 100644
--- a/examples/apps/flux_demo.py
+++ b/examples/apps/flux_demo.py
@@ -120,7 +120,7 @@ def forward_loop(mod):
         "enabled_precisions": enabled_precisions,
         "truncate_double": True,
         "min_block_size": 1,
-        "use_python_runtime": False,
+        "use_python_runtime": True,
         "immutable_weights": False,
         "offload_module_to_cpu": args.low_vram_mode,
         "use_explicit_typing": use_explicit_typing,
diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index f1a7f9a8fc..fe9a01b06c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -532,6 +532,7 @@ def aten_ops_gelu(
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.matmul, supports_dynamic_shapes=True)
+@dynamo_tensorrt_converter(torch.ops.aten.matmul.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.dot.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mm.default, supports_dynamic_shapes=True)
 @dynamo_tensorrt_converter(torch.ops.aten.mv.default, supports_dynamic_shapes=True)
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
index ab9629b0db..097a81b8d1 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
@@ -1,8 +1,8 @@
+import logging
 import operator
 import warnings
 from typing import Any, Callable, Optional, Union
 
-import numpy as np
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
@@ -20,6 +20,8 @@
 )
 from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
 
+logger = logging.getLogger(__name__)
+
 
 def get_python_op_from_trt_elementwise_op(
     trt_op: TRTElementWiseOp,
@@ -148,7 +150,11 @@ def convert_binary_elementwise(
             ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
         )
 
-    if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
+    if len(lhs_val.shape) == len(rhs_val.shape) and all(
+        a == b or a == 1 or b == 1 for a, b in zip(lhs_val.shape, rhs_val.shape)
+    ):
+        logger.info(f"skip broadcast for {name}")
+    elif has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
         lhs_val, rhs_val = broadcast(
             ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
         )
diff --git a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
index 52b541d3a8..eca5d7fe77 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -172,6 +172,7 @@
     aten.upsample_trilinear3d.vec,
     aten.upsample_bicubic2d.vec,
     aten.linear.default,
+    aten.matmul.default,
 }
 
 
diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
index 4cead5d0cb..fb7b833a5f 100644
--- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py
+++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py
@@ -9,7 +9,6 @@
     _get_decomp_for_cia,
 )
 from torch._ops import OpOverload
-
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim
 from torch_tensorrt.dynamo.utils import to_torch_device
@@ -423,8 +422,8 @@ def instance_norm_decomposition(
 
 @register_torch_trt_decomposition(
     torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
-) # type: ignore
-def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
+)
+def full_like_decomposition(*args: Any, **kwargs: Any) -> torch.Tensor:
     input = args[0]
     shape = args[0].shape
     fill_value = args[1]
@@ -454,11 +453,13 @@ def scaled_dot_product_attention_decomposition(
 ) -> torch.Tensor:
     L, S = query.size(-2), key.size(-2)
     device = query.device
-    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=device)
+
+    if is_causal or attn_mask is not None:
+        attn_bias = torch.zeros((L, S), dtype=query.dtype, device=device)
 
     if is_causal:
         assert attn_mask is None, "attn_mask must be None when is_causal=True"
-        temp_mask = torch.ones(L, S, dtype=torch.bool, device=device).tril(diagonal=0)
+        temp_mask = torch.ones((L, S), dtype=torch.bool, device=device).tril(diagonal=0)
         attn_bias = attn_bias.masked_fill(temp_mask.logical_not(), float("-inf"))
 
     if attn_mask is not None:
@@ -471,7 +472,7 @@ def scaled_dot_product_attention_decomposition(
     key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
     value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)
 
-    attn_weight = query @ key.transpose(-2, -1)
+    attn_weight = torch.matmul(query, key.transpose(-2, -1))
 
     if scale is None:
         scale = torch.sqrt(torch.scalar_tensor(query.size(-1), dtype=torch.int))
@@ -479,9 +480,12 @@ def scaled_dot_product_attention_decomposition(
     else:
         attn_weight = attn_weight * scale
 
-    attn_weight = attn_weight + attn_bias
+    if is_causal or attn_mask is not None:
+        # We only add attn_bias when we have to, otherwise this will have a negative impact on the performance even it's 0.
+        attn_weight = attn_weight + attn_bias
+
     attn_weight = torch.softmax(attn_weight, dim=-1)
-    return attn_weight @ value
+    return torch.matmul(attn_weight, value)
 
 
 @register_torch_trt_decomposition(
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
index e569c45cfa..282693d299 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -10,6 +10,18 @@
 
 
 def split_addmm_nodes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    """
+    Splits all `torch.ops.aten.addmm.default` nodes in the FX graph into separate
+    `add` and `mm` nodes. This is useful for passes that want to insert additional
+    logic (such as FP32 accumulation) specifically around the matrix multiplication
+    operation, rather than the fused addmm.
+
+    Args:
+        gm (torch.fx.GraphModule): The FX graph module to transform.
+
+    Returns:
+        torch.fx.GraphModule: The modified FX graph module with addmm nodes split.
+ """ target = torch.ops.aten.addmm.default addmm_nodes = [node for node in gm.graph.nodes if node.target == target] for addmm_node in addmm_nodes: @@ -52,6 +64,7 @@ def accumulate_fp32_matmul( matmul_targets = [ torch.ops.aten.mm.default, torch.ops.aten.bmm.default, + torch.ops.aten.matmul.default, ] # Split torch.addmm nodes into add + mm and only add cast nodes around mm nodes diff --git a/tools/perf/Flux/benchmark.sh b/tools/perf/Flux/benchmark.sh index 79f5e4b66c..3b29ac0989 100644 --- a/tools/perf/Flux/benchmark.sh +++ b/tools/perf/Flux/benchmark.sh @@ -1,9 +1,20 @@ #TODO: Enter the HF Token huggingface-cli login --token HF_TOKEN +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> pytorch_fp16_gpu_utilization.txt & +NVIDIA_SMI_PID=$! +python flux_perf.py --pytorch --max_batch_size 3 > pytorch_fp16_benchmark.txt +kill $NVIDIA_SMI_PID + nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp8_gpu_utilization.txt & NVIDIA_SMI_PID=$! -python flux_perf.py --dtype fp8 --low_vram_mode> fp8_benchmark.txt +python flux_perf.py --dtype fp8 --max_batch_size 3 > fp8_benchmark.txt +kill $NVIDIA_SMI_PID + + +nvidia-smi --query-gpu=index,utilization.gpu,utilization.memory,temperature.gpu,temperature.memory,power.draw,clocks.sm,clocks.mem,memory.total,memory.used --format=csv,nounits -lms 500 >> fp16_gpu_utilization.txt & +NVIDIA_SMI_PID=$! +python flux_perf.py --dtype fp16 --max_batch_size 3 > fp16_benchmark.txt kill $NVIDIA_SMI_PID diff --git a/tools/perf/Flux/flux_perf.py b/tools/perf/Flux/flux_perf.py index 1d3b2acbbc..969f4c93d8 100644 --- a/tools/perf/Flux/flux_perf.py +++ b/tools/perf/Flux/flux_perf.py @@ -44,9 +44,22 @@ def benchmark(pipe, prompt, inference_step, batch_size=1, iterations=1): return +from diffusers import FluxPipeline + + def main(args): print(f"Running flux_perfwith args: {args}") - pipe, backbone, trt_gm = compile_model(args) + if not args.pytorch: + pipe, backbone, trt_gm = compile_model(args) + else: + pipe = ( + FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", + torch_dtype=torch.float16, + ) + .to(torch.float16) + .to("cuda:0") + ) benchmark(pipe, ["Test"], 20, batch_size=args.max_batch_size, iterations=3) @@ -83,6 +96,11 @@ def main(args): action="store_true", help="Use dynamic shapes", ) + parser.add_argument( + "--pytorch", + action="store_true", + help="Use pytorch runtime and no tensorrt", + ) parser.add_argument("--max_batch_size", type=int, default=1) args = parser.parse_args() main(args)