Skip to content

Commit 26efeb9

Browse files
LucasWilkinson authored and FeiDaLI committed
[BugFix] Potential Fix for FA3 full-cudagraph IMA (vllm-project#25490)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent bb58365 commit 26efeb9

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

vllm/v1/attention/backends/flash_attn.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
194194

195195
self.use_full_cuda_graph = \
196196
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
197+
self.max_cudagraph_size = self.compilation_config.max_capture_size
197198

198199
if self.use_full_cuda_graph and self.aot_schedule:
199-
self.max_cudagraph_size = self.compilation_config.max_capture_size
200-
201200
if self.max_cudagraph_size > 992:
202201
# This condition derives from FA3's internal heuristic.
203202
# TODO(woosuk): Support larger cudagraph sizes.
@@ -259,6 +258,15 @@ def build(self,
259258
self.aot_schedule = False
260259
aot_schedule = False
261260

261+
max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible
262+
if self.use_full_cuda_graph and \
263+
num_actual_tokens <= self.max_cudagraph_size:
264+
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
265+
# usage, because the intermediate buffers of size [num_splits,
266+
# num_heads, num_tokens, head_size] are allocated. Therefore,
267+
# we only set num_splits when using cuda graphs.
268+
max_num_splits = self.max_num_splits
269+
262270
def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
263271
max_seq_len, causal):
264272
cache_dtype = self.cache_config.cache_dtype
@@ -281,7 +289,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
281289
page_size=self.block_size,
282290
causal=causal,
283291
window_size=self.aot_sliding_window,
284-
num_splits=self.max_num_splits,
292+
num_splits=max_num_splits,
285293
)
286294
return None
287295

@@ -322,7 +330,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
322330
max_seq_len=max_seq_len,
323331
causal=causal)
324332
# For FA3 + full cudagraph
325-
max_num_splits = 0
326333
if self.use_full_cuda_graph and scheduler_metadata is not None:
327334
n = scheduler_metadata.shape[0]
328335
self.scheduler_metadata[:n] = scheduler_metadata
@@ -333,13 +340,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
333340
self.scheduler_metadata[n:] = 0
334341
scheduler_metadata = self.scheduler_metadata[:n]
335342

336-
if num_actual_tokens <= self.max_cudagraph_size:
337-
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
338-
# usage, because the intermediate buffers of size [num_splits,
339-
# num_heads, num_tokens, head_size] are allocated. Therefore,
340-
# we only set num_splits when using cuda graphs.
341-
max_num_splits = self.max_num_splits
342-
343343
attn_metadata = FlashAttentionMetadata(
344344
num_actual_tokens=num_actual_tokens,
345345
max_query_len=max_query_len,

0 commit comments

Comments (0)