|
1 | 1 | from typing import Dict, List, Set, Tuple
|
2 | 2 |
|
3 |
| -from vllm.executor.executor_base import ExecutorBase |
| 3 | +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase |
4 | 4 | from vllm.logger import init_logger
|
5 | 5 | from vllm.lora.request import LoRARequest
|
6 | 6 | from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
| 7 | +from vllm.utils import make_async |
7 | 8 |
|
8 | 9 | logger = init_logger(__name__)
|
9 | 10 |
|
@@ -73,3 +74,22 @@ def check_health(self) -> None:
|
73 | 74 | # NeuronExecutor will always be healthy as long as
|
74 | 75 | # it's running.
|
75 | 76 | return
|
| 77 | + |
| 78 | + |
class NeuronExecutorAsync(NeuronExecutor, ExecutorAsyncBase):
    """Asynchronous facade over NeuronExecutor.

    Satisfies the ExecutorAsyncBase interface by delegating to the
    synchronous driver worker, pushed onto a thread via make_async so the
    event loop is not blocked.
    """

    async def execute_model_async(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> SamplerOutput:
        # NOTE(review): only seq_group_metadata_list is forwarded; the
        # swap/copy mappings are ignored here — presumably Neuron does not
        # support block swapping. Confirm against the synchronous
        # NeuronExecutor.execute_model contract.
        run_model = make_async(self.driver_worker.execute_model)
        return await run_model(seq_group_metadata_list=seq_group_metadata_list)

    async def check_health_async(self) -> None:
        # Mirrors the synchronous check: a running NeuronExecutor is
        # always considered healthy, so this is a no-op.
        return
0 commit comments