diff --git a/vllm/config.py b/vllm/config.py
index 20ca20ad2b6d..bc0b59d767b6 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2125,139 +2125,113 @@ def __post_init__(self):
         self.device = torch.device(self.device_type)
 
 
+SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
+                            "draft_model"]
+SpeculativeAcceptanceMethod = Literal["rejection_sampler",
+                                      "typical_acceptance_sampler"]
+
+
+@config
 @dataclass
 class SpeculativeConfig:
-    """
-    Configuration for speculative decoding.
-    Configurable parameters include:
-    - General Speculative Decoding Control:
-        - num_speculative_tokens (int): The number of speculative
-            tokens, if provided. It will default to the number in the draft
-            model config if present, otherwise, it is required.
-        - model (Optional[str]): The name of the draft model, eagle head,
-            or additional weights, if provided.
-        - method (Optional[str]): The name of the speculative method to use.
-            If users provide and set the `model` param, the speculative method
-            type will be detected automatically if possible, if `model` param
-            is not provided, the method name must be provided.
-            - Possible values:
-                - ngram
-                    Related additional configuration:
-                    - prompt_lookup_max (Optional[int]):
-                        Maximum size of ngram token window when using Ngram
-                        proposer, required when method is set to ngram.
-                    - prompt_lookup_min (Optional[int]):
-                        Minimum size of ngram token window when using Ngram
-                        proposer, if provided. Defaults to 1.
-                - eagle
-                - medusa
-                - mlp_speculator
-                - draft_model
-        - acceptance_method (str): The method to use for accepting draft
-            tokens. This can take two possible values: 'rejection_sampler' and
-            'typical_acceptance_sampler' for RejectionSampler and
-            TypicalAcceptanceSampler respectively. If not specified, it
-            defaults to 'rejection_sampler'.
-            - Possible values:
-                - rejection_sampler
-                - typical_acceptance_sampler
-                    Related additional configuration:
-                    - posterior_threshold (Optional[float]):
-                        A threshold value that sets a lower bound on the
-                        posterior probability of a token in the target model
-                        for it to be accepted. This threshold is used only
-                        when we use the TypicalAcceptanceSampler for token
-                        acceptance.
-                    - posterior_alpha (Optional[float]):
-                        Scaling factor for entropy-based threshold, applied
-                        when using TypicalAcceptanceSampler.
-        - draft_tensor_parallel_size (Optional[int]): The degree of the tensor
-            parallelism for the draft model. Can only be 1 or the same as the
-            target model's tensor parallel size.
-        - disable_logprobs (bool): If set to True, token log probabilities are
-            not returned during speculative decoding. If set to False, token
-            log probabilities are returned according to the log probability
-            settings in SamplingParams. If not specified, it defaults to True.
-
-    - Draft Model Configuration:
-        - quantization (Optional[str]): Quantization method that was used to
-            quantize the draft model weights. If None, we assume the
-            model weights are not quantized. Note that it only takes effect
-            when using the draft model-based speculative method.
-        - max_model_len (Optional[int]): The maximum model length of the
-            draft model. Used when testing the ability to skip
-            speculation for some sequences.
-        - revision: The specific model version to use for the draft model. It
-            can be a branch name, a tag name, or a commit id. If unspecified,
-            will use the default version.
-        - code_revision: The specific revision to use for the draft model code
-            on Hugging Face Hub. It can be a branch name, a tag name, or a
-            commit id. If unspecified, will use the default version.
+    """Configuration for speculative decoding."""
-
-    - Advanced Control:
-        - disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
-            batch expansion for scoring proposals. If not specified, it
-            defaults to False.
-        - disable_by_batch_size (Optional[int]): Disable speculative decoding
-            for new incoming requests when the number of enqueued requests is
-            larger than this value, if provided.
-
-    Although the parameters above are structured hierarchically, there is no
-    need to nest them during configuration.
-
-    Non-configurable internal parameters include:
-    - Model Configuration:
-        - target_model_config (ModelConfig): The configuration of the target
-            model.
-        - draft_model_config (ModelConfig): The configuration of the draft
-            model initialized internal.
-    - Parallelism Configuration:
-        - target_parallel_config (ParallelConfig): The parallel configuration
-            for the target model.
-        - draft_parallel_config (ParallelConfig): The parallel configuration
-            for the draft model initialized internal.
-    - Execution Control:
-        - enable_chunked_prefill (bool): Whether vLLM is configured to use
-            chunked prefill or not. Used for raising an error since it's not
-            yet compatible with speculative decode.
-        - disable_log_stats (bool): Whether to disable the periodic printing of
-            stage times in speculative decoding.
-    """
 
-    # speculative configs from cli args
+    # General speculative decoding control
     num_speculative_tokens: int = field(default=None,
                                         init=True)  # type: ignore
-    method: Optional[str] = None
-    acceptance_method: str = "rejection_sampler"
+    """The number of speculative tokens, if provided. It will default to the
+    number in the draft model config if present; otherwise, it is required."""
+    model: Optional[str] = None
+    """The name of the draft model, eagle head, or additional weights, if
+    provided."""
+    method: Optional[SpeculativeMethod] = None
+    """The name of the speculative method to use. If users provide and set the
+    `model` param, the speculative method type will be detected automatically
+    if possible; if the `model` param is not provided, the method name must be
+    provided.
+
+    If using the `ngram` method, the related configuration `prompt_lookup_max`
+    and `prompt_lookup_min` should be considered."""
+    acceptance_method: SpeculativeAcceptanceMethod = "rejection_sampler"
+    """The method to use for accepting draft tokens:\n
+    - "rejection_sampler" maps to `RejectionSampler`.\n
+    - "typical_acceptance_sampler" maps to `TypicalAcceptanceSampler`.
+
+    If using `typical_acceptance_sampler`, the related configuration
+    `posterior_threshold` and `posterior_alpha` should be considered."""
     draft_tensor_parallel_size: Optional[int] = None
+    """The degree of the tensor parallelism for the draft model. Can only be 1
+    or the same as the target model's tensor parallel size."""
     disable_logprobs: bool = True
+    """If set to True, token log probabilities are not returned during
+    speculative decoding. If set to False, token log probabilities are returned
+    according to the log probability settings in SamplingParams."""
 
-    model: Optional[str] = None
+    # Draft model configuration
     quantization: Optional[str] = None
+    """Quantization method that was used to quantize the draft model weights.
+    If `None`, we assume the model weights are not quantized. Note that it only
+    takes effect when using the draft model-based speculative method."""
     max_model_len: Optional[int] = None
+    """The maximum model length of the draft model. Used when testing the
+    ability to skip speculation for some sequences."""
     revision: Optional[str] = None
+    """The specific model version to use for the draft model. It can be a
+    branch name, a tag name, or a commit id. If unspecified, will use the
+    default version."""
     code_revision: Optional[str] = None
+    """The specific revision to use for the draft model code on Hugging Face
+    Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
+    will use the default version."""
 
+    # Advanced control
     disable_mqa_scorer: bool = False
+    """Disable the MQA scorer and fall back to batch expansion for scoring
+    proposals."""
     disable_by_batch_size: Optional[int] = None
+    """Disable speculative decoding for new incoming requests when the number
+    of enqueued requests is larger than this value, if provided."""
+
+    # Ngram proposer configuration
     prompt_lookup_max: Optional[int] = None
+    """Maximum size of ngram token window when using Ngram proposer, required
+    when method is set to ngram."""
     prompt_lookup_min: Optional[int] = None
+    """Minimum size of ngram token window when using Ngram proposer, if
+    provided. Defaults to 1."""
+
+    # Typical acceptance sampler configuration
     posterior_threshold: Optional[float] = None
+    """A threshold value that sets a lower bound on the posterior probability
+    of a token in the target model for it to be accepted. This threshold is
+    used only when we use the `TypicalAcceptanceSampler` for token acceptance.
+    """
     posterior_alpha: Optional[float] = None
+    """Scaling factor for entropy-based threshold, applied when using
+    `TypicalAcceptanceSampler`."""
 
     # required configuration params passed from engine
     target_model_config: ModelConfig = field(default=None,
                                              init=True)  # type: ignore
+    """The configuration of the target model."""
     target_parallel_config: ParallelConfig = field(default=None,
                                                    init=True)  # type: ignore
+    """The parallel configuration for the target model."""
     enable_chunked_prefill: bool = field(default=None,
                                          init=True)  # type: ignore
+    """Whether vLLM is configured to use chunked prefill or not. Used for
+    raising an error since it's not yet compatible with speculative decode."""
     disable_log_stats: bool = field(default=None, init=True)  # type: ignore
+    """Whether to disable the periodic printing of stage times in speculative
+    decoding."""
 
     # params generated in the post-init stage
     draft_model_config: ModelConfig = field(default=None,
                                             init=True)  # type: ignore
+    """The configuration of the draft model, initialized internally."""
     draft_parallel_config: ParallelConfig = field(default=None,
                                                   init=True)  # type: ignore
+    """The parallel configuration for the draft model, initialized internally.
+    """
 
     def compute_hash(self) -> str:
         """
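The fields documented above map one-to-one onto the keys a user passes in a speculative decoding config. As a minimal offline sketch (not part of this diff), assuming vLLM's `LLM` entrypoint forwards a `speculative_config` dict into these fields, and with a placeholder model name:

    # Hypothetical usage sketch: ngram speculation needs no draft model;
    # it proposes continuations by matching n-grams in the prompt itself.
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder target model
        speculative_config={
            "method": "ngram",
            "num_speculative_tokens": 5,  # draft tokens proposed per step
            "prompt_lookup_max": 4,       # required when method is "ngram"
            "prompt_lookup_min": 1,
        },
    )
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)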
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 06529ae25a83..ea82138811c9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -766,11 +766,18 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
                             help=('Maximum number of forward steps per '
                                   'scheduler call.'))
 
-        parser.add_argument('--speculative-config',
-                            type=json.loads,
-                            default=None,
-                            help='The configurations for speculative decoding.'
-                            ' Should be a JSON string.')
+        # Speculative arguments
+        speculative_group = parser.add_argument_group(
+            title="SpeculativeConfig",
+            description=SpeculativeConfig.__doc__,
+        )
+        speculative_group.add_argument(
+            '--speculative-config',
+            type=json.loads,
+            default=None,
+            help='The configurations for speculative decoding.'
+            ' Should be a JSON string.')
+
         parser.add_argument(
             '--ignore-patterns',
             action="append",
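The argparse pattern introduced above can be reproduced standalone: a flag parsed with `type=json.loads` sitting in an argument group whose description is pulled from the config class's docstring, so `--help` shows the class documentation next to the flag. A self-contained sketch with a stub class (not vLLM's actual parser plumbing):

    import argparse
    import json

    class SpeculativeConfig:
        """Configuration for speculative decoding."""

    parser = argparse.ArgumentParser()
    # Group the flag under the config class; its docstring becomes the
    # group description in --help output.
    group = parser.add_argument_group(
        title="SpeculativeConfig",
        description=SpeculativeConfig.__doc__,
    )
    # json.loads runs on the raw string, so the parsed value is a dict.
    group.add_argument('--speculative-config',
                       type=json.loads,
                       default=None,
                       help='The configurations for speculative decoding.'
                            ' Should be a JSON string.')

    args = parser.parse_args([
        '--speculative-config',
        '{"method": "ngram", "num_speculative_tokens": 5,'
        ' "prompt_lookup_max": 4}',
    ])
    print(args.speculative_config["method"])  # -> "ngram"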