diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 4e9e5506bb58..d823c45d5990 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -119,16 +119,21 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None: def _validate_structured_output(self, params: SamplingParams) -> None: if not params.guided_decoding or not self.decoding_config: return - if self.decoding_config.guided_decoding_backend != "xgrammar": - raise ValueError( - "Only xgrammar structured output is supported in V1.") - if (params.guided_decoding.backend - and params.guided_decoding.backend != 'xgrammar'): - raise ValueError( - "Only xgrammar structured output is supported in V1.") - if self.vllm_config.speculative_config: - raise ValueError("Structured output is not supported with " - "speculative decoding.") + + supported_backends = ["xgrammar"] + engine_level_backend = self.decoding_config.guided_decoding_backend + if engine_level_backend not in supported_backends: + raise ValueError(f"Only {supported_backends} structured output is " + "supported in V1.") + if params.guided_decoding.backend: + if params.guided_decoding.backend != engine_level_backend: + raise ValueError("Request-level structured output backend " + "must match engine-level backend. " + f"{params.guided_decoding.backend}" + f" != {engine_level_backend}") + else: + params.guided_decoding.backend = engine_level_backend + if vllm.platforms.current_platform.is_tpu(): raise ValueError("Structured output is not supported on TPU.") diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 5ed7b832aac5..58ac00e985a9 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -7,75 +7,27 @@ from vllm.config import VllmConfig from vllm.logger import init_logger -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs -from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer -from vllm.utils import LazyLoader -from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar) +from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend if TYPE_CHECKING: import numpy as np import numpy.typing as npt - import xgrammar as xgr + import torch from vllm.v1.request import Request -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") logger = init_logger(__name__) class StructuredOutputManager: + """Engine-level manager for structured output requests.""" def __init__(self, vllm_config: VllmConfig): + self.backend: Optional[StructuredOutputBackend] = None self.vllm_config = vllm_config - self.init_complete = False - - def _delayed_init(self): - """Initialization delayed until we know it is needed.""" - tokenizer_group = init_tokenizer_from_configs( - model_config=self.vllm_config.model_config, - scheduler_config=self.vllm_config.scheduler_config, - parallel_config=self.vllm_config.parallel_config, - lora_config=self.vllm_config.lora_config) # type: ignore[arg-type] - tokenizer_group.ping() - - tokenizer = tokenizer_group.get_lora_tokenizer(None) - self.vocab_size = self.vllm_config.model_config.get_vocab_size() - if isinstance(tokenizer, MistralTokenizer): - # NOTE: ideally, xgrammar should handle this accordingly. 
- # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 - try: - encoded_vocab = [ - token for token, _ in sorted( - tokenizer.get_vocab().items(), - key=lambda x: x[1], - ) - ] - stop_token_ids = None - if hasattr( - tokenizer, - "eos_token_id", - ) and tokenizer.eos_token_id is not None: - stop_token_ids = [tokenizer.eos_token_id] - except AttributeError as e: - raise ValueError( - f"Cannot get the vocabulary of the tokenizer " - f"{type(tokenizer)}. The tokenizer should have a " - "get_vocab method.") from e - tokenizer_info = xgr.TokenizerInfo( - encoded_vocab=encoded_vocab, - # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 - vocab_type=xgr.VocabType.BYTE_FALLBACK, - vocab_size=self.vocab_size, - stop_token_ids=stop_token_ids, - add_prefix_space=True, - ) - else: - tokenizer_info = xgr.TokenizerInfo.from_huggingface( - tokenizer, - vocab_size=self.vocab_size, - ) - self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + self._grammar_bitmask: Optional[torch.Tensor] = None # The default max_workers if not specified is the number of CPUs * 5, # which is way too high since these tasks are CPU-bound, not I/O bound. @@ -83,28 +35,30 @@ def _delayed_init(self): # compilation, so we set it to half the number of CPUs. max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) - self._grammar_bitmask = xgr.allocate_token_bitmask( - self.vllm_config.scheduler_config.max_num_seqs, - self.vocab_size, - ) - - self.init_complete = True def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: return - # The first time this is called, we need to finish initialization - # of xgrammar. We defer it to avoid the import of xgrammar and - # initialization cost if it is not going to be used. - if not self.init_complete: - self._delayed_init() + # Initialize the backend the first time it is needed. + # + # NOTE: We only support a single backend. We do NOT support different + # backends on a per-request basis in V1 (for now, anyway...). + if self.backend is None: + backend_name = request.sampling_params.guided_decoding.backend_name + if backend_name == "xgrammar": + self.backend = XgrammarBackend(self.vllm_config) + else: + raise ValueError( + f"Unsupported structured output backend: {backend_name}") - grammar: Future[Grammar] = self.executor.submit( - self._async_create_grammar, request) + grammar: Future[StructuredOutputGrammar] = self.executor.submit( + self._async_create_grammar, request, self.backend) request.structured_output_request.grammar = grammar # type: ignore[assignment] - def _async_create_grammar(self, request: Request) -> Grammar: + def _async_create_grammar( + self, request: Request, + backend: StructuredOutputBackend) -> StructuredOutputGrammar: key = request.structured_output_request.structured_output_key # type: ignore[union-attr] # Note that the request was validated in the engine core client, @@ -114,28 +68,8 @@ def _async_create_grammar(self, request: Request) -> Grammar: # though it should be unlikely as we test that up front as well. 
request_type, grammar_spec = key - if request_type == StructuredOutputOptions.JSON: - # TODO -- allow any_whitespace to be configurable - # pending merge of https://github.com/vllm-project/vllm/pull/12744 - ctx = self.compiler.compile_json_schema(grammar_spec, - any_whitespace=False) - elif request_type == StructuredOutputOptions.JSON_OBJECT: - ctx = self.compiler.compile_builtin_json_grammar() - elif request_type == StructuredOutputOptions.GRAMMAR: - ctx = self.compiler.compile_grammar(grammar_spec) - elif request_type == StructuredOutputOptions.REGEX: - ctx = self.compiler.compile_regex(grammar_spec) - else: - logger.error("Validation should have already occurred. " - "Please file an issue.") - raise ValueError( - f"grammar is not of valid supported types. ({request_type!s})") - - return Grammar( - matcher=xgr.GrammarMatcher(ctx), - vocab_size=self.vocab_size, - ctx=ctx, - ) + assert self.backend is not None + return self.backend.compile_grammar(request_type, grammar_spec) def grammar_bitmask( self, @@ -147,6 +81,11 @@ def grammar_bitmask( if not structured_output_request_ids: return None + if self._grammar_bitmask is None: + assert self.backend is not None + self._grammar_bitmask = self.backend.allocate_token_bitmask( + self.vllm_config.scheduler_config.max_num_seqs) + # Fill the bitmask using the index of each request equal to its # position in the batch. Resize the bitmask down to the size of # the batch. @@ -154,7 +93,7 @@ def grammar_bitmask( for req_id, batch_index in structured_output_request_ids.items(): request = requests[req_id].structured_output_request assert request is not None and request.grammar is not None - if not request.grammar.matcher.is_terminated(): + if not request.grammar.is_terminated(): request.grammar.fill_bitmask(bitmask_tensor, batch_index) if batch_len < self._grammar_bitmask.shape[0]: bitmask_tensor = self._grammar_bitmask[:batch_len] diff --git a/vllm/v1/structured_output/backend_types.py b/vllm/v1/structured_output/backend_types.py new file mode 100644 index 000000000000..6dc2a92411de --- /dev/null +++ b/vllm/v1/structured_output/backend_types.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +import enum +from abc import ABC, abstractmethod + +import torch + + +class StructuredOutputOptions(enum.Enum): + JSON = enum.auto() + JSON_OBJECT = enum.auto() + REGEX = enum.auto() + GRAMMAR = enum.auto() + CHOICE = enum.auto() + + +StructuredOutputKey = tuple[StructuredOutputOptions, str] + + +class StructuredOutputGrammar(ABC): + """Request-level backend for structured output requests.""" + + @abstractmethod + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """ + Determines whether the provided tokens are accepted for the + given request. + + Args: + request_id (str): The unique identifier for the request. + tokens (list[int]): A list of token IDs to evaluate. + + Returns: + bool: True if the tokens are accepted, False otherwise. + """ + + @abstractmethod + def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None: + """ + Fills the bitmask for a specific batch index. + + Args: + bitmask (torch.Tensor): The bitmask to fill + batch_index (int): The index in the bitmask to fill + """ + + @abstractmethod + def is_terminated(self) -> bool: + """ + Checks whether the structured output process has terminated. + + Returns: + bool: True if the process is terminated, False otherwise. + """ + + @abstractmethod + def reset(self): + """ + Resets the state of the structured output grammar. 
+ """ + + +class StructuredOutputBackend(ABC): + """Engine-level backend for structured output requests.""" + + @abstractmethod + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + """ + Compiles a grammar specification into a structured output grammar. + + Args: + request_type (StructuredOutputOptions): The type of structured + output request. + grammar_spec (str): The grammar specification to compile. + + Returns: + StructuredOutputGrammar: The compiled structured output grammar. + """ + + @abstractmethod + def allocate_token_bitmask(self, max_num_seqs: int): + """ + Allocates a token bitmask for the specified maximum number of sequences. + + Args: + max_num_seqs (int): The maximum number of sequences for which + to allocate the bitmask. + """ diff --git a/vllm/v1/structured_output/backend_xgrammar.py b/vllm/v1/structured_output/backend_xgrammar.py new file mode 100644 index 000000000000..ce93ca5c751b --- /dev/null +++ b/vllm/v1/structured_output/backend_xgrammar.py @@ -0,0 +1,143 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import torch + +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import LazyLoader +from vllm.v1.structured_output.backend_types import (StructuredOutputBackend, + StructuredOutputGrammar, + StructuredOutputOptions) + +if TYPE_CHECKING: + import xgrammar as xgr +else: + xgr = LazyLoader("xgr", globals(), "xgrammar") + +logger = init_logger(__name__) + + +class XgrammarBackend(StructuredOutputBackend): + + def __init__(self, vllm_config: VllmConfig): + self.vllm_config = vllm_config + tokenizer_group = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + parallel_config=vllm_config.parallel_config, + lora_config=vllm_config.lora_config) # type: ignore[arg-type] + tokenizer_group.ping() + + tokenizer = tokenizer_group.get_lora_tokenizer(None) + self.vocab_size = vllm_config.model_config.get_vocab_size() + if isinstance(tokenizer, MistralTokenizer): + # NOTE: ideally, xgrammar should handle this accordingly. + # refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 + try: + encoded_vocab = [ + token for token, _ in sorted( + tokenizer.get_vocab().items(), + key=lambda x: x[1], + ) + ] + stop_token_ids = None + if hasattr( + tokenizer, + "eos_token_id", + ) and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. 
The tokenizer should have a " + "get_vocab method.") from e + tokenizer_info = xgr.TokenizerInfo( # type: ignore + encoded_vocab=encoded_vocab, + # NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501 + vocab_type=xgr.VocabType.BYTE_FALLBACK, + vocab_size=self.vocab_size, + stop_token_ids=stop_token_ids, + add_prefix_space=True, + ) + else: + tokenizer_info = xgr.TokenizerInfo.from_huggingface( + tokenizer, + vocab_size=self.vocab_size, + ) + self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) + + def compile_grammar(self, request_type: StructuredOutputOptions, + grammar_spec: str) -> StructuredOutputGrammar: + if request_type == StructuredOutputOptions.JSON: + ctx = self.compiler.compile_json_schema(grammar_spec, + any_whitespace=False) + elif request_type == StructuredOutputOptions.JSON_OBJECT: + ctx = self.compiler.compile_builtin_json_grammar() + elif request_type == StructuredOutputOptions.GRAMMAR: + ctx = self.compiler.compile_grammar(grammar_spec) + elif request_type == StructuredOutputOptions.REGEX: + ctx = self.compiler.compile_regex(grammar_spec) + else: + logger.error( + "Validation should have already occurred. Please file an issue." + ) + raise ValueError( + f"grammar is not of valid supported types. ({request_type!s})") + + return XgrammarGrammar( + matcher=xgr.GrammarMatcher(ctx), + vocab_size=self.vocab_size, + ctx=ctx, + ) + + def allocate_token_bitmask(self, max_num_seqs: int): + return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size) + + +@dataclass +class XgrammarGrammar(StructuredOutputGrammar): + # NOTE: This would be a generic-enough class for + # supporting different backends, in the future. + # For now, just xgrammar. + # + # TODO: support max_rollback_tokens + # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string + # for jump-forward decoding + + vocab_size: int + matcher: xgr.GrammarMatcher = field(hash=False) + ctx: xgr.CompiledGrammar = field(hash=False) + num_processed_tokens: int = field(default_factory=lambda: 0, + repr=False, + hash=False, + init=False) + + def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: + """Accepts a list of tokens and advances the FSM. + + Returns True if the FSM was advanced successfully. + Returns False if the FSM failed to advance. + """ + for token in tokens: + if not self.matcher.accept_token(token): + logger.error( + "Failed to advance FSM for request %s " + "for tokens %s. 
Please file an issue.", request_id, token) + return False + self.num_processed_tokens += 1 + return True + + def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None: + self.matcher.fill_next_token_bitmask(bitmask, idx) + + def is_terminated(self) -> bool: + return self.matcher.is_terminated() + + def reset(self): + self.num_processed_tokens = 0 + self.matcher.reset() diff --git a/vllm/v1/structured_output/grammar.py b/vllm/v1/structured_output/grammar.py deleted file mode 100644 index 0e9b2b172261..000000000000 --- a/vllm/v1/structured_output/grammar.py +++ /dev/null @@ -1,77 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations - -import enum -from dataclasses import dataclass, field -from typing import TYPE_CHECKING - -import torch - -from vllm.logger import init_logger -from vllm.utils import LazyLoader - -if TYPE_CHECKING: - import xgrammar as xgr -else: - xgr = LazyLoader("xgr", globals(), "xgrammar") - -logger = init_logger(__name__) - - -class StructuredOutputOptions(enum.Enum): - JSON = enum.auto() - JSON_OBJECT = enum.auto() - REGEX = enum.auto() - GRAMMAR = enum.auto() - CHOICE = enum.auto() - - -StructuredOutputKey = tuple[StructuredOutputOptions, str] - - -@dataclass -class Grammar: - # NOTE: This would be a generic-enough class for - # supporting different backends, in the future. - # For now, just xgrammar. - # - # TODO: support max_rollback_tokens - # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string - # for jump-forward decoding - - vocab_size: int - matcher: xgr.GrammarMatcher = field(hash=False) - ctx: xgr.CompiledGrammar = field(hash=False) - num_processed_tokens: int = field(default_factory=lambda: 0, - repr=False, - hash=False, - init=False) - - def accept_tokens(self, request_id: str, tokens: list[int]) -> bool: - """Accepts a list of tokens and advances the FSM. - - Returns True if the FSM was advanced successfully. - Returns False if the FSM failed to advance. - """ - for token in tokens: - if not self.matcher.accept_token(token): - logger.error( - "Failed to advance FSM for request %s " - "for tokens %s. 
Please file an issue.", request_id, token) - return False - self.num_processed_tokens += 1 - return True - - def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool: - return self.matcher.fill_next_token_bitmask(bitmask, idx) - - def reset(self): - self.num_processed_tokens = 0 - self.matcher.reset() - - def __copy__(self): - return Grammar( - matcher=xgr.GrammarMatcher(self.ctx), - vocab_size=self.vocab_size, - ctx=self.ctx, - ) diff --git a/vllm/v1/structured_output/request.py b/vllm/v1/structured_output/request.py index fbcfd541df54..718fa5834edb 100644 --- a/vllm/v1/structured_output/request.py +++ b/vllm/v1/structured_output/request.py @@ -9,15 +9,17 @@ from typing import Optional, Union, cast from vllm.sampling_params import SamplingParams -from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, - StructuredOutputOptions) +from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar, + StructuredOutputKey, + StructuredOutputOptions) @dataclasses.dataclass class StructuredOutputRequest: sampling_params: SamplingParams - _grammar: Optional[Union[Future[Grammar], Grammar]] = None + _grammar: Optional[Union[Future[StructuredOutputGrammar], + StructuredOutputGrammar]] = None def _check_grammar_completion(self) -> bool: # NOTE: We have to lazy import to gate circular imports @@ -37,12 +39,16 @@ def is_grammar_ready(self) -> bool: return self._check_grammar_completion() @property - def grammar(self) -> Optional[Grammar]: + def grammar(self) -> Optional[StructuredOutputGrammar]: completed = self._check_grammar_completion() - return cast(Optional[Grammar], self._grammar) if completed else None + return cast(Optional[StructuredOutputGrammar], + self._grammar) if completed else None @grammar.setter - def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None: + def grammar( + self, grammar: Union[StructuredOutputGrammar, + Future[StructuredOutputGrammar]] + ) -> None: self._grammar = grammar @functools.cached_property
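
Illustration (not part of this diff): a minimal sketch of how a hypothetical additional backend could plug into the StructuredOutputBackend / StructuredOutputGrammar interface introduced in backend_types.py above. The NoOpBackend / NoOpGrammar names, the permissive accept-everything behaviour, and the assumption that the token bitmask uses the same packed int32, set-bit-means-allowed layout as xgrammar are all placeholders chosen only to show the shape of the API, not anything this change adds.

import torch

from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
                                                     StructuredOutputGrammar,
                                                     StructuredOutputOptions)


class NoOpGrammar(StructuredOutputGrammar):
    """Hypothetical request-level grammar that never rejects tokens."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size
        self.num_processed_tokens = 0

    def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
        # Advance unconditionally; a real backend would consult its matcher
        # and return False when a token violates the grammar.
        self.num_processed_tokens += len(tokens)
        return True

    def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
        # Allow every token: set all bits in this request's row, assuming the
        # packed int32 layout where a set bit marks an allowed token.
        bitmask[batch_index].fill_(-1)

    def is_terminated(self) -> bool:
        return False

    def reset(self):
        self.num_processed_tokens = 0


class NoOpBackend(StructuredOutputBackend):
    """Hypothetical engine-level backend; vocab_size stands in for what a
    real backend would read from the tokenizer / model config."""

    def __init__(self, vocab_size: int):
        self.vocab_size = vocab_size

    def compile_grammar(self, request_type: StructuredOutputOptions,
                        grammar_spec: str) -> StructuredOutputGrammar:
        # A real backend would compile grammar_spec according to request_type
        # (JSON schema, regex, grammar, ...) before returning the grammar.
        return NoOpGrammar(self.vocab_size)

    def allocate_token_bitmask(self, max_num_seqs: int):
        # One int32 packs 32 vocabulary entries, mirroring xgrammar's layout.
        return torch.full((max_num_seqs, (self.vocab_size + 31) // 32),
                          -1,
                          dtype=torch.int32)

With an implementation like this, only the backend-selection branch in StructuredOutputManager.grammar_init (next to the existing backend_name == "xgrammar" check) and the supported_backends list in Processor._validate_structured_output would need to change; the rest of the manager is already backend-agnostic after this refactor.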