Skip to content

Commit d9ebcdd

Browse files
authored
Merge pull request #172 from structuredllm/json_quote
Fix ByteFSM for dead state
2 parents 9687b53 + e337ca6 commit d9ebcdd

File tree

7 files changed

+158
-66
lines changed

7 files changed

+158
-66
lines changed

syncode/mask_store/byte_fsm.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def _build_byte_fsm(self, regex_fsm):
5050
self.transitions = {}
5151

5252
# Create a mapping from byte values to category numbers
53-
self.byte_to_category = {}
53+
self.byte_to_category = {}
5454

5555
# Extract the mapping from the regex FSM's alphabet and build our byte-level alphabet
5656
for char, category in regex_fsm.alphabet.items():
@@ -84,8 +84,8 @@ def _build_byte_fsm(self, regex_fsm):
8484
# Copy the transitions from the regex FSM to our byte FSM
8585
for state, category_transitions in regex_fsm.map.items():
8686
for category, target in category_transitions.items():
87-
self.transitions[state][category] = target
88-
87+
self.transitions[state][category] = target
88+
8989
# Handle multi-byte Unicode characters separately
9090
# This is needed because a multi-byte character might need special handling
9191
for char, category in regex_fsm.alphabet.items():
@@ -95,30 +95,42 @@ def _build_byte_fsm(self, regex_fsm):
9595
char_bytes = char.encode('utf-8')
9696
if len(char_bytes) <= 1:
9797
continue
98-
98+
99+
# Add an explicit dead state for invalid transitions
100+
dead_state = f"DEAD"
101+
if dead_state not in self.transitions:
102+
self.transitions[dead_state] = {}
103+
99104
# For multi-byte characters, we need to add special transitions
100105
# Make a copy of states to avoid modifying the dictionary during iteration
101106
states_to_process = list(self.transitions.keys())
102107
for state in states_to_process:
103108
if category in self.transitions[state]:
104109
target = self.transitions[state][category]
110+
else:
111+
target = dead_state
105112

106-
# Create intermediate states for the multi-byte character
107-
current = state
108-
for i, byte in enumerate(char_bytes):
109-
if byte not in self.alphabet:
110-
# Add the byte to the alphabet with a new category
111-
byte_category = f"{byte}_{i}"
112-
self.byte_to_category[byte] = byte_category
113+
# Create intermediate states for the multi-byte character
114+
current = state
115+
for i, byte in enumerate(char_bytes):
116+
if byte not in self.alphabet:
117+
# Add the byte to the alphabet with a new category
118+
byte_category = f"{byte}_{i}"
119+
self.byte_to_category[byte] = byte_category
113120

114-
if i < len(char_bytes) - 1:
121+
if i < len(char_bytes) - 1:
122+
if byte_category not in self.transitions[current]:
123+
# Create a new state for this byte
115124
next_state = f"{current}_{byte}_{i}_{char}"
116125
if next_state not in self.transitions:
117126
self.transitions[next_state] = {}
118127
self.transitions[current][byte_category] = next_state
119128
current = next_state
120129
else:
121-
self.transitions[current][byte_category] = target
130+
# Transition already exists
131+
current = self.transitions[current][byte_category]
132+
else:
133+
self.transitions[current][byte_category] = target
122134

123135
@lru_cache(maxsize=100000)
124136
def _get_category(self, byte_val: int) -> Any:

syncode/mask_store/fsm_set.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import time
22
import interegular
3-
from typing import Any, Optional, Tuple, Iterable, Dict
3+
from typing import Any, Optional, Tuple, Iterable, Dict, Union
44
from syncode.mask_store.byte_fsm import ByteFSM
55
import logging
66
logger = logging.getLogger(__name__)
@@ -21,11 +21,20 @@ def __hash__(self):
2121
return self._hash
2222

2323
@staticmethod
24-
def det_hash(terminal: str, state_id: int):
24+
def det_hash(terminal: str, state_id: Union[str, int]):
2525
h = 0
2626
for char in terminal:
2727
h = (h * 31 + ord(char)) & 0xFFFFFFFF
28-
h = (h * 31 + state_id) & 0xFFFFFFFF
28+
29+
# Handle state_id based on its type
30+
if isinstance(state_id, str):
31+
# If state_id is a string, hash each character
32+
for char in state_id:
33+
h = (h * 31 + ord(char)) & 0xFFFFFFFF
34+
else:
35+
# If state_id is an integer, hash it directly
36+
h = (h * 31 + state_id) & 0xFFFFFFFF
37+
2938
return h
3039

3140
def __repr__(self):

tests/mask_store/test_byte_fsm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ def test_consume_prefix(self):
120120
("[email protected]", (False, None)),
121121
("user@", (True, b"")), # Live state
122122
("invalid", (True, b"")) # Live state for [a-z]+
123+
]),
124+
('"[^"”“]+"', [
125+
('\"key”', (False, None)),
123126
])
124127
]
125128

tests/mask_store/test_mask_store_go.py

