@@ -226,7 +226,42 @@ def _clean_up_pretty_token(token: str) -> str:
     return token.replace("\n", "\\n").strip()


-def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List[str]:
+def _encode_with_offsets(
+    txt: str,
+    tokenizer: TokenizerLike,
+    add_special_tokens: bool = True,
+    **kwargs: Any,
+) -> Tuple[List[int], List[Tuple[int, int]]]:
+    enc = tokenizer(
+        txt,
+        return_offsets_mapping=True,
+        add_special_tokens=add_special_tokens,
+        **kwargs,
+    )
+    input_ids = cast(List[int], enc["input_ids"])
+    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
+    assert len(input_ids) == len(offset_mapping), (
+        f"{len(input_ids)} != {len(offset_mapping)}: {txt} -> "
+        f"{input_ids}, {offset_mapping}"
+    )
+    # For the case where offsets are not set properly (the end and start are
+    # equal for all tokens - fall back on the start of the next span in the
+    # offset mapping)
+    offset_mapping_corrected = []
+    for i, (start, end) in enumerate(offset_mapping):
+        if start == end:
+            if (i + 1) < len(offset_mapping):
+                end = offset_mapping[i + 1][0]
+            else:
+                end = len(txt)
+        offset_mapping_corrected.append((start, end))
+    return input_ids, offset_mapping_corrected
+
+
+def _convert_ids_to_pretty_tokens(
+    ids: Tensor,
+    tokenizer: TokenizerLike,
+) -> List[str]:
     """
     Convert ids to tokens without ugly unicode characters (e.g., Ġ). See:
     https://github.com/huggingface/transformers/issues/4786 and
@@ -241,32 +276,26 @@ def _convert_ids_to_pretty_tokens(ids: Tensor, tokenizer: TokenizerLike) -> List
     > used spaces in its process
     """
     txt = tokenizer.decode(ids)
+    input_ids: Optional[List[int]] = None
     # Don't add special tokens (they're either already there, or we don't want them)
-    enc = tokenizer(txt, return_offsets_mapping=True, add_special_tokens=False)
-    input_ids = cast(List[int], enc["input_ids"])
-    offset_mapping = cast(List[Tuple[int, int]], enc["offset_mapping"])
+    input_ids, offset_mapping = _encode_with_offsets(
+        txt, tokenizer, add_special_tokens=False
+    )

     pretty_tokens = []
     end_prev = -1
     idx = 0
-    for i, (input_id, offset) in enumerate(zip(input_ids, offset_mapping)):
+    for i, offset in enumerate(offset_mapping):
         start, end = offset
-        if start == end:
-            # For the case where offsets are not set properly (the end and start are
-            # equal for all tokens - fall back on the start of the next span in the
-            # offset mapping)
-            if (i + 1) < len(input_ids):
-                end = offset_mapping[i + 1][0]
-            else:
-                end = len(txt)
-        if input_id != ids[idx]:
+        if input_ids[i] != ids[idx]:
             # When the re-encoded string doesn't match the original encoding we skip
             # this token and hope for the best, falling back on a naive method. This
             # can happen when a tokenizer might add a token that corresponds to
             # a space only when add_special_tokens=False.
             warnings.warn(
-                f"(i={i}) input_id {input_id} != ids[idx] {ids[idx]} (corresponding "
-                f"to text: {repr(txt[start:end])}). Skipping this token.",
+                f"(i={i}, idx={idx}) input_ids[i] {input_ids[i]} != ids[idx] "
+                f"{ids[idx]} (corresponding to text: {repr(txt[start:end])}). "
+                "Skipping this token.",
                 stacklevel=2,
             )
             continue
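As a side note for reviewers (not part of the commit), here is a minimal standalone sketch of the offset-correction fallback that _encode_with_offsets applies when a tokenizer returns degenerate, zero-width offsets; the sample text and offsets below are hypothetical:

    # Zero-width spans borrow the start of the next span; the last one
    # extends to the end of the decoded text.
    txt = "Hello world"
    offset_mapping = [(0, 0), (6, 6)]  # every end == start
    offset_mapping_corrected = []
    for i, (start, end) in enumerate(offset_mapping):
        if start == end:
            if (i + 1) < len(offset_mapping):
                end = offset_mapping[i + 1][0]
            else:
                end = len(txt)
        offset_mapping_corrected.append((start, end))
    print(offset_mapping_corrected)  # [(0, 6), (6, 11)]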