From 7d0cae3b72ae397e3161f898c4a66a5f970ead48 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 13:57:07 +0800 Subject: [PATCH 01/15] add adapter support Signed-off-by: youkaichao --- vllm/config.py | 4 ++++ vllm/engine/arg_utils.py | 8 ++++++++ vllm/worker/worker_base.py | 13 +++++++++++++ 3 files changed, 25 insertions(+) diff --git a/vllm/config.py b/vllm/config.py index f87d2d6e82cf..93a74a9d9c7e 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1366,6 +1366,7 @@ class ParallelConfig: # will be determined based on the platform. worker_cls: str = "auto" sd_worker_cls: str = "auto" + worker_adapter_cls: str = "" # world_size is TPxPP, it affects the number of workers we create. world_size: int = field(init=False) @@ -1523,6 +1524,9 @@ def _verify_args(self) -> None: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") + assert isinstance(self.worker_adapter_cls, str), ( + "worker_adapter_cls must be a string (qualified class name).") + @dataclass class SchedulerConfig: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 989eb4dbfd14..f163bafc813f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1015,6 +1015,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=str, default="auto", help='The worker class to use for distributed execution.') + parser.add_argument( + '--worker-adapter-cls', + type=str, + default="", + help='The worker adapter class on top of the worker cls, ' + 'it is useful if you just want to add new functions to the worker ' + 'class without changing the existing functions.') parser.add_argument( "--generation-config", type=nullable_str, @@ -1209,6 +1216,7 @@ def create_engine_config(self, ray_workers_use_nsight=self.ray_workers_use_nsight, distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, + worker_adapter_cls=self.worker_adapter_cls, ) max_model_len = model_config.max_model_len diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 7cc1562a5bce..d7317c911839 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -558,10 +558,23 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: worker_class = resolve_obj_by_qualname( self.vllm_config.parallel_config.worker_cls) else: + logger.warning( + "passing worker_cls as a class object is strongly deprecated," + " as the serialization of class objects can be tricky and" + " error-prone. To be safe, please keep the class in a separate" + " module and pass the qualified name of the class as a string." + ) assert isinstance(self.vllm_config.parallel_config.worker_cls, bytes) worker_class = cloudpickle.loads( self.vllm_config.parallel_config.worker_cls) + if self.vllm_config.parallel_config.worker_adapter_cls: + worker_adapter_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_adapter_cls) + if worker_adapter_class not in worker_class.__bases__: + # dynamically inherit the worker adapter class + worker_class.__bases__ = worker_class.__bases__ + ( + worker_adapter_class, ) with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization self.worker = worker_class(**kwargs) From c7cfb5327899eba05e56e6886d4d5bcb953863f5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 14:04:55 +0800 Subject: [PATCH 02/15] check duplicate Signed-off-by: youkaichao --- vllm/worker/worker_base.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d7317c911839..10ec45660070 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -573,6 +573,13 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: self.vllm_config.parallel_config.worker_adapter_cls) if worker_adapter_class not in worker_class.__bases__: # dynamically inherit the worker adapter class + for attr in dir(worker_adapter_class): + if attr.startswith("__"): + continue + assert not hasattr(worker_class, attr), ( + f"Worker class {worker_class} already has an attribute" + f" {attr}, which conflicts with the worker" + f" adapter class {worker_adapter_class}.") worker_class.__bases__ = worker_class.__bases__ + ( worker_adapter_class, ) with set_current_vllm_config(self.vllm_config): From 52c9f0eb144b4773206c0b347e19e481143c2faf Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 14:12:47 +0800 Subject: [PATCH 03/15] update colocate Signed-off-by: youkaichao --- .buildkite/test-pipeline.yaml | 6 +- examples/offline_inference/rlhf.py | 66 +-------------------- examples/offline_inference/rlhf_colocate.py | 36 +---------- 3 files changed, 8 insertions(+), 100 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d0f5c94ffd8d..d251efcbff54 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -143,8 +143,10 @@ steps: - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - - python3 ../examples/offline_inference/rlhf.py - - RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py + - pushd ../examples/offline_inference + - python3 rlhf.py + - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - popd - label: Metrics, Tracing Test # 10min num_gpus: 2 diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index 172d18cbce2f..50152ba4d094 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -18,72 +18,11 @@ import torch from ray.util.placement_group import placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rlhf_utils import stateless_init_process_group from transformers import AutoModelForCausalLM from vllm import LLM, SamplingParams from vllm.utils import get_ip, get_open_port -from vllm.worker.worker import Worker - - -def stateless_init_process_group(master_address, master_port, rank, world_size, - device): - """ - vLLM provides `StatelessProcessGroup` to create a process group - without considering the global process group in torch.distributed. - It is recommended to create `StatelessProcessGroup`, and then initialize - the data-plane communication (NCCL) between external (train processes) - and vLLM workers. - """ - from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator - from vllm.distributed.utils import StatelessProcessGroup - pg = StatelessProcessGroup.create(host=master_address, - port=master_port, - rank=rank, - world_size=world_size) - pynccl = PyNcclCommunicator(pg, device=device) - return pynccl - - -class MyWorker(Worker): - """ - The `MyWorker` class inherits from `Worker` to provide custom functions. - For simplicity, we define the `MyWorker` class in this self-contained - script. Normally, we should define the `MyWorker` class in a separate - file and pass the qualified name of the class to the `worker_cls` - parameter. - """ - - def init_weight_update_group(self, master_address, master_port, - rank_offset, world_size): - from vllm.distributed.parallel_state import get_world_group - rank = get_world_group().rank + rank_offset - self.model_update_group = stateless_init_process_group( - master_address, - master_port, - rank, - world_size, - self.device, - ) - - def update_weight(self, name, dtype, shape): - weight = torch.empty(shape, dtype=dtype, device="cuda") - self.model_update_group.broadcast(weight, - src=0, - stream=torch.cuda.current_stream()) - - self.model_runner.model.load_weights(weights=[(name, weight)]) - - del weight - - def check_weights_changed(self): - """ - Check if the weights are updated to 0. - """ - weights_updated = True - for name, p in self.model_runner.model.named_parameters(): - weights_updated = weights_updated and torch.allclose( - p, torch.zeros_like(p)) - return weights_updated class MyLLM(LLM): @@ -129,7 +68,7 @@ def __init__(self, *args, **kwargs): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_cls=MyWorker, + worker_adapter_cls="rlhf_utils.WorkerAdapter", tensor_parallel_size=2, distributed_executor_backend="ray", ) @@ -159,6 +98,7 @@ def __init__(self, *args, **kwargs): handle = llm.collective_rpc.remote("init_weight_update_group", args=(master_address, master_port, 1, 3)) + model_update_group = stateless_init_process_group(master_address, master_port, 0, 3, torch.device("cuda:0")) ray.get(handle) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 15dc7edc18ad..0e4b4eaab612 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -17,40 +17,6 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy from vllm import LLM -from vllm.worker.worker import Worker - - -class MyWorker(Worker): - - def report_device_id(self) -> str: - from vllm.platforms import current_platform - self.device_uuid = current_platform.get_device_uuid(self.device.index) - return self.device_uuid - - def update_weights_from_ipc_handles(self, ipc_handles): - handles = ipc_handles[self.device_uuid] - device_id = self.device.index - weights = [] - for name, handle in handles.items(): - func, args = handle - list_args = list(args) - # the key is to change device id to the current device id - # in case two processes have different CUDA_VISIBLE_DEVICES - list_args[6] = device_id - tensor = func(*list_args) - weights.append((name, tensor)) - self.model_runner.model.load_weights(weights=weights) - torch.cuda.synchronize() - - def check_weights_changed(self): - """ - Check if the weights are updated to 0. - """ - weights_updated = True - for name, p in self.model_runner.model.named_parameters(): - weights_updated = weights_updated and torch.allclose( - p, torch.zeros_like(p)) - return weights_updated class MyLLM(LLM): @@ -150,7 +116,7 @@ def get_weight_ipc_handles(self): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_cls=MyWorker, + worker_adapter_cls="rlhf_utils.ColocateWorkerAdapter", tensor_parallel_size=2, distributed_executor_backend="ray", gpu_memory_utilization=0.4, From f7bb5d7cf0866bb363692d371e2322636ecc2a5b Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 14:19:20 +0800 Subject: [PATCH 04/15] add files Signed-off-by: youkaichao --- examples/offline_inference/rlhf_utils.py | 105 +++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 examples/offline_inference/rlhf_utils.py diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py new file mode 100644 index 000000000000..7a82a41c5e80 --- /dev/null +++ b/examples/offline_inference/rlhf_utils.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + + +def stateless_init_process_group(master_address, master_port, rank, world_size, + device): + """ + vLLM provides `StatelessProcessGroup` to create a process group + without considering the global process group in torch.distributed. + It is recommended to create `StatelessProcessGroup`, and then initialize + the data-plane communication (NCCL) between external (train processes) + and vLLM workers. + """ + from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator + from vllm.distributed.utils import StatelessProcessGroup + pg = StatelessProcessGroup.create(host=master_address, + port=master_port, + rank=rank, + world_size=world_size) + pynccl = PyNcclCommunicator(pg, device=device) + return pynccl + + +class WorkerAdapter: + """ + The class for vLLM's worker to inherit from. + By defining an adapter, the code can work no matter what is + the underlying worker class. This way, the code can be compatible + with both vLLM V0 and V1. + NOTE: we define this class in a separate module, and the main module + should pass the full qualified name as `worker_adapter_cls` argument. + """ + + def init_weight_update_group(self, master_address, master_port, + rank_offset, world_size): + from vllm.distributed.parallel_state import get_world_group + rank = get_world_group().rank + rank_offset + self.model_update_group = stateless_init_process_group( + master_address, + master_port, + rank, + world_size, + self.device, + ) + + def update_weight(self, name, dtype, shape): + weight = torch.empty(shape, dtype=dtype, device="cuda") + self.model_update_group.broadcast(weight, + src=0, + stream=torch.cuda.current_stream()) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated + + +class ColocateWorkerAdapter: + """ + The class for vLLM's worker to inherit from, in the colocate setting. + By defining an adapter, the code can work no matter what is + the underlying worker class. This way, the code can be compatible + with both vLLM V0 and V1. + NOTE: we define this class in a separate module, and the main module + should pass the full qualified name as `worker_adapter_cls` argument. + """ + + def report_device_id(self) -> str: + from vllm.platforms import current_platform + self.device_uuid = current_platform.get_device_uuid(self.device.index) + return self.device_uuid + + def update_weights_from_ipc_handles(self, ipc_handles): + handles = ipc_handles[self.device_uuid] + device_id = self.device.index + weights = [] + for name, handle in handles.items(): + func, args = handle + list_args = list(args) + # the key is to change device id to the current device id + # in case two processes have different CUDA_VISIBLE_DEVICES + list_args[6] = device_id + tensor = func(*list_args) + weights.append((name, tensor)) + self.model_runner.model.load_weights(weights=weights) + torch.cuda.synchronize() + + def check_weights_changed(self): + """ + Check if the weights are updated to 0. + """ + weights_updated = True + for name, p in self.model_runner.model.named_parameters(): + weights_updated = weights_updated and torch.allclose( + p, torch.zeros_like(p)) + return weights_updated From 1525a3ac78590fc67294312182fa003ab7721dd7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 14:21:43 +0800 Subject: [PATCH 05/15] add in engine args Signed-off-by: youkaichao --- vllm/engine/arg_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f163bafc813f..02a9b1af8791 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -202,6 +202,7 @@ class EngineArgs: override_pooler_config: Optional[PoolerConfig] = None compilation_config: Optional[CompilationConfig] = None worker_cls: str = "auto" + worker_adapter_cls: str = "" kv_transfer_config: Optional[KVTransferConfig] = None From 08416a67eafed41200d334e09a096bb257921eca Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 14:28:49 +0800 Subject: [PATCH 06/15] add logging Signed-off-by: youkaichao --- vllm/worker/worker_base.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 10ec45660070..cfe4da0250a6 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -571,6 +571,9 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: if self.vllm_config.parallel_config.worker_adapter_cls: worker_adapter_class = resolve_obj_by_qualname( self.vllm_config.parallel_config.worker_adapter_cls) + logger.info( + "Injecting worker adapter class %s" + "into worker class %s", worker_adapter_class, worker_class) if worker_adapter_class not in worker_class.__bases__: # dynamically inherit the worker adapter class for attr in dir(worker_adapter_class): From f5db641474b916ff446828918703148cff5ab5f4 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 14:31:26 +0800 Subject: [PATCH 07/15] add logging Signed-off-by: youkaichao --- vllm/worker/worker_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index cfe4da0250a6..21a3191bb4f5 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -572,8 +572,8 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: worker_adapter_class = resolve_obj_by_qualname( self.vllm_config.parallel_config.worker_adapter_cls) logger.info( - "Injecting worker adapter class %s" - "into worker class %s", worker_adapter_class, worker_class) + "Injecting %s into %s for extended collective_rpc call", + worker_adapter_class, worker_class) if worker_adapter_class not in worker_class.__bases__: # dynamically inherit the worker adapter class for attr in dir(worker_adapter_class): From dd7b3c37a1e1583be1582da92d39ee88ea00a9ec Mon Sep 17 00:00:00 2001 From: youkaichao Date: Tue, 4 Mar 2025 14:58:51 +0800 Subject: [PATCH 08/15] comments Signed-off-by: youkaichao --- vllm/worker/worker_base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 21a3191bb4f5..d65b422f52db 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -575,7 +575,7 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: "Injecting %s into %s for extended collective_rpc call", worker_adapter_class, worker_class) if worker_adapter_class not in worker_class.__bases__: - # dynamically inherit the worker adapter class + # check any conflicts between worker and worker_adapter for attr in dir(worker_adapter_class): if attr.startswith("__"): continue @@ -583,6 +583,7 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: f"Worker class {worker_class} already has an attribute" f" {attr}, which conflicts with the worker" f" adapter class {worker_adapter_class}.") + # dynamically inherit the worker adapter class worker_class.__bases__ = worker_class.__bases__ + ( worker_adapter_class, ) with set_current_vllm_config(self.vllm_config): From 731d4f623e07edc23ead4820d4775f34dad9127f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Mar 2025 12:46:44 +0800 Subject: [PATCH 09/15] use mixin Signed-off-by: youkaichao --- examples/offline_inference/rlhf.py | 2 +- examples/offline_inference/rlhf_colocate.py | 2 +- examples/offline_inference/rlhf_utils.py | 4 ++-- vllm/config.py | 7 ++++--- vllm/engine/arg_utils.py | 8 ++++---- vllm/worker/worker_base.py | 16 ++++++++-------- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index 50152ba4d094..bc3c0439a13f 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -68,7 +68,7 @@ def __init__(self, *args, **kwargs): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_adapter_cls="rlhf_utils.WorkerAdapter", + worker_mixin_cls="rlhf_utils.WorkerAdapter", tensor_parallel_size=2, distributed_executor_backend="ray", ) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 0e4b4eaab612..2ecf809ee9f9 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -116,7 +116,7 @@ def get_weight_ipc_handles(self): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_adapter_cls="rlhf_utils.ColocateWorkerAdapter", + worker_mixin_cls="rlhf_utils.ColocateWorkerAdapter", tensor_parallel_size=2, distributed_executor_backend="ray", gpu_memory_utilization=0.4, diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index 7a82a41c5e80..6527b3d3dd46 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -28,7 +28,7 @@ class WorkerAdapter: the underlying worker class. This way, the code can be compatible with both vLLM V0 and V1. NOTE: we define this class in a separate module, and the main module - should pass the full qualified name as `worker_adapter_cls` argument. + should pass the full qualified name as `worker_mixin_cls` argument. """ def init_weight_update_group(self, master_address, master_port, @@ -71,7 +71,7 @@ class ColocateWorkerAdapter: the underlying worker class. This way, the code can be compatible with both vLLM V0 and V1. NOTE: we define this class in a separate module, and the main module - should pass the full qualified name as `worker_adapter_cls` argument. + should pass the full qualified name as `worker_mixin_cls` argument. """ def report_device_id(self) -> str: diff --git a/vllm/config.py b/vllm/config.py index 788f3ff6fca6..7dc5ba1a5285 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1366,7 +1366,7 @@ class ParallelConfig: # will be determined based on the platform. worker_cls: str = "auto" sd_worker_cls: str = "auto" - worker_adapter_cls: str = "" + worker_mixin_cls: str = "" # world_size is TPxPP, it affects the number of workers we create. world_size: int = field(init=False) @@ -1524,8 +1524,9 @@ def _verify_args(self) -> None: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") - assert isinstance(self.worker_adapter_cls, str), ( - "worker_adapter_cls must be a string (qualified class name).") + assert isinstance( + self.worker_mixin_cls, + str), ("worker_mixin_cls must be a string (qualified class name).") @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 02a9b1af8791..fe6650a49f8a 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -202,7 +202,7 @@ class EngineArgs: override_pooler_config: Optional[PoolerConfig] = None compilation_config: Optional[CompilationConfig] = None worker_cls: str = "auto" - worker_adapter_cls: str = "" + worker_mixin_cls: str = "" kv_transfer_config: Optional[KVTransferConfig] = None @@ -1017,10 +1017,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="auto", help='The worker class to use for distributed execution.') parser.add_argument( - '--worker-adapter-cls', + '--worker-mixin-cls', type=str, default="", - help='The worker adapter class on top of the worker cls, ' + help='The worker mixin class on top of the worker cls, ' 'it is useful if you just want to add new functions to the worker ' 'class without changing the existing functions.') parser.add_argument( @@ -1217,7 +1217,7 @@ def create_engine_config(self, ray_workers_use_nsight=self.ray_workers_use_nsight, distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, - worker_adapter_cls=self.worker_adapter_cls, + worker_mixin_cls=self.worker_mixin_cls, ) max_model_len = model_config.max_model_len diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index d65b422f52db..796cabfd02be 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -568,24 +568,24 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: bytes) worker_class = cloudpickle.loads( self.vllm_config.parallel_config.worker_cls) - if self.vllm_config.parallel_config.worker_adapter_cls: - worker_adapter_class = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_adapter_cls) + if self.vllm_config.parallel_config.worker_mixin_cls: + worker_mixin_cls = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_mixin_cls) logger.info( "Injecting %s into %s for extended collective_rpc call", - worker_adapter_class, worker_class) - if worker_adapter_class not in worker_class.__bases__: + worker_mixin_cls, worker_class) + if worker_mixin_cls not in worker_class.__bases__: # check any conflicts between worker and worker_adapter - for attr in dir(worker_adapter_class): + for attr in dir(worker_mixin_cls): if attr.startswith("__"): continue assert not hasattr(worker_class, attr), ( f"Worker class {worker_class} already has an attribute" f" {attr}, which conflicts with the worker" - f" adapter class {worker_adapter_class}.") + f" mixin class {worker_mixin_cls}.") # dynamically inherit the worker adapter class worker_class.__bases__ = worker_class.__bases__ + ( - worker_adapter_class, ) + worker_mixin_cls, ) with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization self.worker = worker_class(**kwargs) From ae2d12f80d66a9993f49082b9651dcf21023f50f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Mar 2025 12:48:01 +0800 Subject: [PATCH 10/15] use mixin Signed-off-by: youkaichao --- examples/offline_inference/rlhf.py | 2 +- examples/offline_inference/rlhf_colocate.py | 2 +- examples/offline_inference/rlhf_utils.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index bc3c0439a13f..f660947b07b2 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -68,7 +68,7 @@ def __init__(self, *args, **kwargs): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_mixin_cls="rlhf_utils.WorkerAdapter", + worker_mixin_cls="rlhf_utils.WorkerMixin", tensor_parallel_size=2, distributed_executor_backend="ray", ) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 2ecf809ee9f9..410e7aaad0cb 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -116,7 +116,7 @@ def get_weight_ipc_handles(self): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_mixin_cls="rlhf_utils.ColocateWorkerAdapter", + worker_mixin_cls="rlhf_utils.ColocateWorkerMixin", tensor_parallel_size=2, distributed_executor_backend="ray", gpu_memory_utilization=0.4, diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index 6527b3d3dd46..93591cc567f0 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -21,7 +21,7 @@ def stateless_init_process_group(master_address, master_port, rank, world_size, return pynccl -class WorkerAdapter: +class WorkerMixin: """ The class for vLLM's worker to inherit from. By defining an adapter, the code can work no matter what is @@ -64,7 +64,7 @@ def check_weights_changed(self): return weights_updated -class ColocateWorkerAdapter: +class ColocateWorkerMixin: """ The class for vLLM's worker to inherit from, in the colocate setting. By defining an adapter, the code can work no matter what is From 3661da3683002cfe7c08746cb1c20b4c456f288f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Mar 2025 12:48:53 +0800 Subject: [PATCH 11/15] use mixin Signed-off-by: youkaichao --- examples/offline_inference/rlhf_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index 93591cc567f0..7cdcc511c18c 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -24,7 +24,7 @@ def stateless_init_process_group(master_address, master_port, rank, world_size, class WorkerMixin: """ The class for vLLM's worker to inherit from. - By defining an adapter, the code can work no matter what is + By defining a mixin class, the code can work no matter what is the underlying worker class. This way, the code can be compatible with both vLLM V0 and V1. NOTE: we define this class in a separate module, and the main module @@ -67,7 +67,7 @@ def check_weights_changed(self): class ColocateWorkerMixin: """ The class for vLLM's worker to inherit from, in the colocate setting. - By defining an adapter, the code can work no matter what is + By defining a mixin class, the code can work no matter what is the underlying worker class. This way, the code can be compatible with both vLLM V0 and V1. NOTE: we define this class in a separate module, and the main module From 1b43146741e9fe29bda147c19b7287d11c3cc85c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Mar 2025 12:51:48 +0800 Subject: [PATCH 12/15] use mixin Signed-off-by: youkaichao --- vllm/worker/worker_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 796cabfd02be..275d349b6e91 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -575,7 +575,7 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: "Injecting %s into %s for extended collective_rpc call", worker_mixin_cls, worker_class) if worker_mixin_cls not in worker_class.__bases__: - # check any conflicts between worker and worker_adapter + # check any conflicts between worker and worker_mixin_cls for attr in dir(worker_mixin_cls): if attr.startswith("__"): continue @@ -583,7 +583,7 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: f"Worker class {worker_class} already has an attribute" f" {attr}, which conflicts with the worker" f" mixin class {worker_mixin_cls}.") - # dynamically inherit the worker adapter class + # dynamically inherit the worker mixin class worker_class.__bases__ = worker_class.__bases__ + ( worker_mixin_cls, ) with set_current_vllm_config(self.vllm_config): From 00c6adc33e988d51f030fd81d4b3a929c6720fe3 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Thu, 6 Mar 2025 12:53:26 +0800 Subject: [PATCH 13/15] polish logging Signed-off-by: youkaichao --- vllm/worker/worker_base.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 275d349b6e91..43c55d22753a 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -571,9 +571,7 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: if self.vllm_config.parallel_config.worker_mixin_cls: worker_mixin_cls = resolve_obj_by_qualname( self.vllm_config.parallel_config.worker_mixin_cls) - logger.info( - "Injecting %s into %s for extended collective_rpc call", - worker_mixin_cls, worker_class) + extended_calls = [] if worker_mixin_cls not in worker_class.__bases__: # check any conflicts between worker and worker_mixin_cls for attr in dir(worker_mixin_cls): @@ -583,9 +581,13 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: f"Worker class {worker_class} already has an attribute" f" {attr}, which conflicts with the worker" f" mixin class {worker_mixin_cls}.") + extended_calls.append(attr) # dynamically inherit the worker mixin class worker_class.__bases__ = worker_class.__bases__ + ( worker_mixin_cls, ) + logger.info( + "Injected %s into %s for extended collective_rpc calls %s", + worker_mixin_cls, worker_class, extended_calls) with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization self.worker = worker_class(**kwargs) From 86743da2ee584b1a1fe8e89ec8ebefa1e853e5f8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 7 Mar 2025 00:16:45 +0800 Subject: [PATCH 14/15] rename to worker_extension_cls Signed-off-by: youkaichao --- examples/offline_inference/rlhf.py | 2 +- examples/offline_inference/rlhf_colocate.py | 2 +- examples/offline_inference/rlhf_utils.py | 12 +++++------ vllm/config.py | 7 +++---- vllm/engine/arg_utils.py | 8 +++---- vllm/worker/worker_base.py | 23 +++++++++++---------- 6 files changed, 27 insertions(+), 27 deletions(-) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index f660947b07b2..e2dec14269ad 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -68,7 +68,7 @@ def __init__(self, *args, **kwargs): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_mixin_cls="rlhf_utils.WorkerMixin", + worker_extension_cls="rlhf_utils.WorkerMixin", tensor_parallel_size=2, distributed_executor_backend="ray", ) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 410e7aaad0cb..41c40b98f04a 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -116,7 +116,7 @@ def get_weight_ipc_handles(self): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_mixin_cls="rlhf_utils.ColocateWorkerMixin", + worker_extension_cls="rlhf_utils.ColocateWorkerMixin", tensor_parallel_size=2, distributed_executor_backend="ray", gpu_memory_utilization=0.4, diff --git a/examples/offline_inference/rlhf_utils.py b/examples/offline_inference/rlhf_utils.py index 7cdcc511c18c..11b73b7c4a0a 100644 --- a/examples/offline_inference/rlhf_utils.py +++ b/examples/offline_inference/rlhf_utils.py @@ -21,14 +21,14 @@ def stateless_init_process_group(master_address, master_port, rank, world_size, return pynccl -class WorkerMixin: +class WorkerExtension: """ The class for vLLM's worker to inherit from. - By defining a mixin class, the code can work no matter what is + By defining an extension class, the code can work no matter what is the underlying worker class. This way, the code can be compatible with both vLLM V0 and V1. NOTE: we define this class in a separate module, and the main module - should pass the full qualified name as `worker_mixin_cls` argument. + should pass the full qualified name as `worker_extension_cls` argument. """ def init_weight_update_group(self, master_address, master_port, @@ -64,14 +64,14 @@ def check_weights_changed(self): return weights_updated -class ColocateWorkerMixin: +class ColocateWorkerExtension: """ The class for vLLM's worker to inherit from, in the colocate setting. - By defining a mixin class, the code can work no matter what is + By defining an extension class, the code can work no matter what is the underlying worker class. This way, the code can be compatible with both vLLM V0 and V1. NOTE: we define this class in a separate module, and the main module - should pass the full qualified name as `worker_mixin_cls` argument. + should pass the full qualified name as `worker_extension_cls` argument. """ def report_device_id(self) -> str: diff --git a/vllm/config.py b/vllm/config.py index 7dc5ba1a5285..9b84d0405dc9 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1366,7 +1366,7 @@ class ParallelConfig: # will be determined based on the platform. worker_cls: str = "auto" sd_worker_cls: str = "auto" - worker_mixin_cls: str = "" + worker_extension_cls: str = "" # world_size is TPxPP, it affects the number of workers we create. world_size: int = field(init=False) @@ -1524,9 +1524,8 @@ def _verify_args(self) -> None: raise ValueError("Unable to use nsight profiling unless workers " "run with Ray.") - assert isinstance( - self.worker_mixin_cls, - str), ("worker_mixin_cls must be a string (qualified class name).") + assert isinstance(self.worker_extension_cls, str), ( + "worker_extension_cls must be a string (qualified class name).") @dataclass diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fe6650a49f8a..d033acff5b0d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -202,7 +202,7 @@ class EngineArgs: override_pooler_config: Optional[PoolerConfig] = None compilation_config: Optional[CompilationConfig] = None worker_cls: str = "auto" - worker_mixin_cls: str = "" + worker_extension_cls: str = "" kv_transfer_config: Optional[KVTransferConfig] = None @@ -1017,10 +1017,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default="auto", help='The worker class to use for distributed execution.') parser.add_argument( - '--worker-mixin-cls', + '--worker-extension-cls', type=str, default="", - help='The worker mixin class on top of the worker cls, ' + help='The worker extension class on top of the worker cls, ' 'it is useful if you just want to add new functions to the worker ' 'class without changing the existing functions.') parser.add_argument( @@ -1217,7 +1217,7 @@ def create_engine_config(self, ray_workers_use_nsight=self.ray_workers_use_nsight, distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, - worker_mixin_cls=self.worker_mixin_cls, + worker_extension_cls=self.worker_extension_cls, ) max_model_len = model_config.max_model_len diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 43c55d22753a..e5662e69343c 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -568,26 +568,27 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None: bytes) worker_class = cloudpickle.loads( self.vllm_config.parallel_config.worker_cls) - if self.vllm_config.parallel_config.worker_mixin_cls: - worker_mixin_cls = resolve_obj_by_qualname( - self.vllm_config.parallel_config.worker_mixin_cls) + if self.vllm_config.parallel_config.worker_extension_cls: + worker_extension_cls = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_extension_cls) extended_calls = [] - if worker_mixin_cls not in worker_class.__bases__: - # check any conflicts between worker and worker_mixin_cls - for attr in dir(worker_mixin_cls): + if worker_extension_cls not in worker_class.__bases__: + # check any conflicts between worker and worker_extension_cls + for attr in dir(worker_extension_cls): if attr.startswith("__"): continue assert not hasattr(worker_class, attr), ( f"Worker class {worker_class} already has an attribute" f" {attr}, which conflicts with the worker" - f" mixin class {worker_mixin_cls}.") - extended_calls.append(attr) - # dynamically inherit the worker mixin class + f" extension class {worker_extension_cls}.") + if callable(getattr(worker_extension_cls, attr)): + extended_calls.append(attr) + # dynamically inherit the worker extension class worker_class.__bases__ = worker_class.__bases__ + ( - worker_mixin_cls, ) + worker_extension_cls, ) logger.info( "Injected %s into %s for extended collective_rpc calls %s", - worker_mixin_cls, worker_class, extended_calls) + worker_extension_cls, worker_class, extended_calls) with set_current_vllm_config(self.vllm_config): # To make vLLM config available during worker initialization self.worker = worker_class(**kwargs) From 6d9f76bb719a86dedb9404669d0fab8176b80a93 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Fri, 7 Mar 2025 00:18:27 +0800 Subject: [PATCH 15/15] rename Signed-off-by: youkaichao --- examples/offline_inference/rlhf.py | 2 +- examples/offline_inference/rlhf_colocate.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/rlhf.py b/examples/offline_inference/rlhf.py index e2dec14269ad..b0418c092ca3 100644 --- a/examples/offline_inference/rlhf.py +++ b/examples/offline_inference/rlhf.py @@ -68,7 +68,7 @@ def __init__(self, *args, **kwargs): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_extension_cls="rlhf_utils.WorkerMixin", + worker_extension_cls="rlhf_utils.WorkerExtension", tensor_parallel_size=2, distributed_executor_backend="ray", ) diff --git a/examples/offline_inference/rlhf_colocate.py b/examples/offline_inference/rlhf_colocate.py index 41c40b98f04a..3ceac0fa2e20 100644 --- a/examples/offline_inference/rlhf_colocate.py +++ b/examples/offline_inference/rlhf_colocate.py @@ -116,7 +116,7 @@ def get_weight_ipc_handles(self): )(MyLLM).remote( model="facebook/opt-125m", enforce_eager=True, - worker_extension_cls="rlhf_utils.ColocateWorkerMixin", + worker_extension_cls="rlhf_utils.ColocateWorkerExtension", tensor_parallel_size=2, distributed_executor_backend="ray", gpu_memory_utilization=0.4,