-
-
Notifications
You must be signed in to change notification settings - Fork 10.6k
[Core][V0] Add guidance backend for structured output #14589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
vllm-bot
merged 1 commit into
vllm-project:main
from
russellb:llguidance-v0-integration
Mar 20, 2025
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
from re import escape as regex_escape | ||
|
||
import llguidance | ||
from transformers import PreTrainedTokenizerBase | ||
|
||
from vllm.model_executor.guided_decoding.guidance_logits_processors import ( | ||
GuidanceLogitsProcessor) | ||
from vllm.sampling_params import GuidedDecodingParams | ||
|
||
|
||
def get_local_guidance_guided_decoding_logits_processor( | ||
guided_params: GuidedDecodingParams, | ||
tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor: | ||
""" | ||
Given an OpenAI-compatible request, check for guided decoding parameters | ||
and get the necessary logits processor for the given guide. | ||
""" | ||
|
||
grm = "" | ||
if guided_params.json: | ||
grm = llguidance.LLMatcher.grammar_from_json_schema( | ||
guided_params.json, | ||
overrides={"whitespace_pattern": guided_params.whitespace_pattern}) | ||
elif guided_params.json_object: | ||
grm = llguidance.LLMatcher.grammar_from_json_schema( | ||
'{"type": "object"}', | ||
overrides={"whitespace_pattern": guided_params.whitespace_pattern}) | ||
elif guided_params.regex: | ||
grm = llguidance.grammar_from("regex", guided_params.regex) | ||
elif guided_params.choice: | ||
# choice just uses regex | ||
choices = (regex_escape(str(choice)) | ||
for choice in guided_params.choice) | ||
choices_regex = "(" + "|".join(choices) + ")" | ||
grm = llguidance.grammar_from("regex", choices_regex) | ||
elif guided_params.grammar: | ||
# this supports Lark and GBNF | ||
grm = llguidance.grammar_from("grammar", guided_params.grammar) | ||
|
||
if grm: | ||
return GuidanceLogitsProcessor(grm, tokenizer) | ||
|
||
raise ValueError("Unknown guided decoding mode") |
85 changes: 85 additions & 0 deletions
85
vllm/model_executor/guided_decoding/guidance_logits_processors.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
import os | ||
from typing import Any, List | ||
|
||
import llguidance | ||
import llguidance.hf | ||
import llguidance.torch | ||
import torch | ||
from transformers import PreTrainedTokenizerBase | ||
|
||
from vllm.logger import init_logger | ||
|
||
logger = init_logger(__name__) | ||
|
||
|
||
class GuidanceLogitsProcessor: | ||
"""Base Guidance Logits Processor""" | ||
|
||
cached_tokenizers: dict[str, Any] = {} | ||
|
||
def __init__( | ||
self, | ||
grammar: str, | ||
tokenizer: PreTrainedTokenizerBase, | ||
) -> None: | ||
"""Base Guidance Logits Processor | ||
|
||
Args: | ||
grammar (str) | ||
grammar to guide the generation | ||
tokenizer (PreTrainedTokenizerBase) | ||
model's tokenizer | ||
""" | ||
self.grammar = grammar | ||
self.tokenizer = tokenizer | ||
self.tokenizer_name = tokenizer.name_or_path | ||
self.new_sampling = False | ||
self.initialized = False | ||
|
||
def _initialize(self): | ||
if self.initialized: | ||
return | ||
|
||
ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path, | ||
None) | ||
if ll_tokenizer is None: | ||
ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None) | ||
self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer | ||
|
||
self.ll_tokenizer = ll_tokenizer | ||
self.ll_matcher = llguidance.LLMatcher( | ||
self.ll_tokenizer, | ||
self.grammar, | ||
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")), | ||
) | ||
|
||
# create reusable bitmask | ||
self.bitmask = llguidance.torch.allocate_token_bitmask( | ||
1, self.ll_tokenizer.vocab_size) | ||
|
||
self.initialized = True | ||
|
||
def __call__( | ||
self, | ||
input_ids: List[int], | ||
scores: torch.Tensor, | ||
) -> torch.Tensor: | ||
# we initialize the guidance model here | ||
# to avoid pickling ll_tokenizer and ll_interpreter | ||
self._initialize() | ||
|
||
if self.new_sampling and len(input_ids) > 0: | ||
self.ll_matcher.consume_token(input_ids[-1]) | ||
err = self.ll_matcher.get_error() | ||
if err: | ||
logger.warning("Error in LLMatcher: %s", err) | ||
|
||
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, | ||
0) | ||
llguidance.torch.apply_token_bitmask_inplace( | ||
scores, self.bitmask.to(scores.device)) | ||
|
||
self.new_sampling = True | ||
|
||
return scores |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.