Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import signal
import time
import uuid
from dataclasses import dataclass
from threading import Thread
from typing import Optional
from typing import Optional, Union
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -292,6 +293,68 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
client.shutdown()


@dataclass
class MyDataclass:
    """Custom return type used to exercise utility-call serialization.

    Instances are produced by the ``echo_dc`` dummy utility so the test can
    check that a non-builtin type comes back from the engine core intact.
    """

    # Text payload; the test compares it with the argument it sent.
    message: str


# Dummy utility function to monkey-patch into engine core.
def echo_dc(
    self,
    msg: str,
    return_list: bool = False,
) -> Union[MyDataclass, list[MyDataclass]]:
    """Echo ``msg`` back wrapped in ``MyDataclass``.

    The return value is deliberately a dataclass (or a list of them) so the
    test can verify support for returning custom types, which need special
    handling to work with msgspec.

    Args:
        self: Receiver supplied when the function is installed as a method;
            it is not used.
        msg: Text to print and to store in each returned instance.
        return_list: When true, return a list of three separate instances
            instead of a single one.

    Returns:
        One ``MyDataclass`` holding ``msg``, or a list of three of them.
    """
    print(f"echo dc util function called: {msg}")
    if not return_list:
        return MyDataclass(msg)
    return [MyDataclass(msg) for _ in range(3)]


@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
        monkeypatch: pytest.MonkeyPatch):
    """Check that a utility call can return a custom dataclass or a list of them.

    The dummy ``echo_dc`` function is patched onto ``EngineCore``, a
    multiprocess asyncio client is created, and the utility is invoked once
    per return shape. The client is always shut down afterwards.
    """
    # NOTE(review): block nesting was reconstructed from a flattened diff
    # view -- confirm it against the repository before applying.
    with monkeypatch.context() as patcher:
        patcher.setenv("VLLM_USE_V1", "1")

        # Insecure serialization must be enabled for custom return types.
        patcher.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        # Install the dummy utility on the engine core class. raising=False
        # lets setattr add an attribute the class does not already define.
        patcher.setattr(EngineCore, "echo_dc", echo_dc, raising=False)

        engine_config = EngineArgs(
            model=MODEL_NAME,
            enforce_eager=True,
        ).create_engine_config(usage_context=UsageContext.UNKNOWN_CONTEXT)
        executor_cls = Executor.get_class(engine_config)

        with set_default_torch_num_threads(1):
            client = EngineCoreClient.make_client(
                multiprocess_mode=True,
                asyncio_mode=True,
                vllm_config=engine_config,
                executor_class=executor_cls,
                log_stats=True,
            )

        try:
            # Utility method returning a custom / non-native data type.
            async_client: AsyncMPClient = client

            # Single-object return.
            single = await async_client.call_utility_async(
                "echo_dc", "testarg2", False)
            assert isinstance(single, MyDataclass)
            assert single.message == "testarg2"

            # List-of-objects return.
            many = await async_client.call_utility_async(
                "echo_dc", "testarg2", True)
            assert isinstance(many, list)
            for item in many:
                assert isinstance(item, MyDataclass)
                assert item.message == "testarg2"
        finally:
            client.shutdown()


@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],
Expand Down
9 changes: 8 additions & 1 deletion vllm/v1/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,13 @@ def finished(self) -> bool:
return self.finish_reason is not None


class UtilityResult:
    """Wrapper for special handling when serializing/deserializing.

    Holds the value returned by a utility call in a dedicated type, so that
    encoder and decoder hooks can recognise it by class.
    """

    def __init__(self, r: Any = None) -> None:
        """Store ``r`` (default ``None``) as the wrapped result."""
        # The wrapped value, exactly as passed in.
        self.result = r


