25 changes: 15 additions & 10 deletions vllm/v1/engine/processor.py
@@ -119,16 +119,21 @@ def _validate_lora(self, lora_request: Optional[LoRARequest]) -> None:
def _validate_structured_output(self, params: SamplingParams) -> None:
if not params.guided_decoding or not self.decoding_config:
return
if self.decoding_config.guided_decoding_backend != "xgrammar":
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if (params.guided_decoding.backend
and params.guided_decoding.backend != 'xgrammar'):
raise ValueError(
"Only xgrammar structured output is supported in V1.")
if self.vllm_config.speculative_config:
raise ValueError("Structured output is not supported with "
"speculative decoding.")

supported_backends = ["xgrammar"]
engine_level_backend = self.decoding_config.guided_decoding_backend
if engine_level_backend not in supported_backends:
raise ValueError(f"Only {supported_backends} structured output is "
"supported in V1.")
if params.guided_decoding.backend:
if params.guided_decoding.backend != engine_level_backend:
raise ValueError("Request-level structured output backend "
"must match engine-level backend. "
f"{params.guided_decoding.backend}"
f" != {engine_level_backend}")
else:
params.guided_decoding.backend = engine_level_backend

if vllm.platforms.current_platform.is_tpu():
raise ValueError("Structured output is not supported on TPU.")

123 changes: 31 additions & 92 deletions vllm/v1/structured_output/__init__.py
@@ -7,104 +7,58 @@

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
from vllm.utils import LazyLoader
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
StructuredOutputGrammar)
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend

if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import xgrammar as xgr
import torch

from vllm.v1.request import Request
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")

logger = init_logger(__name__)


class StructuredOutputManager:
"""Engine-level manager for structured output requests."""

def __init__(self, vllm_config: VllmConfig):
self.backend: Optional[StructuredOutputBackend] = None
self.vllm_config = vllm_config
self.init_complete = False

def _delayed_init(self):
"""Initialization delayed until we know it is needed."""
tokenizer_group = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
parallel_config=self.vllm_config.parallel_config,
lora_config=self.vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()

tokenizer = tokenizer_group.get_lora_tokenizer(None)
self.vocab_size = self.vllm_config.model_config.get_vocab_size()
if isinstance(tokenizer, MistralTokenizer):
# NOTE: ideally, xgrammar should handle this accordingly.
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
try:
encoded_vocab = [
token for token, _ in sorted(
tokenizer.get_vocab().items(),
key=lambda x: x[1],
)
]
stop_token_ids = None
if hasattr(
tokenizer,
"eos_token_id",
) and tokenizer.eos_token_id is not None:
stop_token_ids = [tokenizer.eos_token_id]
except AttributeError as e:
raise ValueError(
f"Cannot get the vocabulary of the tokenizer "
f"{type(tokenizer)}. The tokenizer should have a "
"get_vocab method.") from e
tokenizer_info = xgr.TokenizerInfo(
encoded_vocab=encoded_vocab,
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
vocab_type=xgr.VocabType.BYTE_FALLBACK,
vocab_size=self.vocab_size,
stop_token_ids=stop_token_ids,
add_prefix_space=True,
)
else:
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
tokenizer,
vocab_size=self.vocab_size,
)
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
self._grammar_bitmask: Optional[torch.Tensor] = None

# The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound.
# We also know we would never dominate CPU usage with just grammar
# compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs,
self.vocab_size,
)

self.init_complete = True

def grammar_init(self, request: Request) -> None:
if request.structured_output_request is None:
return

# The first time this is called, we need to finish initialization
# of xgrammar. We defer it to avoid the import of xgrammar and
# initialization cost if it is not going to be used.
if not self.init_complete:
self._delayed_init()
# Initialize the backend the first time it is needed.
#
# NOTE: We only support a single backend. We do NOT support different
# backends on a per-request basis in V1 (for now, anyway...).
if self.backend is None:
backend_name = request.sampling_params.guided_decoding.backend_name
if backend_name == "xgrammar":
self.backend = XgrammarBackend(self.vllm_config)
else:
raise ValueError(
f"Unsupported structured output backend: {backend_name}")

grammar: Future[Grammar] = self.executor.submit(
self._async_create_grammar, request)
grammar: Future[StructuredOutputGrammar] = self.executor.submit(
self._async_create_grammar, request, self.backend)
request.structured_output_request.grammar = grammar # type: ignore[assignment]

def _async_create_grammar(self, request: Request) -> Grammar:
def _async_create_grammar(
self, request: Request,
backend: StructuredOutputBackend) -> StructuredOutputGrammar:
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]

# Note that the request was validated in the engine core client,
@@ -114,28 +68,8 @@ def _async_create_grammar(self, request: Request) -> Grammar:
# though it should be unlikely as we test that up front as well.
request_type, grammar_spec = key

if request_type == StructuredOutputOptions.JSON:
# TODO -- allow any_whitespace to be configurable
# pending merge of https://github.com/vllm-project/vllm/pull/12744
ctx = self.compiler.compile_json_schema(grammar_spec,
any_whitespace=False)
elif request_type == StructuredOutputOptions.JSON_OBJECT:
ctx = self.compiler.compile_builtin_json_grammar()
elif request_type == StructuredOutputOptions.GRAMMAR:
ctx = self.compiler.compile_grammar(grammar_spec)
elif request_type == StructuredOutputOptions.REGEX:
ctx = self.compiler.compile_regex(grammar_spec)
else:
logger.error("Validation should have already occurred. "
"Please file an issue.")
raise ValueError(
f"grammar is not of valid supported types. ({request_type!s})")

return Grammar(
matcher=xgr.GrammarMatcher(ctx),
vocab_size=self.vocab_size,
ctx=ctx,
)
assert self.backend is not None
Collaborator:
We should try not to use assert in the critical path (and I believe this is one).

Given that -O and -OO will strip asserts (I know we aren't using those flags at the moment, but it's probably worth knowing).

Member Author:

Well, it's something that should never happen, and we'd want to know if it did because we know it will break anyway. It also gives hints to mypy, which is often why I end up adding it.

This should be covered in a style guide somewhere so we have guidelines for the project.
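For context on the -O concern raised above: Python's optimized modes compile assert statements away entirely, so any check expressed that way silently disappears at runtime. A tiny demonstration (the file name is hypothetical):

# strip_assert_demo.py
# Run normally:   python strip_assert_demo.py    -> raises AssertionError
# Run optimized:  python -O strip_assert_demo.py -> the assert is removed and
#                                                   the print below executes
backend = None
assert backend is not None, "backend must be initialized"
print("assert was stripped: __debug__ =", __debug__)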

return self.backend.compile_grammar(request_type, grammar_spec)

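The pattern in this file -- construct a single backend lazily on the first structured-output request, then hand grammar compilation to a thread pool and store the resulting Future on the request -- reduces to a few lines. The sketch below is an illustrative stand-in under that reading; LazyBackendManager and its callable argument are not vLLM APIs.

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional


class LazyBackendManager:
    """Sketch of the lazy-init + async-compile pattern used by the manager."""

    def __init__(self, make_backend: Callable[[], object]):
        self._make_backend = make_backend  # e.g. lambda: XgrammarBackend(cfg)
        self._backend: Optional[object] = None
        self._executor = ThreadPoolExecutor(max_workers=2)

    def submit(self, request_type, grammar_spec) -> Future:
        # Plain requests never pay for backend construction; the first
        # structured-output request does, exactly once.
        if self._backend is None:
            self._backend = self._make_backend()
        # Compilation can be slow, so it runs off the critical path; callers
        # poll the returned Future (e.g. future.done()) before scheduling.
        return self._executor.submit(self._backend.compile_grammar,
                                     request_type, grammar_spec)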
def grammar_bitmask(
self,
@@ -147,14 +81,19 @@ def grammar_bitmask(
if not structured_output_request_ids:
return None

if self._grammar_bitmask is None:
assert self.backend is not None
self._grammar_bitmask = self.backend.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs)

# Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of
# the batch.
bitmask_tensor = self._grammar_bitmask
for req_id, batch_index in structured_output_request_ids.items():
request = requests[req_id].structured_output_request
assert request is not None and request.grammar is not None
if not request.grammar.matcher.is_terminated():
if not request.grammar.is_terminated():
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
if batch_len < self._grammar_bitmask.shape[0]:
bitmask_tensor = self._grammar_bitmask[:batch_len]
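The token bitmask allocated above holds one row per sequence in the batch, with one bit per vocabulary token; fill_bitmask sets the bits for tokens that are currently legal, and the scheduler slices the rows down to the live batch size. A rough sketch of that layout follows; the packed-int32 shape and zero initialization mirror the xgrammar convention and are assumptions, not taken from this diff.

import torch


def allocate_token_bitmask(max_num_seqs: int, vocab_size: int) -> torch.Tensor:
    # One row per sequence, 32 vocabulary tokens packed into each int32 word.
    return torch.zeros(max_num_seqs, (vocab_size + 31) // 32, dtype=torch.int32)


def token_is_allowed(bitmask: torch.Tensor, batch_index: int, token_id: int) -> bool:
    # A set bit means the grammar currently permits this token.
    word = int(bitmask[batch_index, token_id // 32])
    return bool((word >> (token_id % 32)) & 1)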
89 changes: 89 additions & 0 deletions vllm/v1/structured_output/backend_types.py
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0

import enum
from abc import ABC, abstractmethod

import torch


class StructuredOutputOptions(enum.Enum):
JSON = enum.auto()
JSON_OBJECT = enum.auto()
REGEX = enum.auto()
GRAMMAR = enum.auto()
CHOICE = enum.auto()


StructuredOutputKey = tuple[StructuredOutputOptions, str]


class StructuredOutputGrammar(ABC):
"""Request-level backend for structured output requests."""

@abstractmethod
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
"""
Determines whether the provided tokens are accepted for the
given request.

Args:
request_id (str): The unique identifier for the request.
tokens (list[int]): A list of token IDs to evaluate.

Returns:
bool: True if the tokens are accepted, False otherwise.
"""

@abstractmethod
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
"""
Fills the bitmask for a specific batch index.

Args:
bitmask (torch.Tensor): The bitmask to fill
batch_index (int): The index in the bitmask to fill
"""

@abstractmethod
def is_terminated(self) -> bool:
"""
Checks whether the structured output process has terminated.

Returns:
bool: True if the process is terminated, False otherwise.
"""

@abstractmethod
def reset(self):
"""
Resets the state of the structured output grammar.
"""


class StructuredOutputBackend(ABC):
"""Engine-level backend for structured output requests."""

@abstractmethod
def compile_grammar(self, request_type: StructuredOutputOptions,
grammar_spec: str) -> StructuredOutputGrammar:
"""
Compiles a grammar specification into a structured output grammar.

Args:
request_type (StructuredOutputOptions): The type of structured
output request.
grammar_spec (str): The grammar specification to compile.

Returns:
StructuredOutputGrammar: The compiled structured output grammar.
"""

@abstractmethod
def allocate_token_bitmask(self, max_num_seqs: int):
"""
Allocates a token bitmask for the specified maximum number of sequences.

Args:
max_num_seqs (int): The maximum number of sequences for which
to allocate the bitmask.
"""