
Commit a42c7fb

Refactor logits processor
1 parent 3383371 commit a42c7fb

8 files changed: +186 -158 lines changed

syncode/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 from syncode.infer import Syncode
-from syncode.grammar_decoder import SyncodeLogitsProcessor
+from syncode.grammar_mask.logits_processor import SyncodeLogitsProcessor
 from syncode.parsers.grammars import Grammar
 import syncode.common as common
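
Because syncode/__init__.py still re-exports the class from its new location, code that imports through the package root is unaffected by the move; only the internal module path changes. For example (usage note, not part of this commit):

from syncode import Syncode, SyncodeLogitsProcessor, Grammar  # still works after the refactor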

syncode/evaluation/code_eval.py

Lines changed: 14 additions & 13 deletions
@@ -86,12 +86,13 @@ def run_eval_for_task(syncode, num_samples_per_task, format_tabs, problems, samp
     # We tokenize the whole thing together since tokenizer just the generated_ids messes up with the
     # indentation and removes the initial whitespaces in some cases
     raw_completion = syncode.model.tokenizer.decode(generated_ids, skip_special_tokens=True)
+    grammar_constrainer = syncode.model.logits_processor.grammar_engine
 
     # Post-processing to filter out using stop word
     if syncode.model.grammar != None and syncode.model.grammar.name == "python":
-        completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, syncode.model.grammar_decoder, raw_completion, stop_words)
+        completion = CodeEval.postproces_completion_python(syncode.model, i, batch_size, input_ids_cutoff, generated_ids, grammar_constrainer, raw_completion, stop_words)
     elif syncode.model.grammar != None and syncode.model.grammar.name == "go":
-        completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, syncode.model.grammar_decoder, input_ids_cutoff)
+        completion = CodeEval.postproces_completion_go(syncode.model, i, batch_size, raw_completion, generated_ids, grammar_constrainer, input_ids_cutoff)
     else: # TODO: handle the case for other grammars
         completion = raw_completion

@@ -121,39 +122,39 @@ def write_results(syncode, out_path, avg_time, functional_result, num_tasks=1):
         f.write(f"Averge time taken for each task: {avg_time:.2f}s\n")
         f.write("\n")
 
-    def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_decoder, raw_completion, stop_words):
+    def postproces_completion_python(hf_model, i, batch_size, input_ids_cutoff, generated_ids, grammar_constrainer, raw_completion, stop_words):
         generated_output = hf_model.tokenizer.decode(generated_ids[input_ids_cutoff:])
 
-        if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_decoder is not None:
-            # Use when the stop word does not exist in the completion and grammar_decoder is used
+        if all(stop_word not in generated_output for stop_word in stop_words) and hf_model.tokenizer.eos_token_id != generated_ids[-1] and grammar_constrainer is not None:
+            # Use when the stop word does not exist in the completion and grammar_constrainer is used
             function_incomplete = [False for _ in range(batch_size)]
-            completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
+            completion = CodeEval.compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion)
         else:
             completion = raw_completion
         return completion
 
-    def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_decoder, input_ids_cutoff):
+    def postproces_completion_go(hf_model, i, batch_size, raw_completion, generated_ids, grammar_constrainer, input_ids_cutoff):
         if hf_model.mode != "original":
-            # When the grammar_decoder is used
+            # When the grammar_constrainer is used
             function_incomplete = [False for _ in range(batch_size)]
-            completion = CodeEval.compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion)
+            completion = CodeEval.compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion)
 
         if function_incomplete[i]:
             completion += "}"
 
         return completion
 
-    def compute_backup_completion(hf_model, grammar_decoder, function_incomplete, i, raw_completion):
-        if grammar_decoder.function_ends[i] is not None:
-            fn_ends = sorted(list(set(grammar_decoder.function_ends[i])))
+    def compute_backup_completion(hf_model, grammar_constrainer, function_incomplete, i, raw_completion):
+        if grammar_constrainer.function_ends[i] is not None:
+            fn_ends = sorted(list(set(grammar_constrainer.function_ends[i])))
             if len(fn_ends) > 1:
                 # if the function end is not None, then the last valid state is the function end
                 last_valid_state = fn_ends[1]
                 return raw_completion[:last_valid_state]
 
         # otherwise, the last valid state is the last valid state
         function_incomplete[i] = True
-        last_valid_state = grammar_decoder.last_valid_state[i]
+        last_valid_state = grammar_constrainer.last_valid_state[i]
 
         # Use when the stop word does not exist in the completion
         backup_completion = raw_completion[:last_valid_state]
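
The backtracking that compute_backup_completion relies on can be pictured with a small standalone sketch (hypothetical helper and offsets, not part of this commit): the grammar constrainer records character offsets where functions end (function_ends) and the last offset at which the output still parsed (last_valid_state), and the raw completion is cut back to one of those offsets.

def truncate_to_valid(raw_completion: str, function_ends, last_valid_state: int) -> str:
    if function_ends:
        fn_ends = sorted(set(function_ends))
        if len(fn_ends) > 1:
            # Mirrors the diff above: cut at the second recorded function end
            return raw_completion[:fn_ends[1]]
    # Otherwise fall back to the last offset at which parsing still succeeded
    return raw_completion[:last_valid_state]

raw = "def f():\n    return 1\n\ndef g("
print(repr(truncate_to_valid(raw, function_ends=[0, 22], last_valid_state=22)))
# 'def f():\n    return 1\n'  (the dangling "def g(" is dropped)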

syncode/grammar_decoder.py renamed to syncode/grammar_mask/grammar_constrainer.py

Lines changed: 96 additions & 76 deletions
@@ -10,81 +10,80 @@
 import logging
 logger = logging.getLogger(__name__)
 
-
-# Set to True for debugging
-DEBUG = True
-
-class SyncodeLogitsProcessor(LogitsProcessor):
+class GrammarConstrainer:
     """
-    This class is used to filter the logits of the model to only allow syntactically valid tokens for Python.
-
-    Args:
-        grammar (str): The grammar to use for parsing e.g. "python".
-        tokenizer (PreTrainedTokenizer): The tokenizer to use for decoding.
-        use_cache (bool, optional): Whether to use the cache. Defaults to True.
-        parse_output_only (bool, optional): Whether to parse the prompt. Defaults to False.
-        num_samples (int, optional): The number of sequences to generate. Defaults to 1.
-        dev_mode (bool, optional): Whether to run in development mode. Defaults to False.
-        parser (str, optional): The parser to use. Defaults to 'lalr'.
-        mode (str, optional): The mode to use. Defaults to 'grammar_mask'.
+    Core class for constraining LLM token generation based on formal grammar rules.
+
+    This class handles the parsing of generated code, validates its grammatical correctness,
+    and creates token masks to ensure syntactically valid generations.
+
+    The class supports two primary operating modes:
+
+    1. `grammar_mask` (Conservative/Overapproximation):
+       This mode is more permissive and overapproximates the set of acceptable tokens.
+       It allows a wider range of tokens that might be syntactically valid given the
+       limited lookahead of the parser. This mode preserves more of the LLM's original
+       token distribution while still enforcing basic syntactic correctness.
+
+    2. `grammar_strict` (Strict/Underapproximation):
+       This mode is stricter and underapproximates the set of acceptable tokens.
+       It enforces tighter grammatical constraints and may be more invasive in the
+       LLM's generation process. It sometimes breaks LLM tokens that would have been
+       syntactically correct when considered as a whole, potentially affecting the
+       fluency or accuracy of generation.
+
+    Example illustrating the difference:
+       Consider generating Python code with the partial input: `def calculate`
+
+       In `grammar_mask` mode, it might allow tokens like:
+       - "(num" (combining opening parenthesis and parameter name as one token)
+
+       In `grammar_strict` mode, it would force separate tokens:
+       - "(" followed by "num" (requiring two separate token generations)
+
+    For more details on the approximation methods, refer to the SynCode paper:
+    https://arxiv.org/abs/2403.01632
     """
     def __init__(self,
-        grammar: Grammar,
-        tokenizer: PreTrainedTokenizer,
-        use_cache=True,
-        parse_output_only=True,
-        num_samples=1,
-        dev_mode=False,
-        parser='lalr',
-        mode='grammar_mask'):
-
+        grammar: Grammar,
+        tokenizer: PreTrainedTokenizer,
+        byte_tokenizer: ByteTokenizer,
+        use_cache=True,
+        parse_output_only=True,
+        batch_size=1,
+        dev_mode=False,
+        parser='lalr',
+        mode='grammar_mask'):
+
         self.tokenizer = tokenizer
-        self.byte_tokenizer = ByteTokenizer(tokenizer)
-
+        self.byte_tokenizer = byte_tokenizer
         self.grammar = grammar
         self.dev_mode = dev_mode
-        self.batch_size = num_samples
+        self.batch_size = batch_size
         self.parse_failed = False
 
         # For backtracking to syntactically valid completions
-        self.last_valid_state: list = []
-        self.function_ends: list = []
+        self.last_valid_state = [0 for _ in range(self.batch_size)]
+        self.function_ends = [None for _ in range(self.batch_size)]
 
         # We use this when only the LLM output is parsed and not (input+output)
         self.parse_output_only = parse_output_only
-        self.start_from = None
+        self.start_from = None
 
         # Ignore whitespace tokens
         self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
 
         # Create parser
         self.inc_parser: IncrementalParser = create_parser(self.grammar, parser=parser, ignore_whitespace=self._ignore_whitespace)
 
-        # Load dfa mask store
+        # Load dfa mask store with specified mode (grammar_mask or grammar_strict)
         self.dfa_mask_store = MaskStore.init_mask_store(
             grammar=self.grammar,
             tokenizer=self.tokenizer,
             use_cache=use_cache,
-            mode=mode,
+            mode=mode, # Controls approximation strategy for token masking
         )
-
-
-    def _get_ignore_whitespace(self, grammar):
-        """
-        Check if the grammar allows whitespace tokens to be ignored.
-        """
-        base_parser = create_base_parser(grammar)
-        terminals = base_parser.terminals
-        ignore_terminals = base_parser.ignore_tokens
-
-        import regex
-        ignore_whitespace = False
-        for ig_name in ignore_terminals:
-            for terminal in terminals:
-                if terminal.name == ig_name:
-                    if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
-                        ignore_whitespace = True # convert to boolean tensor mask. This is useful for fast union operations
-        return ignore_whitespace
+
 
     def reset(self):
         """
@@ -96,6 +95,15 @@ def reset(self):
         self.start_from = None
         self.inc_parser.reset()
 
+    def _set_start_from(self, input_ids):
+        """
+        Sets the starting point for parsing based on whether we're parsing only the output or the full input+output.
+        """
+        if self.start_from is None:
+            if self.parse_output_only:
+                self.start_from = input_ids.size(1)
+            else:
+                self.start_from = 0
 
     def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) -> bool:
         """
@@ -134,19 +142,26 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
         is_valid = self.dfa_mask_store.is_valid_prefix(res)
 
         if is_valid:
-            self.update_valid_state(partial_code, 0, res)
+            self._update_valid_state(partial_code, 0, res)
 
         return is_valid
 
-    def _set_start_from(self, input_ids):
-        if self.start_from is None:
-            if self.parse_output_only:
-                self.start_from = input_ids.size(1)
-            else:
-                self.start_from = 0
-
-
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+    def mask_scores(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        """
+        Mask scores by zeroing out invalid next tokens based on grammar constraints.
+
+        The exact behavior depends on whether we're using grammar_mask mode (conservative/
+        overapproximation) or grammar_strict mode (strict/underapproximation). In both cases,
+        tokens that would lead to definitely invalid syntax are masked out by setting their
+        scores to negative infinity.
+
+        Args:
+            input_ids (torch.LongTensor): The input ids.
+            scores (torch.FloatTensor): The scores to be masked.
+
+        Returns:
+            torch.FloatTensor: The masked scores.
+        """
         self._set_start_from(input_ids) # start_from is used for choosing where the parsing should start
         partial_codes = self._get_partial_codes(input_ids)
 
@@ -188,7 +203,7 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
             else:
                 res.remainder = res.remainder.encode('utf-8')
 
-            self.update_valid_state(partial_code, idx, res)
+            self._update_valid_state(partial_code, idx, res)
         except Exception as e:
             if self.dev_mode == True:
                 raise e
@@ -200,7 +215,6 @@ def _parse_partial_code(self, idx: int, partial_code: str, remainder_bytes: byte
             skip = True
         return res, skip
 
-
     def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
         """
         Get the partial codes for the input_ids and return the remainder bytes if the partial code is not a valid UTF-8 string.
@@ -219,9 +233,8 @@ def _get_partial_codes(self, input_ids: torch.LongTensor) -> list[(str, bytes)]:
             )
             output.append((partial_code, remainder_bytes))
         return output
-
 
-    def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
+    def _update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
         """
         This a simple heuristic to cut off the generated output at the end of the function.
         TODO: Put this under a flag to enable/disable this heuristic.
@@ -237,7 +250,6 @@ def update_valid_state(self, partial_code: str, idx: int, r: ParseResult):
             if accept_seq[0] == '$END' or accept_seq[0] == 'EOF':
                 self.last_valid_state[idx] = len(partial_code) - len(r.remainder)
 
-
     @staticmethod
     def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
         """
@@ -253,16 +265,6 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
             A tuple (string, remainder) where:
             - string is the longest valid UTF-8 prefix of the input as a Python string
             - remainder is the rest of the bytes that could not be decoded as UTF-8
-
-        Examples:
-            >>> bytes_to_string(b'Hello, world!')
-            ('Hello, world!', b'')
-            >>> bytes_to_string(b'Hello, \xe2\x82\xac!') # Euro symbol (€) followed by !
-            ('Hello, €!', b'')
-            >>> bytes_to_string(b'Hello, \xe2\x82!') # Incomplete Euro symbol
-            ('Hello, ', b'\xe2\x82!')
-            >>> bytes_to_string(b'\xff\xfe') # Invalid UTF-8
-            ('', b'\xff\xfe')
         """
         if not isinstance(byte_sequence, bytes):
             raise TypeError("Input must be a bytes object")
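
The doctest examples removed above illustrate the behavior that _bytes_to_string documents: decode the longest valid UTF-8 prefix and hand back the undecodable tail, which matters when a token boundary splits a multi-byte character. A minimal standalone sketch of that behavior (assumed helper name, not this class's implementation):

def longest_utf8_prefix(byte_sequence: bytes) -> tuple[str, bytes]:
    if not isinstance(byte_sequence, bytes):
        raise TypeError("Input must be a bytes object")
    try:
        return byte_sequence.decode("utf-8"), b""
    except UnicodeDecodeError as e:
        # e.start is the offset of the first byte that cannot be decoded,
        # so everything before it is a valid UTF-8 prefix
        return byte_sequence[:e.start].decode("utf-8"), byte_sequence[e.start:]

print(longest_utf8_prefix(b"Hello, \xe2\x82\xac!"))  # ('Hello, €!', b'')
print(longest_utf8_prefix(b"Hello, \xe2\x82"))       # ('Hello, ', b'\xe2\x82')
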
@@ -292,3 +294,21 @@ def _bytes_to_string(byte_sequence: bytes) -> tuple[str, bytes]:
             return byte_sequence[:valid_end].decode('utf-8'), byte_sequence[valid_end:]
         else:
             return "", byte_sequence
+
+    def _get_ignore_whitespace(self, grammar):
+        """
+        Check if the grammar allows whitespace tokens to be ignored.
+        """
+        base_parser = create_base_parser(grammar)
+        terminals = base_parser.terminals
+        ignore_terminals = base_parser.ignore_tokens
+
+        import regex
+        ignore_whitespace = False
+        for ig_name in ignore_terminals:
+            for terminal in terminals:
+                if terminal.name == ig_name:
+                    if regex.match(terminal.pattern.to_regexp(), ' ') is not None:
+                        ignore_whitespace = True # convert to boolean tensor mask. This is useful for fast union operations
+        return ignore_whitespace
+
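
Taken together, the rename of __call__ to mask_scores and the new syncode/grammar_mask/logits_processor.py module (accessed in the code_eval.py hunk above via syncode.model.logits_processor.grammar_engine) point to a split between the grammar engine and the HuggingFace-facing adapter. A hedged sketch of what such an adapter could look like, assuming only the mask_scores(input_ids, scores) method shown in this diff and not the actual SyncodeLogitsProcessor implementation:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class GrammarLogitsAdapter(LogitsProcessor):
    """Illustrative adapter: exposes a GrammarConstrainer-style engine to HF generate()."""

    def __init__(self, grammar_engine):
        # grammar_engine is assumed to expose mask_scores(input_ids, scores), which
        # sets the scores of grammar-invalid next tokens to -inf (see mask_scores above)
        self.grammar_engine = grammar_engine

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return self.grammar_engine.mask_scores(input_ids, scores)

# Usage with HuggingFace generation (model, tokenizer and engine construction omitted):
# processors = LogitsProcessorList([GrammarLogitsAdapter(grammar_engine)])
# model.generate(input_ids, logits_processor=processors, max_new_tokens=128)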