class UtilityOutput(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
Expand All @@ -132,7 +139,7 @@ class UtilityOutput(

# Non-None implies the call failed, result should be None.
failure_message: Optional[str] = None
result: Any = None
result: Optional[UtilityResult] = None


class EngineCoreOutputs(
Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput)
UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
Expand Down Expand Up @@ -710,8 +710,8 @@ def _handle_client_request(self, request_type: EngineCoreRequestType,
output = UtilityOutput(call_id)
try:
method = getattr(self, method_name)
output.result = method(
*self._convert_msgspec_args(method, args))
result = method(*self._convert_msgspec_args(method, args))
output.result = UtilityResult(result)
except BaseException as e:
logger.exception("Invocation of %s method failed", method_name)
output.failure_message = (f"Call to {method_name} method"
Expand Down
3 changes: 2 additions & 1 deletion vllm/v1/engine/core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,8 @@ def _process_utility_output(output: UtilityOutput,
if output.failure_message is not None:
future.set_exception(Exception(output.failure_message))
else:
future.set_result(output.result)
assert output.result is not None
future.set_result(output.result.result)


class SyncMPClient(MPClient):
Expand Down
44 changes: 44 additions & 0 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import importlib
import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional, Union

import cloudpickle
import msgspec
import numpy as np
import torch
import zmq
Expand All @@ -22,6 +24,7 @@
MultiModalFlatField, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField, NestedTensors)
from vllm.v1.engine import UtilityResult

logger = init_logger(__name__)

Expand All @@ -46,6 +49,10 @@ def _log_insecure_serialization_warning():
"VLLM_ALLOW_INSECURE_SERIALIZATION=1")


def _typestr(t: type):
return t.__module__, t.__qualname__


class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.

Expand Down Expand Up @@ -122,6 +129,18 @@ def enc_hook(self, obj: Any) -> Any:
for itemlist in mm._items_by_modality.values()
for item in itemlist]

if isinstance(obj, UtilityResult):
result = obj.result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION or result is None:
return None, result
# Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization.
cls = result.__class__
return _typestr(cls) if cls is not list else [
_typestr(type(v)) for v in result
], result

if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise TypeError(f"Object of type {type(obj)} is not serializable"
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
Expand Down Expand Up @@ -237,8 +256,33 @@ def dec_hook(self, t: type, obj: Any) -> Any:
k: self._decode_nested_tensors(v)
for k, v in obj.items()
})
if t is UtilityResult:
return self._decode_utility_result(obj)
return obj

def _decode_utility_result(self, obj: Any) -> UtilityResult:
    """Rebuild a ``UtilityResult`` from its encoded ``(type info, value)`` pair.

    Args:
        obj: Two-item sequence ``(result_type, result)``. ``result_type`` is
            ``None`` (the value is used as decoded), a ``[module, qualname]``
            pair naming the type of ``result``, or a list of such pairs with
            one entry per element of a list-valued ``result``.

    Returns:
        A ``UtilityResult`` wrapping the (possibly converted) value.

    Raises:
        TypeError: If type info is present but
            ``VLLM_ALLOW_INSECURE_SERIALIZATION`` is not enabled, or if the
            payload does not have the expected list structure.
        ValueError: If the per-element type list and the result list differ
            in length.
    """
    result_type, result = obj
    if result_type is None:
        return UtilityResult(result)

    if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
        raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
                        "be set to use custom utility result types")

    # The payload arrives over the wire, so validate it with real exceptions:
    # assert statements are removed when Python runs with -O.
    if not isinstance(result_type, list):
        raise TypeError("Malformed utility result: type info must be a "
                        f"list, got {type(result_type).__name__}")

    # A [module, qualname] pair of strings describes a single value.
    if len(result_type) == 2 and isinstance(result_type[0], str):
        return UtilityResult(self._convert_result(result_type, result))

    # Otherwise the type info is a list of pairs, one per result element.
    if not isinstance(result, list):
        raise TypeError("Malformed utility result: expected a list value "
                        f"for per-element type info, got "
                        f"{type(result).__name__}")
    if len(result_type) != len(result):
        # zip() would silently drop the unmatched tail.
        raise ValueError("Malformed utility result: "
                         f"{len(result_type)} type entries for "
                         f"{len(result)} values")
    return UtilityResult([
        self._convert_result(rt, r) for rt, r in zip(result_type, result)
    ])

def _convert_result(self, result_type: Sequence[str], result: Any):
    """Convert a decoded value to the type named by ``result_type``.

    Args:
        result_type: ``(module name, qualified name)`` pair identifying the
            target type.
        result: The decoded value to convert.

    Returns:
        ``result`` converted to the resolved type by ``msgspec.convert``,
        using this decoder's ``dec_hook`` for custom types.

    Raises:
        ImportError: If the named module cannot be imported.
        AttributeError: If the qualified name cannot be resolved in it.
    """
    mod_name, name = result_type
    # NOTE: this imports a module named by the incoming data. The caller
    # only reaches this point when VLLM_ALLOW_INSECURE_SERIALIZATION is set.
    target: Any = importlib.import_module(mod_name)
    # A nested class has a dotted __qualname__ ("Outer.Inner"), which a
    # single getattr cannot resolve, so walk the path one attribute at a time.
    for part in name.split("."):
        target = getattr(target, part)
    return msgspec.convert(result, target, dec_hook=self.dec_hook)

def _decode_ndarray(self, arr: Any) -> np.ndarray:
dtype, shape, data = arr
# zero-copy decode. We assume the ndarray will not be kept around,
Expand Down