@@ -1,7 +1,7 @@
 ---
 layout: blog_detail
 title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
-author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
+author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He"
 ---

 {:style="width:100%"}
@@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a
 alibi_bias = generate_alibi_bias() # [num_heads]

 def alibi(score, b, h, q_idx, kv_idx):
-    bias = alibi_bias[h] * (q_idx - kv_idx)
+    bias = alibi_bias[h] * (kv_idx - q_idx)
     return score + bias
 ```

@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask

 # If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks

 def sliding_window(b, h, q_idx, kv_idx):
     return q_idx - kv_idx <= SLIDING_WINDOW

-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
 ```

 We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity.
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti
 - The Jax team's work on SplashAttention
 - Philippe Tillet and Keren Zhou for helping us with Triton
 - Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
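
For reference, the two corrected snippets above can be exercised end to end. The sketch below is not part of the commit: it assumes the current `torch.nn.attention.flex_attention` module (which exports `flex_attention`, `create_block_mask`, and `and_masks`), substitutes a hypothetical per-head slope tensor for the post's `generate_alibi_bias()`, and picks arbitrary shapes and window size purely for illustration.

```python
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)

B, H, S, D = 2, 8, 1024, 64   # batch, heads, sequence length, head dim (arbitrary)
SLIDING_WINDOW = 256          # illustrative window size
device = "cuda" if torch.cuda.is_available() else "cpu"

q = torch.randn(B, H, S, D, device=device)
k = torch.randn(B, H, S, D, device=device)
v = torch.randn(B, H, S, D, device=device)

# Hypothetical stand-in for the post's generate_alibi_bias(): one slope per head.
alibi_bias = 2.0 ** (-8.0 * torch.arange(1, H + 1, device=device) / H)

# Corrected alibi score_mod: the bias depends on (kv_idx - q_idx), per the fix above.
def alibi(score, b, h, q_idx, kv_idx):
    return score + alibi_bias[h] * (kv_idx - q_idx)

# Corrected sliding-window causal mask: a key position is kept only if it is
# causal AND within the window, hence and_masks rather than or_masks.
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

sliding_window_causal = and_masks(causal_mask, sliding_window)

# Build the block mask once (broadcast over batch and heads), then call flex_attention.
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device=device
)
out = flex_attention(q, k, v, score_mod=alibi, block_mask=block_mask)
print(out.shape)  # (B, H, S, D)
```

In practice `flex_attention` would typically be wrapped in `torch.compile` to get the fused kernel; the eager call here only checks that the corrected `score_mod` and `mask_mod` definitions run.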