-
Notifications
You must be signed in to change notification settings - Fork 530
[CORE] concurrent partial prefills #2356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,14 +16,16 @@ | |
| # | ||
|
|
||
| from dataclasses import dataclass, fields | ||
| from typing import Type, Union | ||
| from typing import Optional, Type, Union | ||
|
|
||
| from vllm.config import SchedulerConfig | ||
|
|
||
|
|
||
| @dataclass | ||
| class AscendSchedulerConfig(SchedulerConfig): | ||
| enable_chunked_prefill: bool = False | ||
| max_long_partial_prefills: Optional[Union[int, float]] = None | ||
| long_prefill_token_threshold: Optional[Union[int, float]] = None | ||
| policy: str = "fcfs" | ||
| num_scheduler_steps: int = 1 | ||
| scheduler_cls: Union[str, Type[object]] = ( | ||
|
|
@@ -41,6 +43,8 @@ def initialize_from_config( | |
| } | ||
| # Override default values into original SchedulerConfig | ||
| scheduler_config["enable_chunked_prefill"] = False | ||
| scheduler_config["max_long_partial_prefills"] = None | ||
| scheduler_config["long_prefill_token_threshold"] = None | ||
| scheduler_config["policy"] = "fcfs" | ||
| scheduler_config["num_scheduler_steps"] = 1 | ||
| scheduler_config["scheduler_cls"] = ( | ||
|
|
@@ -55,6 +59,17 @@ def __post_init__(self) -> None: | |
| self.max_num_encoder_input_tokens = self.max_num_batched_tokens | ||
| self.encoder_cache_size = self.max_num_batched_tokens | ||
| self.chunked_prefill_enabled = self.enable_chunked_prefill | ||
| # concurrent partial prefills. Default is inf | ||
| if self.max_long_partial_prefills is None: | ||
| self.max_long_partial_prefills = float('inf') | ||
| self.long_prefill_token_threshold = float('inf') | ||
| else: | ||
| if self.long_prefill_token_threshold is None: | ||
| self.long_prefill_token_threshold = \ | ||
| int(self.max_model_len * 0.04) | ||
|
Comment on lines
+67
to
+69
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The default value calculation for `long_prefill_token_threshold` can evaluate to 0 when `max_model_len` is small (anything below 25 tokens makes `int(self.max_model_len * 0.04)` zero), which would then fail the `assert (self.long_prefill_token_threshold > 0)` check below. Consider clamping it to at least 1:

    if self.long_prefill_token_threshold is None:
        self.long_prefill_token_threshold = max(1, int(self.max_model_len * 0.04))
||
|
|
||
| assert (self.max_long_partial_prefills > 0) | ||
| assert (self.long_prefill_token_threshold > 0) | ||
| if self.policy != "fcfs": | ||
| raise NotImplementedError( | ||
| f"currently AscendScheduler only supports fcfs policy, got {self.policy}" | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -75,6 +75,10 @@ def schedule(self) -> SchedulerOutput: | |||||||
| # and put back at the head of the waiting queue later | ||||||||
| skipped_waiting_requests: deque[Request] = deque() | ||||||||
|
|
||||||||
| # Skip long prompt requests in prefill stage. | ||||||||
| # long_prefill_budget is float('inf') if not use. | ||||||||
| long_prefill_budget = self.vllm_config.scheduler_config.max_long_partial_prefills | ||||||||
|
|
||||||||
| # Schedule prefill requests first. | ||||||||
| while self.waiting and token_budget > 0: | ||||||||
| if len(self.running) == self.max_num_running_reqs: | ||||||||
|
|
@@ -173,6 +177,11 @@ def skip_cur_request(): | |||||||
| skip_cur_request() | ||||||||
| continue | ||||||||
|
|
||||||||
| if num_new_tokens > self.vllm_config.scheduler_config.long_prefill_token_threshold \ | ||||||||
| and long_prefill_budget <= 0: | ||||||||
| skip_cur_request() | ||||||||
| continue | ||||||||
|
|
||||||||
| new_blocks = self.kv_cache_manager.allocate_slots( | ||||||||
| request, | ||||||||
| num_new_tokens + num_external_computed_tokens, | ||||||||
|
|
@@ -222,6 +231,7 @@ def skip_cur_request(): | |||||||
| # Update request info. | ||||||||
| num_scheduled_tokens[request.request_id] = num_new_tokens | ||||||||
| token_budget -= num_new_tokens | ||||||||
| long_prefill_budget -= 1 | ||||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The `long_prefill_budget -= 1` decrement is applied to every scheduled prefill request, not only to requests whose `num_new_tokens` exceeds `long_prefill_token_threshold`. As written, short prefills also consume the long-prefill budget, so long requests can be skipped even when fewer than `max_long_partial_prefills` long prefills are in flight. Decrement the budget only when the scheduled request is actually a long prefill.

Suggested change
|
||||||||
| request.status = RequestStatus.RUNNING | ||||||||
| request.num_computed_tokens = num_computed_tokens | ||||||||
| # Count the number of prefix cached tokens. | ||||||||
|
|
||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default value for `long_prefill_token_threshold` is documented as `False`, which is inconsistent with its type `Union[int, float]` and the implementation. In the code, its default is `None`, and it's then either set to `float('inf')` or calculated based on `max_model_len` if `max_long_partial_prefills` is set. Using `False` here is misleading for users. Please update the documented default value to `None` to align with the implementation.