diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8c4ae36c5654..01c6ddb0fe69 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -143,6 +143,7 @@ [ "AllegroTransformer3DModel", "AsymmetricAutoencoderKL", + "AttentionBackendName", "AuraFlowTransformer2DModel", "AutoencoderDC", "AutoencoderKL", @@ -215,6 +216,7 @@ "UVit2DModel", "VQModel", "WanTransformer3DModel", + "attention_backend", ] ) _import_structure["optimization"] = [ @@ -749,6 +751,7 @@ from .models import ( AllegroTransformer3DModel, AsymmetricAutoencoderKL, + AttentionBackendName, AuraFlowTransformer2DModel, AutoencoderDC, AutoencoderKL, @@ -820,6 +823,7 @@ UVit2DModel, VQModel, WanTransformer3DModel, + attention_backend, ) from .optimization import ( get_constant_schedule, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 58322800332a..b75d4ae8a542 100755 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -26,6 +26,7 @@ if is_torch_available(): _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"] + _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"] _import_structure["auto_model"] = ["AutoModel"] _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_dc"] = ["AutoencoderDC"] @@ -109,6 +110,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if is_torch_available(): from .adapter import MultiAdapter, T2IAdapter + from .attention_dispatch import AttentionBackendName, attention_backend from .auto_model import AutoModel from .autoencoders import ( AsymmetricAutoencoderKL, diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py new file mode 100644 index 000000000000..c6c78a44a632 --- /dev/null +++ b/src/diffusers/models/attention_dispatch.py @@ -0,0 +1,1098 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import functools +import inspect +import math +from enum import Enum +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union + +import torch + +from ..utils import ( + get_logger, + is_flash_attn_3_available, + is_flash_attn_available, + is_flash_attn_version, + is_sageattention_available, + is_sageattention_version, + is_torch_npu_available, + is_torch_version, + is_torch_xla_available, + is_torch_xla_version, + is_xformers_available, + is_xformers_version, +) +from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS + + +logger = get_logger(__name__) # pylint: disable=invalid-name + + +if is_flash_attn_available() and is_flash_attn_version(">=", "2.6.3"): + from flash_attn import flash_attn_func, flash_attn_varlen_func +else: + logger.warning("`flash-attn` is not available or the version is too old. 
Please install `flash-attn>=2.6.3`.") + flash_attn_func = None + flash_attn_varlen_func = None + + +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_func as flash_attn_3_func + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func +else: + flash_attn_3_func = None + flash_attn_3_varlen_func = None + + +if is_sageattention_available() and is_sageattention_version(">=", "2.1.1"): + from sageattention import ( + sageattn, + sageattn_qk_int8_pv_fp8_cuda, + sageattn_qk_int8_pv_fp8_cuda_sm90, + sageattn_qk_int8_pv_fp16_cuda, + sageattn_qk_int8_pv_fp16_triton, + sageattn_varlen, + ) +else: + logger.warning( + "`sageattention` is not available or the version is too old. Please install `sageattention>=2.1.1`." + ) + sageattn = None + sageattn_qk_int8_pv_fp16_cuda = None + sageattn_qk_int8_pv_fp16_triton = None + sageattn_qk_int8_pv_fp8_cuda = None + sageattn_qk_int8_pv_fp8_cuda_sm90 = None + sageattn_varlen = None + + +if is_torch_version(">=", "2.5.0"): + # We cannot import the flex_attention function from the package directly because it is expected (from the + # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the + # compiled function. + import torch.nn.attention.flex_attention as flex_attention + + +if is_torch_npu_available(): + from torch_npu import npu_fusion_attention +else: + npu_fusion_attention = None + + +if is_torch_xla_available() and is_torch_xla_version(">", "2.2"): + from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention +else: + xla_flash_attention = None + + +if is_xformers_available() and is_xformers_version(">=", "0.0.29"): + import xformers.ops as xops +else: + logger.warning("`xformers` is not available or the version is too old. 
Please install `xformers>=0.0.29`.") + xops = None + + +_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"] +_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"] +_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"] + + +class AttentionBackendName(str, Enum): + # EAGER = "eager" + + # `flash-attn` + FLASH = "flash" + FLASH_VARLEN = "flash_varlen" + _FLASH_3 = "_flash_3" + _FLASH_VARLEN_3 = "_flash_varlen_3" + + # PyTorch native + FLEX = "flex" + NATIVE = "native" + _NATIVE_CUDNN = "_native_cudnn" + _NATIVE_EFFICIENT = "_native_efficient" + _NATIVE_FLASH = "_native_flash" + _NATIVE_MATH = "_native_math" + _NATIVE_NPU = "_native_npu" + _NATIVE_XLA = "_native_xla" + + # `sageattention` + SAGE = "sage" + SAGE_VARLEN = "sage_varlen" + _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda" + _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90" + _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda" + _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton" + # TODO: let's not add support for Sparge Attention now because it requires tuning per model + # We can look into supporting something "autotune"-ing in the future + # SPARGE = "sparge" + + # `xformers` + XFORMERS = "xformers" + + +class _AttentionBackendRegistry: + _backends = {} + _constraints = {} + _supported_arg_names = {} + _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND) + _checks_enabled = DIFFUSERS_ATTN_CHECKS + + @classmethod + def register(cls, backend: AttentionBackendName, constraints: Optional[List[Callable]] = None): + logger.debug(f"Registering attention backend: {backend} with constraints: {constraints}") + + def decorator(func): + cls._backends[backend] = func + cls._constraints[backend] = constraints or [] + cls._supported_arg_names[backend] = set(inspect.signature(func).parameters.keys()) + return func + + return decorator + + @classmethod + def get_active_backend(cls): + return cls._active_backend, cls._backends[cls._active_backend] + + @classmethod + def list_backends(cls): + return list(cls._backends.keys()) + + +@contextlib.contextmanager +def attention_backend(backend: AttentionBackendName = AttentionBackendName.NATIVE): + """ + Context manager to set the active attention backend. 
+ """ + if backend not in _AttentionBackendRegistry._backends: + raise ValueError(f"Backend {backend} is not registered.") + + old_backend = _AttentionBackendRegistry._active_backend + _AttentionBackendRegistry._active_backend = backend + + try: + yield + finally: + _AttentionBackendRegistry._active_backend = old_backend + + +def dispatch_attention_fn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + attention_kwargs: Optional[Dict[str, Any]] = None, + *, + backend: Optional[AttentionBackendName] = None, +) -> torch.Tensor: + attention_kwargs = attention_kwargs or {} + + if backend is None: + # If no backend is specified, we either use the default backend (set via the DIFFUSERS_ATTN_BACKEND environment + # variable), or we use a custom backend based on whether user is using the `attention_backend` context manager + backend_name, backend_fn = _AttentionBackendRegistry.get_active_backend() + else: + backend_name = AttentionBackendName(backend) + backend_fn = _AttentionBackendRegistry._backends.get(backend_name) + + kwargs = { + "query": query, + "key": key, + "value": value, + "attn_mask": attn_mask, + "dropout_p": dropout_p, + "is_causal": is_causal, + "scale": scale, + "enable_gqa": enable_gqa, + **attention_kwargs, + } + + if _AttentionBackendRegistry._checks_enabled: + removed_kwargs = set(kwargs) - set(_AttentionBackendRegistry._supported_arg_names[backend_name]) + if removed_kwargs: + logger.warning(f"Removing unsupported arguments for attention backend {backend_name}: {removed_kwargs}.") + for check in _AttentionBackendRegistry._constraints.get(backend_name): + check(**kwargs) + + kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]} + return backend_fn(**kwargs) + + +# ===== Checks ===== +# A list of very simple functions to catch common errors quickly when debugging. + + +def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: + if attn_mask is not None: + raise ValueError("Attention mask must be None for this backend.") + + +def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: + if attn_mask is not None and is_causal: + raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") + + +def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.device != key.device or query.device != value.device: + raise ValueError("Query, key, and value must be on the same device.") + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError("Query, key, and value must have the same dtype.") + + +def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device(query, key, value) + if query.device.type != "cuda": + raise ValueError("Query, key, and value must be on a CUDA device.") + + +def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: + def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_device_cuda(query, key, value) + if torch.cuda.get_device_capability(query.device) < (major, minor): + raise ValueError( + f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." 
+ ) + + return check_device_cuda + + +def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + if query.dtype != key.dtype: + raise ValueError("Query and key must have the same dtype.") + if query.dtype != value.dtype: + raise ValueError("Query and value must have the same dtype.") + + +def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: + _check_qkv_dtype_match(query, key, value) + if query.dtype not in (torch.bfloat16, torch.float16): + raise ValueError("Query, key, and value must be either bfloat16 or float16.") + + +def _check_shape( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + **kwargs, +) -> None: + if query.shape[-1] != key.shape[-1]: + raise ValueError("Query and key must have the same last dimension.") + if query.shape[-2] != value.shape[-2]: + raise ValueError("Query and value must have the same second to last dimension.") + if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: + raise ValueError("Attention mask must match the key's second to last dimension.") + + +# ===== Helper functions ===== + + +@functools.lru_cache(maxsize=8) +def _prepare_for_flash_attn_or_sage_varlen( + batch_size: int, + seq_len_q: int, + seq_len_kv: int, + attn_mask: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, +) -> None: + seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) + if attn_mask is None: + seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) + else: + seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) + cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) + cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) + cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) + max_seqlen_q = seqlens_q.max().item() + max_seqlen_k = seqlens_k.max().item() + return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) + + +def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: + """ + Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_[q|k] in + FlashAttention/Sage varlen. + + Supports 1D to 4D shapes and common broadcasting patterns. + """ + if attn_mask.dtype != torch.bool: + raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") + + if attn_mask.ndim == 1: + # [seq_len_k] -> broadcast across batch + attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 2: + # [batch_size, seq_len_k]. Maybe broadcast across batch + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 3: + # [batch_size, seq_len_q, seq_len_k] -> reduce over query dimension + # We do this reduction because we know that arbitrary QK masks is not supported in Flash/Sage varlen. + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." 
+ ) + attn_mask = attn_mask.any(dim=1) + attn_mask = attn_mask.expand(batch_size, seq_len_k) + + elif attn_mask.ndim == 4: + # [batch_size, num_heads, seq_len_q, seq_len_k] or broadcastable versions + if attn_mask.size(0) not in [1, batch_size]: + raise ValueError( + f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." + ) + attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) # [B, H, Q, K] + attn_mask = attn_mask.any(dim=(1, 2)) # [B, K] + + else: + raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") + + if attn_mask.shape != (batch_size, seq_len_k): + raise ValueError( + f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" + ) + + return attn_mask + + +def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return q_idx >= kv_idx + + +# ===== Attention backends ===== + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = flash_attn_func( + q=query, + k=key, + v=value, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + out = out.permute(0, 2, 1, 3) + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLASH_VARLEN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + dropout_p: float = 0.0, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + 
value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = flash_attn_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + dropout_p=dropout_p, + softmax_scale=scale, + causal=is_causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + deterministic=deterministic, + return_attn_probs=return_attn_probs, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_3, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, +) -> torch.Tensor: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out, lse, *_ = flash_attn_3_func( + q=query, + k=key, + v=value, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + attention_chunk=0, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = out.permute(0, 2, 1, 3) + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._FLASH_VARLEN_3, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _flash_varlen_attention_3( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + scale: Optional[float] = None, + is_causal: bool = False, + window_size: Tuple[int, int] = (-1, -1), + softcap: float = 0.0, + deterministic: bool = False, + return_attn_probs: bool = False, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out, lse, *_ = flash_attn_3_varlen_func( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + 
cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + seqused_q=None, + seqused_k=None, + softmax_scale=scale, + causal=is_causal, + qv=None, + q_descale=None, + k_descale=None, + v_descale=None, + window_size=window_size, + softcap=softcap, + num_splits=1, + pack_gqa=None, + deterministic=deterministic, + sm_margin=0, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + + return (out, lse) if return_attn_probs else out + + +@_AttentionBackendRegistry.register( + AttentionBackendName.FLEX, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _native_flex_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + kernel_options: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + # TODO: should we LRU cache the block mask creation? + score_mod = None + block_mask = None + batch_size, num_heads, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is None or isinstance(attn_mask, flex_attention.BlockMask): + block_mask = attn_mask + elif is_causal: + block_mask = flex_attention.create_block_mask( + _flex_attention_causal_mask_mod, batch_size, num_heads, seq_len_q, seq_len_kv, query.device + ) + elif torch.is_tensor(attn_mask): + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + + attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv) + + if attn_mask.dtype == torch.bool: + # TODO: this probably does not work but verify! + def mask_mod(batch_idx, head_idx, q_idx, kv_idx): + return attn_mask[batch_idx, head_idx, q_idx, kv_idx] + + block_mask = flex_attention.create_block_mask( + mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device + ) + else: + + def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): + return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx] + else: + raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.") + + return flex_attention.flex_attention( + query=query, + key=key, + value=value, + score_mod=score_mod, + block_mask=block_mask, + scale=scale, + enable_gqa=enable_gqa, + return_lse=return_lse, + kernel_options=kernel_options, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.NATIVE, + constraints=[_check_device, _check_shape], +) +def _native_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_CUDNN, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_cudnn_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + return 
torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_EFFICIENT, + constraints=[_check_device, _check_shape], +) +def _native_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_FLASH, + constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_MATH, + constraints=[_check_device, _check_shape], +) +def _native_math_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH): + return torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_NPU, + constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _native_npu_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + scale: Optional[float] = None, +) -> torch.Tensor: + return npu_fusion_attention( + query, + key, + value, + query.size(1), # num_heads + input_layout="BNSD", + pse=None, + scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale, + pre_tockens=65536, + next_tokens=65536, + keep_prob=1.0 - dropout_p, + sync=False, + inner_precise=0, + )[0] + + +# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853 +@_AttentionBackendRegistry.register( + AttentionBackendName._NATIVE_XLA, + constraints=[_check_device, _check_shape], +) +def _native_xla_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, +) -> torch.Tensor: + query = query / math.sqrt(query.shape[-1]) + return xla_flash_attention( + q=query, + k=key, + v=value, + causal=is_causal, + ) + + 
+@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + sm_scale=scale, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.SAGE_VARLEN, + constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape], +) +def _sage_varlen_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_k: Optional[int] = None, + is_causal: bool = False, + scale: Optional[float] = None, + smooth_k: bool = True, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + batch_size, _, seq_len_q, _ = query.shape + _, _, seq_len_kv, _ = key.shape + + if attn_mask is not None: + attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv) + + if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)): + (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = ( + _prepare_for_flash_attn_or_sage_varlen( + batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device + ) + ) + else: + seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device) + cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device) + cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device) + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + key_valid, value_valid = [], [] + for b in range(batch_size): + valid_len = seqlens_k[b] + key_valid.append(key[b, :valid_len]) + value_valid.append(value[b, :valid_len]) + + query_packed = query.flatten(0, 1) + key_packed = torch.cat(key_valid, dim=0) + value_packed = torch.cat(value_valid, dim=0) + + out = sageattn_varlen( + q=query_packed, + k=key_packed, + v=value_packed, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + ) + out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3) + + return out + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def _sage_qk_int8_pv_fp8_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90, + constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape], +) +def 
_sage_qk_int8_pv_fp8_cuda_sm90_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp8_cuda_sm90( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_cuda_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread", + pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32", + smooth_k: bool = True, + smooth_v: bool = False, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_cuda( + q=query, + k=key, + v=value, + tensor_layout="HND", + is_causal=is_causal, + qk_quant_gran=qk_quant_gran, + sm_scale=scale, + pv_accum_dtype=pv_accum_dtype, + smooth_k=smooth_k, + smooth_v=smooth_v, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON, + constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape], +) +def _sage_qk_int8_pv_fp16_triton_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + scale: Optional[float] = None, + quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton", + smooth_k: bool = True, + return_lse: bool = False, +) -> torch.Tensor: + return sageattn_qk_int8_pv_fp16_triton( + q=query, + k=key, + v=value, + tensor_layout="HND", + quantization_backend=quantization_backend, + is_causal=is_causal, + sm_scale=scale, + smooth_k=smooth_k, + return_lse=return_lse, + ) + + +@_AttentionBackendRegistry.register( + AttentionBackendName.XFORMERS, + constraints=[_check_attn_mask_or_causal, _check_device, _check_shape], +) +def _xformers_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, +) -> torch.Tensor: + batch_size, num_heads_q, seq_len_q, _ = query.shape + _, num_heads_kv, seq_len_kv, _ = key.shape + + # TODO: check if `contiguous` is really needed since it may cause unnecessary slowdowns + if is_causal: + attn_mask = xops.LowerTriangularMask() + elif attn_mask is not None: + if attn_mask.ndim == 2: + attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1) + elif attn_mask.ndim != 4: + raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.") + attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query) + + # QKV need to be in [batch, seq_len, num_heads, head_dim] format for xformers + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + + if enable_gqa: + if num_heads_q % num_heads_kv != 0: + raise ValueError("Number of heads in query must be divisible by number of heads in key/value.") + num_heads_per_group = num_heads_q // num_heads_kv + query = 
query.unflatten(2, (num_heads_kv, -1)) + key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1) + + out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale) + if enable_gqa: + out = out.flatten(2, 3) + out = out.permute(0, 2, 1, 3) + return out diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 23ae05e2ab96..802a31e101fb 100755 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -23,6 +23,7 @@ from ..utils import deprecate, is_torch_xla_available, logging from ..utils.import_utils import is_torch_npu_available, is_torch_xla_version, is_xformers_available from ..utils.torch_utils import is_torch_version, maybe_allow_in_graph +from .attention_dispatch import dispatch_attention_fn logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -996,6 +997,8 @@ def forward( class MochiAttnProcessor2_0: """Attention processor used in Mochi.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") @@ -1073,8 +1076,8 @@ def apply_rotary_emb(x, freqs_cos, freqs_sin): valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2) valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2) - attn_output = F.scaled_dot_product_attention( - valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False + attn_output = dispatch_attention_fn( + valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False, backend=self._attention_backend ) valid_sequence_length = attn_output.size(2) attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length)) @@ -2275,6 +2278,8 @@ def __call__( class FluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -2339,8 +2344,14 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -2367,6 +2378,8 @@ def __call__( class FluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -2449,7 +2462,9 @@ def __call__( inner_precise=0, )[0] else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2473,6 
+2488,8 @@ def __call__( class FusedFluxAttnProcessor2_0: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -2543,7 +2560,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2568,6 +2587,8 @@ def __call__( class FusedFluxAttnProcessor2_0_NPU: """Attention processor used typically in processing the SD3-like self-attention projections.""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError( @@ -2654,7 +2675,9 @@ def __call__( inner_precise=0, )[0] else: - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2679,6 +2702,8 @@ def __call__( class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module): """Flux Attention processor for IP-Adapter.""" + _attention_backend = None + def __init__( self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None ): @@ -2776,7 +2801,9 @@ def __call__( query = apply_rotary_emb(query, image_rotary_emb) key = apply_rotary_emb(key, image_rotary_emb) - hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = dispatch_attention_fn( + query, key, value, dropout_p=0.0, is_causal=False, backend=self._attention_backend + ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.to(query.dtype) @@ -2806,8 +2833,14 @@ def __call__( ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - current_ip_hidden_states = F.scaled_dot_product_attention( - ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + current_ip_hidden_states = dispatch_attention_fn( + ip_query, + ip_key, + ip_value, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape( batch_size, -1, attn.heads * head_dim @@ -2826,6 +2859,8 @@ class CogVideoXAttnProcessor2_0: query and key vectors, but does not include spatial normalization. 
""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -2872,8 +2907,14 @@ def __call__( if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -2895,6 +2936,8 @@ class FusedCogVideoXAttnProcessor2_0: query and key vectors, but does not include spatial normalization. """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -2943,8 +2986,14 @@ def __call__( if not attn.is_cross_attention: key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -3130,9 +3179,10 @@ class AttnProcessorNPU: Processor for implementing flash attention using torch_npu. Torch_npu supports only fp16 and bf16 data types. If fp32 is used, F.scaled_dot_product_attention will be used for computation, but the acceleration effect on NPU is not significant. - """ + _attention_backend = None + def __init__(self): if not is_torch_npu_available(): raise ImportError("AttnProcessorNPU requires torch_npu extensions and is supported only on npu devices.") @@ -3216,8 +3266,14 @@ def __call__( )[0] else: # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -3244,6 +3300,8 @@ class AttnProcessor2_0: Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
""" + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -3310,8 +3368,14 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) @@ -3554,6 +3618,8 @@ class MochiVaeAttnProcessor2_0: Attention processor used in Mochi VAE. """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -3614,8 +3680,14 @@ def __call__( # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=attn.is_causal + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=attn.is_causal, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 55ce0cf79fb9..cfd2495a0863 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -599,6 +599,52 @@ def enable_group_offload( low_cpu_mem_usage=low_cpu_mem_usage, ) + def set_attention_backend(self, backend: str) -> None: + """ + Set the attention backend for the model. + + Args: + backend (`str`): + The name of the backend to set. Must be one of the available backends defined in + `AttentionBackendName`. Available backends can be found in + `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product + attention as backend. + """ + from .attention_dispatch import AttentionBackendName + from .attention_processor import Attention, MochiAttention + + backend = backend.lower() + available_backends = {x.value for x in AttentionBackendName.__members__.values()} + if backend not in available_backends: + raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends)) + + backend = AttentionBackendName(backend) + attention_classes = (Attention, MochiAttention) + + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = backend + + def reset_attention_backend(self) -> None: + """ + Resets the attention backend for the model. Following calls to `forward` will use the environment default or + the torch native scaled dot product attention. 
+ """ + from .attention_processor import Attention, MochiAttention + + attention_classes = (Attention, MochiAttention) + for module in self.modules(): + if not isinstance(module, attention_classes): + continue + processor = module.processor + if processor is None or not hasattr(processor, "_attention_backend"): + continue + processor._attention_backend = None + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/transformers/transformer_lumina2.py b/src/diffusers/models/transformers/transformer_lumina2.py index a873a6ec9444..78a26b89907c 100644 --- a/src/diffusers/models/transformers/transformer_lumina2.py +++ b/src/diffusers/models/transformers/transformer_lumina2.py @@ -24,6 +24,7 @@ from ...loaders.single_file_model import FromOriginalModelMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import LuminaFeedForward +from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention from ..embeddings import TimestepEmbedding, Timesteps, apply_rotary_emb, get_1d_rotary_pos_embed from ..modeling_outputs import Transformer2DModelOutput @@ -71,6 +72,8 @@ class Lumina2AttnProcessor2_0: used in the Lumina2Transformer2DModel model. It applies normalization and RoPE on query and key vectors. """ + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") @@ -137,8 +140,8 @@ def __call__( key = key.transpose(1, 2) value = value.transpose(1, 2) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, scale=softmax_scale + hidden_states = dispatch_attention_fn( + query, key, value, attn_mask=attention_mask, scale=softmax_scale, backend=self._attention_backend ) hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index baa0ede4184e..84a5e732bc0e 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -23,6 +23,7 @@ from ...loaders import FromOriginalModelMixin, PeftAdapterMixin from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ..attention import FeedForward +from ..attention_dispatch import dispatch_attention_fn from ..attention_processor import Attention from ..cache_utils import CacheMixin from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed @@ -35,6 +36,8 @@ class WanAttnProcessor2_0: + _attention_backend = None + def __init__(self): if not hasattr(F, "scaled_dot_product_attention"): raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. 
To use it, please upgrade PyTorch to 2.0.") @@ -90,14 +93,26 @@ def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) - hidden_states_img = F.scaled_dot_product_attention( - query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + hidden_states_img = dispatch_attention_fn( + query, + key_img, + value_img, + attn_mask=None, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) hidden_states_img = hidden_states_img.type_as(query) - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + hidden_states = dispatch_attention_fn( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + backend=self._attention_backend, ) hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) hidden_states = hidden_states.type_as(query) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 2df05cb8eb36..cadcedb98a14 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -67,6 +67,9 @@ is_bitsandbytes_version, is_bs4_available, is_cosmos_guardrail_available, + is_flash_attn_3_available, + is_flash_attn_available, + is_flash_attn_version, is_flax_available, is_ftfy_available, is_gguf_available, @@ -90,6 +93,8 @@ is_peft_version, is_pytorch_retinaface_available, is_safetensors_available, + is_sageattention_available, + is_sageattention_version, is_scipy_available, is_sentencepiece_available, is_tensorboard_available, @@ -108,6 +113,7 @@ is_unidecode_available, is_wandb_available, is_xformers_available, + is_xformers_version, requires_backends, ) from .loading_utils import get_module_from_name, get_submodule_by_name, load_image, load_video diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 7c04287d33ed..f8f04cc03abd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -41,6 +41,8 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DIFFUSERS_REQUEST_TIMEOUT = 60 +DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") +DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. 
Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 97bc3f317b32..3933fec29656 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -85,6 +85,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AttentionBackendName(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AuraFlowTransformer2DModel(metaclass=DummyObject): _backends = ["torch"] @@ -1150,6 +1165,10 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +def attention_backend(*args, **kwargs): + requires_backends(attention_backend, ["torch"]) + + def get_constant_schedule(*args, **kwargs): requires_backends(get_constant_schedule, ["torch"]) diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index f7244e97b878..4fe71801e8f9 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -219,6 +219,9 @@ def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[b _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity") _nltk_available, _nltk_version = _is_package_available("nltk") _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail") +_sageattention_available, _sageattention_version = _is_package_available("sageattention") +_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn") +_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3") def is_torch_available(): @@ -377,6 +380,18 @@ def is_hpu_available(): return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch")) +def is_sageattention_available(): + return _sageattention_available + + +def is_flash_attn_available(): + return _flash_attn_available + + +def is_flash_attn_3_available(): + return _flash_attn_3_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -803,6 +818,51 @@ def is_optimum_quanto_version(operation: str, version: str): return compare_versions(parse(_optimum_quanto_version), operation, version) +def is_xformers_version(operation: str, version: str): + """ + Compares the current xformers version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _xformers_available: + return False + return compare_versions(parse(_xformers_version), operation, version) + + +def is_sageattention_version(operation: str, version: str): + """ + Compares the current sageattention version to a given reference with an operation. 
+ + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _sageattention_available: + return False + return compare_versions(parse(_sageattention_version), operation, version) + + +def is_flash_attn_version(operation: str, version: str): + """ + Compares the current flash-attention version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A version string + """ + if not _flash_attn_available: + return False + return compare_versions(parse(_flash_attn_version), operation, version) + + def get_objects_from_module(module): """ Returns a dict of object names and values in a module, while skipping private/internal objects
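
The diff above introduces two user-facing entry points: the module-level `attention_backend` context manager and the `set_attention_backend` / `reset_attention_backend` methods added to `ModelMixin`. A minimal usage sketch follows; the Flux checkpoint id is only an illustrative assumption, and any diffusers model whose attention processors route through `dispatch_attention_fn` behaves the same way.

```python
import torch

from diffusers import AttentionBackendName, FluxTransformer2DModel, attention_backend

# Hypothetical checkpoint for illustration; substitute any supported model.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Option 1: pin a backend on the model. The string must match an AttentionBackendName value.
transformer.set_attention_backend("_native_cudnn")
# ... run inference ...
transformer.reset_attention_backend()  # back to the environment default / torch SDPA

# Option 2: switch the backend only for the duration of a context. A processor-level backend
# set via Option 1 takes precedence over the context manager, so reset it first if both are
# used. The FLASH backend additionally requires flash-attn>=2.6.3 to be installed.
with attention_backend(AttentionBackendName.FLASH):
    pass  # run the forward pass / denoising loop here
```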
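`dispatch_attention_fn` can also be exercised directly, which is handy for comparing backends in isolation. A small sketch, assuming a CUDA device and the `[batch, heads, seq_len, head_dim]` layout that the registered backends expect:

```python
import torch

from diffusers.models.attention_dispatch import AttentionBackendName, dispatch_attention_fn

q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

# An explicit `backend=` bypasses both the context manager and the environment default.
out_native = dispatch_attention_fn(q, k, v, backend=AttentionBackendName.NATIVE)
out_math = dispatch_attention_fn(q, k, v, backend=AttentionBackendName._NATIVE_MATH)

print(out_native.shape)  # torch.Size([2, 8, 128, 64])
print(torch.allclose(out_native, out_math, atol=1e-2, rtol=1e-2))
```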
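Finally, the new `DIFFUSERS_ATTN_BACKEND` and `DIFFUSERS_ATTN_CHECKS` constants are read from the environment when `diffusers.utils.constants` is first imported, so process-wide defaults have to be configured before the library is imported. A sketch, assuming `sageattention>=2.1.1` is installed for the chosen default:

```python
import os

# Must be set before `diffusers` (and hence `attention_dispatch`) is imported.
os.environ["DIFFUSERS_ATTN_BACKEND"] = "sage"  # default backend when none is specified
os.environ["DIFFUSERS_ATTN_CHECKS"] = "1"      # run the `_check_*` constraints on every dispatch

import diffusers  # noqa: E402
```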