1
1
import asyncio
2
- from typing import Any , List , Optional
2
+ from typing import Any , Callable , List , Optional , Union
3
+
4
+ import cloudpickle
3
5
4
6
from vllm .executor .executor_base import DistributedExecutorBase
5
7
from vllm .executor .multiproc_worker_utils import (
9
11
from vllm .model_executor .layers .sampler import SamplerOutput
10
12
from vllm .sequence import ExecuteModelRequest
11
13
from vllm .utils import (_run_task_with_lock , get_distributed_init_method ,
12
- get_ip , get_open_port , make_async )
14
+ get_ip , get_open_port , make_async , run_method )
13
15
from vllm .worker .worker_base import WorkerWrapperBase
14
16
15
17
logger = init_logger (__name__ )
@@ -107,7 +109,7 @@ def _driver_execute_model(
107
109
108
110
def _run_workers (
109
111
self ,
110
- method : str ,
112
+ method : Union [ str , Callable ] ,
111
113
* args ,
112
114
async_run_tensor_parallel_workers_only : bool = False ,
113
115
max_concurrent_workers : Optional [int ] = None ,
@@ -121,6 +123,11 @@ def _run_workers(
121
123
It will also be run asynchronously and return a list of futures
122
124
rather than blocking on the results.
123
125
"""
126
+ if isinstance (method , str ):
127
+ sent_method = method
128
+ else :
129
+ sent_method = cloudpickle .dumps (method )
130
+ del method
124
131
125
132
if max_concurrent_workers :
126
133
raise NotImplementedError (
@@ -129,18 +136,18 @@ def _run_workers(
129
136
if async_run_tensor_parallel_workers_only :
130
137
# Run only non-driver workers and just return futures.
131
138
return [
132
- worker .execute_method (method , * args , ** kwargs )
139
+ worker .execute_method (sent_method , * args , ** kwargs )
133
140
for worker in self .non_driver_workers
134
141
]
135
142
136
143
# Start all remote workers first.
137
144
worker_outputs = [
138
- worker .execute_method (method , * args , ** kwargs )
145
+ worker .execute_method (sent_method , * args , ** kwargs )
139
146
for worker in self .workers
140
147
]
141
148
142
- driver_worker_method = getattr (self .driver_worker , method )
143
- driver_worker_output = driver_worker_method ( * args , ** kwargs )
149
+ driver_worker_output = run_method (self .driver_worker , sent_method ,
150
+ args , kwargs )
144
151
145
152
# Get the results of the workers.
146
153
return [driver_worker_output
0 commit comments