Lines changed: 0 additions & 48 deletions
This file was deleted.
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import sys
2+
import os
3+
import time
4+
import unittest
5+
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
6+
import syncode.common as common
7+
from syncode.parsers.incremental_parser import ParseResult
8+
from syncode.parse_result import AcceptSequence, RemainderState
9+
from syncode.mask_store.mask_store import MaskStore
10+
from syncode.parsers.grammars.grammar import Grammar
11+
from tests.test_utils import CustomAssertMixin
12+
13+
14+
class TestMaskGo(unittest.TestCase, CustomAssertMixin):
    """Grammar-mask tests for the Go grammar with a Qwen2.5 tokenizer."""

    def setUp(self):
        model = 'Qwen/Qwen2.5-1.5B-Instruct'
        tokenizer = common.load_tokenizer(model)
        self.mask_store = MaskStore.init_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
        return super().setUp()

    def test_mask(self):
        # After '1' (DECIMAL_LIT, possibly followed by PLUS), operator
        # tokens must be acceptable.
        r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
        # NOTE(review): the mask is computed twice on purpose here —
        # presumably to exercise the cached path; confirm before removing.
        self.mask_store.get_accept_mask(r, get_list=True)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        for token in [' +', ' +=', ' ++']:
            self.assertInWithLimit(token, result_list, f"{token} not found in result list")

    def test_mask2(self):
        # Inside a line comment, almost the whole vocabulary should be allowed.
        r = ParseResult({AcceptSequence(['EOS'])}, b'\n // 1.', RemainderState.MAYBE_COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertTrue(len(result_list) > 32000, "Result list is smaller than expected")

    def test_mask3(self):
        r = ParseResult({AcceptSequence(['__ANON_14'])}, b'', RemainderState.COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit(":=", result_list, ":= not found in result list")

    def test_mask4(self):
        r = ParseResult({AcceptSequence(['__IGNORE_0'])}, b'', RemainderState.COMPLETE)
        self.assertInWithLimit("\t", self.mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")

    def test_mask5(self):
        r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, b'{', RemainderState.MAYBE_COMPLETE)
        self.assertInWithLimit("\t", self.mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")

    def test_mask6(self):
        r = ParseResult({AcceptSequence(['NAME'])}, b'for', RemainderState.MAYBE_COMPLETE)
        self.assertInWithLimit(" {", self.mask_store.get_accept_mask(r, get_list=True), "Opening brace not found in result list")
52+
class TestMaskJSON(unittest.TestCase, CustomAssertMixin):
    """Grammar-mask tests for a custom JSON grammar that forbids curly quotes.

    The NONEMPTY_STRING terminal excludes the typographic quotes ” and “,
    so the mask must never allow tokens containing them inside a string.
    """

    def setUp(self):
        model = 'google/gemma-2-2b-it'
        tokenizer = common.load_tokenizer(model)

        # BUGFIX: this must be a plain string, not an f-string. The grammar
        # body contains literal '{' and '}' (in the `object` rule), which an
        # f-string treats as replacement fields — a SyntaxError before
        # Python 3.12 and silent grammar corruption on 3.12+. No
        # interpolation is used, so the 'f' prefix is dropped.
        custom_json_grammar = """
            ?start: start_value
            ?start_value: object
                | array

            ?value: object
                | array
                | EMPTY_STRING
                | NONEMPTY_STRING
                | SIGNED_NUMBER -> number
                | "true" -> true
                | "false" -> false
                | "null" -> null

            array : "[" [value ("," value)*] "]"
            object : "{" [pair ("," pair)*] "}"
            pair : NONEMPTY_STRING ":" value

            NONEMPTY_STRING: /\"[^"”“]+\"/
            EMPTY_STRING: /\"\"/

            DIGIT: "0".."9"
            HEXDIGIT: "a".."f"|"A".."F"|DIGIT
            INT: DIGIT+
            SIGNED_INT: ["+"|"-"] INT
            DECIMAL: INT "." INT? | "." INT

            _EXP: ("e"|"E") SIGNED_INT
            FLOAT: INT _EXP | DECIMAL _EXP?
            NUMBER: FLOAT | INT
            SIGNED_NUMBER: ["+"|"-"] NUMBER
            WS: /[ \t\f\r\n]/+

            %ignore WS
        """
        self.mask_store = MaskStore.init_mask_store(grammar=Grammar(custom_json_grammar), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
        return super().setUp()

    def test_mask(self):
        # Mid-string: a closing ASCII quote must be allowed, curly quotes must not.
        r = ParseResult({AcceptSequence(['NONEMPTY_STRING'])}, b'"key', RemainderState.INCOMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit('"', result_list, '" not found in result list')
        self.assertNotIn('”', result_list)
        self.assertNotIn('“', result_list)
104+
if __name__ == '__main__':
    # Run only the JSON mask tests when executed directly.
    loader = unittest.TestLoader()
    unittest.TextTestRunner().run(loader.loadTestsFromTestCase(TestMaskJSON))

tests/mask_store/test_mask_store_python.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TestDFAMaskLlama(unittest.TestCase, CustomAssertMixin):
1919
mask_store = MaskStore.init_mask_store(
2020
grammar=Grammar('python'),
2121
tokenizer=tokenizer,
22-
use_cache=True,
22+
use_cache=False,
2323
indent=True,
2424
mode="grammar_strict"
2525
)

tests/parser/test_grammar_json.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,13 @@ def test_json_parser2(self):
2525
r = inc_parser.get_acceptable_next_terminals(partial_code)
2626
assert r.remainder == ''
2727
assert r.remainder_state == RemainderState.COMPLETE
28-
28+
29+
def test_json_parser3(self):
30+
# Tests when the last incomplete word is unparsed
31+
inc_parser.reset()
32+
partial_code = '{\n "key'
33+
r = inc_parser.get_acceptable_next_terminals(partial_code)
34+
assert AcceptSequence(['NONEMPTY_STRING']) in r.accept_sequences
35+
36+
if __name__ == '__main__':
    # Allow running this test module directly.
    unittest.main()

0 commit comments

Comments
 (0)