Attention Dispatcher #11368

Closed · a-r-r-o-w wants to merge 14 commits

Conversation

a-r-r-o-w
Member

@a-r-r-o-w a-r-r-o-w commented Apr 19, 2025

Usage

# test.py
import torch
from diffusers import Lumina2Pipeline, attention_backend

pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says 'Hello, World!' in a colorful park with flowers and trees"

with attention_backend("sage_varlen"):
    image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

# Alternatively, select the backend from the CLI via the DIFFUSERS_ATTN_PROVIDER environment variable:
# fails because flex attention requires head dim to be a power of 2
DIFFUSERS_ATTN_PROVIDER="flex" CUDA_VISIBLE_DEVICES=3 python3 test.py
# dispatches to cudnn internally in pytorch, so it's the same as using "_native_cudnn" (see below)
DIFFUSERS_ATTN_PROVIDER="native" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="flash_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="sage_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_cudnn" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_efficient" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="xformers" CUDA_VISIBLE_DEVICES=3 python3 test.py

Attention-only benchmark
import torch
from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

torch.manual_seed(0)

# Wan 1.3B/CogVideoX
batch = 1
num_heads = 12
head_dim = 128
dtype = torch.bfloat16

resolutions = [(1, 512, 512), (1, 1024, 1024), (49, 480, 720), (29, 1024, 1024), (81, 480, 832)]
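# Latent sequence length for each (frames, height, width) resolution: 4x temporal
# compression (plus the first latent frame), 8x8 spatial compression, and a 2x2 patch
# embedding (factor 4); these factors are assumptions matching Wan/CogVideoX-style video models.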
seq_lens = [((res[0] - 1) // 4 + 1) * res[1] * res[2] // 8 // 8 // 4 for res in resolutions]
print("Sequence lengths:", seq_lens)

for seq_len in seq_lens:
    flops = 4 * batch * num_heads * head_dim * seq_len * seq_len  # two matmuls (QK^T and attn @ V), 2 * S^2 * D FLOPs each per head

    torch.manual_seed(0)
    query = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    key = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    value = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")

    results = {}
    
    for backend in ["flash", "flash_varlen", "_native_flash", "_native_cudnn", "_native_efficient", "xformers", "_sage_qk_int8_pv_fp16_cuda"]:
        with attention_backend(backend):
            for _ in range(5):
                # Warmup
                _ = dispatch_attention_fn(query, key, value)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            result = dispatch_attention_fn(query, key, value)
            end.record()
            torch.cuda.synchronize()

            elapsed_time = start.elapsed_time(end) / 1000  # CUDA events report milliseconds; convert to seconds
            results[backend] = elapsed_time
    
    tflops_s_flash = flops / results["flash"] / 1e12
    tflops_s_flash_varlen = flops / results["flash_varlen"] / 1e12
    tflops_s_native_flash = flops / results["_native_flash"] / 1e12
    tflops_s_native_cudnn = flops / results["_native_cudnn"] / 1e12
    tflops_s_native_efficient = flops / results["_native_efficient"] / 1e12
    tflops_s_xformers = flops / results["xformers"] / 1e12
    tflops_s_sage_qk_int8_pv_fp16_cuda = flops / results["_sage_qk_int8_pv_fp16_cuda"] / 1e12

    print()
    print(f"Shape: {query.shape}")
    print(f"TFLOPs: {flops / 1e12:.2f}")
    print("===== TFLOPS =====")
    print(f"                     (flash): {tflops_s_flash:.2f}")
    print(f"              (flash_varlen): {tflops_s_flash_varlen:.2f}")
    print(f"              (native_flash): {tflops_s_native_flash:.2f}")
    print(f"              (native_cudnn): {tflops_s_native_cudnn:.2f}")
    print(f"          (native_efficient): {tflops_s_native_efficient:.2f}")
    print(f"                  (xformers): {tflops_s_xformers:.2f}")
    print(f"(_sage_qk_int8_pv_fp16_cuda): {tflops_s_sage_qk_int8_pv_fp16_cuda:.2f}")
    print("==========")

Model benchmark
import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
import torch.nn.attention.flex_attention
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    LTXPipeline,
    MochiPipeline,
    WanPipeline,
    AttentionBackendName,
    attention_backend,
)
from diffusers.hooks import apply_group_offloading
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch

torch.nn.attention.flex_attention.flex_attention = torch.compile(torch.nn.attention.flex_attention.flex_attention, mode="max-autotune", dynamic=False, fullgraph=True)
torch.nn.attention.flex_attention.create_block_mask = torch.compile(torch.nn.attention.flex_attention.create_block_mask, mode="max-autotune", dynamic=False, fullgraph=True)

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_accumulation = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=(
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        device="cuda",
        dtype=dtype,
    )

    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "black-forest-labs/FLUX.1-dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.vae.enable_tiling()

    pipe.text_encoder.to("cuda")
    pipe.text_encoder_2.to("cuda")
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="A cat holding a sign that says hello world", prompt_2=None, device="cuda"
    )
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")
    del pipe.text_encoder
    del pipe.text_encoder_2
    pipe.text_encoder = None
    pipe.text_encoder_2 = None
    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.", device="cuda", dtype=torch.float16
    )
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_ltx_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "a-r-r-o-w/LTX-Video-diffusers"
    cache_dir = None

    pipe = LTXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(
        prompt="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
        negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "negative_prompt_embeds": negative_prompt_embeds,
        "negative_prompt_attention_mask": negative_prompt_attention_mask,
        "width": 768,
        "height": 512,
        "num_frames": 161,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None

    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_wan(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    cache_dir = None

    pipe = WanPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    
    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "worst quality, low quality, blurry, distorted, out of focus, bad composition"
    
    pipe.text_encoder.to("cuda")
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")
    del pipe.text_encoder
    pipe.text_encoder = None

    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 480,
        "width": 832,
        "num_frames": 81,
        "guidance_scale": 5.0,
        "num_inference_steps": 30,
        **kwargs,
    }

    return pipe, generation_kwargs


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_ltx_video(pipe: LTXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latent_num_frames = (kwargs["num_frames"] - 1) // pipe.vae_temporal_compression_ratio + 1
    latent_height = kwargs["height"] // pipe.vae_spatial_compression_ratio
    latent_width = kwargs["width"] // pipe.vae_spatial_compression_ratio

    latents = pipe._unpack_latents(
        latents,
        latent_num_frames,
        latent_height,
        latent_width,
        pipe.transformer_spatial_patch_size,
        pipe.transformer_temporal_patch_size,
    )
    latents = pipe._denormalize_latents(
        latents, pipe.vae.latents_mean, pipe.vae.latents_std, pipe.vae.config.scaling_factor
    )
    latents = latents.to(pipe.vae.dtype)

    timestep = None
    video = pipe.vae.decode(latents, timestep, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=24)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_wan(pipe: WanPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype)
    latents_mean = (
        torch.tensor(pipe.vae.config.latents_mean)
        .view(1, pipe.vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
        latents.device, latents.dtype
    )
    latents = latents / latents_std + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=16)
    return filename


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


MODEL_MAPPING = {
    "allegro": {
        "prepare": prepare_allegro,
        "decode": decode_allegro,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "decode": decode_cogvideox_1_0,
    },
    "flux": {
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "decode": decode_hunyuan_video,
    },
    "latte": {
        "prepare": prepare_latte,
        "decode": decode_latte,
    },
    "ltx_video": {
        "prepare": prepare_ltx_video,
        "decode": decode_ltx_video,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "decode": decode_mochi,
    },
    "wan": {
        "prepare": prepare_wan,
        "decode": decode_wan,
    }
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator().manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


from diffusers.hooks import ModelHook, HookRegistry
from accelerate.utils import send_to_device

class MoveToCUDAHook(ModelHook):
    def pre_forward(self, module, *args, **kwargs):
        args = send_to_device(args, "cuda")
        kwargs = send_to_device(kwargs, "cuda")
        return args, kwargs

    def post_forward(self, module, output):
        output = send_to_device(output, "cpu")
        return output


@torch.no_grad()
def main(model_id: str, output_dir: str, dtype: str, offloading_type: str, num_blocks_per_group: int, use_stream: bool, compile: bool, attn_provider: str, num_images_per_prompt: int):
    if attn_provider == "flex":
        import torch.nn.attention.flex_attention as flex_attention

        flex_attention.flex_attention = torch.compile(flex_attention.flex_attention, mode="max-autotune-no-cudagraphs", fullgraph=True)
        flex_attention.create_block_mask = torch.compile(flex_attention.create_block_mask, mode="max-autotune-no-cudagraphs", fullgraph=True)

    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]
    reset_memory()

    try:
        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        extra_keys = {}
        if model_id == "wan":
            extra_keys = {"num_videos_per_prompt": num_images_per_prompt}
        else:
            extra_keys = {"num_images_per_prompt": num_images_per_prompt}
        generation_kwargs.update(extra_keys)

        # 2. Apply group offloading
        if offloading_type == "model":
            pipe.enable_model_cpu_offload()
        elif offloading_type == "sequential":
            pipe.enable_sequential_cpu_offload()
        elif offloading_type in ["block_level", "leaf_level"]:
            apply_group_offloading(
                pipe.transformer,
                offload_type=offloading_type,
                num_blocks_per_group=num_blocks_per_group,
                offload_device=torch.device("cpu"),
                onload_device=torch.device("cuda"),
                non_blocking=True,
                use_stream=use_stream,
            )
        else:
            pipe.transformer.to("cuda")
            # registry = HookRegistry.check_if_exists_or_initialize(pipe.transformer)
            # registry.register_hook(MoveToCUDAHook(), "MoveToCUDAHook")
        
        pipe.vae.to("cuda")
        torch.cuda.synchronize()

        reset_memory()
        model_max_memory_reserved = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

        if compile:
            pipe.transformer = torch.compile(
                pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False
            )

        registry_vae = HookRegistry.check_if_exists_or_initialize(pipe.vae.decoder)
        registry_vae.register_hook(MoveToCUDAHook(), "MoveToCUDAHook")

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        with attention_backend(attn_provider):
            for _ in range(num_warmups):
                run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        with attention_backend(attn_provider):
            time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_max_memory_reserved = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---attn_provider-{attn_provider}---dtype-{dtype}---offloading_type-{offloading_type}---num_blocks_per_group-{num_blocks_per_group}---use_stream-{use_stream}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            num_frames=generation_kwargs.get("num_frames", None),
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "attn_provider": attn_provider,
            "time": time,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "model_memory": model_max_memory_reserved,
            "inference_memory": inference_max_memory_reserved,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "attn_provider": attn_provider,
            "time": None,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "model_memory": None,
            "inference_memory": None,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi", "ltx_video", "wan"],
        help="Model to run benchmark for.",
    )
    parser.add_argument("--attn_provider", type=str, default="native", choices=[x.value for x in AttentionBackendName.__members__.values()])
    parser.add_argument("--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt.")
    parser.add_argument(
        "--output_dir", required=True, type=str, help="Path where the benchmark artifacts and outputs are to be saved."
    )
    parser.add_argument("--dtype", type=str, default="bf16", choices=list(STR_TO_COMPUTE_DTYPE.keys()), help="torch.dtype to use for inference.")
    parser.add_argument("--offloading_type", type=str, default="none", choices=["none", "model", "sequential", "block_level", "leaf_level"], help="Type of offloading to use.")
    parser.add_argument("--num_blocks_per_group", type=int, default=None, help="Number of layers per group for group offloading.")
    parser.add_argument("--use_stream", action="store_true", default=False, help="Whether to use CUDA streams for offloading.")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(
        args.model_id,
        args.output_dir,
        args.dtype,
        args.offloading_type,
        args.num_blocks_per_group,
        args.use_stream,
        args.compile,
        args.attn_provider,
        args.num_images_per_prompt,
    )

Results: 4090

Results with PyTorch 2.7 stable, CUDA 12.6

Wan

| model_id | attn_provider | time (s) | offloading_type | use_stream | num_blocks | model_memory (GB) | inference_memory (GB) | compile |
|---|---|---|---|---|---|---|---|---|
| wan | flash | 142.816 | none | False | | 2.912 | 4.455 | False |
| wan | flash_varlen | 144.221 | none | False | | 2.912 | 4.455 | False |
| wan | flex | 146.176 | none | False | | 2.912 | 4.455 | False |
| wan | native | 144.692 | none | False | | 2.912 | 4.455 | False |
| wan | _native_cudnn | 144.901 | none | False | | 2.912 | 4.455 | False |
| wan | _native_efficient | 184.593 | none | False | | 2.912 | 4.455 | False |
| wan | _native_flash | 144.611 | none | False | | 2.912 | 4.455 | False |
| wan | sage | 102.281 | none | False | | 2.912 | 4.455 | False |
| wan | sage_varlen | 112.254 | none | False | | 2.912 | 4.455 | False |
| wan | xformers | 142.909 | none | False | | 2.912 | 4.455 | False |
| wan | flash | 147.230 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flash_varlen | 148.197 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flex | 150.197 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | native | 148.783 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_cudnn | 149.177 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_efficient | 188.643 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_flash | 148.753 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage | 106.032 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage_varlen | 116.081 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | xformers | 147.119 | leaf_level | True | | 0.249 | 1.819 | False |

Results: A100

Results with PyTorch 2.7 stable, CUDA 12.2

Wan

| model_id | attn_provider | time (s) | offloading_type | use_stream | num_blocks | model_memory (GB) | inference_memory (GB) | compile |
|---|---|---|---|---|---|---|---|---|
| wan | flash | 123.107 | none | False | | 2.912 | 4.455 | False |
| wan | flash_varlen | 125.355 | none | False | | 2.912 | 4.455 | False |
| wan | flex | 143.088 | none | False | | 2.912 | 4.455 | False |
| wan | native | 130.183 | none | False | | 2.912 | 4.455 | False |
| wan | _native_cudnn | 137.591 | none | False | | 2.912 | 4.455 | False |
| wan | _native_efficient | 183.795 | none | False | | 2.912 | 4.455 | False |
| wan | _native_flash | 131.384 | none | False | | 2.912 | 4.455 | False |
| wan | sage | 119.741 | none | False | | 2.912 | 4.455 | False |
| wan | sage_varlen | 131.515 | none | False | | 2.912 | 4.455 | False |
| wan | xformers | 125.414 | none | False | | 2.912 | 4.455 | False |
| wan | flash | 127.484 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flash_varlen | 129.351 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flex | 146.739 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | native | 133.718 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_cudnn | 141.970 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_efficient | 188.268 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_flash | 133.996 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage | 123.269 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage_varlen | 133.422 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | xformers | 127.743 | leaf_level | True | | 0.249 | 1.819 | False |

cc @DN6 @sayakpaul @yiyixuxu

Supported backends: flash, flash_varlen, flex, native, sage, sage_varlen, xformers

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Member

@sayakpaul sayakpaul left a comment

Interesting PR! I only left some higher-level comments. My major comment is around having an attention config class instead of environment vars. Or would that be too much for this PR?


For the attention config class (if decided to proceed that route), I was thinking of the following APIs:

attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...
)
model.set_attn_config(attn_config)

@a-r-r-o-w
Member Author

The environment vars were initially only for my quick testing from the CLI instead of changing the code every time. We can get rid of them completely.

The intended API in my mind, and what currently exists in the PR, is the context manager:

from diffusers import attention_provider

with attention_provider("sage_varlen"):
    model(...)

Can change once we finalize something

Collaborator

@DN6 DN6 left a comment

It's looking good 👍🏽 Nice work! Registry makes sense here. Just some minor comments on the initial pass.

I would also add the torch NPU backend and XLA flash attention, e.g.:

hidden_states = torch_npu.npu_fusion_attention(

from torch_xla.experimental.custom_kernel import flash_attention

I also think configuring attention without env variables or a context manager might be needed, e.g. you may want to run the transformer in the pipeline with SageAttention while the other components use regular SDPA. The config object that @sayakpaul suggested makes sense.
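
For illustration, a minimal sketch of this per-component use case, using the ModelMixin.set_attention_backend API that was added later in this PR (see below); only the transformer is switched to sage attention while the other components keep regular SDPA:

import torch
from diffusers import WanPipeline

pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Only the transformer routes attention through the dispatcher's "sage" backend;
# the text encoder and VAE continue to use plain torch SDPA.
pipe.transformer.set_attention_backend("sage")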

@a-r-r-o-w a-r-r-o-w marked this pull request as ready for review May 16, 2025 12:22
@a-r-r-o-w
Member Author

For the attention config class (if decided to proceed that route), I was thinking of the following APIs:

attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...
)
model.set_attn_config(attn_config)

@sayakpaul @DN6 How would you recommend we set a per-model attention backend? The backend info needs to be propagated to the attention dispatcher when the forward method is called. The easiest way, and how I've done it for training/CP, is to attach a simple pre-forward hook that sets the backend, cp_mesh, and any other attributes when the forward method is invoked; a sketch of this idea follows below. If you have recommendations, I'll modify the implementation accordingly.
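
A minimal sketch of that pre-forward hook idea, assembled only from pieces that appear elsewhere in this thread (diffusers' ModelHook/HookRegistry and the attention_backend context manager); it is an illustration, not the implementation used for training/CP:

from diffusers import attention_backend
from diffusers.hooks import HookRegistry, ModelHook


class AttentionBackendHook(ModelHook):
    def __init__(self, backend: str):
        super().__init__()
        self.backend = backend
        self._ctx = None

    def pre_forward(self, module, *args, **kwargs):
        # Enter the backend context right before the wrapped module's forward runs.
        self._ctx = attention_backend(self.backend)
        self._ctx.__enter__()
        return args, kwargs

    def post_forward(self, module, output):
        # Exit the context so modules outside this one keep their previous backend.
        self._ctx.__exit__(None, None, None)
        self._ctx = None
        return output


# registry = HookRegistry.check_if_exists_or_initialize(transformer)
# registry.register_hook(AttentionBackendHook("sage_varlen"), "AttentionBackendHook")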

Currently, you need to first replace the calls to F.scaled_dot_product_attention with diffusers.models.attention_dispatch.dispatch_attention_fn in the modeling code and then invoke one or more models under the attention_backend context manager:

from diffusers import attention_backend

with attention_backend("flash_varlen"):
    output = transformer(...)

If the context manager is not used, it defaults to the original behaviour of calling native torch attention.
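
To make the modeling-code change concrete, a hypothetical standalone snippet (tensor shapes follow the benchmark script above; not taken from an actual diffusers attention processor):

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn

query = torch.randn(1, 12, 4096, 128, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, 12, 4096, 128, device="cuda", dtype=torch.bfloat16)
value = torch.randn(1, 12, 4096, 128, device="cuda", dtype=torch.bfloat16)

# Before: plain PyTorch SDPA, which always uses the native backends.
out_sdpa = torch.nn.functional.scaled_dot_product_attention(query, key, value)

# After: routed through the dispatcher, which honors the attention_backend context
# manager and falls back to native torch attention when no context is active.
out_dispatched = dispatch_attention_fn(query, key, value)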

@a-r-r-o-w a-r-r-o-w requested review from DN6, sayakpaul and yiyixuxu May 16, 2025 18:54
@sayakpaul
Member

How would you recommend we set per-model attention backend? The backend info needs to be propagated to the attention dispatcher when the forward method is called. The easiest way and how I've done it for training/CP is to attach a simple pre-forward hook that sets the backend, cp_mesh, and any other attributes, when the forward method is invoked.

I was thinking that upon calling set_attn_config() we would set them? I prefer the set method over the context manager approach as I feel it's a bit more explicit.

Member

@sayakpaul sayakpaul left a comment

Looking really good. I am currently trying to benchmark against FA3 as well while we're at it. I will update this thread once we have the results.

@sayakpaul
Member

@a-r-r-o-w I was able to run FA3 with your code and here are some results:

Sequence lengths: [1024, 4096, 17550, 32768, 32760]

Shape: torch.Size([1, 12, 1024, 128])
TFLOPs: 0.01
===== TFLOPS =====
                     (flash): 67.79
              (native_flash): 68.02
              (native_cudnn): 60.44
==========

Shape: torch.Size([1, 12, 4096, 128])
TFLOPs: 0.10
===== TFLOPS =====
                     (flash): 677.87
              (native_flash): 325.90
              (native_cudnn): 660.36
==========

Shape: torch.Size([1, 12, 17550, 128])
TFLOPs: 1.89
===== TFLOPS =====
                     (flash): 740.64
              (native_flash): 348.37
              (native_cudnn): 626.17
==========

Shape: torch.Size([1, 12, 32768, 128])
TFLOPs: 6.60
===== TFLOPS =====
                     (flash): 724.79
              (native_flash): 363.38
              (native_cudnn): 701.14
==========

Shape: torch.Size([1, 12, 32760, 128])
TFLOPs: 6.59
===== TFLOPS =====
                     (flash): 669.26
              (native_flash): 353.29
              (native_cudnn): 586.65
==========

I can open a PR to your branch for the changes I had to make to make it work. LMK.

@a-r-r-o-w
Member Author

@sayakpaul Super cool, thanks! I hope you didn't face too much trouble with building FA3 😅

I actually already have the required changes for FA3 (and some other things like NPU and XLA) locally. I didn't benchmark yet though so thanks for that, and I can push my changes soon

@sayakpaul
Member

I hope you didn't face too much trouble with building FA3 😅

It just took time. I used Docker instead of the default env of the cluster.

@a-r-r-o-w
Member Author

Pushed some changes to support FA3, NPU and XLA. They are all marked private since FA3 is a beta release and NPU and XLA are untested.

PyTorch's cuDNN backend is close to FA3, but in almost all problem shapes the latter is faster, similar to FA2 built from source.

@a-r-r-o-w
Member Author

@sayakpaul @DN6 Based on our discussion, I've added support for set_attention_backend at the ModelMixin level. There are now two ways to enable the dispatcher.


`set_attention_backend("...")` for diffusers native implementations
import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention_processor import Attention


class MyModel(ModelMixin):
    def __init__(self):
        super().__init__()

        self.attention = Attention(
            query_dim=10,
            heads=2,
            dim_head=5,
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor):
        return x + self.mlp(x + self.attention(x))


dtype = torch.bfloat16
device = "cuda"

model = MyModel().to(device, dtype=dtype)
input = torch.randn(2, 64, 10).to(device, dtype=dtype)

output_native = model(input)

model.set_attention_backend("flash")
output_flash = model(input)

model.set_attention_backend("sage")
output_sage = model(input)

model.set_attention_backend("_native_math")
output_native_math = model(input)

diff1 = torch.abs(output_native - output_flash).max()
diff2 = torch.abs(output_native - output_sage).max()
diff3 = torch.abs(output_native - output_native_math).max()
print(diff1, diff2, diff3)

Context manager for custom implementations
import torch
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention_dispatch import dispatch_attention_fn, attention_backend


class AttentionProcessor:
    def __call__(
        self,
        attn,
        x: torch.Tensor,
    ) -> torch.Tensor:
        q, k, v = (y.unflatten(2, (attn.heads, -1)).permute(0, 2, 1, 3).contiguous() for y in attn.qkv(x).chunk(3, dim=-1))
        return attn.o(dispatch_attention_fn(q, k, v).permute(0, 2, 1, 3).flatten(2))


class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.heads = 2
        self.qkv = torch.nn.Linear(10, 30)
        self.o = torch.nn.Linear(10, 10)
        self.processor = AttentionProcessor()
    
    def forward(self, x: torch.Tensor):
        return self.processor(self, x)


class MyModel(ModelMixin):
    def __init__(self):
        super().__init__()

        self.attention = Attention()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor):
        return x + self.mlp(x + self.attention(x))


dtype = torch.bfloat16
device = "cuda"

model = MyModel().to(device, dtype=dtype)
input = torch.randn(2, 64, 10).to(device, dtype=dtype)

output_native = model(input)

with attention_backend("flash"):
    output_flash = model(input)

with attention_backend("sage"):
    output_sage = model(input)

with attention_backend("_native_math"):
    output_native_math = model(input)

diff1 = torch.abs(output_native - output_flash).max()
diff2 = torch.abs(output_native - output_sage).max()
diff3 = torch.abs(output_native - output_native_math).max()
print(diff1, diff2, diff3)

@sayakpaul
Member

This is looking much much better IMO!

I think with proper documentation, we can make the differences between the scopes of set_attn_backend() and set_attn_processor() much clearer.

@a-r-r-o-w
Member Author

Continued in #11916

@a-r-r-o-w a-r-r-o-w closed this Jul 15, 2025
@a-r-r-o-w a-r-r-o-w deleted the attention-dispatcher branch July 15, 2025 10:40