1010import logging
1111logger = logging .getLogger (__name__ )
1212
13-
14- # Set to True for debugging
15- DEBUG = True
16-
17- class SyncodeLogitsProcessor (LogitsProcessor ):
13+ class GrammarConstrainer :
1814 """
19- This class is used to filter the logits of the model to only allow syntactically valid tokens for Python.
20-
21- Args:
22- grammar (str): The grammar to use for parsing e.g. "python".
23- tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
24- use_cache (bool, optional): Whether to use the cache. Defaults to True.
25- parse_output_only (bool, optional): Whether to parse the prompt. Defaults to False.
26- num_samples (int, optional): The number of sequences to generate. Defaults to 1.
27- dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
28- parser (str, optional): The parser to use. Defaults to 'lalr'.
29- mode (str, optional): The mode to use. Defaults to 'grammar_mask'.
15+ Core class for constraining LLM token generation based on formal grammar rules.
16+
17+ This class handles the parsing of generated code, validates its grammatical correctness,
18+ and creates token masks to ensure syntactically valid generations.
19+
20+ The class supports two primary operating modes:
21+
22+ 1. `grammar_mask` (Conservative/Overapproximation):
23+ This mode is more permissive and overapproximates the set of acceptable tokens.
24+ It allows a wider range of tokens that might be syntactically valid given the
25+ limited lookahead of the parser. This mode preserves more of the LLM's original
26+ token distribution while still enforcing basic syntactic correctness.
27+
28+ 2. `grammar_strict` (Strict/Underapproximation):
29+ This mode is stricter and underapproximates the set of acceptable tokens.
30+ It enforces tighter grammatical constraints and may be more invasive in the
31+ LLM's generation process. It sometimes breaks LLM tokens that would have been
32+ syntactically correct when considered as a whole, potentially affecting the
33+ fluency or accuracy of generation.
34+
35+ Example illustrating the difference:
36+ Consider generating Python code with the partial input: `def calculate`
37+
38+ In `grammar_mask` mode, it might allow tokens like:
39+ - "(num" (combining opening parenthesis and parameter name as one token)
40+
41+ In `grammar_strict` mode, it would force separate tokens:
42+ - "(" followed by "num" (requiring two separate token generations)
43+
44+ For more details on the approximation methods, refer to the SynCode paper:
45+ https://arxiv.org/abs/2403.01632
3046 """
def __init__(self,
             grammar: Grammar,
             tokenizer: PreTrainedTokenizer,
             byte_tokenizer: ByteTokenizer,
             use_cache=True,
             parse_output_only=True,
             batch_size=1,
             dev_mode=False,
             parser='lalr',
             mode='grammar_mask'):
    """
    Store the configuration and build the incremental parser and the DFA mask store.

    Args:
        grammar (Grammar): The grammar used to build the parser and the mask store.
        tokenizer (PreTrainedTokenizer): The tokenizer handed to the mask store.
        byte_tokenizer (ByteTokenizer): Byte-level tokenizer supplied by the caller; stored as-is.
        use_cache (bool, optional): Forwarded to `MaskStore.init_mask_store`. Defaults to True.
        parse_output_only (bool, optional): Whether only the LLM output is parsed and not
            (input+output). Defaults to True.
        batch_size (int, optional): Number of entries in the per-sequence state lists. Defaults to 1.
        dev_mode (bool, optional): Stored as `self.dev_mode`. Defaults to False.
        parser (str, optional): Parser kind forwarded to `create_parser`. Defaults to 'lalr'.
        mode (str, optional): Mask-store mode, `grammar_mask` or `grammar_strict`.
            Defaults to 'grammar_mask'.
    """
    # Plain configuration handed in by the caller
    self.grammar = grammar
    self.tokenizer = tokenizer
    self.byte_tokenizer = byte_tokenizer
    self.batch_size = batch_size
    self.dev_mode = dev_mode

    # We use this when only the LLM output is parsed and not (input+output);
    # start_from stays None until the first call to _set_start_from
    self.parse_output_only = parse_output_only
    self.start_from = None

    self.parse_failed = False

    # For backtracking to syntactically valid completions: one slot per sequence in the batch
    self.last_valid_state = [0 for _ in range(self.batch_size)]
    self.function_ends = [None for _ in range(self.batch_size)]

    # Ignore whitespace tokens (must be known before the parser is created)
    whitespace_ignored = self._get_ignore_whitespace(self.grammar)
    self._ignore_whitespace = whitespace_ignored

    # Create parser
    self.inc_parser: IncrementalParser = create_parser(
        self.grammar,
        parser=parser,
        ignore_whitespace=whitespace_ignored,
    )

    # Load dfa mask store; `mode` selects grammar_mask or grammar_strict
    mask_store_options = dict(
        grammar=self.grammar,
        tokenizer=self.tokenizer,
        use_cache=use_cache,
        mode=mode,
    )
    self.dfa_mask_store = MaskStore.init_mask_store(**mask_store_options)
