diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 955c25f30051..193d3c466f1e 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -662,7 +662,7 @@ def copy_and_call(*args): class ConcreteSizeEntry: runtime_shape: int need_to_compile: bool # the size is in compile_sizes - use_cudagraph: bool # the size is in capture_sizes + use_cudagraph: bool # the size is in cudagraph_capture_sizes compiled: bool = False runnable: Callable = None # type: ignore @@ -709,8 +709,8 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, self.compile_sizes: Set[int] = set( self.compilation_config.compile_sizes) - self.capture_sizes: Set[int] = set( - self.compilation_config.capture_sizes + self.cudagraph_capture_sizes: Set[int] = set( + self.compilation_config.cudagraph_capture_sizes ) if self.compilation_config.use_cudagraph else set() self.first_run_finished = False @@ -728,11 +728,11 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig, # to_be_compiled_sizes tracks the remaining sizes to compile, # and updates during the compilation process, so we need to copy it self.to_be_compiled_sizes: Set[int] = self.compile_sizes.copy() - for shape in self.compile_sizes.union(self.capture_sizes): + for shape in self.compile_sizes.union(self.cudagraph_capture_sizes): self.concrete_size_entries[shape] = ConcreteSizeEntry( runtime_shape=shape, need_to_compile=shape in self.compile_sizes, - use_cudagraph=shape in self.capture_sizes, + use_cudagraph=shape in self.cudagraph_capture_sizes, ) def check_for_ending_compilation(self): diff --git a/vllm/config.py b/vllm/config.py index b0a92b2e2134..69990fa910b3 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2703,10 +2703,11 @@ class CompilationConfig(BaseModel): - use_inductor: whether to use inductor compilation. - False: inductor compilation is not used. graph runs in eager. - True: inductor compilation is used. one graph for symbolic shape - is compiled. In addition, compile for cudagraph sizes that are - in candidate_compile_sizes, using configurations - in inductor_compile_config. - - candidate_compile_sizes: sizes to compile for inductor. + is compiled. In addition, compile for compile_sizes, + using configurations in inductor_compile_config. + - compile_sizes: sizes to compile for inductor. In addition + to integers, it also supports "cudagraph_capture_sizes" to + specify the sizes for cudagraph capture. - inductor_compile_config: additional configurations for inductor. - None: use default configurations. - inductor_passes: additional passes for inductor. It is a dictionary @@ -2734,7 +2735,7 @@ class CompilationConfig(BaseModel): splitting_ops: List[str] = Field(default=None) # type: ignore use_inductor: bool = True - candidate_compile_sizes: Optional[List[int]] = Field(default=None) + compile_sizes: Optional[List[Union[int, str]]] = Field(default=None) inductor_compile_config: Dict = Field(default_factory=dict) inductor_passes: Dict[str, str] = Field(default_factory=dict) @@ -2782,8 +2783,6 @@ def model_post_init(self, __context: Any) -> None: pass_config: PassConfig = Field(default_factory=PassConfig) # not configurable, computed after init - compile_sizes: List[int] = PrivateAttr - capture_sizes: List[int] = PrivateAttr max_capture_size: int = PrivateAttr # optimization: # Intuitively, bs_to_padded_graph_size should be Dict[int, int]. @@ -2909,43 +2908,47 @@ def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: from vllm.compilation.backends import VllmBackend return VllmBackend(vllm_config) - def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): + def init_with_cudagraph_sizes(self, + cudagraph_capture_sizes: List[int]) -> None: """To complete the initialization of config, we need to know the cudagraph sizes.""" if self.cudagraph_capture_sizes is None: - self.capture_sizes = sizes_to_specialize + self.cudagraph_capture_sizes = cudagraph_capture_sizes else: - self.capture_sizes = self.cudagraph_capture_sizes + # de-duplicate the sizes provided by the config + self.cudagraph_capture_sizes = list( + set(self.cudagraph_capture_sizes)) logger.info(("cudagraph sizes specified by model runner" " %s is overridden by config %s"), - sizes_to_specialize, self.cudagraph_capture_sizes) - - if self.candidate_compile_sizes is None: - self.candidate_compile_sizes = [] - self.compile_sizes = [ - x for x in self.candidate_compile_sizes if x in self.capture_sizes - ] - ignored_sizes = [ - x for x in self.candidate_compile_sizes - if x not in self.capture_sizes - ] - if ignored_sizes: - logger.warning(("candidate_compile_sizes %s are ignored " - "because they are not cudagraph capture sizes."), - ignored_sizes) + cudagraph_capture_sizes, self.cudagraph_capture_sizes) + + computed_compile_sizes = [] + if self.compile_sizes is not None: + # de-duplicate the sizes provided by the config + self.compile_sizes = list(set(self.compile_sizes)) + for x in self.compile_sizes: + if isinstance(x, str): + assert x == "cudagraph_capture_sizes", \ + "Unrecognized size type in compile_sizes, " \ + f"expect 'cudagraph_capture_sizes', got {x}" + computed_compile_sizes.extend(self.cudagraph_capture_sizes) + else: + assert isinstance(x, int) + computed_compile_sizes.append(x) + self.compile_sizes = computed_compile_sizes # type: ignore # sort to make sure cudagraph capture sizes are in descending order - self.capture_sizes.sort(reverse=True) - self.max_capture_size = self.capture_sizes[ - 0] if self.capture_sizes else 0 + self.cudagraph_capture_sizes.sort(reverse=True) + self.max_capture_size = self.cudagraph_capture_sizes[ + 0] if self.cudagraph_capture_sizes else 0 # pre-compute the mapping from batch size to padded graph size self.bs_to_padded_graph_size = [ 0 for i in range(self.max_capture_size + 1) ] - for end, start in zip(self.capture_sizes, - self.capture_sizes[1:] + [0]): + for end, start in zip(self.cudagraph_capture_sizes, + self.cudagraph_capture_sizes[1:] + [0]): for bs in range(start, end): if bs == start: self.bs_to_padded_graph_size[bs] = start @@ -3216,14 +3219,14 @@ def _set_cudagraph_sizes(self): However, if users specify the cudagraph capture sizes through compilation config, we will use the specified sizes instead. - In the end, `vllm_config.compilation_config.capture_sizes` will be the - final sizes to capture cudagraph (in descending order). + In the end, `vllm_config.compilation_config.cudagraph_capture_sizes` + will be the final sizes to capture cudagraph (in descending order). During runtime, if batchsize is larger than - `vllm_config.compilation_config.capture_sizes`, + `vllm_config.compilation_config.cudagraph_capture_sizes`, no cudagraph will be used. If the batch size is no larger than - `vllm_config.compilation_config.capture_sizes`, + `vllm_config.compilation_config.cudagraph_capture_sizes`, we can quickly find the padded graph size for a given batch size by looking up `vllm_config.compilation_config.bs_to_padded_graph_size`. """ diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index c8aec8dd3afa..f7ce21d0ae98 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -120,7 +120,8 @@ def __init__(self, labelnames: List[str], vllm_config: VllmConfig): labelnames=labelnames) buckets = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8096] if not vllm_config.model_config.enforce_eager: - buckets = vllm_config.compilation_config.capture_sizes.copy() + buckets = vllm_config.compilation_config.\ + cudagraph_capture_sizes.copy() buckets.sort() self.histogram_iteration_tokens = self._histogram_cls( name="vllm:iteration_tokens_total", diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 2350074c23a5..8fd54d0772da 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,6 +1,6 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast import numpy as np import torch @@ -127,7 +127,8 @@ def __init__( # self.cudagraph_batch_sizes sorts in ascending order. # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( - reversed(self.vllm_config.compilation_config.capture_sizes)) + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) # Cache the device properties. self.device_properties = torch.cuda.get_device_properties(self.device) @@ -826,10 +827,12 @@ def load_model(self) -> None: @torch.inference_mode() def _dummy_run( self, - model: nn.Module, num_tokens: int, - kv_caches: List[torch.Tensor], + kv_caches: Optional[List[torch.Tensor]] = None, ) -> torch.Tensor: + model = self.model + if kv_caches is None: + kv_caches = self.kv_caches if self.is_multimodal_model: input_ids = None inputs_embeds = self.inputs_embeds[:num_tokens] @@ -955,8 +958,7 @@ def profile_run(self) -> None: self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) + hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches) logits = self.model.compute_logits(hidden_states, None) logits = logits[:self.max_num_tokens] # TODO(woosuk): Consider the memory usage of the sampler. @@ -982,8 +984,8 @@ def capture_model(self) -> None: for num_tokens in reversed(self.cudagraph_batch_sizes): for _ in range(self.vllm_config.compilation_config. cudagraph_num_of_warmups): - self._dummy_run(self.model, num_tokens, self.kv_caches) - self._dummy_run(self.model, num_tokens, self.kv_caches) + self._dummy_run(num_tokens) + self._dummy_run(num_tokens) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index bd40112aea5e..d2a3dc4b3a3c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -170,6 +170,18 @@ def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None: self.model_runner.initialize_kv_cache(kv_cache_config) def compile_or_warm_up_model(self) -> None: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model() # Reset the seed to ensure that the random state is not affected by diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index cb2ff0c934da..901e424d3285 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1244,13 +1244,19 @@ def set_in_profile_run(self): @torch.inference_mode() def profile_run(self) -> None: + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + self._dummy_run(max_num_batched_tokens, max_num_seqs) + + def _dummy_run(self, + max_num_batched_tokens: int, + max_num_seqs: int = 1) -> None: with self.set_in_profile_run(): # Enable top-k sampling to reflect the accurate memory usage. sampling_params = \ SamplingParams(top_p=0.99, top_k=self.vocab_size - 1) - max_num_batched_tokens = \ - self.scheduler_config.max_num_batched_tokens - max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests # that will have unique loras, an therefore the max amount of memory # consumption create dummy lora request copies from the lora request @@ -1479,13 +1485,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: for virtual_engine in range( self.parallel_config.pipeline_parallel_size): # Only rank 0 should print progress bar during capture - capture_sizes = ( - tqdm( - self.vllm_config.compilation_config.capture_sizes, - desc="Capturing CUDA graph shapes", - ) if get_tensor_model_parallel_rank() == 0 else - self.vllm_config.compilation_config.capture_sizes) - for batch_size in capture_sizes: + cudagraph_capture_sizes = (tqdm( + self.vllm_config.compilation_config. + cudagraph_capture_sizes, + desc="Capturing CUDA graph shapes", + ) if get_tensor_model_parallel_rank() == 0 else + self.vllm_config.compilation_config. + cudagraph_capture_sizes) + for batch_size in cudagraph_capture_sizes: attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( batch_size, diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 29d62ddda3dc..ff8eb31c6ca5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -288,6 +288,18 @@ def _init_cache_engine(self): self.gpu_cache) def _warm_up_model(self) -> None: + # warm up sizes that are not in cudagraph capture sizes, + # but users still want to compile for better performance, + # e.g. for the max-num-batched token size in chunked prefill. + warmup_sizes = self.vllm_config.compilation_config.compile_sizes.copy() + if not self.model_config.enforce_eager: + warmup_sizes = [ + x for x in warmup_sizes if x not in + self.vllm_config.compilation_config.cudagraph_capture_sizes + ] + for size in sorted(warmup_sizes, reverse=True): + logger.info("Compile and warming up model for size %d", size) + self.model_runner._dummy_run(size) if not self.model_config.enforce_eager: self.model_runner.capture_model(self.gpu_cache) # Reset the seed to ensure that the random state is not affected by