diff --git a/src/hf_tokenizer.cpp b/src/hf_tokenizer.cpp
index 6d7e742..03d7816 100644
--- a/src/hf_tokenizer.cpp
+++ b/src/hf_tokenizer.cpp
@@ -56,7 +56,7 @@ Error HFTokenizer::load(const std::string& path) {
   json parsed_json;
   try {
     parsed_json = json::parse(contents);
-  } catch (const json::exception& e) {
+  } catch (const std::exception& e) {
     TK_LOG(Error, "Error parsing json file: %s", e.what());
     return Error::LoadFailure;
   }
@@ -76,7 +76,7 @@ Error HFTokenizer::load(const std::string& path) {
 
     // Store for future use.
     special_token_map_.emplace(std::move(special_token_map));
-  } catch (const json::out_of_range& e) {
+  } catch (const std::exception& e) {
    TK_LOG(Info, "Could not parse special tokens: %s", e.what());
    return Error::LoadFailure;
  }
@@ -96,7 +96,7 @@ Error HFTokenizer::load(const std::string& path) {
     auto token_map =
         TK_UNWRAP(detail::build_token_map(std::move(token_pairs)));
     token_map_.emplace(std::move(token_map));
-  } catch (const json::out_of_range& e) {
+  } catch (const std::exception& e) {
     TK_LOG(Info, "Could not parse tokens: %s", e.what());
     return Error::LoadFailure;
   }
@@ -114,7 +114,7 @@ Error HFTokenizer::load(const std::string& path) {
     } else {
       TK_LOG(Info, "Normalizer field is null, skipping");
     }
-  } catch (const json::out_of_range& e) {
+  } catch (const std::exception& e) {
     // No "Normalizer" field found
     TK_LOG(
         Info,
@@ -129,7 +129,7 @@ Error HFTokenizer::load(const std::string& path) {
             .parse_json(parsed_json.at("pre_tokenizer"))
             .create();
     TK_LOG(Info, "Pretokenizer set up");
-  } catch (const json::out_of_range& e) {
+  } catch (const std::exception& e) {
     TK_LOG(Info, "Could not parse pre_tokenizer: %s", e.what());
     return Error::LoadFailure;
   }
@@ -138,7 +138,7 @@ Error HFTokenizer::load(const std::string& path) {
   try {
     _decoder =
         TokenDecoderConfig().parse_json(parsed_json.at("decoder")).create();
-  } catch (const json::out_of_range&) {
+  } catch (const std::exception&) {
     // No decoder specified
   }
 
@@ -192,7 +192,7 @@ Error HFTokenizer::load(const std::string& path) {
         "Built merge ranks map with %" PRId64 " entries",
         static_cast<int64_t>(merge_ranks.size()));
     merge_ranks_.emplace(std::move(merge_ranks));
-  } catch (const json::out_of_range& e) {
+  } catch (const std::exception& e) {
     TK_LOG(Error, "Could not parse merges: %s", e.what());
     return Error::LoadFailure;
   }
@@ -211,7 +211,7 @@ Error HFTokenizer::load(const std::string& path) {
   json parsed_config_json;
   try {
     parsed_config_json = json::parse(config_contents);
-  } catch (const json::exception& e) {
+  } catch (const std::exception& e) {
     TK_LOG(Error, "Error parsing model config json json file: %s", e.what());
     return Error::LoadFailure;
   }
@@ -239,7 +239,7 @@ Error HFTokenizer::load(const std::string& path) {
     }
     bos_tok_ = *bos_res;
     eos_tok_ = *eos_res;
-  } catch (const json::out_of_range& e) {
+  } catch (const std::exception& e) {
     TK_LOG(Error, "Could not eos/bos from tokenizer config: %s", e.what());
     return Error::LoadFailure;
   }
diff --git a/src/pre_tokenizer.cpp b/src/pre_tokenizer.cpp
index 279fc39..42f0f97 100644
--- a/src/pre_tokenizer.cpp
+++ b/src/pre_tokenizer.cpp
@@ -48,14 +48,15 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const {
     }
 
     // Validate invert parameter
-    bool invert_flag = invert ? *invert : false;
-    if (invert_flag) {
+    const bool invert_flag = invert ? *invert : false;
+    const bool delimiter_flag = is_delimiter ? *is_delimiter : false;
+    if (invert_flag && delimiter_flag) {
       throw std::runtime_error(
-          "invert=true is not supported for Split PreTokenizer. Only invert=false is supported.");
+          "invert=true is not supported for Split PreTokenizer with a String pattern.");
     }
 
-    return PreTokenizer::Ptr(new RegexPreTokenizer(
-        *pattern, is_delimiter ? *is_delimiter : false, behavior_str));
+    return PreTokenizer::Ptr(
+        new RegexPreTokenizer(*pattern, delimiter_flag, behavior_str));
   }
   if (type == "Digits") {
     if (individual_digits) {
@@ -143,16 +144,51 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {
 // RegexPreTokenizer
 ///////////////////////////////////////////////////////////
 
+namespace {
+
+// Make Hugging Face Split patterns RE2-compatible by:
+// 1) removing the negative look-ahead "\s+(?!\S)" (→ "\s+$")
+// 2) expanding the inline case-insensitive contractions
+//    "(?i:'s|'t|'re|'ve|'m|'ll|'d)" into explicit alternations.
+static void replace_all_in_place(
+    std::string& input,
+    const std::string& needle,
+    const std::string& replacement) {
+  if (needle.empty()) {
+    return;
+  }
+  size_t search_pos = 0;
+  while ((search_pos = input.find(needle, search_pos)) != std::string::npos) {
+    input.replace(search_pos, needle.size(), replacement);
+    search_pos += replacement.size();
+  }
+}
+
+static std::string make_re2_compatible(std::string pattern) {
+  const std::string lookahead_trailing_space = R"(\s+(?!\S))";
+  const std::string trailing_space_replacement = R"(\s+$)";
+  replace_all_in_place(
+      pattern, lookahead_trailing_space, trailing_space_replacement);
+  const std::string ci_contractions = R"((?i:'s|'t|'re|'ve|'m|'ll|'d))";
+  const std::string contractions_expanded =
+      "(?:'s|'S|'t|'T|'re|'RE|'ve|'VE|'m|'M|'ll|'LL|'d|'D)";
+  replace_all_in_place(pattern, ci_contractions, contractions_expanded);
+  return pattern;
+}
+
+} // namespace
+
 std::unique_ptr<IRegex> RegexPreTokenizer::create_regex_(
     const std::string& pattern) {
   assert(!pattern.empty());
-  return TK_UNWRAP_THROW(create_regex(pattern));
+  return TK_UNWRAP_THROW(create_regex(make_re2_compatible(pattern)));
 }
 
 std::vector<std::string> RegexPreTokenizer::pre_tokenize(
     const std::string& input) const {
-  if (!regex_)
+  if (!regex_) {
     return {};
+  }
   std::vector<std::string> results;
   auto matches = regex_->find_all(input);
 
diff --git a/test/test_hf_tokenizer.py b/test/test_hf_tokenizer.py
index 6162dc1..dbed244 100644
--- a/test/test_hf_tokenizer.py
+++ b/test/test_hf_tokenizer.py
@@ -48,3 +48,16 @@ def test_llama3_2_1b(self) -> None:
         tokens = tokenizer.encode(PROMPT)
         cpp_tokens = cpp_tokenizer.encode(PROMPT, bos=1)
         self.assertEqual(tokens, cpp_tokens)
+
+    def test_phi_4_mini(self) -> None:
+        tokenizer = AutoTokenizer.from_pretrained(
+            "software-mansion/react-native-executorch-phi-4-mini"
+        )
+        tokenizer_path = tokenizer.save_pretrained(self.temp_dir.name)[-1]
+
+        cpp_tokenizer = CppHFTokenizer()
+        cpp_tokenizer.load(tokenizer_path)
+
+        tokens = tokenizer.encode(PROMPT)
+        cpp_tokens = cpp_tokenizer.encode(PROMPT)
+        self.assertEqual(tokens, cpp_tokens)
diff --git a/third-party/targets.bzl b/third-party/targets.bzl
index 8daebc1..9bdcd8c 100644
--- a/third-party/targets.bzl
+++ b/third-party/targets.bzl
@@ -12,6 +12,9 @@ def define_common_targets():
         exported_headers = subdir_glob([
             ("llama.cpp-unicode/include", "*.h"),
         ]),
+        compiler_flags = [
+            "-Wno-error=deprecated-declarations",
+        ],
         visibility = ["@EXECUTORCH_CLIENTS", "//pytorch/tokenizers/..."],
     )
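
For reference only (not part of the patch): a minimal standalone sketch of what the RE2-compatibility rewrite does to a typical Hugging Face Split pattern. The sample pattern (a cl100k/Llama-3 style split regex), the helper replace_all, and main are assumptions for illustration; only the two rewrite rules are taken from make_re2_compatible in the diff above.

// Illustration only (assumed example, not part of this change): apply the
// same two rewrites that make_re2_compatible performs and print the result.
#include <iostream>
#include <string>

static void replace_all(
    std::string& s,
    const std::string& needle,
    const std::string& replacement) {
  size_t pos = 0;
  while ((pos = s.find(needle, pos)) != std::string::npos) {
    s.replace(pos, needle.size(), replacement);
    pos += replacement.size();
  }
}

int main() {
  // Assumed sample: a cl100k-style Split pattern as it might appear in tokenizer.json.
  std::string pattern =
      R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3})"
      R"(| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+)";

  // Rule 1: the negative look-ahead "\s+(?!\S)" becomes "\s+$".
  replace_all(pattern, R"(\s+(?!\S))", R"(\s+$)");

  // Rule 2: the case-insensitive contraction group is expanded into
  // explicit lower/upper-case alternations.
  replace_all(
      pattern,
      R"((?i:'s|'t|'re|'ve|'m|'ll|'d))",
      "(?:'s|'S|'t|'T|'re|'RE|'ve|'VE|'m|'M|'ll|'LL|'d|'D)");

  std::cout << pattern << "\n";
  return 0;
}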