11 changes: 6 additions & 5 deletions benchmarks/benchmark_serving_structured_output.py
@@ -999,11 +999,12 @@ def main(args: argparse.Namespace):
                        type=float,
                        default=1.0,
                        help="Ratio of Structured Outputs requests")
-    parser.add_argument("--structured-output-backend",
-                        type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
-                        help="Backend to use for structured outputs")
+    parser.add_argument(
+        "--structured-output-backend",
+        type=str,
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
+        default="xgrammar",
+        help="Backend to use for structured outputs")

    args = parser.parse_args()
    main(args)
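For orientation, a minimal self-contained sketch of how the new choice behaves at the CLI; only the flag definition is taken from the diff above, the parse call is illustrative:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--structured-output-backend",
        type=str,
        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
        default="xgrammar",
        help="Backend to use for structured outputs")

    # "guidance" is now accepted; anything outside choices exits with an error.
    args = parser.parse_args(["--structured-output-backend", "guidance"])
    assert args.structured_output_backend == "guidance"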
1 change: 1 addition & 0 deletions requirements/common.txt
@@ -18,6 +18,7 @@ pillow # Required for image processing
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer >= 0.10.11, < 0.11
+llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
outlines == 0.1.11
lark == 1.2.2
xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"
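The environment markers restrict llguidance to x86_64, arm64, and aarch64 hosts, so the package can be absent elsewhere. A hedged sketch of a guarded import under that assumption (the availability flag is hypothetical, not vLLM's API):

    # Hypothetical guard: llguidance may be missing on platforms outside the
    # environment marker above, so treat the import as optional.
    try:
        import llguidance  # noqa: F401
        LLGUIDANCE_AVAILABLE = True  # illustrative flag, not from vLLM
    except ImportError:
        LLGUIDANCE_AVAILABLE = False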
4 changes: 3 additions & 1 deletion tests/entrypoints/llm/test_guided_generate.py
@@ -14,7 +14,9 @@
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS = [
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
]


@pytest.fixture(scope="module")
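The expanded list presumably feeds this file's parametrized tests; a hedged sketch of the pattern (the test body is illustrative, not the file's actual tests):

    import pytest

    from vllm.sampling_params import GuidedDecodingParams

    GUIDED_DECODING_BACKENDS = [
        "outlines", "lm-format-enforcer", "xgrammar", "guidance"
    ]

    @pytest.mark.parametrize("backend", GUIDED_DECODING_BACKENDS)
    def test_backend_accepted(backend):
        # Each backend name should be accepted by GuidedDecodingParams.
        params = GuidedDecodingParams(regex=r"\d+", backend=backend)
        assert params.backend_name == backend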
4 changes: 3 additions & 1 deletion tests/model_executor/test_guided_processors.py
@@ -16,7 +16,9 @@
from vllm.sampling_params import GuidedDecodingParams

MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
GUIDED_DECODING_BACKENDS = [
"outlines", "lm-format-enforcer", "xgrammar", "guidance"
]
GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

4 changes: 3 additions & 1 deletion vllm/config.py
@@ -2785,7 +2785,9 @@ def compute_hash(self) -> str:
        return hash_str

    def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
+        valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        ]

        backend = GuidedDecodingParams(
            backend=self.guided_decoding_backend).backend_name
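A hedged sketch of what the validation above accepts and rejects, assuming DecodingConfig is the dataclass in vllm/config.py that owns guided_decoding_backend:

    from vllm.config import DecodingConfig

    DecodingConfig(guided_decoding_backend="guidance")  # valid after this PR

    try:
        DecodingConfig(guided_decoding_backend="no-such-backend")
    except ValueError as err:
        print(err)  # message lists the valid guided backends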
27 changes: 22 additions & 5 deletions vllm/model_executor/guided_decoding/__init__.py
@@ -79,6 +79,12 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
"xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines")

elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")

# If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback.
elif not xgr_installed:
@@ -88,9 +94,9 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,

    if (guided_params.backend_name == "outlines"
            and guided_params.json_object is not None):
-        # outlines doesn't support json_object, fallback to xgrammar
+        # outlines doesn't support json_object, fallback to guidance
        fallback_or_error(guided_params,
-                         "outlines does not support json_object.", "xgrammar")
+                         "outlines does not support json_object.", "guidance")

    return guided_params
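A hedged sketch of the resulting fallback behavior, assuming maybe_backend_fallback (the enclosing helper in this module) is called as below:

    from vllm.model_executor.guided_decoding import maybe_backend_fallback
    from vllm.sampling_params import GuidedDecodingParams

    # json_object with outlines (or xgrammar) should now fall back to guidance.
    params = GuidedDecodingParams(json_object=True, backend="outlines")
    params = maybe_backend_fallback(params)
    print(params.backend)  # expected: "guidance"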

@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config, reasoner)
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', "
+        "'guidance'")


def get_local_guided_decoding_logits_processor(
@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
            get_local_xgrammar_guided_decoding_logits_processor)
        return get_local_xgrammar_guided_decoding_logits_processor(
            guided_params, tokenizer, model_config, reasoner)
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)

    raise ValueError(
        f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines', 'lm-format-enforcer', 'xgrammar', "
+        "'guidance'")
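End to end, the local dispatcher can now return a guidance-backed processor. A hedged sketch; the tokenizer is illustrative, and the None placeholders assume model_config and reasoner are unused on the guidance path:

    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding import (
        get_local_guided_decoding_logits_processor)
    from vllm.sampling_params import GuidedDecodingParams

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    params = GuidedDecodingParams(regex=r"[0-9]{4}", backend="guidance")
    processor = get_local_guided_decoding_logits_processor(
        params, tokenizer, model_config=None, reasoner=None)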
44 changes: 44 additions & 0 deletions vllm/model_executor/guided_decoding/guidance_decoding.py
@@ -0,0 +1,44 @@
# SPDX-License-Identifier: Apache-2.0
from re import escape as regex_escape

import llguidance
from transformers import PreTrainedTokenizerBase

from vllm.model_executor.guided_decoding.guidance_logits_processors import (
    GuidanceLogitsProcessor)
from vllm.sampling_params import GuidedDecodingParams


def get_local_guidance_guided_decoding_logits_processor(
        guided_params: GuidedDecodingParams,
        tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
    """
    Given an OpenAI-compatible request, check for guided decoding parameters
    and get the necessary logits processor for the given guide.
    """

    grm = ""
    if guided_params.json:
        grm = llguidance.LLMatcher.grammar_from_json_schema(
            guided_params.json,
            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
    elif guided_params.json_object:
        grm = llguidance.LLMatcher.grammar_from_json_schema(
            '{"type": "object"}',
            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
    elif guided_params.regex:
        grm = llguidance.grammar_from("regex", guided_params.regex)
    elif guided_params.choice:
        # choice just uses regex
        choices = (regex_escape(str(choice))
                   for choice in guided_params.choice)
        choices_regex = "(" + "|".join(choices) + ")"
        grm = llguidance.grammar_from("regex", choices_regex)
    elif guided_params.grammar:
        # this supports Lark and GBNF
        grm = llguidance.grammar_from("grammar", guided_params.grammar)

    if grm:
        return GuidanceLogitsProcessor(grm, tokenizer)

    raise ValueError("Unknown guided decoding mode")
85 changes: 85 additions & 0 deletions vllm/model_executor/guided_decoding/guidance_logits_processors.py
@@ -0,0 +1,85 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Any, List

import llguidance
import llguidance.hf
import llguidance.torch
import torch
from transformers import PreTrainedTokenizerBase

from vllm.logger import init_logger

logger = init_logger(__name__)


class GuidanceLogitsProcessor:
    """Base Guidance Logits Processor"""

    cached_tokenizers: dict[str, Any] = {}

    def __init__(
        self,
        grammar: str,
        tokenizer: PreTrainedTokenizerBase,
    ) -> None:
        """Base Guidance Logits Processor

        Args:
            grammar (str)
                grammar to guide the generation
            tokenizer (PreTrainedTokenizerBase)
                model's tokenizer
        """
        self.grammar = grammar
        self.tokenizer = tokenizer
        self.tokenizer_name = tokenizer.name_or_path
        self.new_sampling = False
        self.initialized = False

    def _initialize(self):
        if self.initialized:
            return

        ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
                                                  None)
        if ll_tokenizer is None:
            ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
            self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer

        self.ll_tokenizer = ll_tokenizer
        self.ll_matcher = llguidance.LLMatcher(
            self.ll_tokenizer,
            self.grammar,
            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
        )

        # create reusable bitmask
        self.bitmask = llguidance.torch.allocate_token_bitmask(
            1, self.ll_tokenizer.vocab_size)

        self.initialized = True

    def __call__(
        self,
        input_ids: List[int],
        scores: torch.Tensor,
    ) -> torch.Tensor:
        # we initialize the guidance model here
        # to avoid pickling ll_tokenizer and ll_interpreter
        self._initialize()

        if self.new_sampling and len(input_ids) > 0:
            self.ll_matcher.consume_token(input_ids[-1])
            err = self.ll_matcher.get_error()
            if err:
                logger.warning("Error in LLMatcher: %s", err)

        llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
                                                 0)
        llguidance.torch.apply_token_bitmask_inplace(
            scores, self.bitmask.to(scores.device))

        self.new_sampling = True

        return scores
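A hedged sketch of the processor's calling convention: vLLM invokes it once per decoding step with the generated token ids and the raw logits. The zero-filled scores tensor is a placeholder; real logits come from the model and match the tokenizer's vocabulary size:

    import llguidance
    import torch
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    grammar = llguidance.LLMatcher.grammar_from_json_schema('{"type": "object"}')
    processor = GuidanceLogitsProcessor(grammar, tokenizer)

    # First step: no tokens consumed yet; grammar-illegal tokens are masked
    # out of the scores in place.
    scores = torch.zeros(len(tokenizer))  # placeholder logits
    scores = processor([], scores)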