@@ -2711,10 +2711,11 @@ class CompilationConfig(BaseModel):
2711
2711
- use_inductor: whether to use inductor compilation.
2712
2712
- False: inductor compilation is not used. graph runs in eager.
2713
2713
- True: inductor compilation is used. one graph for symbolic shape
2714
- is compiled. In addition, compile for cudagraph sizes that are
2715
- in candidate_compile_sizes, using configurations
2716
- in inductor_compile_config.
2717
- - candidate_compile_sizes: sizes to compile for inductor.
2714
+ is compiled. In addition, compile for compile_sizes,
2715
+ using configurations in inductor_compile_config.
2716
+ - compile_sizes: sizes to compile for inductor. In addition
2717
+ to integers, the string "cudagraph_capture_sizes" is accepted
2718
+ and is expanded to all of the cudagraph capture sizes.
2718
2719
- inductor_compile_config: additional configurations for inductor.
2719
2720
- None: use default configurations.
2720
2721
- inductor_passes: additional passes for inductor. It is a dictionary
@@ -2742,7 +2743,7 @@ class CompilationConfig(BaseModel):
2742
2743
splitting_ops : List [str ] = Field (default = None ) # type: ignore
2743
2744
2744
2745
use_inductor : bool = True
2745
- candidate_compile_sizes : Optional [List [int ]] = Field (default = None )
2746
+ compile_sizes : Optional [List [Union [ int , str ] ]] = Field (default = None )
2746
2747
inductor_compile_config : Dict = Field (default_factory = dict )
2747
2748
inductor_passes : Dict [str , str ] = Field (default_factory = dict )
2748
2749
@@ -2790,8 +2791,6 @@ def model_post_init(self, __context: Any) -> None:
2790
2791
pass_config : PassConfig = Field (default_factory = PassConfig )
2791
2792
2792
2793
# not configurable, computed after init
2793
- compile_sizes : List [int ] = PrivateAttr
2794
- capture_sizes : List [int ] = PrivateAttr
2795
2794
max_capture_size : int = PrivateAttr
2796
2795
local_cache_dir : str = PrivateAttr # local cache dir for each rank
2797
2796
# optimization:
@@ -2918,43 +2917,47 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
2918
2917
from vllm .compilation .backends import VllmBackend
2919
2918
return VllmBackend (vllm_config )
2920
2919
2921
- def init_with_cudagraph_sizes (self , sizes_to_specialize : List [int ]):
2920
+ def init_with_cudagraph_sizes (self ,
2921
+ cudagraph_capture_sizes : List [int ]) -> None :
2922
2922
"""To complete the initialization of config,
2923
2923
we need to know the cudagraph sizes."""
2924
2924
2925
2925
if self .cudagraph_capture_sizes is None :
2926
- self .capture_sizes = sizes_to_specialize
2926
+ self .cudagraph_capture_sizes = cudagraph_capture_sizes
2927
2927
else :
2928
- self .capture_sizes = self .cudagraph_capture_sizes
2928
+ # de-duplicate the sizes provided by the config
2929
+ self .cudagraph_capture_sizes = list (
2930
+ set (self .cudagraph_capture_sizes ))
2929
2931
logger .info (("cudagraph sizes specified by model runner"
2930
2932
" %s is overridden by config %s" ),
2931
- sizes_to_specialize , self .cudagraph_capture_sizes )
2932
-
2933
- if self .candidate_compile_sizes is None :
2934
- self .candidate_compile_sizes = []
2935
- self .compile_sizes = [
2936
- x for x in self .candidate_compile_sizes if x in self .capture_sizes
2937
- ]
2938
- ignored_sizes = [
2939
- x for x in self .candidate_compile_sizes
2940
- if x not in self .capture_sizes
2941
- ]
2942
- if ignored_sizes :
2943
- logger .warning (("candidate_compile_sizes %s are ignored "
2944
- "because they are not cudagraph capture sizes." ),
2945
- ignored_sizes )
2933
+ cudagraph_capture_sizes , self .cudagraph_capture_sizes )
2934
+
2935
+ computed_compile_sizes = []
2936
+ if self .compile_sizes is not None :
2937
+ # de-duplicate the sizes provided by the config
2938
+ self .compile_sizes = list (set (self .compile_sizes ))
2939
+ for x in self .compile_sizes :
2940
+ if isinstance (x , str ):
2941
+ assert x == "cudagraph_capture_sizes" , \
2942
+ "Unrecognized size type in compile_sizes, " \
2943
+ f"expect 'cudagraph_capture_sizes', got { x } "
2944
+ computed_compile_sizes .extend (self .cudagraph_capture_sizes )
2945
+ else :
2946
+ assert isinstance (x , int )
2947
+ computed_compile_sizes .append (x )
2948
+ self .compile_sizes = computed_compile_sizes # type: ignore
2946
2949
2947
2950
# sort to make sure cudagraph capture sizes are in descending order
2948
- self .capture_sizes .sort (reverse = True )
2949
- self .max_capture_size = self .capture_sizes [
2950
- 0 ] if self .capture_sizes else 0
2951
+ self .cudagraph_capture_sizes .sort (reverse = True )
2952
+ self .max_capture_size = self .cudagraph_capture_sizes [
2953
+ 0 ] if self .cudagraph_capture_sizes else 0
2951
2954
2952
2955
# pre-compute the mapping from batch size to padded graph size
2953
2956
self .bs_to_padded_graph_size = [
2954
2957
0 for i in range (self .max_capture_size + 1 )
2955
2958
]
2956
- for end , start in zip (self .capture_sizes ,
2957
- self .capture_sizes [1 :] + [0 ]):
2959
+ for end , start in zip (self .cudagraph_capture_sizes ,
2960
+ self .cudagraph_capture_sizes [1 :] + [0 ]):
2958
2961
for bs in range (start , end ):
2959
2962
if bs == start :
2960
2963
self .bs_to_padded_graph_size [bs ] = start
@@ -3225,14 +3228,14 @@ def _set_cudagraph_sizes(self):
3225
3228
However, if users specify the cudagraph capture sizes through
3226
3229
compilation config, we will use the specified sizes instead.
3227
3230
3228
- In the end, `vllm_config.compilation_config.capture_sizes` will be the
3229
- final sizes to capture cudagraph (in descending order).
3231
+ In the end, `vllm_config.compilation_config.cudagraph_capture_sizes`
3232
+ will be the final sizes to capture cudagraph (in descending order).
3230
3233
3231
3234
During runtime, if batchsize is larger than
3232
- `vllm_config.compilation_config.capture_sizes `,
3235
+ the largest of `vllm_config.compilation_config.cudagraph_capture_sizes`,
3233
3236
no cudagraph will be used.
3234
3237
If the batch size is no larger than
3235
- `vllm_config.compilation_config.capture_sizes `,
3238
+ the largest of `vllm_config.compilation_config.cudagraph_capture_sizes`,
3236
3239
we can quickly find the padded graph size for a given batch size by
3237
3240
looking up `vllm_config.compilation_config.bs_to_padded_graph_size`.
3238
3241
"""
0 commit comments