Attention Dispatcher #11368
Conversation
Supported backends: flash, flash_varlen, flex, native, sage, sage_varlen, xformers.
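For orientation, here is a minimal usage sketch of the context-manager API discussed in this PR. The pipeline, checkpoint id, and prompt are my own illustration and not taken from the PR; only `attention_backend` and the backend names above come from this thread.

```python
import torch
from diffusers import FluxPipeline, attention_backend

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Attention inside the pipeline's transformer is routed through the "flash"
# backend for this call; outside the `with` block, native SDPA is used.
with attention_backend("flash"):
    image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```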
Interesting PR! I only left some higher-level comments. My major comment is around having an attention config class instead of environment vars. Or would that be too much for this PR?
For the attention config class (if we decide to proceed down that route), I was thinking of the following API:
```python
attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...,
)
model.set_attn_config(attn_config)
```
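If that route is taken, the config class itself could be a small dataclass. This is purely a sketch of the proposal; only `attn_implementation` and `enable_gqa` come from the comment above, the other field and all defaults are made up.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class AttentionConfig:
    attn_implementation: str = "native"
    enable_gqa: bool = False
    flash_attn_version: Optional[str] = None  # hypothetical extra knob
```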
The environment vars were initially only for my quick testing from the CLI instead of changing the code every time; we can get rid of them completely. The intended API in my mind, and what currently exists in the PR, is a context manager:

```python
from diffusers import attention_provider

with attention_provider("sage_varlen"):
    model(...)
```

Can change once we finalize something.
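For reference, such a context manager can be implemented with a module-level default that the dispatcher consults. A rough sketch of the mechanism, not the actual diffusers implementation:

```python
import contextlib

# Illustrative default backend that a dispatcher would read at call time,
# temporarily overridden inside the `with` block.
_ACTIVE_BACKEND = "native"

@contextlib.contextmanager
def attention_provider(backend: str):
    global _ACTIVE_BACKEND
    previous, _ACTIVE_BACKEND = _ACTIVE_BACKEND, backend
    try:
        yield
    finally:
        _ACTIVE_BACKEND = previous

with attention_provider("sage_varlen"):
    assert _ACTIVE_BACKEND == "sage_varlen"  # dispatcher would pick this up
assert _ACTIVE_BACKEND == "native"
```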
It's looking good 👍🏽 Nice work! Registry makes sense here. Just some minor comments on the initial pass.
I would also add the torch NPU backend and XLA flash attention, e.g.:
- `hidden_states = torch_npu.npu_fusion_attention(...)`
- `from torch_xla.experimental.custom_kernel import flash_attention`
I do also think configuring attention without env variables and a context manager might be needed, e.g. you may want to run the transformer in the pipeline with sageattention while the other components use regular SDPA. The config object that @sayakpaul suggested makes sense.
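For illustration, per-component selection with the `set_attention_backend` method added later in this PR might look roughly like this; the checkpoint id and pipeline wiring are a hypothetical sketch.

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id; the point is only the per-component call below.
pipe = DiffusionPipeline.from_pretrained(
    "some/transformer-based-model", torch_dtype=torch.bfloat16
).to("cuda")

# Only the transformer runs through sageattention; the text encoder and VAE
# keep the default native SDPA path since their backends are left untouched.
pipe.transformer.set_attention_backend("sage")
```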
@sayakpaul @DN6 How would you recommend we set a per-model attention backend? The backend info needs to be propagated to the attention dispatcher when the forward method is called. The easiest way, and how I've done it for training/CP, is to attach a simple pre-forward hook that sets the backend, cp_mesh, and any other attributes when the forward method is invoked. If you have recommendations, I'll modify the implementation accordingly. Currently, you need to first replace the attention calls with the dispatcher (`dispatch_attention_fn`) and then use the context manager:

```python
from diffusers import attention_backend

with attention_backend("flash_varlen"):
    output = transformer(...)
```

If the context manager is not used, it defaults to the original behaviour of calling native torch attention.
I was thinking that upon calling …
Looking really good. I am currently trying to benchmark against FA3 as well while we're at it. I will update this thread once we have the results.
@a-r-r-o-w I was able to run FA3 with your code and here are some results:

```
Sequence lengths: [1024, 4096, 17550, 32768, 32760]

Shape: torch.Size([1, 12, 1024, 128])
TFLOPs: 0.01
===== TFLOPS =====
(flash): 67.79
(native_flash): 68.02
(native_cudnn): 60.44
==========
Shape: torch.Size([1, 12, 4096, 128])
TFLOPs: 0.10
===== TFLOPS =====
(flash): 677.87
(native_flash): 325.90
(native_cudnn): 660.36
==========
Shape: torch.Size([1, 12, 17550, 128])
TFLOPs: 1.89
===== TFLOPS =====
(flash): 740.64
(native_flash): 348.37
(native_cudnn): 626.17
==========
Shape: torch.Size([1, 12, 32768, 128])
TFLOPs: 6.60
===== TFLOPS =====
(flash): 724.79
(native_flash): 363.38
(native_cudnn): 701.14
==========
Shape: torch.Size([1, 12, 32760, 128])
TFLOPs: 6.59
===== TFLOPS =====
(flash): 669.26
(native_flash): 353.29
(native_cudnn): 586.65
==========
```

I can open a PR to your branch for the changes I had to make to make it work. LMK.
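As a reading aid, the TFLOPs column appears consistent with the standard forward-only attention FLOP count (2·S²·D for Q·Kᵀ plus 2·S²·D for P·V, per batch and head). This is my assumption about how the benchmark computes it, not taken from the script; the quick sketch below reproduces the numbers above.

```python
# Forward-only attention FLOPs: 4 * batch * heads * seq_len^2 * head_dim,
# reported in TFLOPs (1e12 FLOPs).
def attention_tflops(batch, heads, seq_len, head_dim):
    return 4 * batch * heads * seq_len**2 * head_dim / 1e12

for s in [1024, 4096, 17550, 32768, 32760]:
    print(s, round(attention_tflops(1, 12, s, 128), 2))
# -> 0.01, 0.10, 1.89, 6.60, 6.59, matching the table above
```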
@sayakpaul Super cool, thanks! I hope you didn't face too much trouble with building FA3 😅 I actually already have the required changes for FA3 (and some other things like NPU and XLA) locally. I haven't benchmarked yet though, so thanks for that, and I can push my changes soon.
It just took time. I used Docker instead of the default env of the cluster.
Pushed some changes to support FA3, NPU and XLA. They are all marked private since FA3 is a beta release and NPU and XLA are untested. PyTorch's cuDNN backend is close to FA3, but in almost all problem shapes the latter is faster, similar to FA2 built from source.
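For readers following the registry discussion: a minimal sketch of the idea, where backend names map to callables with a common `(q, k, v)` signature and a leading underscore marks private or experimental entries. This is illustrative only, not the actual diffusers implementation.

```python
import math
import torch
import torch.nn.functional as F

def _math_attention(q, k, v):
    # Plain eager-mode attention, standing in for a "private" math backend.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return scores.softmax(dim=-1) @ v

_BACKENDS = {
    "native": F.scaled_dot_product_attention,
    "_native_math": _math_attention,  # underscore = private/experimental
}

def dispatch(q, k, v, backend="native"):
    if backend not in _BACKENDS:
        raise ValueError(f"Unknown attention backend: {backend!r}")
    return _BACKENDS[backend](q, k, v)

q = k = v = torch.randn(1, 2, 16, 8)
out = dispatch(q, k, v, backend="_native_math")
```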
@sayakpaul @DN6 Based on our discussion, I've added support for `set_attention_backend("...")`.

For diffusers-native implementations:

```python
import torch

from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention_processor import Attention


class MyModel(ModelMixin):
    def __init__(self):
        super().__init__()
        self.attention = Attention(
            query_dim=10,
            heads=2,
            dim_head=5,
        )
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor):
        return x + self.mlp(x + self.attention(x))


dtype = torch.bfloat16
device = "cuda"

model = MyModel().to(device, dtype=dtype)
input = torch.randn(2, 64, 10).to(device, dtype=dtype)

output_native = model(input)

model.set_attention_backend("flash")
output_flash = model(input)

model.set_attention_backend("sage")
output_sage = model(input)

model.set_attention_backend("_native_math")
output_native_math = model(input)

diff1 = torch.abs(output_native - output_flash).max()
diff2 = torch.abs(output_native - output_sage).max()
diff3 = torch.abs(output_native - output_native_math).max()
print(diff1, diff2, diff3)
```

Context manager for custom implementations:

```python
import torch

from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.attention_dispatch import dispatch_attention_fn, attention_backend


class AttentionProcessor:
    def __call__(
        self,
        attn,
        x: torch.Tensor,
    ) -> torch.Tensor:
        q, k, v = (
            y.unflatten(2, (attn.heads, -1)).permute(0, 2, 1, 3).contiguous()
            for y in attn.qkv(x).chunk(3, dim=-1)
        )
        return attn.o(dispatch_attention_fn(q, k, v).permute(0, 2, 1, 3).flatten(2))


class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.heads = 2
        self.qkv = torch.nn.Linear(10, 30)
        self.o = torch.nn.Linear(10, 10)
        self.processor = AttentionProcessor()

    def forward(self, x: torch.Tensor):
        return self.processor(self, x)


class MyModel(ModelMixin):
    def __init__(self):
        super().__init__()
        self.attention = Attention()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(10, 20),
            torch.nn.ReLU(),
            torch.nn.Linear(20, 10),
        )

    def forward(self, x: torch.Tensor):
        return x + self.mlp(x + self.attention(x))


dtype = torch.bfloat16
device = "cuda"

model = MyModel().to(device, dtype=dtype)
input = torch.randn(2, 64, 10).to(device, dtype=dtype)

output_native = model(input)

with attention_backend("flash"):
    output_flash = model(input)

with attention_backend("sage"):
    output_sage = model(input)

with attention_backend("_native_math"):
    output_native_math = model(input)

diff1 = torch.abs(output_native - output_flash).max()
diff2 = torch.abs(output_native - output_sage).max()
diff3 = torch.abs(output_native - output_native_math).max()
print(diff1, diff2, diff3)
```
This is looking much, much better IMO! I think with proper documentation, we can make the differences between the scopes of …
Continued in #11916
Usage
- Attention-only benchmark
- Model benchmark

Results: 4090 (PyTorch 2.7 stable, CUDA 12.6)
- Wan

Results: A100 (PyTorch 2.7 stable, CUDA 12.2)
- Wan

cc @DN6 @sayakpaul @yiyixuxu