
Commit 2bb535e

fix: update types and attach bitmask to requests

Signed-off-by: Aaron Pham <[email protected]>
1 parent 39068c8

5 files changed: +62 −231 lines

vllm/utils.py
Lines changed: 0 additions & 1 deletion

@@ -22,7 +22,6 @@
 import threading
 import time
 import traceback
-import types
 import uuid
 import warnings
 import weakref
Lines changed: 14 additions & 100 deletions

@@ -1,13 +1,11 @@
 from __future__ import annotations
 
-import copy
+import copy, enum
 import threading
-from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args
+from typing import TYPE_CHECKING, TypeVar
 
-from transformers import PreTrainedTokenizer
 import xgrammar as xgr
 
 from vllm.config import ModelConfig
@@ -17,8 +15,7 @@
 from .grammar import Grammar
 
 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
-    from typing_extensions import LiteralString
+    from typing_extensions import Self
 
     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 
@@ -31,18 +28,11 @@
 
 @dataclass
 class GrammarCache:
-    value: Grammar | None
+    value: Optional[Grammar]
     event: threading.Event
 
 
-T = TypeVar("T", bound=str)
-
-
-class GuidedDecodingManager(ABC, Generic[T]):
-
-    @abstractmethod
-    def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
-        ...
+class GuidedDecodingManager:
 
     def flush(self):
         with self._lock:
@@ -84,68 +74,21 @@ def collect(self, request: Request):
             return True
         return False
 
-    @classmethod
-    def from_backend(cls,
-                     backend: LiteralString = "xgrammar",
-                     /,
-                     *,
-                     tokenizer_group: BaseTokenizerGroup,
-                     model_config: ModelConfig) -> GuidedDecodingManager[T]:
-        manager_cls = cls._registry.get(backend)
-        if manager_cls is None:
-            raise ValueError(
-                f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}"
-            )
-        return manager_cls(tokenizer_group=tokenizer_group,
-                           model_config=model_config)
-
-    _registry: dict[str, type[GuidedDecodingManager[T]]] = {}
-    _backend: T
-
-    def __init__(self, *, tokenizer_group: BaseTokenizerGroup,
+    def __init__(self, *, backend: str, tokenizer_group: BaseTokenizerGroup,
                  model_config: ModelConfig):
+        self._backend = backend
         self.model_config = model_config
         self.tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {}
        self.executor = ThreadPoolExecutor()
         self._lock = threading.Lock()
-
-    def __init_subclass__(cls, **kwargs: Any):
-        if not hasattr(cls, '__orig_bases__'):
-            raise TypeError(
-                f"{cls.__qualname__} must be subclass of GuidedDecodingManager"
-            )
-
-        backend = None
-        for base in cls.__orig_bases__:
-            if (origin := get_args(base)) and issubclass(
-                    base.__origin__, GuidedDecodingManager):
-                backend = get_args(origin[0])[0]
-                break
-
-        if backend is None:
-            raise TypeError(
-                f"Class {cls.__qualname__} must specify backend as a Literal type"
-            )
-
-        if backend in cls._registry:
-            name = cls._registry[backend].__qualname__
-            raise ValueError(
-                f"Backend '{backend}' is already registered to {name}")
-
-        # Set the backend value from the Literal type
-        cls._backend = backend
         cls._registry[backend] = cls
 
-
-class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]):
-    # cache GrammarCompiler instances based on given tokenizer
-    _compiler_cache: dict[str, xgr.GrammarCompiler] = {}
-    _compiler: xgr.GrammarCompiler | None = None
-
-    def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar:
+    def initialize_cache(self, key: GuidedDecodingKey) -> Self:
         request_type, grammar_spec = key
-        compiler = XGrammarManager.get_compiler(self.tokenizer)
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+            tokenizer, stop_token_ids=stop_token_ids, vocab_size=vocab_size)
+        compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=max_threads)
         if request_type == "json":
             if type(grammar_spec) is not str:
                 ctx = compiler.compile_builtin_json_grammar()
@@ -155,35 +98,6 @@ def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar:
             ctx = compiler.compile_grammar(grammar_spec)
         else:
             raise ValueError("grammar is not of valid supported types.")
-        return Grammar.from_backend(
-            self._backend,
-            matcher=xgr.GrammarMatcher(ctx),
-            vocab_size=self.model_config.hf_text_config.vocab_size,
-            ctx=ctx)
-
-    def flush(self):
-        super().flush()
-        if self._compiler: self._compiler.clear_cache()
-        for compiler in self._compiler_cache.values():
-            compiler.clear_cache()
-        self._compiler_cache.clear()
-
-    @classmethod
-    def get_compiler(
-        cls,
-        tokenizer: PreTrainedTokenizer,
-        *,
-        max_threads: int = 8,
-        # passthrough to TokenizerInfo
-        vocab_size: int | None = None,
-        stop_token_ids: list[int] | int | None = None
-    ) -> xgr.GrammarCompiler:
-        cache_key = str(hash(tokenizer))
-        if cache_key not in cls._compiler_cache:
-            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
-                stop_token_ids=stop_token_ids,
-                vocab_size=vocab_size)
-            cls._compiler_cache[cache_key] = xgr.GrammarCompiler(
-                tokenizer_info, max_threads=max_threads)
-        return cls._compiler_cache[cache_key]
+        return Grammar(matcher=xgr.GrammarMatcher(ctx),
+                       vocab_size=self.model_config.hf_text_config.vocab_size,
+                       ctx=ctx)
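
For orientation, a minimal sketch (not part of the diff) of how the flattened manager would now be constructed and used; `tokenizer_group`, `model_config`, and the JSON key below are assumed placeholders:

# Sketch only: `backend` is now a plain string argument instead of a
# Literal type parameter resolved via __init_subclass__ registration.
manager = GuidedDecodingManager(backend="xgrammar",
                                tokenizer_group=tokenizer_group,
                                model_config=model_config)

# A GuidedDecodingKey is a (request_type, grammar_spec) tuple, matching
# the unpacking in initialize_cache above; the spec here is hypothetical.
grammar = manager.initialize_cache(("json", '{"type": "object"}'))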

vllm/v1/core/guided_decoding/grammar.py
Lines changed: 5 additions & 100 deletions

@@ -14,103 +14,8 @@
 T = TypeVar("T", bound=Annotated[LiteralString, str])
 
 
-class Grammar(ABC, Generic[T]):
+class Grammar:
     finished: bool = False
-
-    @abstractmethod
-    def accept_token(self, token: int) -> bool:
-        """Whether to accept the token and advance the machine state."""
-
-    @abstractmethod
-    def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
-        """Fill the bitmask for the token at the given index."""
-
-    @abstractmethod
-    def allocate_bitmask(self, batch_size: int,
-                         vocab_size: int) -> torch.Tensor:
-        """Allocate a bitmask for the given batch size and vocabulary size."""
-
-    @staticmethod
-    @abstractmethod
-    def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        """Apply the bitmask to the logits."""
-
-    @abstractmethod
-    def reset(self):
-        """Reset the machine state."""
-
-    @abstractmethod
-    def copy(self) -> Self:
-        """Copy the grammar object."""
-
-    def __copy__(self):
-        return self.copy()
-
-    _registry: dict[str, type[Grammar[T]]] = {}
-    _backend: T
-
-    def __init_subclass__(cls):
-        if not hasattr(cls, '__orig_bases__'):
-            raise TypeError(
-                f"Class {cls.__qualname__} must be a subclass of GrammarObject"
-            )
-
-        backend = None
-        for base in cls.__orig_bases__:
-            if (origin := get_args(base)) and issubclass(
-                    base.__origin__, Grammar):
-                backend = get_args(origin[0])[0]
-                break
-
-        if backend is None:
-            raise TypeError(
-                f"Class {cls.__qualname__} must specify backend as Literal type"
-            )
-
-        if backend in cls._registry:
-            name = cls._registry[backend].__qualname__
-            raise ValueError(
-                f"Backend '{backend}' is already registered to {name}")
-
-        # Set the backend value from the Literal type
-        cls._backend = backend
-        cls._registry[backend] = cls
-
-    @overload
-    @classmethod
-    def from_backend(
-        cls,
-        backend: Literal["xgrammar"] = ...,
-        *,
-        matcher: xgr.GrammarMatcher = ...,
-        vocab_size: int = ...,
-        ctx: xgr.CompiledGrammar = ...,
-    ) -> XGrammar:
-        ...
-
-    @overload
-    @classmethod
-    def from_backend(
-        cls,
-        backend: Literal["outlines"] = ...,
-        *,
-        guide: str = ...,
-        whitespace_pattern: str | None = ...,
-    ) -> XGrammar:
-        ...
-
-    @classmethod
-    def from_backend(cls,
-                     backend: LiteralString = "xgrammar",
-                     **kwargs: Any) -> Grammar[T]:
-        grammar_cls = cls._registry.get(backend)
-        if grammar_cls is None:
-            raise ValueError(
-                f"No grammar implementation registered for '{backend}'")
-        return grammar_cls(**kwargs)
-
-
-class XGrammar(Grammar[Literal["xgrammar"]]):
     # https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string for jump-forward decoding
 
     def __init__(self, matcher: xgr.GrammarMatcher, vocab_size: int,
@@ -135,15 +40,15 @@ def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
 
     @staticmethod
     def apply_bitmask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        # Note: In this method, if the tensors have different dimensions
-        # on CPU device fails, but on GPU it runs without error. Hence the
-        # unsqueeze above for scores, to match the token bitmask shape
         xgr.apply_token_bitmask_inplace(logits, vocab_mask)
 
     def reset(self):
         self.matcher.reset()
 
     def copy(self):
-        return XGrammar(matcher=xgr.GrammarMatcher(self.ctx),
+        return Grammar(matcher=xgr.GrammarMatcher(self.ctx),
                        vocab_size=self.vocab_size,
                        ctx=self.ctx)
+
+    def __copy__(self):
+        return self.copy()
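
The Grammar methods kept by this diff wrap xgrammar's token-bitmask flow. As a standalone reference, here is a minimal sketch of that flow against xgrammar's public API; `ctx`, `vocab_size`, and `logits` are assumed inputs, not names from this commit:

import torch
import xgrammar as xgr

# Assumed inputs: `ctx` is an xgr.CompiledGrammar and `logits` is a
# (batch, vocab_size) float tensor produced by the model.
matcher = xgr.GrammarMatcher(ctx)

# One bitmask row per sequence in the batch (here a batch of one).
bitmask = xgr.allocate_token_bitmask(1, vocab_size)

# Fill row 0 from the matcher's current state, then mask disallowed tokens.
matcher.fill_next_token_bitmask(bitmask, 0)
xgr.apply_token_bitmask_inplace(logits, bitmask.to(logits.device))

# Advance the matcher with the sampled token id.
token_id = int(torch.argmax(logits, dim=-1)[0])
matcher.accept_token(token_id)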
