@@ -16,15 +16,15 @@
 
 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig,
+from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
+                         DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs)
+                         ParallelConfig, PoolerConfig, PoolType,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -44,27 +44,17 @@
 
 ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]
 
-DEVICE_OPTIONS = [
-    "auto",
-    "cuda",
-    "neuron",
-    "cpu",
-    "tpu",
-    "xpu",
-    "hpu",
-]
-
 # object is used to allow for special typing forms
 T = TypeVar("T")
 TypeHint = Union[type[Any], object]
 TypeHintT = Union[type[T], object]
 
 
-def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
     if val == "" or val == "None":
         return None
     try:
-        return cast(Callable, return_type)(val)
+        return return_type(val)
     except ValueError as e:
         raise argparse.ArgumentTypeError(
             f"Value {val} cannot be converted to {return_type}.") from e
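
The change above widens `optional_arg` to accept any `str -> T` converter rather than only a class, which is what allows `json.loads` to be passed further down in this diff. A minimal standalone sketch of the resulting behaviour (illustration only, not part of the diff):

    import argparse
    import json
    from typing import Callable, Optional, TypeVar

    T = TypeVar("T")

    def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
        # Mirrors the function above: "" and "None" become None, otherwise convert.
        if val == "" or val == "None":
            return None
        try:
            return return_type(val)
        except ValueError as e:
            raise argparse.ArgumentTypeError(
                f"Value {val} cannot be converted to {return_type}.") from e

    assert optional_arg("None", int) is None            # sentinel strings -> None
    assert optional_arg("8", int) == 8                  # classes still work as converters
    assert optional_arg('{"a": 1}', json.loads) == {"a": 1}  # so does any str -> T callable
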
@@ -82,8 +72,11 @@ def optional_float(val: str) -> Optional[float]:
     return optional_arg(val, float)
 
 
-def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
-    """Parses a string containing comma separate key [str] to value [int]
+def nullable_kvs(val: str) -> Optional[dict[str, int]]:
+    """NOTE: This function is deprecated, args should be passed as JSON
+    strings instead.
+
+    Parses a string containing comma separate key [str] to value [int]
     pairs into a dictionary.
 
     Args:
@@ -117,6 +110,17 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
     return out_dict
 
 
+def optional_dict(val: str) -> Optional[dict[str, int]]:
+    try:
+        return optional_arg(val, json.loads)
+    except ValueError:
+        logger.warning(
+            "Failed to parse JSON string. Attempting to parse as "
+            "comma-separated key=value pairs. This will be deprecated in a "
+            "future release.")
+        return nullable_kvs(val)
+
+
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
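
The new `optional_dict` prefers strict JSON and only falls back to the legacy comma-separated form with a warning. A standalone sketch of that parsing order (with a simplified stand-in for `nullable_kvs`):

    import json
    from typing import Optional

    def _legacy_kvs(val: str) -> dict[str, int]:
        # Simplified stand-in for nullable_kvs above.
        return {k: int(v) for k, v in (item.split("=") for item in val.split(","))}

    def optional_dict(val: str) -> Optional[dict]:
        try:
            return json.loads(val)   # preferred: a JSON string
        except ValueError:           # json.JSONDecodeError is a ValueError
            return _legacy_kvs(val)  # deprecated fallback

    assert optional_dict('{"image": 16, "video": 2}') == {"image": 16, "video": 2}
    assert optional_dict('image=16,video=2') == {"image": 16, "video": 2}
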
@@ -178,12 +182,14 @@ class EngineArgs:
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    tokenizer_pool_size: int = 0
+    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
+    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
+        TokenizerPoolConfig.pool_type
+    tokenizer_pool_extra_config: dict[str, Any] = \
+        get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
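
`get_field` exists because dataclasses reject mutable class-level defaults such as `{}`; instead of hard-coding `None`, the field re-uses the source config's `default_factory`. Roughly — a sketch, not vllm's exact implementation:

    from dataclasses import MISSING, Field, field, fields

    def get_field(cls, name: str) -> Field:
        # Copy a dataclass field's default (or default_factory) so another
        # dataclass can reuse it as its own default.
        src = next(f for f in fields(cls) if f.name == name)
        if src.default_factory is not MISSING:
            return field(default_factory=src.default_factory)
        return field(default=src.default)

With this, `tokenizer_pool_extra_config` defaults to a fresh empty dict per `EngineArgs` instance instead of `None`.
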
@@ -199,14 +205,14 @@ class EngineArgs:
     long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'auto'
+    device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
-    model_loader_extra_config: Optional[
-        dict] = LoadConfig.model_loader_extra_config
+    model_loader_extra_config: dict = \
+        get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
@@ -294,14 +300,15 @@ def is_custom_type(cls: TypeHint) -> bool:
             """Check if the class is a custom type."""
             return cls.__module__ != "builtins"
 
-        def get_kwargs(cls: type[Any]) -> dict[str, Any]:
+        def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             cls_docs = get_attr_docs(cls)
             kwargs = {}
             for field in fields(cls):
                 name = field.name
-                # One of these will always be present
-                default = (field.default_factory
-                           if field.default is MISSING else field.default)
+                default = field.default
+                # This will only be True if default is MISSING
+                if field.default_factory is not MISSING:
+                    default = field.default_factory()
                 kwargs[name] = {"default": default, "help": cls_docs[name]}
 
                 # Make note of if the field is optional and get the actual
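
This fixes a real bug: for factory-backed fields the old expression stored the factory object itself as the argparse default (e.g. `<class 'dict'>`), while the new code calls it. A standalone illustration:

    from dataclasses import MISSING, dataclass, field, fields

    @dataclass
    class Cfg:
        extra: dict = field(default_factory=dict)

    f = fields(Cfg)[0]

    old = f.default_factory if f.default is MISSING else f.default
    assert old is dict  # old logic leaked the factory, not a value

    new = f.default
    if f.default_factory is not MISSING:
        new = f.default_factory()
    assert new == {}    # new logic produces the actual default value
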
@@ -331,8 +338,9 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
                 elif can_be_type(field_type, float):
                     kwargs[name][
                         "type"] = optional_float if optional else float
+                elif can_be_type(field_type, dict):
+                    kwargs[name]["type"] = optional_dict
                 elif (can_be_type(field_type, str)
-                      or can_be_type(field_type, dict)
                       or is_custom_type(field_type)):
                     kwargs[name]["type"] = optional_str if optional else str
                 else:
@@ -674,25 +682,19 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                             'than this, we fall back to the eager mode.')
-        parser.add_argument('--tokenizer-pool-size',
-                            type=int,
-                            default=EngineArgs.tokenizer_pool_size,
-                            help='Size of tokenizer pool to use for '
-                            'asynchronous tokenization. If 0, will '
-                            'use synchronous tokenization.')
-        parser.add_argument('--tokenizer-pool-type',
-                            type=str,
-                            default=EngineArgs.tokenizer_pool_type,
-                            help='Type of tokenizer pool to use for '
-                            'asynchronous tokenization. Ignored '
-                            'if tokenizer_pool_size is 0.')
-        parser.add_argument('--tokenizer-pool-extra-config',
-                            type=optional_str,
-                            default=EngineArgs.tokenizer_pool_extra_config,
-                            help='Extra config for tokenizer pool. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary. Ignored if '
-                            'tokenizer_pool_size is 0.')
+
+        # Tokenizer arguments
+        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
+        tokenizer_group = parser.add_argument_group(
+            title="TokenizerPoolConfig",
+            description=TokenizerPoolConfig.__doc__,
+        )
+        tokenizer_group.add_argument('--tokenizer-pool-size',
+                                     **tokenizer_kwargs["pool_size"])
+        tokenizer_group.add_argument('--tokenizer-pool-type',
+                                     **tokenizer_kwargs["pool_type"])
+        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
+                                     **tokenizer_kwargs["extra_config"])
 
         # Multimodal related configs
         parser.add_argument(
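
Each hand-written `--tokenizer-pool-*` argument is replaced by kwargs generated from `TokenizerPoolConfig` itself, so defaults and help text live in one place. A minimal standalone sketch of the pattern, using a hypothetical `SampleConfig` rather than vllm's classes:

    import argparse
    from dataclasses import dataclass, fields

    @dataclass
    class SampleConfig:
        """Configuration for a sample worker pool."""
        pool_size: int = 0

    def get_kwargs(cls) -> dict:
        # vllm's get_kwargs also derives per-field help text from attribute
        # docstrings via get_attr_docs; this sketch reuses the class docstring.
        return {f.name: {"default": f.default, "help": cls.__doc__}
                for f in fields(cls)}

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="SampleConfig",
                                      description=SampleConfig.__doc__)
    group.add_argument("--pool-size", type=int,
                       **get_kwargs(SampleConfig)["pool_size"])

    assert parser.parse_args([]).pool_size == 0  # default flows from the dataclass
    assert parser.parse_args(["--pool-size", "4"]).pool_size == 4
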
@@ -784,11 +786,15 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
                             type=int,
                             default=EngineArgs.max_prompt_adapter_token,
                             help='Max number of PromptAdapters tokens')
-        parser.add_argument("--device",
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=DEVICE_OPTIONS,
-                            help='Device type for vLLM execution.')
+
+        # Device arguments
+        device_kwargs = get_kwargs(DeviceConfig)
+        device_group = parser.add_argument_group(
+            title="DeviceConfig",
+            description=DeviceConfig.__doc__,
+        )
+        device_group.add_argument("--device", **device_kwargs["device"])
+
         parser.add_argument('--num-scheduler-steps',
                             type=int,
                             default=1,
@@ -1302,8 +1308,6 @@ def create_engine_config(
 
         if self.qlora_adapter_name_or_path is not None and \
                 self.qlora_adapter_name_or_path != "":
-            if self.model_loader_extra_config is None:
-                self.model_loader_extra_config = {}
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
 
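
The deleted `None` check is dead code after this diff: `model_loader_extra_config` now defaults through `get_field` to a fresh dict per instance. Illustration:

    from dataclasses import dataclass, field

    @dataclass
    class Args:
        extra: dict = field(default_factory=dict)

    a, b = Args(), Args()
    assert a.extra == {}          # always a dict, never None
    assert a.extra is not b.extra  # and not shared between instances
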