
Commit 2631863

russellb, lochuynh1412, mmoskal, and aarnphm committed
Add Guidance backend to V0 structured output
This commit is based on PR #10217, updated to be compatible with `main`. It has also been significantly updated and simplified to take advantage of changes to llguidance since the original PR was written. Many thanks to the llguidance team for their help on this.

Signed-off-by: Russell Bryant <[email protected]>
Co-authored-by: Loc Huynh <[email protected]>
Co-authored-by: Michal Moskal <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
1 parent 61c7a1b commit 2631863
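
With this commit, the new backend can be requested by name through the existing guided decoding parameters. A minimal offline sketch, assuming vLLM's public `LLM` API; the prompt and schema are illustrative, and the model name is borrowed from the tests below:

    from vllm import LLM, SamplingParams
    from vllm.sampling_params import GuidedDecodingParams

    llm = LLM(model="Qwen/Qwen2.5-1.5B-Instruct")

    # Select the llguidance-based backend explicitly; json, json_object,
    # regex, choice, and grammar guides are all routed to it (see
    # guidance_decoding.py below).
    guided = GuidedDecodingParams(
        json={"type": "object", "properties": {"answer": {"type": "string"}}},
        backend="guidance")
    outputs = llm.generate("Reply in JSON: what color is the sky?",
                           SamplingParams(guided_decoding=guided))
    print(outputs[0].outputs[0].text)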

8 files changed (+167, −13 lines)


benchmarks/benchmark_serving_structured_output.py

Lines changed: 6 additions & 5 deletions
@@ -999,11 +999,12 @@ def main(args: argparse.Namespace):
                         type=float,
                         default=1.0,
                         help="Ratio of Structured Outputs requests")
-    parser.add_argument("--structured-output-backend",
-                        type=str,
-                        choices=["outlines", "lm-format-enforcer", "xgrammar"],
-                        default="xgrammar",
-                        help="Backend to use for structured outputs")
+    parser.add_argument(
+        "--structured-output-backend",
+        type=str,
+        choices=["outlines", "lm-format-enforcer", "xgrammar", "guidance"],
+        default="xgrammar",
+        help="Backend to use for structured outputs")
 
     args = parser.parse_args()
     main(args)
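
With this change, the serving benchmark can target the new backend as well, e.g. by passing `--structured-output-backend guidance` to `benchmarks/benchmark_serving_structured_output.py`.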

requirements/common.txt

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer >= 0.10.11, < 0.11
+llguidance >= 0.7.2, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64"
 outlines == 0.1.11
 lark == 1.2.2
 xgrammar == 0.1.16; platform_machine == "x86_64" or platform_machine == "aarch64"

tests/entrypoints/llm/test_guided_generate.py

Lines changed: 3 additions & 1 deletion
@@ -14,7 +14,9 @@
 from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 
 MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 
 
 @pytest.fixture(scope="module")

tests/model_executor/test_guided_processors.py

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,9 @@
 from vllm.sampling_params import GuidedDecodingParams
 
 MODEL_NAME = 'HuggingFaceH4/zephyr-7b-beta'
-GUIDED_DECODING_BACKENDS = ["outlines", "lm-format-enforcer", "xgrammar"]
+GUIDED_DECODING_BACKENDS = [
+    "outlines", "lm-format-enforcer", "xgrammar", "guidance"
+]
 GUIDED_DECODING_BACKENDS_WITH_REASONING_SUPPORT = ["outlines", "xgrammar"]
 REASONING_MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

vllm/config.py

Lines changed: 3 additions & 1 deletion
@@ -2785,7 +2785,9 @@ def compute_hash(self) -> str:
         return hash_str
 
     def __post_init__(self):
-        valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar']
+        valid_guided_backends = [
+            'outlines', 'lm-format-enforcer', 'xgrammar', 'guidance'
+        ]
 
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name

vllm/model_executor/guided_decoding/__init__.py

Lines changed: 22 additions & 5 deletions
@@ -79,6 +79,12 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
                 "xgrammar does not support Lark grammars and the "
                 "grammar failed to convert to GBNF.", "outlines")
 
+        elif guided_params.json_object:
+            # https://github.com/mlc-ai/xgrammar/issues/256
+            fallback_or_error(guided_params,
+                              "xgrammar does not support json_object.",
+                              "guidance")
+
     # If the xgrammar module cannot be imported successfully,
     # we should still allow users to use guided decoding with a fallback.
     elif not xgr_installed:
@@ -88,9 +94,9 @@ def fallback_or_error(guided_params: GuidedDecodingParams, message: str,
 
     if (guided_params.backend_name == "outlines"
             and guided_params.json_object is not None):
-        # outlines doesn't support json_object, fallback to xgrammar
+        # outlines doesn't support json_object, fallback to guidance
         fallback_or_error(guided_params,
-                          "outlines does not support json_object.", "xgrammar")
+                          "outlines does not support json_object.", "guidance")
 
     return guided_params
 
@@ -122,10 +128,15 @@ async def get_guided_decoding_logits_processor(
             get_local_xgrammar_guided_decoding_logits_processor)
         return get_local_xgrammar_guided_decoding_logits_processor(
             guided_params, tokenizer, model_config, reasoner)
-
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
+    )
 
 
 def get_local_guided_decoding_logits_processor(
@@ -155,7 +166,13 @@ def get_local_guided_decoding_logits_processor(
         get_local_xgrammar_guided_decoding_logits_processor)
     return get_local_xgrammar_guided_decoding_logits_processor(
         guided_params, tokenizer, model_config, reasoner)
+    if guided_params.backend_name == 'guidance':
+        from vllm.model_executor.guided_decoding.guidance_decoding import (
+            get_local_guidance_guided_decoding_logits_processor)
+        return get_local_guidance_guided_decoding_logits_processor(
+            guided_params, tokenizer)
 
     raise ValueError(
         f"Unknown guided decoding backend '{guided_params.backend}'. "
-        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'")
+        "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar', 'guidance'"
+    )
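
To illustrate the fallback changes above: a `json_object` request no longer errors under xgrammar or outlines but is redirected to guidance. A sketch, assuming the enclosing fallback routine is the module's `maybe_backend_fallback` (that name is not visible in these hunks and is an assumption):

    from vllm.model_executor.guided_decoding import maybe_backend_fallback
    from vllm.sampling_params import GuidedDecodingParams

    # json_object is unsupported by xgrammar
    # (https://github.com/mlc-ai/xgrammar/issues/256) and by outlines, so
    # both now resolve to the guidance backend.
    params = maybe_backend_fallback(
        GuidedDecodingParams(json_object=True, backend="xgrammar"))
    print(params.backend_name)  # expected: "guidance"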
vllm/model_executor/guided_decoding/guidance_decoding.py

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+# SPDX-License-Identifier: Apache-2.0
+from re import escape as regex_escape
+
+import llguidance
+from transformers import PreTrainedTokenizerBase
+
+from vllm.model_executor.guided_decoding.guidance_logits_processors import (
+    GuidanceLogitsProcessor)
+from vllm.sampling_params import GuidedDecodingParams
+
+
+def get_local_guidance_guided_decoding_logits_processor(
+        guided_params: GuidedDecodingParams,
+        tokenizer: PreTrainedTokenizerBase) -> GuidanceLogitsProcessor:
+    """
+    Given an OpenAI-compatible request, check for guided decoding parameters
+    and get the necessary logits processor for the given guide.
+    """
+    grm = ""
+    if guided_params.json:
+        grm = llguidance.LLMatcher.grammar_from_json_schema(
+            guided_params.json,
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+    elif guided_params.json_object:
+        grm = llguidance.LLMatcher.grammar_from_json_schema(
+            '{"type": "object"}',
+            overrides={"whitespace_pattern": guided_params.whitespace_pattern})
+    elif guided_params.regex:
+        grm = llguidance.grammar_from("regex", guided_params.regex)
+    elif guided_params.choice:
+        # choice just uses regex
+        choices = (regex_escape(str(choice))
+                   for choice in guided_params.choice)
+        choices_regex = "(" + "|".join(choices) + ")"
+        grm = llguidance.grammar_from("regex", choices_regex)
+    elif guided_params.grammar:
+        # this supports Lark and GBNF
+        grm = llguidance.grammar_from("grammar", guided_params.grammar)
+
+    if grm:
+        return GuidanceLogitsProcessor(grm, tokenizer)
+
+    raise ValueError("Unknown guided decoding mode")
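
A minimal usage sketch of the factory above; the regex and model name are illustrative, and loading the tokenizer through transformers' `AutoTokenizer` is an assumption about the caller:

    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding.guidance_decoding import (
        get_local_guidance_guided_decoding_logits_processor)
    from vllm.sampling_params import GuidedDecodingParams

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    params = GuidedDecodingParams(regex=r"(yes|no)", backend="guidance")
    processor = get_local_guidance_guided_decoding_logits_processor(
        params, tokenizer)
    # The processor is then applied once per decoding step:
    #     scores = processor(generated_token_ids, scores)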
vllm/model_executor/guided_decoding/guidance_logits_processors.py

Lines changed: 85 additions & 0 deletions

@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+import os
+from typing import Any, List
+
+import llguidance
+import llguidance.hf
+import llguidance.torch
+import torch
+from transformers import PreTrainedTokenizerBase
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class GuidanceLogitsProcessor:
+    """Base Guidance Logits Processor"""
+
+    cached_tokenizers: dict[str, Any] = {}
+
+    def __init__(
+        self,
+        grammar: str,
+        tokenizer: PreTrainedTokenizerBase,
+    ) -> None:
+        """Base Guidance Logits Processor
+
+        Args:
+            grammar (str)
+                grammar to guide the generation
+            tokenizer (PreTrainedTokenizerBase)
+                model's tokenizer
+        """
+        self.grammar = grammar
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer.name_or_path
+        self.new_sampling = False
+        self.initialized = False
+
+    def _initialize(self):
+        if self.initialized:
+            return
+
+        ll_tokenizer = self.cached_tokenizers.get(self.tokenizer.name_or_path,
+                                                  None)
+        if ll_tokenizer is None:
+            ll_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
+            self.cached_tokenizers[self.tokenizer.name_or_path] = ll_tokenizer
+
+        self.ll_tokenizer = ll_tokenizer
+        self.ll_matcher = llguidance.LLMatcher(
+            self.ll_tokenizer,
+            self.grammar,
+            log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
+        )
+
+        # create reusable bitmask
+        self.bitmask = llguidance.torch.allocate_token_bitmask(
+            1, self.ll_tokenizer.vocab_size)
+
+        self.initialized = True
+
+    def __call__(
+        self,
+        input_ids: List[int],
+        scores: torch.Tensor,
+    ) -> torch.Tensor:
+        # we initialize the guidance model here
+        # to avoid pickling ll_tokenizer and ll_interpreter
+        self._initialize()
+
+        if self.new_sampling and len(input_ids) > 0:
+            self.ll_matcher.consume_token(input_ids[-1])
+            err = self.ll_matcher.get_error()
+            if err:
+                logger.warning("Error in LLMatcher: %s", err)
+
+        llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
+                                                 0)
+        llguidance.torch.apply_token_bitmask_inplace(
+            scores, self.bitmask.to(scores.device))
+
+        self.new_sampling = True
+
+        return scores
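
A simplified sketch of how the processor above plugs into a decoding loop; this is not vLLM's actual sampler, the zero logits tensor is a stand-in for model output, and the model name is illustrative:

    import llguidance
    import torch
    from transformers import AutoTokenizer

    from vllm.model_executor.guided_decoding.guidance_logits_processors import (
        GuidanceLogitsProcessor)

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
    # Build a grammar with the same LLMatcher helper that
    # guidance_decoding.py above uses.
    grammar = llguidance.LLMatcher.grammar_from_json_schema('{"type": "object"}')
    processor = GuidanceLogitsProcessor(grammar, tokenizer)

    generated: list[int] = []             # token ids produced so far
    scores = torch.zeros(len(tokenizer))  # stand-in for one step of logits
    scores = processor(generated, scores)  # masks grammar-forbidden tokens
    next_token = int(torch.argmax(scores))
    generated.append(next_token)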
