Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 202 additions & 1 deletion tests/v1/engine/test_engine_core_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import uuid
from dataclasses import dataclass
from threading import Thread
from typing import Optional, Union
from typing import Any, Optional, Union
from unittest.mock import MagicMock

import pytest
Expand Down Expand Up @@ -331,6 +331,46 @@ def echo_dc(
return [val for _ in range(3)] if return_list else val


# Dummy utility function to test dict serialization with custom types.
def echo_dc_dict(
self,
msg: str,
return_dict: bool = False,
) -> Union[MyDataclass, dict[str, MyDataclass]]:
print(f"echo dc dict util function called: {msg}")
val = None if msg is None else MyDataclass(msg)
# Return dict of dataclasses to verify support for returning dicts
# with custom value types.
if return_dict:
return {"key1": val, "key2": val, "key3": val}
else:
return val


# Dummy utility function to test nested structures with custom types.
def echo_dc_nested(
self,
msg: str,
structure_type: str = "list_of_dicts",
) -> Any:
print(f"echo dc nested util function called: {msg}, "
f"structure: {structure_type}")
val = None if msg is None else MyDataclass(msg)

if structure_type == "list_of_dicts": # noqa
# Return list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
return [{"a": val, "b": val}, {"c": val, "d": val}]
elif structure_type == "dict_of_lists":
# Return dict of lists: {"list1": [val, val], "list2": [val, val]}
return {"list1": [val, val], "list2": [val, val]}
elif structure_type == "deep_nested":
# Return deeply nested: {"outer": [{"inner": [val, val]},
# {"inner": [val]}]}
return {"outer": [{"inner": [val, val]}, {"inner": [val]}]}
else:
return val


@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_return(
monkeypatch: pytest.MonkeyPatch):
Expand Down Expand Up @@ -384,6 +424,167 @@ async def test_engine_core_client_util_method_custom_return(
client.shutdown()


@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_custom_dict_return(
monkeypatch: pytest.MonkeyPatch):

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

# Must set insecure serialization to allow returning custom types.
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

# Monkey-patch core engine utility function to test.
m.setattr(EngineCore, "echo_dc_dict", echo_dc_dict, raising=False)

engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)

try:
# Test utility method returning custom / non-native data type.
core_client: AsyncMPClient = client

# Test single object return
result = await core_client.call_utility_async(
"echo_dc_dict", "testarg3", False)
assert isinstance(result,
MyDataclass) and result.message == "testarg3"

# Test dict return with custom value types
result = await core_client.call_utility_async(
"echo_dc_dict", "testarg3", True)
assert isinstance(result, dict) and len(result) == 3
for key, val in result.items():
assert key in ["key1", "key2", "key3"]
assert isinstance(val,
MyDataclass) and val.message == "testarg3"

# Test returning dict with None values
result = await core_client.call_utility_async(
"echo_dc_dict", None, True)
assert isinstance(result, dict) and len(result) == 3
for key, val in result.items():
assert key in ["key1", "key2", "key3"]
assert val is None

finally:
client.shutdown()


@pytest.mark.asyncio(loop_scope="function")
async def test_engine_core_client_util_method_nested_structures(
monkeypatch: pytest.MonkeyPatch):

with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

# Must set insecure serialization to allow returning custom types.
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

# Monkey-patch core engine utility function to test.
m.setattr(EngineCore, "echo_dc_nested", echo_dc_nested, raising=False)

engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)

with set_default_torch_num_threads(1):
client = EngineCoreClient.make_client(
multiprocess_mode=True,
asyncio_mode=True,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=True,
)

try:
core_client: AsyncMPClient = client

# Test list of dicts: [{"a": val, "b": val}, {"c": val, "d": val}]
result = await core_client.call_utility_async(
"echo_dc_nested", "nested1", "list_of_dicts")
assert isinstance(result, list) and len(result) == 2
for i, item in enumerate(result):
assert isinstance(item, dict)
if i == 0:
assert "a" in item and "b" in item
assert isinstance(
item["a"],
MyDataclass) and item["a"].message == "nested1"
assert isinstance(
item["b"],
MyDataclass) and item["b"].message == "nested1"
else:
assert "c" in item and "d" in item
assert isinstance(
item["c"],
MyDataclass) and item["c"].message == "nested1"
assert isinstance(
item["d"],
MyDataclass) and item["d"].message == "nested1"

