38 changes: 25 additions & 13 deletions syncode/mask_store/byte_fsm.py
@@ -50,7 +50,7 @@ def _build_byte_fsm(self, regex_fsm):
         self.transitions = {}
 
         # Create a mapping from byte values to category numbers
-        self.byte_to_category = {}
+        self.byte_to_category = {}
 
         # Extract the mapping from the regex FSM's alphabet and build our byte-level alphabet
         for char, category in regex_fsm.alphabet.items():
@@ -84,8 +84,8 @@ def _build_byte_fsm(self, regex_fsm):
         # Copy the transitions from the regex FSM to our byte FSM
         for state, category_transitions in regex_fsm.map.items():
             for category, target in category_transitions.items():
-                self.transitions[state][category] = target
+                self.transitions[state][category] = target
 
         # Handle multi-byte Unicode characters separately
         # This is needed because a multi-byte character might need special handling
         for char, category in regex_fsm.alphabet.items():
@@ -95,30 +95,42 @@ def _build_byte_fsm(self, regex_fsm):
             char_bytes = char.encode('utf-8')
             if len(char_bytes) <= 1:
                 continue
 
+            # Add an explicit dead state for invalid transitions
+            dead_state = f"DEAD"
+            if dead_state not in self.transitions:
+                self.transitions[dead_state] = {}
 
             # For multi-byte characters, we need to add special transitions
+            # Make a copy of states to avoid modifying the dictionary during iteration
+            states_to_process = list(self.transitions.keys())
+            for state in states_to_process:
                 if category in self.transitions[state]:
                     target = self.transitions[state][category]
+                else:
+                    target = dead_state
 
-                    # Create intermediate states for the multi-byte character
-                    current = state
-                    for i, byte in enumerate(char_bytes):
-                        if byte not in self.alphabet:
-                            # Add the byte to the alphabet with a new category
-                            byte_category = f"{byte}_{i}"
-                            self.byte_to_category[byte] = byte_category
+                # Create intermediate states for the multi-byte character
+                current = state
+                for i, byte in enumerate(char_bytes):
+                    if byte not in self.alphabet:
+                        # Add the byte to the alphabet with a new category
+                        byte_category = f"{byte}_{i}"
+                        self.byte_to_category[byte] = byte_category
 
-                        if i < len(char_bytes) - 1:
+                    if i < len(char_bytes) - 1:
                         if byte_category not in self.transitions[current]:
                             # Create a new state for this byte
                             next_state = f"{current}_{byte}_{i}_{char}"
                             if next_state not in self.transitions:
                                 self.transitions[next_state] = {}
                             self.transitions[current][byte_category] = next_state
                             current = next_state
                         else:
-                            self.transitions[current][byte_category] = target
+                            # Transition already exists
+                            current = self.transitions[current][byte_category]
                     else:
                         self.transitions[current][byte_category] = target
@lru_cache(maxsize=100000)
def _get_category(self, byte_val: int) -> Any:
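Note: the hunk above fixes two issues in the multi-byte handling: it iterates over a copy of the states (so the transitions dictionary is not mutated mid-iteration) and it follows an existing intermediate state instead of overwriting its transition. Below is a minimal sketch of the byte-chaining idea using plain dicts, not the actual ByteFSM internals; the add_char helper and the state-name scheme are illustrative only.

transitions = {0: {}}  # state -> {byte: next_state}

def add_char(start, char, target):
    """Route char's UTF-8 bytes through intermediate states to target."""
    data = char.encode("utf-8")
    current = start
    for i, byte in enumerate(data):
        if i < len(data) - 1:
            nxt = transitions[current].get(byte)
            if nxt is None:
                # Create a fresh intermediate state for this byte
                nxt = f"{current}_{byte}_{i}"
                transitions[nxt] = {}
                transitions[current][byte] = nxt
            # Follow the (possibly pre-existing) intermediate state,
            # mirroring the `current = self.transitions[...]` fix above
            current = nxt
        else:
            transitions[current][byte] = target  # final byte reaches target

add_char(0, "é", 1)  # bytes c3 a9
add_char(0, "è", 1)  # bytes c3 a8, shares the 0xc3 prefix state with "é"
print(transitions)   # {0: {195: '0_195_0'}, '0_195_0': {169: 1, 168: 1}}

Two characters sharing a UTF-8 prefix (here 0xC3) reuse the same intermediate state, which is exactly the case the old code clobbered by reassigning the transition.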
15 changes: 12 additions & 3 deletions syncode/mask_store/fsm_set.py
@@ -1,6 +1,6 @@
 import time
 import interegular
-from typing import Any, Optional, Tuple, Iterable, Dict
+from typing import Any, Optional, Tuple, Iterable, Dict, Union
 from syncode.mask_store.byte_fsm import ByteFSM
 import logging
 logger = logging.getLogger(__name__)
@@ -21,11 +21,20 @@ def __hash__(self):
         return self._hash
 
     @staticmethod
-    def det_hash(terminal: str, state_id: int):
+    def det_hash(terminal: str, state_id: Union[str, int]):
         h = 0
         for char in terminal:
             h = (h * 31 + ord(char)) & 0xFFFFFFFF
-        h = (h * 31 + state_id) & 0xFFFFFFFF
+
+        # Handle state_id based on its type
+        if isinstance(state_id, str):
+            # If state_id is a string, hash each character
+            for char in state_id:
+                h = (h * 31 + ord(char)) & 0xFFFFFFFF
+        else:
+            # If state_id is an integer, hash it directly
+            h = (h * 31 + state_id) & 0xFFFFFFFF
+
         return h
 
     def __repr__(self):
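Since byte_fsm.py now synthesizes string-named states (the "DEAD" state and the intermediate states created for multi-byte characters), det_hash must accept both integer and string state ids. A standalone copy of the updated hash for illustration, assuming the surrounding FSMState class is otherwise unchanged:

from typing import Union

def det_hash(terminal: str, state_id: Union[str, int]) -> int:
    # 31-based rolling hash, masked to 32 bits, as in the diff above
    h = 0
    for char in terminal:
        h = (h * 31 + ord(char)) & 0xFFFFFFFF
    if isinstance(state_id, str):
        # String state ids (e.g. "DEAD") are folded in character by character
        for char in state_id:
            h = (h * 31 + ord(char)) & 0xFFFFFFFF
    else:
        # Integer state ids are folded in directly, as before
        h = (h * 31 + state_id) & 0xFFFFFFFF
    return h

print(det_hash("NONEMPTY_STRING", 3))       # integer state id
print(det_hash("NONEMPTY_STRING", "DEAD"))  # string state id from byte_fsm.py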
3 changes: 3 additions & 0 deletions tests/mask_store/test_byte_fsm.py
@@ -120,6 +120,9 @@ def test_consume_prefix(self):
("[email protected]", (False, None)),
("user@", (True, b"")), # Live state
("invalid", (True, b"")) # Live state for [a-z]+
]),
('"[^"”“]+"', [
('\"key”', (False, None)),
])
]

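The new test case asserts that '"key”' is rejected by the pattern "[^"”“]+": the curly quote is excluded from the character class and cannot close the string either. The same behaviour can be sanity-checked with plain `re`, independent of ByteFSM:

import re

# The character class excludes the Unicode curly quotes ” and “
pattern = re.compile(r'"[^"”“]+"')
print(pattern.fullmatch('"key"'))  # matches: ASCII-quoted string
print(pattern.fullmatch('"key”'))  # None: a curly quote cannot close it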
48 changes: 0 additions & 48 deletions tests/mask_store/test_mask_store_go.py

This file was deleted.

107 changes: 107 additions & 0 deletions tests/mask_store/test_mask_store_misc.py
@@ -0,0 +1,107 @@
import sys
import os
import time
import unittest
sys.path.append(os.path.dirname(os.path.realpath(__file__)) + '/../..')
import syncode.common as common
from syncode.parsers.incremental_parser import ParseResult
from syncode.parse_result import AcceptSequence, RemainderState
from syncode.mask_store.mask_store import MaskStore
from syncode.parsers.grammars.grammar import Grammar
from tests.test_utils import CustomAssertMixin


class TestMaskGo(unittest.TestCase, CustomAssertMixin):
    def setUp(self):
        model = 'Qwen/Qwen2.5-1.5B-Instruct'
        tokenizer = common.load_tokenizer(model)
        self.mask_store = MaskStore.init_mask_store(grammar=Grammar('go'), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
        return super().setUp()

    def test_mask(self):
        r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
        self.mask_store.get_accept_mask(r, get_list=True)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        for token in [' +', ' +=', ' ++']:
            self.assertInWithLimit(token, result_list, f"{token} not found in result list")

    def test_mask2(self):
        r = ParseResult({AcceptSequence(['EOS'])}, b'\n // 1.', RemainderState.MAYBE_COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertTrue(len(result_list) > 32000, "Result list is smaller than expected")

    def test_mask3(self):
        r = ParseResult({AcceptSequence(['__ANON_14'])}, b'', RemainderState.COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        # Assert that the ':=' token is present in the result
        self.assertInWithLimit(":=", result_list, ":= not found in result list")

    def test_mask4(self):
        r = ParseResult({AcceptSequence(['__IGNORE_0'])}, b'', RemainderState.COMPLETE)
        self.assertInWithLimit("\t", self.mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")

    def test_mask5(self):
        r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, b'{', RemainderState.MAYBE_COMPLETE)
        self.assertInWithLimit("\t", self.mask_store.get_accept_mask(r, get_list=True), "Tab character not found in result list")

    def test_mask6(self):
        r = ParseResult({AcceptSequence(['NAME'])}, b'for', RemainderState.MAYBE_COMPLETE)
        self.assertInWithLimit(" {", self.mask_store.get_accept_mask(r, get_list=True), "Opening brace not found in result list")


class TestMaskJSON(unittest.TestCase, CustomAssertMixin):
    def setUp(self):
        model = 'google/gemma-2-2b-it'
        tokenizer = common.load_tokenizer(model)

        # Plain (non-f) string: the grammar contains literal { and } braces
        custom_json_grammar = """
        ?start: start_value
        ?start_value: object
                    | array

        ?value: object
              | array
              | EMPTY_STRING
              | NONEMPTY_STRING
              | SIGNED_NUMBER -> number
              | "true" -> true
              | "false" -> false
              | "null" -> null

        array : "[" [value ("," value)*] "]"
        object : "{" [pair ("," pair)*] "}"
        pair : NONEMPTY_STRING ":" value

        NONEMPTY_STRING: /"[^"”“]+"/
        EMPTY_STRING: /""/

        DIGIT: "0".."9"
        HEXDIGIT: "a".."f"|"A".."F"|DIGIT
        INT: DIGIT+
        SIGNED_INT: ["+"|"-"] INT
        DECIMAL: INT "." INT? | "." INT

        _EXP: ("e"|"E") SIGNED_INT
        FLOAT: INT _EXP | DECIMAL _EXP?
        NUMBER: FLOAT | INT
        SIGNED_NUMBER: ["+"|"-"] NUMBER
        WS: /[ \t\f\r\n]/+

        %ignore WS
        """
        self.mask_store = MaskStore.init_mask_store(grammar=Grammar(custom_json_grammar), tokenizer=tokenizer, use_cache=False, mode='grammar_mask')
        return super().setUp()

    def test_mask(self):
        r = ParseResult({AcceptSequence(['NONEMPTY_STRING'])}, b'"key', RemainderState.INCOMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit('"', result_list, '" not found in result list')
        self.assertNotIn('”', result_list)
        self.assertNotIn('“', result_list)


if __name__ == '__main__':
    # Run only the JSON tests by default
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMaskJSON)
    unittest.TextTestRunner().run(suite)
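The character class in NONEMPTY_STRING deliberately excludes the Unicode curly quotes ” and “, so the mask keeps a model from closing a JSON string with a smart quote. A cut-down version of the grammar can be checked with plain Lark (assumed here to be the engine behind Grammar; the trimmed grammar below covers only objects and strings):

from lark import Lark

grammar = r"""
?start: object
object : "{" [pair ("," pair)*] "}"
pair : NONEMPTY_STRING ":" value
?value: object | NONEMPTY_STRING
NONEMPTY_STRING: /"[^"”“]+"/
WS: /[ \t\f\r\n]+/
%ignore WS
"""

parser = Lark(grammar, start="start")
print(parser.parse('{ "key": "value" }').pretty())  # parses fine
# parser.parse('{ "key": "value” }')  # would raise: curly quote rejected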
2 changes: 1 addition & 1 deletion tests/mask_store/test_mask_store_python.py
@@ -19,7 +19,7 @@ class TestDFAMaskLlama(unittest.TestCase, CustomAssertMixin):
     mask_store = MaskStore.init_mask_store(
         grammar=Grammar('python'),
         tokenizer=tokenizer,
-        use_cache=True,
+        use_cache=False,
         indent=True,
         mode="grammar_strict"
     )
11 changes: 10 additions & 1 deletion tests/parser/test_grammar_json.py
@@ -25,4 +25,13 @@ def test_json_parser2(self):
         r = inc_parser.get_acceptable_next_terminals(partial_code)
         assert r.remainder == ''
         assert r.remainder_state == RemainderState.COMPLETE
 
+
+    def test_json_parser3(self):
+        # Tests when the last incomplete word is unparsed
+        inc_parser.reset()
+        partial_code = '{\n "key'
+        r = inc_parser.get_acceptable_next_terminals(partial_code)
+        assert AcceptSequence(['NONEMPTY_STRING']) in r.accept_sequences
+
 if __name__ == '__main__':
     unittest.main()
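For context, test_json_parser3 covers the case where the trailing '"key' cannot be lexed into any complete terminal. A toy illustration of that remainder idea, not SynCode's incremental parser (the token table and helper below are invented for the example):

import re

TOKENS = [
    ("LBRACE", r"\{"),
    ("NONEMPTY_STRING", r'"[^"”“]+"'),
    ("WS", r"[ \t\f\r\n]+"),
]

def lex_with_remainder(code: str):
    pos, tokens = 0, []
    while pos < len(code):
        for name, pat in TOKENS:
            m = re.match(pat, code[pos:])
            if m:
                tokens.append((name, m.group()))
                pos += m.end()
                break
        else:
            # No terminal matches: the tail is an incomplete remainder,
            # and NONEMPTY_STRING is the terminal that could complete it
            return tokens, code[pos:]
    return tokens, ""

print(lex_with_remainder('{\n "key'))  # ([('LBRACE', '{'), ('WS', '\n ')], '"key')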