@@ -4,14 +4,22 @@
 from typing import Any, Optional
 
 import pytest
-from transformers import AutoTokenizer
+from transformers import (AutoTokenizer, PreTrainedTokenizer,
+                          PreTrainedTokenizerFast)
 
 from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
-from vllm.transformers_utils.detokenizer import (Detokenizer,
-                                                 detokenize_incrementally)
+from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import get_tokenizer_group
 from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
+from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.detokenizer import (FastIncrementalDetokenizer,
+                                        IncrementalDetokenizer,
+                                        SlowIncrementalDetokenizer)
+
+SPECIAL_TOKS_TRUTH = [
+    "Some text with adjacent special tokens                <|padding|><|padding|><fim_prefix><fim_middle><fim_suffix>other text<fim_pad>",  # noqa
+]
 
 TRUTH = [
     "Hello here, this is a simple test",
@@ -22,7 +30,8 @@
     # incomplete UTF-8 characters
     # see https://github.com/vllm-project/vllm/pull/9625
     "ပုံပြင်လေးပြောပြပါ်",
-]
+] + SPECIAL_TOKS_TRUTH
+
 TOKENIZERS = [
     "facebook/opt-125m",
     "gpt2",
@@ -38,26 +47,37 @@
 ]
 
 
-def _run_incremental_decode(tokenizer, all_input_ids,
-                            skip_special_tokens: bool, starting_index: int):
-    decoded_text = ""
-    offset = 0
-    token_offset = 0
-    prev_tokens = None
-    for i in range(starting_index, len(all_input_ids)):
-        new_tokens, text, offset, token_offset = detokenize_incrementally(
-            tokenizer,
-            all_input_ids[:i + 1],
-            prev_tokens,
-            offset,
-            token_offset,
-            skip_special_tokens=skip_special_tokens)
-        decoded_text += text
-        if prev_tokens is None:
-            prev_tokens = new_tokens
-        else:
-            prev_tokens += new_tokens
-    return decoded_text
+def _run_incremental_decode(tokenizer,
+                            all_input_ids,
+                            skip_special_tokens: bool,
+                            starting_index: int,
+                            spaces_between_special_tokens: bool = True,
+                            fast: Optional[bool] = None):
+
+    prompt_token_ids = all_input_ids[:starting_index]
+
+    params = SamplingParams(
+        skip_special_tokens=skip_special_tokens,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+    )
+    request = EngineCoreRequest("", "", prompt_token_ids, None, None, None,
+                                params, None, 0.0, None)
+
+    if fast is None:
+        detokenizer = IncrementalDetokenizer.from_new_request(
+            tokenizer, request)
+    elif fast:
+        detokenizer = FastIncrementalDetokenizer(tokenizer, request)
+    else:
+        detokenizer = SlowIncrementalDetokenizer(tokenizer, request)
+
+    output_text = ""
+    for i, token_id in enumerate(all_input_ids[starting_index:]):
+        detokenizer.update([token_id], False)
+        finished = i == len(all_input_ids) - starting_index - 1
+        output_text += detokenizer.get_next_output_text(finished, delta=True)
+
+    return output_text, detokenizer.output_token_ids
 
 
 @pytest.fixture
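Note for readers following the API change: the rewritten helper above drives the v1 incremental detokenizer directly instead of the removed `detokenize_incrementally` function. Below is a minimal standalone sketch of the same flow, assuming the `EngineCoreRequest` positional layout used in this diff (it may differ in other vLLM versions):

```python
from transformers import AutoTokenizer

from vllm.sequence import SamplingParams
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.detokenizer import IncrementalDetokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = tokenizer("Hello here, this is a simple test",
                      add_special_tokens=False).input_ids

# No prompt tokens here: every id fed below counts as newly generated output.
params = SamplingParams(skip_special_tokens=True)
request = EngineCoreRequest("", "", [], None, None, None, params, None, 0.0,
                            None)
detokenizer = IncrementalDetokenizer.from_new_request(tokenizer, request)

text = ""
for i, token_id in enumerate(token_ids):
    detokenizer.update([token_id], False)  # append one generated token
    finished = i == len(token_ids) - 1
    text += detokenizer.get_next_output_text(finished, delta=True)

assert text == "Hello here, this is a simple test"
```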
@@ -85,11 +105,13 @@ def test_mistral_edge_case(tokenizer, truth):
     starting_index = 0
     all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
 
-    decoded_text = _run_incremental_decode(tokenizer,
-                                           all_input_ids,
-                                           skip_special_tokens=True,
-                                           starting_index=starting_index)
+    decoded_text, out_ids = _run_incremental_decode(
+        tokenizer,
+        all_input_ids,
+        skip_special_tokens=True,
+        starting_index=starting_index)
     assert decoded_text == truth
+    assert out_ids == all_input_ids[starting_index:]
 
 
 @pytest.fixture
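Background on why incremental decoding needs buffering at all: byte-level BPE tokenizers can place token boundaries inside a multi-byte UTF-8 character, which is what the Burmese entry in `TRUTH` (see vllm-project/vllm#9625) exercises. A quick illustrative sketch, assuming gpt2's byte-level vocabulary:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
ids = tok("ပုံပြင်လေးပြောပြပါ်", add_special_tokens=False).input_ids

# Decoding token-by-token emits U+FFFD replacement characters whenever a
# token boundary falls inside a multi-byte character...
naive = "".join(tok.decode([i]) for i in ids)
assert "\ufffd" in naive

# ...so the detokenizer must hold bytes back until a character completes,
# matching what a single decode over the full sequence produces.
assert "\ufffd" not in tok.decode(ids)
```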
@@ -106,40 +128,86 @@ def skip_special_tokens(request, tokenizer_name) -> Generator[bool, Any, None]:
 @pytest.mark.parametrize("with_prompt", [True, False])
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", (True, False), indirect=True)
-def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens):
+@pytest.mark.parametrize("spaces_between_special_tokens", (True, False))
+@pytest.mark.parametrize("fast", (True, False))
+def test_decode_streaming(tokenizer, truth, with_prompt, skip_special_tokens,
+                          spaces_between_special_tokens, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    if skip_special_tokens and not spaces_between_special_tokens:
+        pytest.skip()
+
+    if not fast and isinstance(tokenizer, PreTrainedTokenizerFast):
+        # Fix up inconsistency in fast/slow tokenizer behaviour.
+        tokenizer.add_special_tokens({
+            "additional_special_tokens": [
+                at for at in
+                tokenizer._tokenizer.get_added_tokens_decoder().values()
+                if at.special
+            ]
+        })
+
+    extra_decode_args = {} if not isinstance(tokenizer, PreTrainedTokenizer) \
+        else {"spaces_between_special_tokens": spaces_between_special_tokens}
+
+    truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
+    if tokenizer.bos_token_id is not None:
+        truth_tokens.insert(0, tokenizer.bos_token_id)
+    truth_tokens.append(tokenizer.eos_token_id)
+
+    new_truth = tokenizer.decode(truth_tokens,
+                                 skip_special_tokens=skip_special_tokens,
+                                 **extra_decode_args)
+
     if with_prompt:
-        truth_tokens = tokenizer(truth, add_special_tokens=False).input_ids
-        prompt_input_ids = truth_tokens[:len(truth) // 2]
-        generated_input_ids = truth_tokens[len(truth) // 2:]
+        num_prompt_tokens = len(
+            tokenizer(truth[:len(truth) // 2],
+                      add_special_tokens=False).input_ids)
+        if tokenizer.bos_token_id is not None:
+            num_prompt_tokens += 1
+
+        prompt_input_ids = truth_tokens[:num_prompt_tokens]
+        generated_input_ids = truth_tokens[num_prompt_tokens:]
         all_input_ids = prompt_input_ids + generated_input_ids
         starting_index = len(prompt_input_ids)
         prompt = tokenizer.decode(prompt_input_ids,
-                                  skip_special_tokens=skip_special_tokens)
-        generated = truth[len(prompt):]
+                                  skip_special_tokens=skip_special_tokens,
+                                  **extra_decode_args)
+
+        generated = new_truth[len(prompt):]
     else:
-        generated = truth
+        generated = new_truth
         starting_index = 0
-        all_input_ids = tokenizer(truth, add_special_tokens=False).input_ids
-    if skip_special_tokens:
-        if tokenizer.bos_token_id is not None:
-            all_input_ids = [tokenizer.bos_token_id] + all_input_ids
-            starting_index += 1
-        all_input_ids = all_input_ids + [tokenizer.eos_token_id]
+        all_input_ids = truth_tokens
 
-    decoded_text = _run_incremental_decode(
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer,
         all_input_ids,
         skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        starting_index=starting_index,
+        spaces_between_special_tokens=spaces_between_special_tokens,
+        fast=fast)
 
     assert decoded_text == generated
+    assert out_ids == all_input_ids[starting_index:]
 
-    decoded_text = _run_incremental_decode(
+
+@pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
+@pytest.mark.parametrize("fast", (True, False))
+def test_oov_decode(tokenizer, fast):
+    if fast and not isinstance(tokenizer, PreTrainedTokenizerFast):
+        pytest.skip()
+
+    decoded_text, out_ids = _run_incremental_decode(
         tokenizer, [len(tokenizer)],
-        skip_special_tokens=skip_special_tokens,
-        starting_index=starting_index)
+        skip_special_tokens=True,
+        starting_index=0,
+        spaces_between_special_tokens=True,
+        fast=fast)
 
     assert decoded_text == ''
+    assert out_ids == [len(tokenizer)]
 
 
 @pytest.fixture
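The `spaces_between_special_tokens` flag parametrized above is only honoured by slow tokenizers, which is why `extra_decode_args` is gated on `isinstance(tokenizer, PreTrainedTokenizer)`. A rough sketch of the behavioural difference, assuming gpt2's slow tokenizer and its `<|endoftext|>` special token:

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
ids = tok("hi", add_special_tokens=False).input_ids + [tok.eos_token_id] * 2

# Slow tokenizers join special tokens to neighbouring text with spaces by
# default, and concatenate them directly when the flag is disabled:
# "hi <|endoftext|> <|endoftext|>" vs "hi<|endoftext|><|endoftext|>".
spaced = tok.decode(ids, spaces_between_special_tokens=True)
packed = tok.decode(ids, spaces_between_special_tokens=False)
assert spaced != packed
```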
@@ -165,15 +233,14 @@ def detokenizer(tokenizer_name: str) -> Detokenizer:
 @pytest.fixture(name="complete_sequence_token_ids")
 def create_complete_sequence_token_ids(complete_sequence: str,
                                        tokenizer) -> list[int]:
-    complete_sequence_token_ids = tokenizer(complete_sequence).input_ids
-    return complete_sequence_token_ids
+    return tokenizer(complete_sequence, add_special_tokens=False).input_ids
 
 
 def create_sequence(prompt_token_ids=None):
-    prompt_token_ids = prompt_token_ids or [1]
+    prompt_token_ids = prompt_token_ids or []
     return Sequence(
         seq_id=0,
-        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
+        inputs=token_inputs(prompt_token_ids),
         block_size=16,
     )
 
@@ -224,7 +291,7 @@ def test_decode_sequence_logprobs(complete_sequence: str,
     assert sequential_result == "".join(sequential_logprobs_text_chosen_token)
     assert sequential_result != "".join(sequential_logprobs_text_other_token)
 
-    if skip_special_tokens:
+    if not skip_special_tokens:
         # Text for logprobs for the chosen token should be the same as the
         # generated text. Note that this will only be true if we skip
         # special tokens.
@@ -233,10 +300,23 @@ def test_decode_sequence_logprobs(complete_sequence: str,
 
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
-def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
+def test_decode_prompt_logprobs(complete_sequence: str,
+                                complete_sequence_token_ids: list[int],
                                 detokenizer: Detokenizer):
+
+    # We want to use skip_special_tokens=False here but Mistral tokenizers
+    # don't support that.
+    if complete_sequence not in SPECIAL_TOKS_TRUTH:
+        skip_special_tokens = True
+    elif not isinstance(detokenizer.tokenizer_group.get_lora_tokenizer(None),
+                        MistralTokenizer):
+        skip_special_tokens = False
+    else:
+        pytest.skip("MistralTokenizers don't support "
+                    "skip_special_tokens=False")
+        return
     """Verify Detokenizer decodes prompt logprobs correctly."""
-    sampling_params = SamplingParams(skip_special_tokens=True,
+    sampling_params = SamplingParams(skip_special_tokens=skip_special_tokens,
                                      prompt_logprobs=1)
 
     # Run sequentially.
@@ -256,8 +336,10 @@ def test_decode_prompt_logprobs(complete_sequence_token_ids: list[int],
     # decoded_prompt_logprobs doesn't contain the first token.
     token_ids = complete_sequence_token_ids
     tokenizer = detokenizer.get_tokenizer_for_seq(seq)
-    text_full = tokenizer.decode(token_ids, skip_special_tokens=True)
-    text_first = tokenizer.decode(token_ids[0], skip_special_tokens=True)
+    text_full = tokenizer.decode(token_ids,
+                                 skip_special_tokens=skip_special_tokens)
+    text_first = tokenizer.decode(token_ids[0],
+                                  skip_special_tokens=skip_special_tokens)
     text = text_full[len(text_first):]
 
     # Text for logprobs for the chosen token should be the same as the