1+ import sys
2+ import os
3+ import time
4+ import unittest
5+ sys .path .append (os .path .dirname (os .path .realpath (__file__ )) + '/../..' )
6+ import syncode .common as common
7+ from syncode .parsers .incremental_parser import ParseResult
8+ from syncode .parse_result import AcceptSequence , RemainderState
9+ from syncode .mask_store .mask_store import MaskStore
10+ from syncode .parsers .grammars .grammar import Grammar
11+ from tests .test_utils import CustomAssertMixin
12+
13+
class TestMaskGo(unittest.TestCase, CustomAssertMixin):
    """Check the accept masks produced for the built-in 'go' grammar.

    Every test builds a ``ParseResult`` by hand (accept sequences, remainder
    bytes, remainder state), asks the mask store for the accepted tokens as a
    list, and checks that specific tokens are present.
    """

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer and build the Go mask store once for the class.

        The fixture depends only on constants, so it is built once per class
        instead of once per test method.

        NOTE(review): this shares a single store between all tests; it assumes
        ``get_accept_mask`` leaves the store unchanged -- confirm.
        """
        super().setUpClass()
        model = 'Qwen/Qwen2.5-1.5B-Instruct'
        tokenizer = common.load_tokenizer(model)
        cls.mask_store = MaskStore.init_mask_store(
            grammar=Grammar('go'),
            tokenizer=tokenizer,
            use_cache=False,
            mode='grammar_mask',
        )

    def test_mask(self):
        """Remainder ``1`` with DECIMAL_LIT, PLUS accepted: '+'-led tokens are allowed."""
        r = ParseResult({AcceptSequence(['DECIMAL_LIT', 'PLUS'])}, b'1', RemainderState.MAYBE_COMPLETE)
        # NOTE(review): the mask is requested twice and only the second result
        # is checked; presumably this exercises a repeated lookup -- confirm
        # whether the first call is still needed.
        self.mask_store.get_accept_mask(r, get_list=True)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        for token in (' +', ' +=', ' ++'):
            # subTest reports every missing token instead of stopping at the first.
            with self.subTest(token=token):
                self.assertInWithLimit(token, result_list, f"{token} not found in result list")

    def test_mask2(self):
        """With EOS accepted and a comment-like remainder, more than 32000 tokens are allowed."""
        r = ParseResult({AcceptSequence(['EOS'])}, b'\n // 1.', RemainderState.MAYBE_COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        # assertGreater prints the actual length on failure.
        self.assertGreater(len(result_list), 32000, "Result list is smaller than expected")

    def test_mask3(self):
        """With ``__ANON_14`` accepted and an empty remainder, ':=' is allowed."""
        r = ParseResult({AcceptSequence(['__ANON_14'])}, b'', RemainderState.COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit(":=", result_list, ":= not found in result list")

    def test_mask4(self):
        """With ``__IGNORE_0`` accepted and an empty remainder, the tab-led token is allowed."""
        r = ParseResult({AcceptSequence(['__IGNORE_0'])}, b'', RemainderState.COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit("\t ", result_list, "Tab character not found in result list")

    def test_mask5(self):
        """Remainder ``{`` with LBRACE, __IGNORE_0 accepted: the tab-led token is allowed."""
        r = ParseResult({AcceptSequence(['LBRACE', '__IGNORE_0'])}, b'{', RemainderState.MAYBE_COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit("\t ", result_list, "Tab character not found in result list")

    def test_mask6(self):
        """Remainder ``for`` with NAME accepted: the ' {' token is allowed."""
        r = ParseResult({AcceptSequence(['NAME'])}, b'for', RemainderState.MAYBE_COMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit(" {", result_list, "Opening brace not found in result list")
50+
51+
class TestMaskJSON(unittest.TestCase, CustomAssertMixin):
    """Check the accept mask produced for a custom JSON grammar.

    The grammar's NONEMPTY_STRING terminal excludes the curly quotes
    '”' and '“' from string contents; the test checks that this exclusion is
    reflected in the tokens the mask store allows.
    """

    @classmethod
    def setUpClass(cls):
        """Load the tokenizer and build the JSON mask store once for the class."""
        super().setUpClass()
        model = 'google/gemma-2-2b-it'
        tokenizer = common.load_tokenizer(model)

        # Deliberately a plain string, not an f-string: the `object` rule
        # contains literal braces, which an f-string would parse as a
        # replacement field and substitute with the formatted expression.
        # The `object` rule is written parallel to the `array` rule.
        #
        # NOTE(review): as written, NONEMPTY_STRING and EMPTY_STRING expect a
        # space after each quote character, while test_mask feeds the
        # remainder b'"key' -- confirm the patterns are as intended.
        custom_json_grammar = """
?start: start_value
?start_value: object
| array

?value: object
| array
| EMPTY_STRING
| NONEMPTY_STRING
| SIGNED_NUMBER -> number
| "true" -> true
| "false" -> false
| "null" -> null

array : "[" [value ("," value)*] "]"
object : "{" [pair ("," pair)*] "}"
pair : NONEMPTY_STRING ":" value

NONEMPTY_STRING: /\" [^"”“]+\" /
EMPTY_STRING: /\" \" /

DIGIT: "0".."9"
HEXDIGIT: "a".."f"|"A".."F"|DIGIT
INT: DIGIT+
SIGNED_INT: ["+"|"-"] INT
DECIMAL: INT "." INT? | "." INT


_EXP: ("e"|"E") SIGNED_INT
FLOAT: INT _EXP | DECIMAL _EXP?
NUMBER: FLOAT | INT
SIGNED_NUMBER: ["+"|"-"] NUMBER
WS: /[ \t \f \r \n ]/+

%ignore WS
"""
        cls.mask_store = MaskStore.init_mask_store(
            grammar=Grammar(custom_json_grammar),
            tokenizer=tokenizer,
            use_cache=False,
            mode='grammar_mask',
        )

    def test_mask(self):
        """Inside an unfinished string, '"' is allowed and the curly quotes are not."""
        r = ParseResult({AcceptSequence(['NONEMPTY_STRING'])}, b'"key', RemainderState.INCOMPLETE)
        result_list = self.mask_store.get_accept_mask(r, get_list=True)
        self.assertInWithLimit('"', result_list, '" not found in result list')
        # The curly quotes are excluded by the NONEMPTY_STRING character class.
        self.assertNotIn('”', result_list)
        self.assertNotIn('“', result_list)
102+
103+
if __name__ == '__main__':
    # Script entry point: run only the JSON tests, and report the outcome
    # through the process exit status so a failing run is not seen as
    # success by the calling shell or CI job.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestMaskJSON)
    result = unittest.TextTestRunner().run(suite)
    sys.exit(0 if result.wasSuccessful() else 1)
0 commit comments