Merged
26 changes: 24 additions & 2 deletions tests/test_config.py
@@ -1,14 +1,36 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from dataclasses import asdict
+from dataclasses import MISSING, Field, asdict, dataclass, field
 
 import pytest
 
-from vllm.config import ModelConfig, PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig, get_field
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
 
 
+def test_get_field():
+
+    @dataclass
+    class TestConfig:
+        a: int
+        b: dict = field(default_factory=dict)
+        c: str = "default"
+
+    with pytest.raises(ValueError):
+        get_field(TestConfig, "a")
+
+    b = get_field(TestConfig, "b")
+    assert isinstance(b, Field)
+    assert b.default is MISSING
+    assert b.default_factory is dict
+
+    c = get_field(TestConfig, "c")
+    assert isinstance(c, Field)
+    assert c.default == "default"
+    assert c.default_factory is MISSING
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [
73 changes: 51 additions & 22 deletions vllm/config.py
@@ -182,6 +182,23 @@ def config(cls: type[Config]) -> type[Config]:
     return cls
 
 
+def get_field(cls: type[Config], name: str) -> Field:
+    """Get the default factory field of a dataclass by name. Used for getting
+    default factory fields in `EngineArgs`."""
+    if not is_dataclass(cls):
+        raise TypeError("The given class is not a dataclass.")
+    cls_fields = {f.name: f for f in fields(cls)}
+    if name not in cls_fields:
+        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
+    named_field: Field = cls_fields.get(name)
+    if (default_factory := named_field.default_factory) is not MISSING:
+        return field(default_factory=default_factory)
+    if (default := named_field.default) is not MISSING:
+        return field(default=default)
+    raise ValueError(
+        f"{cls.__name__}.{name} must have a default value or default factory.")
+
+
 class ModelConfig:
     """Configuration for the model.
 
@@ -1356,20 +1373,26 @@ def verify_with_parallel_config(
         logger.warning("Possibly too large swap space. %s", msg)
 
 
+PoolType = Literal["ray"]
+
+
 @config
 @dataclass
 class TokenizerPoolConfig:
-    """Configuration for the tokenizer pool.
+    """Configuration for the tokenizer pool."""
 
-    Args:
-        pool_size: Number of tokenizer workers in the pool.
-        pool_type: Type of the pool.
-        extra_config: Additional config for the pool.
-            The way the config will be used depends on the
-            pool type.
-    """
-    pool_size: int
-    pool_type: Union[str, type["BaseTokenizerGroup"]]
-    extra_config: dict
+    pool_size: int = 0
+    """Number of tokenizer workers in the pool to use for asynchronous
+    tokenization. If 0, will use synchronous tokenization."""
+
+    pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
+    """Type of tokenizer pool to use for asynchronous tokenization. Ignored if
+    tokenizer_pool_size is 0."""
+
+    extra_config: dict = field(default_factory=dict)
+    """Additional config for the pool. The way the config will be used depends
+    on the pool type. This should be a JSON string that will be parsed into a
+    dictionary. Ignored if tokenizer_pool_size is 0."""
 
     def compute_hash(self) -> str:
         """
@@ -1400,7 +1423,7 @@ def __post_init__(self):
     @classmethod
     def create_config(
         cls, tokenizer_pool_size: int,
-        tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
+        tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1475,7 +1498,7 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: Optional[Union[str, dict]] = None
+    model_loader_extra_config: dict = field(default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format. This should be a JSON string that
     will be parsed into a dictionary."""
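Side note: because the field now defaults to an empty dict rather than None, callers can write into it without a guard, mirroring the simplification in create_engine_config further down. A sketch:

from vllm.config import LoadConfig

cfg = LoadConfig()
# No `if cfg.model_loader_extra_config is None` check needed any more.
cfg.model_loader_extra_config["qlora_adapter_name_or_path"] = "some-adapter"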
@@ -1506,10 +1529,6 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        model_loader_extra_config = self.model_loader_extra_config or {}
-        if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(
-                model_loader_extra_config)
         if isinstance(self.load_format, str):
             load_format = self.load_format.lower()
             self.load_format = LoadFormat(load_format)
@@ -2021,9 +2040,19 @@ def is_multi_step(self) -> bool:
         return self.num_scheduler_steps > 1
 
 
+Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
Member:
We can't really hard code this because OOT platforms may use different names

Member Author:
This is already hard coded by EngineArgs using choices=DEVICE_OPTIONS

Member:
cc @youkaichao is that intended?

Member Author:
To summarise:

  • The argparser behaviour is unchanged
  • In the config we have Literal type hinting (no validation of the input occurs)

Member:
> This is already hard coded by EngineArgs using choices=DEVICE_OPTIONS

usually people just use device="auto" and let the current_platform resolve the device, so i think it should be fine.

i don't see lots of usage with explicitly adding --device=xxx


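To illustrate the point made in the thread above: Literal is a static hint with no runtime check, so out-of-tree platform names still pass through when the config is built programmatically; only the argparse layer constrains the CLI. A minimal sketch:

from dataclasses import dataclass
from typing import Literal

Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]

@dataclass
class Demo:  # hypothetical stand-in for DeviceConfig
    device: Device = "auto"

# Constructs fine at runtime despite violating the Literal; only a static
# type checker (or argparse choices) would reject the unknown name.
print(Demo(device="my_oot_platform").device)  # my_oot_platform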

 @config
 @dataclass
 class DeviceConfig:
-    device: Optional[torch.device]
-    device_type: str
+    """Configuration for the device to use for vLLM execution."""
+
+    device: Union[Device, torch.device] = "auto"
+    """Device type for vLLM execution."""
+    device_type: str = field(init=False)
+    """Device type from the current platform. This is set in
+    `__post_init__`."""
 
     def compute_hash(self) -> str:
         """
@@ -2045,8 +2074,8 @@ def compute_hash(self) -> str:
                                usedforsecurity=False).hexdigest()
         return hash_str
 
-    def __init__(self, device: str = "auto") -> None:
-        if device == "auto":
+    def __post_init__(self):
+        if self.device == "auto":
             # Automated device type detection
             from vllm.platforms import current_platform
             self.device_type = current_platform.device_type
@@ -2057,7 +2086,7 @@ def __init__(self, device: str = "auto") -> None:
                 "to turn on verbose logging to help debug the issue.")
         else:
             # Device type is assigned explicitly
-            self.device_type = device
+            self.device_type = self.device
 
         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
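Side note: a sketch of the new construction path. The resolved device type depends on the installed platform, so the values shown are illustrative:

from vllm.config import DeviceConfig

# "auto" defers detection to current_platform inside __post_init__ ...
print(DeviceConfig().device_type)        # e.g. "cuda" on a GPU machine

# ... while an explicit device is taken at face value.
assert DeviceConfig(device="cpu").device_type == "cpu"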
118 changes: 61 additions & 57 deletions vllm/engine/arg_utils.py
@@ -16,15 +16,15 @@
 
 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig,
+from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
+                         DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs)
+                         ParallelConfig, PoolerConfig, PoolType,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -44,27 +44,17 @@
 
 ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
 
-DEVICE_OPTIONS = [
-    "auto",
-    "cuda",
-    "neuron",
-    "cpu",
-    "tpu",
-    "xpu",
-    "hpu",
-]
-
 # object is used to allow for special typing forms
 T = TypeVar("T")
 TypeHint = Union[type[Any], object]
 TypeHintT = Union[type[T], object]
 
 
-def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
     if val == "" or val == "None":
         return None
     try:
-        return cast(Callable, return_type)(val)
+        return return_type(val)
     except ValueError as e:
         raise argparse.ArgumentTypeError(
             f"Value {val} cannot be converted to {return_type}.") from e
@@ -82,8 +72,11 @@ def optional_float(val: str) -> Optional[float]:
     return optional_arg(val, float)
 
 
-def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
-    """Parses a string containing comma separate key [str] to value [int]
+def nullable_kvs(val: str) -> Optional[dict[str, int]]:
+    """NOTE: This function is deprecated, args should be passed as JSON
+    strings instead.
+
+    Parses a string containing comma separate key [str] to value [int]
     pairs into a dictionary.
 
     Args:
@@ -117,6 +110,17 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
     return out_dict
 

+def optional_dict(val: str) -> Optional[dict[str, int]]:
+    try:
+        return optional_arg(val, json.loads)
+    except ValueError:
+        logger.warning(
+            "Failed to parse JSON string. Attempting to parse as "
+            "comma-separated key=value pairs. This will be deprecated in a "
+            "future release.")
+        return nullable_kvs(val)

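Side note on the intended behaviour of optional_dict (a sketch; the deprecated fallback path is meant to route through nullable_kvs with a warning):

assert optional_dict('{"image": 5}') == {"image": 5}  # JSON is the preferred form
assert optional_dict("") is None                      # empty/None-like input
optional_dict("image=5")  # deprecated key=value form, expected to warn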

 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
@@ -178,12 +182,14 @@ class EngineArgs:
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    tokenizer_pool_size: int = 0
+    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
+    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
+        TokenizerPoolConfig.pool_type
+    tokenizer_pool_extra_config: dict[str, Any] = \
+        get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
@@ -199,14 +205,14 @@ class EngineArgs:
     long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'auto'
+    device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
-    model_loader_extra_config: Optional[
-        dict] = LoadConfig.model_loader_extra_config
+    model_loader_extra_config: dict = \
+        get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
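Side note: the two get_field(...) defaults above exist because dataclasses forbid mutable class-level defaults; a minimal reproduction:

from dataclasses import dataclass, field

try:
    @dataclass
    class Broken:
        extra_config: dict = {}  # rejected at class-definition time
except ValueError as e:
    print(e)  # mutable default <class 'dict'> for field extra_config is not allowed

@dataclass
class Works:
    # field(default_factory=dict) is exactly what get_field returns here
    extra_config: dict = field(default_factory=dict)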
@@ -294,14 +300,15 @@ def is_custom_type(cls: TypeHint) -> bool:
"""Check if the class is a custom type."""
return cls.__module__ != "builtins"

def get_kwargs(cls: type[Any]) -> dict[str, Any]:
def get_kwargs(cls: type[Config]) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
name = field.name
# One of these will always be present
default = (field.default_factory
if field.default is MISSING else field.default)
default = field.default
# This will only be True if default is MISSING
if field.default_factory is not MISSING:
default = field.default_factory()
kwargs[name] = {"default": default, "help": cls_docs[name]}

# Make note of if the field is optional and get the actual
Expand Down Expand Up @@ -331,8 +338,9 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
elif can_be_type(field_type, float):
kwargs[name][
"type"] = optional_float if optional else float
elif can_be_type(field_type, dict):
kwargs[name]["type"] = optional_dict
elif (can_be_type(field_type, str)
or can_be_type(field_type, dict)
or is_custom_type(field_type)):
kwargs[name]["type"] = optional_str if optional else str
else:
@@ -674,25 +682,19 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
             'Additionally for encoder-decoder models, if the '
             'sequence length of the encoder input is larger '
             'than this, we fall back to the eager mode.')
-        parser.add_argument('--tokenizer-pool-size',
-                            type=int,
-                            default=EngineArgs.tokenizer_pool_size,
-                            help='Size of tokenizer pool to use for '
-                            'asynchronous tokenization. If 0, will '
-                            'use synchronous tokenization.')
-        parser.add_argument('--tokenizer-pool-type',
-                            type=str,
-                            default=EngineArgs.tokenizer_pool_type,
-                            help='Type of tokenizer pool to use for '
-                            'asynchronous tokenization. Ignored '
-                            'if tokenizer_pool_size is 0.')
-        parser.add_argument('--tokenizer-pool-extra-config',
-                            type=optional_str,
-                            default=EngineArgs.tokenizer_pool_extra_config,
-                            help='Extra config for tokenizer pool. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary. Ignored if '
-                            'tokenizer_pool_size is 0.')
+
+        # Tokenizer arguments
+        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
+        tokenizer_group = parser.add_argument_group(
+            title="TokenizerPoolConfig",
+            description=TokenizerPoolConfig.__doc__,
+        )
+        tokenizer_group.add_argument('--tokenizer-pool-size',
+                                     **tokenizer_kwargs["pool_size"])
+        tokenizer_group.add_argument('--tokenizer-pool-type',
+                                     **tokenizer_kwargs["pool_type"])
+        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
+                                     **tokenizer_kwargs["extra_config"])
 
         # Multimodal related configs
         parser.add_argument(
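Side note: end to end, the generated group behaves like the old hand-written arguments, except that the extra config is now parsed as JSON via optional_dict. A sketch, assuming FlexibleArgumentParser from vllm.utils (values illustrative):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--tokenizer-pool-size", "4",
                          "--tokenizer-pool-extra-config", '{"num_cpus": 1}'])
assert args.tokenizer_pool_extra_config == {"num_cpus": 1}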
@@ -784,11 +786,15 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
                             type=int,
                             default=EngineArgs.max_prompt_adapter_token,
                             help='Max number of PromptAdapters tokens')
-        parser.add_argument("--device",
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=DEVICE_OPTIONS,
-                            help='Device type for vLLM execution.')
+
+        # Device arguments
+        device_kwargs = get_kwargs(DeviceConfig)
+        device_group = parser.add_argument_group(
+            title="DeviceConfig",
+            description=DeviceConfig.__doc__,
+        )
+        device_group.add_argument("--device", **device_kwargs["device"])
+
         parser.add_argument('--num-scheduler-steps',
                             type=int,
                             default=1,
@@ -1302,8 +1308,6 @@ def create_engine_config(
 
         if self.qlora_adapter_name_or_path is not None and \
             self.qlora_adapter_name_or_path != "":
-            if self.model_loader_extra_config is None:
-                self.model_loader_extra_config = {}
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
 
Expand Down