# Test dict of lists: {"list1": [val, val], "list2": [val, val]}
result = await core_client.call_utility_async(
"echo_dc_nested", "nested2", "dict_of_lists")
assert isinstance(result, dict) and len(result) == 2
assert "list1" in result and "list2" in result
for key, lst in result.items():
assert isinstance(lst, list) and len(lst) == 2
for item in lst:
assert isinstance(
item, MyDataclass) and item.message == "nested2"

# Test deeply nested: {"outer": [{"inner": [val, val]},
# {"inner": [val]}]}
result = await core_client.call_utility_async(
"echo_dc_nested", "nested3", "deep_nested")
assert isinstance(result, dict) and "outer" in result
outer_list = result["outer"]
assert isinstance(outer_list, list) and len(outer_list) == 2

# First dict in outer list should have "inner" with 2 items
inner_dict1 = outer_list[0]
assert isinstance(inner_dict1, dict) and "inner" in inner_dict1
inner_list1 = inner_dict1["inner"]
assert isinstance(inner_list1, list) and len(inner_list1) == 2
for item in inner_list1:
assert isinstance(item,
MyDataclass) and item.message == "nested3"

# Second dict in outer list should have "inner" with 1 item
inner_dict2 = outer_list[1]
assert isinstance(inner_dict2, dict) and "inner" in inner_dict2
inner_list2 = inner_dict2["inner"]
assert isinstance(inner_list2, list) and len(inner_list2) == 1
assert isinstance(
inner_list2[0],
MyDataclass) and inner_list2[0].message == "nested3"

# Test with None values in nested structures
result = await core_client.call_utility_async(
"echo_dc_nested", None, "list_of_dicts")
assert isinstance(result, list) and len(result) == 2
for item in result:
assert isinstance(item, dict)
for val in item.values():
assert val is None

finally:
client.shutdown()


@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],
Expand Down
60 changes: 44 additions & 16 deletions vllm/v1/serial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional, Union
from typing import Any, Callable, Optional, Union

import cloudpickle
import msgspec
Expand Down Expand Up @@ -59,6 +59,42 @@ def _typestr(val: Any) -> Optional[tuple[str, str]]:
return t.__module__, t.__qualname__


def _encode_type_info_recursive(obj: Any) -> Any:
"""Recursively encode type information for nested structures of
lists/dicts."""
if obj is None:
return None
if type(obj) is list:
return [_encode_type_info_recursive(item) for item in obj]
if type(obj) is dict:
return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
return _typestr(obj)


def _decode_type_info_recursive(
type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any],
Any]) -> Any:
"""Recursively decode type information for nested structures of
lists/dicts."""
if type_info is None:
return data
if isinstance(type_info, dict):
assert isinstance(data, dict)
return {
k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
for k in type_info
}
if isinstance(type_info, list) and (
# Exclude serialized tensors/numpy arrays.
len(type_info) != 2 or not isinstance(type_info[0], str)):
assert isinstance(data, list)
return [
_decode_type_info_recursive(ti, d, convert_fn)
for ti, d in zip(type_info, data)
]
return convert_fn(type_info, data)


class MsgpackEncoder:
"""Encoder with custom torch tensor and numpy array serialization.

Expand Down Expand Up @@ -129,12 +165,10 @@ def enc_hook(self, obj: Any) -> Any:
result = obj.result
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
return None, result
# Since utility results are not strongly typed, we also encode
# the type (or a list of types in the case it's a list) to
# help with correct msgspec deserialization.
return _typestr(result) if type(result) is not list else [
_typestr(v) for v in result
], result
# Since utility results are not strongly typed, we recursively
# encode type information for nested structures of lists/dicts
# to help with correct msgspec deserialization.
return _encode_type_info_recursive(result), result

if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise TypeError(f"Object of type {type(obj)} is not serializable"
Expand Down Expand Up @@ -288,15 +322,9 @@ def _decode_utility_result(self, obj: Any) -> UtilityResult:
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
raise TypeError("VLLM_ALLOW_INSECURE_SERIALIZATION must "
"be set to use custom utility result types")
assert isinstance(result_type, list)
if len(result_type) == 2 and isinstance(result_type[0], str):
result = self._convert_result(result_type, result)
else:
assert isinstance(result, list)
result = [
self._convert_result(rt, r)
for rt, r in zip(result_type, result)
]
# Use recursive decoding to handle nested structures
result = _decode_type_info_recursive(result_type, result,
self._convert_result)
return UtilityResult(result)

def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
Expand Down