Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,73 @@ def test_uses_mrope(model_id, uses_mrope):
)

assert config.uses_mrope == uses_mrope


def test_generation_config_loading():
    """Exercise ModelConfig.get_diff_sampling_param() for every
    combination of generation_config (None / "auto") and
    override_generation_config (absent / present)."""
    model_id = "Qwen/Qwen2.5-1.5B-Instruct"

    # Constructor arguments shared by every ModelConfig below.
    common_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    # generation_config=None: no generation config is loaded, so no
    # sampling parameters diverge from the defaults.
    model_config = ModelConfig(model_id,
                               generation_config=None,
                               **common_kwargs)
    assert model_config.get_diff_sampling_param() == {}

    # generation_config="auto": the model's own generation config is
    # loaded from the hub.
    model_config = ModelConfig(model_id,
                               generation_config="auto",
                               **common_kwargs)

    correct_generation_config = {
        "repetition_penalty": 1.1,
        "temperature": 0.7,
        "top_p": 0.8,
        "top_k": 20,
    }
    assert model_config.get_diff_sampling_param() == correct_generation_config

    # A user-supplied override is merged on top of the loaded config.
    override_generation_config = {"temperature": 0.5, "top_k": 5}

    model_config = ModelConfig(
        model_id,
        generation_config="auto",
        override_generation_config=override_generation_config,
        **common_kwargs)

    expected = {**correct_generation_config, **override_generation_config}
    assert model_config.get_diff_sampling_param() == expected

    # generation_config=None plus an override: only the override is used.
    model_config = ModelConfig(
        model_id,
        generation_config=None,
        override_generation_config=override_generation_config,
        **common_kwargs)

    assert model_config.get_diff_sampling_param() == override_generation_config
13 changes: 11 additions & 2 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ class ModelConfig:
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
override_generation_config: Override the generation config with the
given config.
"""

def compute_hash(self) -> str:
Expand Down Expand Up @@ -224,6 +226,7 @@ def __init__(
logits_processor_pattern: Optional[str] = None,
generation_config: Optional[str] = None,
enable_sleep_mode: bool = False,
override_generation_config: Optional[Dict[str, Any]] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
Expand Down Expand Up @@ -366,6 +369,7 @@ def __init__(
self.logits_processor_pattern = logits_processor_pattern

self.generation_config = generation_config
self.override_generation_config = override_generation_config or {}

self._verify_quantization()
self._verify_cuda_graph()
Expand Down Expand Up @@ -902,8 +906,13 @@ def get_diff_sampling_param(self) -> Dict[str, Any]:
"""
if self.generation_config is None:
# When generation_config is not set
return {}
config = self.try_get_generation_config()
config = {}
else:
config = self.try_get_generation_config()

# Overriding with given generation config
config.update(self.override_generation_config)

available_params = [
"repetition_penalty",
"temperature",
Expand Down
25 changes: 19 additions & 6 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ class EngineArgs:
kv_transfer_config: Optional[KVTransferConfig] = None

generation_config: Optional[str] = None
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False

calculate_kv_scales: Optional[bool] = None
Expand Down Expand Up @@ -936,12 +937,23 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=nullable_str,
default=None,
help="The folder path to the generation config. "
"Defaults to None, will use the default generation config in vLLM. "
"If set to 'auto', the generation config will be automatically "
"loaded from model. If set to a folder path, the generation config "
"will be loaded from the specified folder path. If "
"`max_new_tokens` is specified, then it sets a server-wide limit "
"on the number of output tokens for all requests.")
"Defaults to None, no generation config is loaded, vLLM defaults "
"will be used. If set to 'auto', the generation config will be "
"loaded from model path. If set to a folder path, the generation "
"config will be loaded from the specified folder path. If "
"`max_new_tokens` is specified in generation config, then "
"it sets a server-wide limit on the number of output tokens "
"for all requests.")

parser.add_argument(
"--override-generation-config",
type=json.loads,
default=None,
help="Overrides or sets generation config in JSON format. "
"e.g. ``{\"temperature\": 0.5}``. If used with "
"--generation-config=auto, the override parameters will be merged "
"with the default config from the model. If generation-config is "
"None, only the override parameters are used.")

parser.add_argument("--enable-sleep-mode",
action="store_true",
Expand Down Expand Up @@ -1002,6 +1014,7 @@ def create_model_config(self) -> ModelConfig:
override_pooler_config=self.override_pooler_config,
logits_processor_pattern=self.logits_processor_pattern,
generation_config=self.generation_config,
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
)

Expand Down
Loading