@@ -1,13 +1,11 @@
 from __future__ import annotations

-import copy
+import copy, enum
 import threading
-from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, get_args
+from typing import TYPE_CHECKING, Optional, TypeVar

-from transformers import PreTrainedTokenizer
 import xgrammar as xgr

 from vllm.config import ModelConfig
@@ -17,8 +15,7 @@
 from .grammar import Grammar

 if TYPE_CHECKING:
-    from transformers import PreTrainedTokenizer
-    from typing_extensions import LiteralString
+    from typing_extensions import Self

     from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup

@@ -31,18 +28,11 @@

 @dataclass
 class GrammarCache:
-    value: Grammar | None
+    value: Optional[Grammar]
     event: threading.Event


-T = TypeVar("T", bound=str)
-
-
-class GuidedDecodingManager(ABC, Generic[T]):
-
-    @abstractmethod
-    def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
-        ...
+class GuidedDecodingManager:

     def flush(self):
         with self._lock:
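Here `GrammarCache` pairs a possibly not-yet-compiled grammar with a `threading.Event`, so the first request for a key can compile while later requests for the same key block on the event instead of recompiling. A minimal sketch of that pattern, with assumed semantics (the cache internals are elided from this diff; `_Entry`, `get_or_compile`, and the string payload are illustrative stand-ins for the real `Grammar` machinery):

```python
import threading
from dataclasses import dataclass
from typing import Optional


@dataclass
class _Entry:
    value: Optional[str]  # stand-in for a compiled Grammar
    event: threading.Event


_cache: dict[str, _Entry] = {}
_lock = threading.Lock()


def get_or_compile(key: str) -> Optional[str]:
    with _lock:
        entry = _cache.get(key)
        if entry is None:
            # First caller becomes the owner and compiles outside the lock.
            entry = _cache[key] = _Entry(None, threading.Event())
            owner = True
        else:
            owner = False
    if owner:
        entry.value = f"compiled({key})"  # placeholder for real compilation
        entry.event.set()
    else:
        entry.event.wait()  # block until the owning thread finishes
    return entry.value
```

Compiling outside the lock keeps the critical section down to a dictionary probe, which matters once grammar compilation costs milliseconds per key.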
@@ -84,68 +74,21 @@ def collect(self, request: Request):
             return True
         return False

-    @classmethod
-    def from_backend(cls,
-                     backend: LiteralString = "xgrammar",
-                     /,
-                     *,
-                     tokenizer_group: BaseTokenizerGroup,
-                     model_config: ModelConfig) -> GuidedDecodingManager[T]:
-        manager_cls = cls._registry.get(backend)
-        if manager_cls is None:
-            raise ValueError(
-                f"Backend '{backend}' not found in registry. Available backends: {list(cls._registry)}"
-            )
-        return manager_cls(tokenizer_group=tokenizer_group,
-                           model_config=model_config)
-
-    _registry: dict[str, type[GuidedDecodingManager[T]]] = {}
-    _backend: T
-
-    def __init__(self, *, tokenizer_group: BaseTokenizerGroup,
+    def __init__(self, *, backend: str, tokenizer_group: BaseTokenizerGroup,
                  model_config: ModelConfig):
+        self._backend = backend
         self.model_config = model_config
         self.tokenizer = tokenizer_group.get_lora_tokenizer(None)
         self.grammar_cache: dict[GuidedDecodingKey, GrammarCache] = {}
         self.executor = ThreadPoolExecutor()
         self._lock = threading.Lock()
-
-    def __init_subclass__(cls, **kwargs: Any):
-        if not hasattr(cls, '__orig_bases__'):
-            raise TypeError(
-                f"{cls.__qualname__} must be subclass of GuidedDecodingManager"
-            )
-
-        backend = None
-        for base in cls.__orig_bases__:
-            if (origin := get_args(base)) and issubclass(
-                    base.__origin__, GuidedDecodingManager):
-                backend = get_args(origin[0])[0]
-                break
-
-        if backend is None:
-            raise TypeError(
-                f"Class {cls.__qualname__} must specify backend as a Literal type"
-            )
-
-        if backend in cls._registry:
-            name = cls._registry[backend].__qualname__
-            raise ValueError(
-                f"Backend '{backend}' is already registered to {name}")
-
-        # Set the backend value from the Literal type
-        cls._backend = backend
-        cls._registry[backend] = cls

-
-class XGrammarManager(GuidedDecodingManager[Literal["xgrammar"]]):
-    # cache GrammarCompiler instances based on given tokenizer
-    _compiler_cache: dict[str, xgr.GrammarCompiler] = {}
-    _compiler: xgr.GrammarCompiler | None = None
-
-    def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar:
+    def initialize_cache(self, key: GuidedDecodingKey) -> Grammar:
         request_type, grammar_spec = key
-        compiler = XGrammarManager.get_compiler(self.tokenizer)
+        tokenizer_info = xgr.TokenizerInfo.from_huggingface(
+            self.tokenizer,
+            vocab_size=self.model_config.hf_text_config.vocab_size)
+        compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
         if request_type == "json":
             if type(grammar_spec) is not str:
                 ctx = compiler.compile_builtin_json_grammar()
@@ -155,35 +98,6 @@ def initialize_cache(self, key: GuidedDecodingKey) -> XGrammar:
                 ctx = compiler.compile_grammar(grammar_spec)
         else:
             raise ValueError("grammar is not of valid supported types.")
-        return Grammar.from_backend(
-            self._backend,
-            matcher=xgr.GrammarMatcher(ctx),
-            vocab_size=self.model_config.hf_text_config.vocab_size,
-            ctx=ctx)
-
-    def flush(self):
-        super().flush()
-        if self._compiler: self._compiler.clear_cache()
-        for compiler in self._compiler_cache.values():
-            compiler.clear_cache()
-        self._compiler_cache.clear()
-
-    @classmethod
-    def get_compiler(
-            cls,
-            tokenizer: PreTrainedTokenizer,
-            *,
-            max_threads: int = 8,
-            # passthrough to TokenizerInfo
-            vocab_size: int | None = None,
-            stop_token_ids: list[int] | int | None = None
-    ) -> xgr.GrammarCompiler:
-        cache_key = str(hash(tokenizer))
-        if cache_key not in cls._compiler_cache:
-            tokenizer_info = xgr.TokenizerInfo.from_huggingface(
-                tokenizer,
-                stop_token_ids=stop_token_ids,
-                vocab_size=vocab_size)
-            cls._compiler_cache[cache_key] = xgr.GrammarCompiler(
-                tokenizer_info, max_threads=max_threads)
-        return cls._compiler_cache[cache_key]
+        return Grammar(matcher=xgr.GrammarMatcher(ctx),
+                       vocab_size=self.model_config.hf_text_config.vocab_size,
+                       ctx=ctx)
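With the registry and `__init_subclass__` machinery gone, callers pick a backend by passing a plain string to the constructor instead of going through `from_backend`. A rough usage sketch under assumed surroundings (the `tokenizer_group` and `model_config` objects would come from engine setup elsewhere in vLLM; the `("json", ...)` key shape follows the `request_type, grammar_spec = key` unpacking above):

```python
# Hypothetical wiring; in vLLM these objects come from the engine setup.
manager = GuidedDecodingManager(
    backend="xgrammar",
    tokenizer_group=tokenizer_group,  # a BaseTokenizerGroup instance
    model_config=model_config,        # a ModelConfig instance
)

# Compile (or fetch) a grammar for a JSON-schema request.
grammar = manager.initialize_cache(("json", '{"type": "object"}'))
```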