Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/offline_distributed_inference_npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
llm = LLM(
model="Qwen/Qwen2.5-0.5B-Instruct",
tensor_parallel_size=2,
distributed_executor_backend="mp",
distributed_executor_backend="ray",
trust_remote_code=True,
)

Expand Down
22 changes: 22 additions & 0 deletions vllm_ascend/patch/ray_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import torch_npu # noqa: F401
import vllm
from vllm.executor.ray_utils import RayWorkerWrapper

if RayWorkerWrapper is not None:

class NPURayWorkerWrapper(RayWorkerWrapper):
"""Importing torch_npu in other Ray processes through an empty class and
a monkey patch.

When Ray performs a remote call, it serializes the Task or Actor and passes
it to the Worker process, where it is deserialized and executed.

If no patch is applied, the default code of the RayWorkerWrapper provided
by vLLM is used, which does not import torch_npu, causing an error in the
Worker process.
See https://github.com/vllm-project/vllm-ascend/pull/92.
"""

pass

vllm.executor.ray_utils.RayWorkerWrapper = NPURayWorkerWrapper
3 changes: 3 additions & 0 deletions vllm_ascend/platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def mem_get_info(cls) -> Tuple[int, int]:

@classmethod
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# RayWorkerWrapper monkey patch when setup
from vllm_ascend.patch import ray_patch # noqa: F401

parallel_config = vllm_config.parallel_config
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "vllm_ascend.worker.NPUWorker"
Expand Down