70-
71-
72- def _get_ignore_whitespace (self , grammar ):
73- """
74- Check if the grammar allows whitespace tokens to be ignored.
75- """
76- base_parser = create_base_parser (grammar )
77- terminals = base_parser .terminals
78- ignore_terminals = base_parser .ignore_tokens
79-
80- import regex
81- ignore_whitespace = False
82- for ig_name in ignore_terminals :
83- for terminal in terminals :
84- if terminal .name == ig_name :
85- if regex .match (terminal .pattern .to_regexp (), ' ' ) is not None :
86- ignore_whitespace = True # convert to boolean tensor mask. This is useful for fast union operations
87- return ignore_whitespace
86+
8887
8988 def reset (self ):
9089 """
@@ -96,6 +95,15 @@ def reset(self):
9695 self .start_from = None
9796 self .inc_parser .reset ()
9897
98+ def _set_start_from (self , input_ids ):
99+ """
100+ Sets the starting point for parsing based on whether we're parsing only the output or the full input+output.
101+ """
102+ if self .start_from is None :
103+ if self .parse_output_only :
104+ self .start_from = input_ids .size (1 )
105+ else :
106+ self .start_from = 0
99107
100108 def is_valid (self , input_ids : torch .LongTensor , next_token : torch .LongTensor ) -> bool :
101109 """
@@ -134,19 +142,26 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
134142 is_valid = self .dfa_mask_store .is_valid_prefix (res )
135143
136144 if is_valid :
137- self .update_valid_state (partial_code , 0 , res )
145+ self ._update_valid_state (partial_code , 0 , res )
138146
139147 return is_valid
140148
141- def _set_start_from (self , input_ids ):
142- if self .start_from is None :
143- if self .parse_output_only :
144- self .start_from = input_ids .size (1 )
145- else :
146- self .start_from = 0
147-
148-
149- def __call__ (self , input_ids : torch .LongTensor , scores : torch .FloatTensor ) -> torch .FloatTensor :
149+ def mask_scores (self , input_ids : torch .LongTensor , scores : torch .FloatTensor ) -> torch .FloatTensor :
150+ """
151+ Mask scores by suppressing invalid next tokens based on grammar constraints.
152+
153+ The exact behavior depends on whether we're using grammar_mask mode (conservative/
154+ overapproximation) or grammar_strict mode (strict/underapproximation). In both cases,
155+ tokens that would lead to definitely invalid syntax are masked out by setting their
156+ scores to negative infinity.
157+
158+ Args:
159+ input_ids (torch.LongTensor): The input ids.
160+ scores (torch.FloatTensor): The scores to be masked.
161+
162+ Returns:
163+ torch.FloatTensor: The masked scores.
164+ """
150165 self ._set_start_from (input_ids ) # start_from is used for choosing where the parsing should start
151166 partial_codes = self ._get_partial_codes (input_ids )
152167
@@ -188,7 +203,7 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
188203 else :
189204 res .remainder = res .remainder .encode ('utf-8' )
190205
191- self .update_valid_state (partial_code , idx , res )
206+ self ._update_valid_state (partial_code , idx , res )
192207 except Exception as e :
193208 if self .dev_mode == True :
194209 raise e
@@ -200,7 +215,6 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
200215 skip = True
201216 return res , skip
202217
203-
204218 def _get_partial_codes (self , input_ids : torch .LongTensor ) -> list [(str , bytes )]:
205219 """
206220 Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
@@ -219,9 +233,8 @@ def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
219233 )
220234 output .append ((partial_code , remainder_bytes ))
221235 return output
222-
223236
224- def update_valid_state (self , partial_code : str , idx : int , r : ParseResult ):
237+ def _update_valid_state (self , partial_code : str , idx : int , r : ParseResult ):
225238 """
226239 This is a simple heuristic to cut off the generated output at the end of the function.
227240 TODO: Put this under a flag to enable/disable this heuristic.
@@ -237,7 +250,6 @@ def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
237250 if accept_seq [0 ] == '$END' or accept_seq [0 ] == 'EOF' :
238251 self .last_valid_state [idx ] = len (partial_code ) - len (r .remainder )
239252
240-
241253 @staticmethod
242254 def _bytes_to_string (byte_sequence : bytes ) -> tuple [str , bytes ]:
243255 """
@@ -253,16 +265,6 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
253265 A tuple (string, remainder) where:
254266 - string is the longest valid UTF-8 prefix of the input as a Python string
255267 - remainder is the rest of the bytes that could not be decoded as UTF-8
256-
257- Examples:
258- >>> bytes_to_string(b'Hello, world!')
259- ('Hello, world!', b'')
260- >>> bytes_to_string(b'Hello, \xe2 \x82 \xac !') # Euro symbol (€) followed by !
261- ('Hello, €!', b'')
262- >>> bytes_to_string(b'Hello, \xe2 \x82 !') # Incomplete Euro symbol
263- ('Hello, ', b'\xe2 \x82 !')
264- >>> bytes_to_string(b'\xff \xfe ') # Invalid UTF-8
265- ('', b'\xff \xfe ')
266268 """
267269 if not isinstance (byte_sequence , bytes ):
268270 raise TypeError ("Input must be a bytes object" )
@@ -292,3 +294,21 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
292294 return byte_sequence [:valid_end ].decode ('utf-8' ), byte_sequence [valid_end :]
293295 else :
294296 return "" , byte_sequence
297+
def _get_ignore_whitespace(self, grammar):
    """
    Check if the grammar allows whitespace tokens to be ignored.

    A base parser is built for `grammar`; the result is True when at least one
    terminal listed in the parser's ignore tokens has a pattern whose regexp
    matches a single space character.

    Args:
        grammar: The grammar passed to `create_base_parser`.

    Returns:
        bool: True if some ignored terminal matches ' ', False otherwise.
    """
    import regex

    base_parser = create_base_parser(grammar)
    # Set lookup replaces the inner scan over ignore names
    # (assumes the names are hashable, e.g. strings -- TODO confirm)
    ignored_names = set(base_parser.ignore_tokens)

    # any() stops at the first ignored terminal that accepts a space
    return any(
        terminal.name in ignored_names
        and regex.match(terminal.pattern.to_regexp(), ' ') is not None
        for terminal in base_parser.terminals
    )
314+
0 commit comments