[WIP][V1] Ray executor #10725
@@ -4,6 +4,7 @@
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.metrics_types import StatLoggerBase
 from vllm.envs import VLLM_ENABLE_V1_MULTIPROCESSING
+from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import INPUT_REGISTRY, InputRegistry, PromptType
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -18,6 +19,7 @@
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.gpu_executor import GPUExecutor
+from vllm.v1.executor.ray_gpu_executor import RayGPUExecutor

 logger = init_logger(__name__)

@@ -99,7 +101,11 @@ def from_engine_args(

     @classmethod
     def _get_executor_cls(cls, vllm_config: VllmConfig):
-        return GPUExecutor
+        if vllm_config.parallel_config.distributed_executor_backend == "ray":
+            initialize_ray_cluster(vllm_config.parallel_config)
+            return RayGPUExecutor
+        else:
+            return GPUExecutor

     def stop_remote_worker_execution_loop(self) -> None:
         raise NotImplementedError("TP not implemented yet.")

Review comments on the new "ray" branch:
- We also need to have this in …
- I remember hearing that AsyncLLM will be removed or something. Can you sync with them before supporting it?
@@ -158,8 +164,6 @@ def step(self) -> List[RequestOutput]:
         return request_outputs

     # TODO(rob): Can we get rid of these?

     def get_model_config(self):
         return self.model_config
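For context, this is how the new branch in _get_executor_cls() is expected to be reached from user code. The snippet below is an illustrative sketch, not part of the PR: it assumes the existing LLM/EngineArgs entrypoints, that the V1 engine is enabled (e.g. via the VLLM_USE_V1 environment variable at the time of this change), and uses a placeholder model name.

# Illustrative usage sketch (not part of the PR).
from vllm import LLM, SamplingParams

# With distributed_executor_backend="ray", _get_executor_cls() calls
# initialize_ray_cluster() and returns RayGPUExecutor; any other value
# falls through to the existing single-process GPUExecutor.
llm = LLM(
    model="facebook/opt-125m",            # placeholder model
    tensor_parallel_size=2,               # >1 GPU, so a distributed executor is needed
    distributed_executor_backend="ray",
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)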
New file: the DistributedGPUExecutor abstract base class.

@@ -0,0 +1,96 @@
from abc import abstractmethod
from typing import Any, Optional, Tuple

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.core.scheduler import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput

logger = init_logger(__name__)


class DistributedGPUExecutor:
    """Abstract superclass of multi-GPU executor implementations."""

    def __init__(self, vllm_config: VllmConfig):
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of available KV blocks.

        This invokes `determine_num_available_blocks` on each worker and takes
        the min of the results, guaranteeing that the selected cache sizes are
        compatible with all workers.

        Returns:
            - tuple[num_gpu_blocks, num_cpu_blocks]
        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers("determine_num_available_blocks")

        # Since we use a shared centralized controller, we take the minimum
        # number of blocks across all workers to make sure all the memory
        # operators can be applied to all workers.
        num_gpu_blocks = min(b[0] for b in num_blocks)
        return num_gpu_blocks, 0

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        """Initialize the KV cache in all workers."""
        # NOTE: This is logged in the executor because there can be >1 worker
        # with other executors. We could log in the engine level, but work
        # remains to abstract away the device for non-GPU configurations.
        logger.info("# GPU blocks: %d", num_gpu_blocks)
        self._run_workers("initialize_cache", num_gpu_blocks)
        self._run_workers("compile_or_warm_up_model")

    @abstractmethod
    def execute_model(
        self,
        scheduler_output: SchedulerOutput,
    ) -> ModelRunnerOutput:
        raise NotImplementedError

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self._run_workers("save_sharded_state",
                          path=path,
                          pattern=pattern,
                          max_size=max_size)

    @abstractmethod
    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        """Runs the given method on all workers.

        Args:
            async_run_tensor_parallel_workers_only: If True the method will be
                run only in the remote TP workers, not the driver worker.
                It will also be run asynchronously and return a list of futures
                rather than blocking on the results.
        """
        raise NotImplementedError

    @abstractmethod
    def check_health(self) -> None:
        raise NotImplementedError
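The concrete RayGPUExecutor referenced above is part of this PR but is not rendered in this view. The following is a rough, non-authoritative sketch of how a Ray-backed subclass of DistributedGPUExecutor could satisfy the abstract hooks; the GPUWorkerStub actor, its ping/execute helpers, the RaySketchExecutor name, and the import path for DistributedGPUExecutor are assumptions for illustration only, not the PR's actual implementation.

# Illustrative sketch only -- not the RayGPUExecutor added by this PR.
# Assumes Ray is installed; GPUWorkerStub is a hypothetical stand-in for
# vLLM's real worker.
from typing import Any, Optional

import ray

from vllm.config import VllmConfig
# Assumed module path for the new abstract base class shown above.
from vllm.v1.executor.distributed_gpu_executor import DistributedGPUExecutor


@ray.remote(num_gpus=1)
class GPUWorkerStub:
    """Hypothetical per-GPU worker actor (the real PR drives vLLM's Worker)."""

    def __init__(self, vllm_config: VllmConfig, rank: int):
        self.vllm_config = vllm_config
        self.rank = rank

    def ping(self) -> bool:
        return True

    def execute(self, method: str, *args, **kwargs) -> Any:
        # Dispatch a worker method by name, mirroring the _run_workers() contract.
        return getattr(self, method)(*args, **kwargs)

    def determine_num_available_blocks(self):
        return (1024, 0)  # placeholder block counts

    def initialize_cache(self, num_gpu_blocks: int) -> None:
        pass

    def compile_or_warm_up_model(self) -> None:
        pass

    def execute_model(self, scheduler_output) -> Any:
        return None  # a real worker returns a ModelRunnerOutput


class RaySketchExecutor(DistributedGPUExecutor):
    """Sketch of a Ray-backed subclass; placement groups, the driver/worker
    split, and compiled-DAG execution used by a real executor are omitted."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)
        # One actor per GPU in the tensor-parallel group.
        self.workers = [
            GPUWorkerStub.remote(vllm_config, rank)
            for rank in range(self.parallel_config.world_size)
        ]

    def _run_workers(
        self,
        method: str,
        *args,
        async_run_tensor_parallel_workers_only: bool = False,
        max_concurrent_workers: Optional[int] = None,
        **kwargs,
    ) -> Any:
        refs = [w.execute.remote(method, *args, **kwargs) for w in self.workers]
        if async_run_tensor_parallel_workers_only:
            return refs  # return futures; the caller resolves them later
        return ray.get(refs)

    def execute_model(self, scheduler_output):
        # Simplified: broadcast the scheduler output to every worker and take
        # the first worker's output as the result.
        return self._run_workers("execute_model", scheduler_output)[0]

    def check_health(self) -> None:
        # ray.get() raises if any worker actor has died.
        ray.get([w.ping.remote() for w in self.workers])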
Review comment:
- I think this is not needed anymore for V1?