diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index d564cf9988ea..a2e18f970bec 100755
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -194,10 +194,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
         self.use_full_cuda_graph = \
             self.compilation_config.cudagraph_mode.has_full_cudagraphs()
+        self.max_cudagraph_size = self.compilation_config.max_capture_size
         if self.use_full_cuda_graph and self.aot_schedule:
-            self.max_cudagraph_size = self.compilation_config.max_capture_size
-
             if self.max_cudagraph_size > 992:
                 # This condition derives from FA3's internal heuristic.
                 # TODO(woosuk): Support larger cudagraph sizes.
@@ -259,6 +258,15 @@ def build(self,
                     self.aot_schedule = False
                     aot_schedule = False
 
+        max_num_splits = 0  # 0 means use FA3's heuristics, not CG compatible
+        if self.use_full_cuda_graph and \
+                num_actual_tokens <= self.max_cudagraph_size:
+            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
+            # usage, because the intermediate buffers of size [num_splits,
+            # num_heads, num_tokens, head_size] are allocated. Therefore,
+            # we only set num_splits when using cuda graphs.
+            max_num_splits = self.max_num_splits
+
         def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                      max_seq_len, causal):
             cache_dtype = self.cache_config.cache_dtype
@@ -281,7 +289,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     page_size=self.block_size,
                     causal=causal,
                     window_size=self.aot_sliding_window,
-                    num_splits=self.max_num_splits,
+                    num_splits=max_num_splits,
                 )
             return None
 
@@ -322,7 +330,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                                       max_seq_len=max_seq_len,
                                       causal=causal)
         # For FA3 + full cudagraph
-        max_num_splits = 0
         if self.use_full_cuda_graph and scheduler_metadata is not None:
             n = scheduler_metadata.shape[0]
             self.scheduler_metadata[:n] = scheduler_metadata
@@ -333,13 +340,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
             self.scheduler_metadata[n:] = 0
             scheduler_metadata = self.scheduler_metadata[:n]
 
-        if num_actual_tokens <= self.max_cudagraph_size:
-            # NOTE(woosuk): Setting num_splits > 1 may increase the memory
-            # usage, because the intermediate buffers of size [num_splits,
-            # num_heads, num_tokens, head_size] are allocated. Therefore,
-            # we only set num_splits when using cuda graphs.
-            max_num_splits = self.max_num_splits
-
         attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=num_actual_tokens,
            max_query_len=max_query_len,
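
Review note (not part of the patch): taken together, the hunks hoist the max_num_splits decision so it is made once, before the nested schedule() closure, instead of after scheduling. Previously the AOT scheduler always received self.max_num_splits, while the local max_num_splits used later in build() was gated only on num_actual_tokens <= self.max_cudagraph_size, an attribute that was only defined when full CUDA graphs plus AOT scheduling were both enabled. Now self.max_cudagraph_size is always set in __init__, the gate additionally checks self.use_full_cuda_graph, and the scheduler and the metadata see the same value. Below is a minimal standalone sketch of the resulting gating logic; the function pick_num_splits is hypothetical, not vLLM API, and only the attribute names mirror the diff.

def pick_num_splits(num_actual_tokens: int,
                    use_full_cuda_graph: bool,
                    max_cudagraph_size: int,
                    max_num_splits: int) -> int:
    """Choose the num_splits value handed to FA3's AOT scheduler.

    0 lets FA3 pick via its internal heuristic. That is fine in eager mode
    but not CUDA-graph compatible: the chosen split count, and hence the
    shape of the intermediate [num_splits, num_heads, num_tokens, head_size]
    buffers, could differ between capture and replay. Under full CUDA graphs
    we therefore pin a fixed upper bound, at the cost of extra memory.
    """
    if use_full_cuda_graph and num_actual_tokens <= max_cudagraph_size:
        return max_num_splits  # fixed bound: stable shapes across replays
    return 0  # FA3 heuristic: cheapest on memory, not CG compatible

For example, with max_cudagraph_size=512 and max_num_splits=8, a 256-token capture-sized batch gets 8, while a 2048-token eager batch falls back to 0 and FA3's internal heuristic.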