@@ -194,10 +194,9 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
194194
195195 self .use_full_cuda_graph = \
196196 self .compilation_config .cudagraph_mode .has_full_cudagraphs ()
197+ self .max_cudagraph_size = self .compilation_config .max_capture_size
197198
198199 if self .use_full_cuda_graph and self .aot_schedule :
199- self .max_cudagraph_size = self .compilation_config .max_capture_size
200-
201200 if self .max_cudagraph_size > 992 :
202201 # This condition derives from FA3's internal heuristic.
203202 # TODO(woosuk): Support larger cudagraph sizes.
@@ -259,6 +258,15 @@ def build(self,
259258 self .aot_schedule = False
260259 aot_schedule = False
261260
261+ max_num_splits = 0 # 0 means use FA3's heuristics, not CG compatible
262+ if self .use_full_cuda_graph and \
263+ num_actual_tokens <= self .max_cudagraph_size :
264+ # NOTE(woosuk): Setting num_splits > 1 may increase the memory
265+ # usage, because the intermediate buffers of size [num_splits,
266+ # num_heads, num_tokens, head_size] are allocated. Therefore,
267+ # we only set num_splits when using cuda graphs.
268+ max_num_splits = self .max_num_splits
269+
262270 def schedule (batch_size , cu_query_lens , max_query_len , seqlens ,
263271 max_seq_len , causal ):
264272 cache_dtype = self .cache_config .cache_dtype
@@ -281,7 +289,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
281289 page_size = self .block_size ,
282290 causal = causal ,
283291 window_size = self .aot_sliding_window ,
284- num_splits = self . max_num_splits ,
292+ num_splits = max_num_splits ,
285293 )
286294 return None
287295
@@ -322,7 +330,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
322330 max_seq_len = max_seq_len ,
323331 causal = causal )
324332 # For FA3 + full cudagraph
325- max_num_splits = 0
326333 if self .use_full_cuda_graph and scheduler_metadata is not None :
327334 n = scheduler_metadata .shape [0 ]
328335 self .scheduler_metadata [:n ] = scheduler_metadata
@@ -333,13 +340,6 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
333340 self .scheduler_metadata [n :] = 0
334341 scheduler_metadata = self .scheduler_metadata [:n ]
335342
336- if num_actual_tokens <= self .max_cudagraph_size :
337- # NOTE(woosuk): Setting num_splits > 1 may increase the memory
338- # usage, because the intermediate buffers of size [num_splits,
339- # num_heads, num_tokens, head_size] are allocated. Therefore,
340- # we only set num_splits when using cuda graphs.
341- max_num_splits = self .max_num_splits
342-
343343 attn_metadata = FlashAttentionMetadata (
344344 num_actual_tokens = num_actual_tokens ,
345345 max_query_len = max_query_len ,
0 commit comments