src/hf_tokenizer.cpp (9 additions, 9 deletions)

@@ -56,7 +56,7 @@ Error HFTokenizer::load(const std::string& path) {
json parsed_json;
try {
parsed_json = json::parse(contents);
-} catch (const json::exception& e) {
+} catch (const std::exception& e) {
TK_LOG(Error, "Error parsing json file: %s", e.what());
return Error::LoadFailure;
}
@@ -76,7 +76,7 @@ Error HFTokenizer::load(const std::string& path) {

// Store for future use.
special_token_map_.emplace(std::move(special_token_map));
-} catch (const json::out_of_range& e) {
+} catch (const std::exception& e) {
TK_LOG(Info, "Could not parse special tokens: %s", e.what());
return Error::LoadFailure;
}
@@ -96,7 +96,7 @@ Error HFTokenizer::load(const std::string& path) {

auto token_map = TK_UNWRAP(detail::build_token_map(std::move(token_pairs)));
token_map_.emplace(std::move(token_map));
-} catch (const json::out_of_range& e) {
+} catch (const std::exception& e) {
TK_LOG(Info, "Could not parse tokens: %s", e.what());
return Error::LoadFailure;
}
@@ -114,7 +114,7 @@ Error HFTokenizer::load(const std::string& path) {
} else {
TK_LOG(Info, "Normalizer field is null, skipping");
}
-} catch (const json::out_of_range& e) {
+} catch (const std::exception& e) {
// No "Normalizer" field found
TK_LOG(
Info,
@@ -129,7 +129,7 @@ Error HFTokenizer::load(const std::string& path) {
.parse_json(parsed_json.at("pre_tokenizer"))
.create();
TK_LOG(Info, "Pretokenizer set up");
-} catch (const json::out_of_range& e) {
+} catch (const std::exception& e) {
TK_LOG(Info, "Could not parse pre_tokenizer: %s", e.what());
return Error::LoadFailure;
}
@@ -138,7 +138,7 @@ Error HFTokenizer::load(const std::string& path) {
try {
_decoder =
TokenDecoderConfig().parse_json(parsed_json.at("decoder")).create();
-} catch (const json::out_of_range&) {
+} catch (const std::exception&) {
// No decoder specified
}

@@ -192,7 +192,7 @@ Error HFTokenizer::load(const std::string& path) {
"Built merge ranks map with %" PRId64 " entries",
static_cast<int64_t>(merge_ranks.size()));
merge_ranks_.emplace(std::move(merge_ranks));
-} catch (const json::out_of_range& e) {
+} catch (const std::exception& e) {
TK_LOG(Error, "Could not parse merges: %s", e.what());
return Error::LoadFailure;
}
@@ -211,7 +211,7 @@ Error HFTokenizer::load(const std::string& path) {
json parsed_config_json;
try {
parsed_config_json = json::parse(config_contents);
-} catch (const json::exception& e) {
+} catch (const std::exception& e) {
TK_LOG(Error, "Error parsing model config json json file: %s", e.what());
return Error::LoadFailure;
}
@@ -239,7 +239,7 @@ Error HFTokenizer::load(const std::string& path) {
}
bos_tok_ = *bos_res;
eos_tok_ = *eos_res;
-} catch (const json::out_of_range& e) {
+} catch (const std::exception& e) {
TK_LOG(Error, "Could not eos/bos from tokenizer config: %s", e.what());
return Error::LoadFailure;
}
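Note on the hf_tokenizer.cpp change: every catch is widened from json::out_of_range (or json::exception) to std::exception. All of nlohmann::json's exceptions (parse_error, type_error, out_of_range, ...) derive from std::exception, so the wider handler also covers fields that are present but have an unexpected type, and a malformed config surfaces as Error::LoadFailure (or a logged skip) instead of an exception escaping HFTokenizer::load(). A standalone sketch of the failure mode the narrower catch would miss (the JSON below is made up purely for illustration):

```cpp
// Standalone sketch (not part of this PR): nlohmann::json's at() can throw
// json::type_error as well as json::out_of_range, and all of its exceptions
// derive from std::exception.
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
  const json parsed = json::parse(R"({"added_tokens": 42})");
  try {
    // "added_tokens" is present but is a number, not an object/array, so
    // at("id") throws json::type_error; a handler limited to
    // json::out_of_range would let the exception escape.
    const auto id = parsed.at("added_tokens").at("id");
    std::cout << id << "\n";
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << "\n";
  }
  return 0;
}
```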
src/pre_tokenizer.cpp (43 additions, 7 deletions)

@@ -48,14 +48,15 @@ PreTokenizer::Ptr PreTokenizerConfig::create() const {
}

// Validate invert parameter
-bool invert_flag = invert ? *invert : false;
-if (invert_flag) {
+const bool invert_flag = invert ? *invert : false;
+const bool delimiter_flag = is_delimiter ? *is_delimiter : false;
+if (invert_flag && delimiter_flag) {
throw std::runtime_error(
-"invert=true is not supported for Split PreTokenizer. Only invert=false is supported.");
+"invert=true is not supported for Split PreTokenizer with a String pattern.");
}

-return PreTokenizer::Ptr(new RegexPreTokenizer(
-*pattern, is_delimiter ? *is_delimiter : false, behavior_str));
+return PreTokenizer::Ptr(
+new RegexPreTokenizer(*pattern, delimiter_flag, behavior_str));
}
if (type == "Digits") {
if (individual_digits) {
@@ -143,16 +144,51 @@ PreTokenizerConfig& PreTokenizerConfig::parse_json(const json& json_config) {

// RegexPreTokenizer ///////////////////////////////////////////////////////////

+namespace {
+
+// Make Hugging Face Split patterns RE2-compatible by:
+// 1) removing the negative look-ahead "\s+(?!\S)" (→ "\s+$")
+// 2) expanding the inline case-insensitive contractions
+// "(?i:'s|'t|'re|'ve|'m|'ll|'d)" into explicit alternations.
+static void replace_all_in_place(
+std::string& input,
+const std::string& needle,
+const std::string& replacement) {
+if (needle.empty()) {
+return;
+}
+size_t search_pos = 0;
+while ((search_pos = input.find(needle, search_pos)) != std::string::npos) {
+input.replace(search_pos, needle.size(), replacement);
+search_pos += replacement.size();
+}
+}
+
+static std::string make_re2_compatible(std::string pattern) {
+const std::string lookahead_trailing_space = R"(\s+(?!\S))";
+const std::string trailing_space_replacement = R"(\s+$)";
+replace_all_in_place(
+pattern, lookahead_trailing_space, trailing_space_replacement);
+const std::string ci_contractions = R"((?i:'s|'t|'re|'ve|'m|'ll|'d))";
+const std::string contractions_expanded =
+"(?:'s|'S|'t|'T|'re|'RE|'ve|'VE|'m|'M|'ll|'LL|'d|'D)";
+replace_all_in_place(pattern, ci_contractions, contractions_expanded);
+return pattern;
+}
+
+} // namespace
+
std::unique_ptr<IRegex> RegexPreTokenizer::create_regex_(
const std::string& pattern) {
assert(!pattern.empty());
-return TK_UNWRAP_THROW(create_regex(pattern));
+return TK_UNWRAP_THROW(create_regex(make_re2_compatible(pattern)));
}

std::vector<std::string> RegexPreTokenizer::pre_tokenize(
const std::string& input) const {
-if (!regex_)
+if (!regex_) {
return {};
+}

std::vector<std::string> results;
auto matches = regex_->find_all(input);
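Note on the pre_tokenizer.cpp change: Split patterns from tokenizer.json are now passed through make_re2_compatible before the regex is compiled, and invert=true is only rejected when it is combined with a String (delimiter) pattern. The sketch below mirrors the PR's file-local helpers to show the rewrite; the input pattern is a made-up example of the usual Hugging Face Split shape, not any particular model's pattern:

```cpp
// Standalone sketch mirroring the file-local helpers added in pre_tokenizer.cpp.
#include <iostream>
#include <string>

static void replace_all(std::string& s, const std::string& from, const std::string& to) {
  for (size_t pos = 0; (pos = s.find(from, pos)) != std::string::npos; pos += to.size()) {
    s.replace(pos, from.size(), to);
  }
}

int main() {
  std::string pattern =
      R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|\p{L}+|\p{N}{1,3}|\s+(?!\S)|\s+)";
  // RE2 does not support look-ahead, so "\s+(?!\S)" is rewritten to "\s+$" ...
  replace_all(pattern, R"(\s+(?!\S))", R"(\s+$)");
  // ... and the inline case-insensitive contraction group is spelled out.
  replace_all(pattern, R"((?i:'s|'t|'re|'ve|'m|'ll|'d))",
              "(?:'s|'S|'t|'T|'re|'RE|'ve|'VE|'m|'M|'ll|'LL|'d|'D)");
  std::cout << pattern << "\n";
  // Prints:
  // (?:'s|'S|'t|'T|'re|'RE|'ve|'VE|'m|'M|'ll|'LL|'d|'D)|\p{L}+|\p{N}{1,3}|\s+$|\s+
  return 0;
}
```

The rewritten pattern avoids RE2's unsupported negative look-ahead and approximates the case-insensitive contraction matching by listing the upper-case variants explicitly.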
test/test_hf_tokenizer.py (13 additions, 0 deletions)

@@ -48,3 +48,16 @@ def test_llama3_2_1b(self) -> None:
tokens = tokenizer.encode(PROMPT)
cpp_tokens = cpp_tokenizer.encode(PROMPT, bos=1)
self.assertEqual(tokens, cpp_tokens)
+
+def test_phi_4_mini(self) -> None:
+tokenizer = AutoTokenizer.from_pretrained(
+"software-mansion/react-native-executorch-phi-4-mini"
+)
+tokenizer_path = tokenizer.save_pretrained(self.temp_dir.name)[-1]
+
+cpp_tokenizer = CppHFTokenizer()
+cpp_tokenizer.load(tokenizer_path)
+
+tokens = tokenizer.encode(PROMPT)
+cpp_tokens = cpp_tokenizer.encode(PROMPT)
+self.assertEqual(tokens, cpp_tokens)
third-party/targets.bzl (3 additions, 0 deletions)

@@ -12,6 +12,9 @@ def define_common_targets():
exported_headers = subdir_glob([
("llama.cpp-unicode/include", "*.h"),
]),
+compiler_flags = [
+"-Wno-error=deprecated-declarations",
+],
visibility = ["@EXECUTORCH_CLIENTS", "//pytorch/tokenizers/..."],
)
