Skip to content

Commit e55b43a

Browse files
liuyanyirasmith
authored andcommitted
[Frontend] Support override generation config in args (vllm-project#12409)
Signed-off-by: liuyanyi <[email protected]>
1 parent 5f21f8d commit e55b43a

File tree

3 files changed

+100
-8
lines changed

3 files changed

+100
-8
lines changed

tests/test_config.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,3 +281,73 @@ def test_uses_mrope(model_id, uses_mrope):
281281
)
282282

283283
assert config.uses_mrope == uses_mrope
284+
285+
286+
def test_generation_config_loading():
287+
model_id = "Qwen/Qwen2.5-1.5B-Instruct"
288+
289+
# When set generation_config to None, the default generation config
290+
# will not be loaded.
291+
model_config = ModelConfig(model_id,
292+
task="auto",
293+
tokenizer=model_id,
294+
tokenizer_mode="auto",
295+
trust_remote_code=False,
296+
seed=0,
297+
dtype="float16",
298+
generation_config=None)
299+
assert model_config.get_diff_sampling_param() == {}
300+
301+
# When set generation_config to "auto", the default generation config
302+
# should be loaded.
303+
model_config = ModelConfig(model_id,
304+
task="auto",
305+
tokenizer=model_id,
306+
tokenizer_mode="auto",
307+
trust_remote_code=False,
308+
seed=0,
309+
dtype="float16",
310+
generation_config="auto")
311+
312+
correct_generation_config = {
313+
"repetition_penalty": 1.1,
314+
"temperature": 0.7,
315+
"top_p": 0.8,
316+
"top_k": 20,
317+
}
318+
319+
assert model_config.get_diff_sampling_param() == correct_generation_config
320+
321+
# The generation config could be overridden by the user.
322+
override_generation_config = {"temperature": 0.5, "top_k": 5}
323+
324+
model_config = ModelConfig(
325+
model_id,
326+
task="auto",
327+
tokenizer=model_id,
328+
tokenizer_mode="auto",
329+
trust_remote_code=False,
330+
seed=0,
331+
dtype="float16",
332+
generation_config="auto",
333+
override_generation_config=override_generation_config)
334+
335+
override_result = correct_generation_config.copy()
336+
override_result.update(override_generation_config)
337+
338+
assert model_config.get_diff_sampling_param() == override_result
339+
340+
# When generation_config is set to None and override_generation_config
341+
# is set, the override_generation_config should be used directly.
342+
model_config = ModelConfig(
343+
model_id,
344+
task="auto",
345+
tokenizer=model_id,
346+
tokenizer_mode="auto",
347+
trust_remote_code=False,
348+
seed=0,
349+
dtype="float16",
350+
generation_config=None,
351+
override_generation_config=override_generation_config)
352+
353+
assert model_config.get_diff_sampling_param() == override_generation_config

vllm/config.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ class ModelConfig:
165165
`logits_processors` extra completion argument. Defaults to None,
166166
which allows no processors.
167167
generation_config: Configuration parameter file for generation.
168+
override_generation_config: Override the generation config with the
169+
given config.
168170
"""
169171

170172
def compute_hash(self) -> str:
@@ -225,6 +227,7 @@ def __init__(
225227
logits_processor_pattern: Optional[str] = None,
226228
generation_config: Optional[str] = None,
227229
enable_sleep_mode: bool = False,
230+
override_generation_config: Optional[Dict[str, Any]] = None,
228231
) -> None:
229232
self.model = model
230233
self.tokenizer = tokenizer
@@ -368,6 +371,7 @@ def __init__(
368371
self.logits_processor_pattern = logits_processor_pattern
369372

370373
self.generation_config = generation_config
374+
self.override_generation_config = override_generation_config or {}
371375

372376
self._verify_quantization()
373377
self._verify_cuda_graph()
@@ -904,8 +908,13 @@ def get_diff_sampling_param(self) -> Dict[str, Any]:
904908
"""
905909
if self.generation_config is None:
906910
# When generation_config is not set
907-
return {}
908-
config = self.try_get_generation_config()
911+
config = {}
912+
else:
913+
config = self.try_get_generation_config()
914+
915+
# Overriding with given generation config
916+
config.update(self.override_generation_config)
917+
909918
available_params = [
910919
"repetition_penalty",
911920
"temperature",

vllm/engine/arg_utils.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ class EngineArgs:
195195
kv_transfer_config: Optional[KVTransferConfig] = None
196196

197197
generation_config: Optional[str] = None
198+
override_generation_config: Optional[Dict[str, Any]] = None
198199
enable_sleep_mode: bool = False
199200

200201
calculate_kv_scales: Optional[bool] = None
@@ -936,12 +937,23 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
936937
type=nullable_str,
937938
default=None,
938939
help="The folder path to the generation config. "
939-
"Defaults to None, will use the default generation config in vLLM. "
940-
"If set to 'auto', the generation config will be automatically "
941-
"loaded from model. If set to a folder path, the generation config "
942-
"will be loaded from the specified folder path. If "
943-
"`max_new_tokens` is specified, then it sets a server-wide limit "
944-
"on the number of output tokens for all requests.")
940+
"Defaults to None, no generation config is loaded, vLLM defaults "
941+
"will be used. If set to 'auto', the generation config will be "
942+
"loaded from model path. If set to a folder path, the generation "
943+
"config will be loaded from the specified folder path. If "
944+
"`max_new_tokens` is specified in generation config, then "
945+
"it sets a server-wide limit on the number of output tokens "
946+
"for all requests.")
947+
948+
parser.add_argument(
949+
"--override-generation-config",
950+
type=json.loads,
951+
default=None,
952+
help="Overrides or sets generation config in JSON format. "
953+
"e.g. ``{\"temperature\": 0.5}``. If used with "
954+
"--generation-config=auto, the override parameters will be merged "
955+
"with the default config from the model. If generation-config is "
956+
"None, only the override parameters are used.")
945957

946958
parser.add_argument("--enable-sleep-mode",
947959
action="store_true",
@@ -1002,6 +1014,7 @@ def create_model_config(self) -> ModelConfig:
10021014
override_pooler_config=self.override_pooler_config,
10031015
logits_processor_pattern=self.logits_processor_pattern,
10041016
generation_config=self.generation_config,
1017+
override_generation_config=self.override_generation_config,
10051018
enable_sleep_mode=self.enable_sleep_mode,
10061019
)
10071020

0 commit comments

Comments
 (0)