Skip to content

Add regex unit tests and enable shared linkage in fbcode #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions include/pytorch/tokenizers/pcre2_regex.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@ namespace tokenizers {
class Pcre2Regex : public IRegex {
public:
/**
* @brief Construct a PCRE2 regex with the given pattern.
*
* @brief Construct a PCRE2 regex.
*/
explicit Pcre2Regex(){};

/**
* @brief Compile the given regex pattern.
* @param pattern The regex pattern to compile.
* @return An Error object indicating success or failure of the compilation.
*/
explicit Pcre2Regex(const std::string& pattern);
virtual Error compile(const std::string& pattern) override;

/**
* @brief Destructor to clean up PCRE2 resources.
Expand All @@ -44,9 +49,6 @@ class Pcre2Regex : public IRegex {
private:
pcre2_code* regex_;
pcre2_match_data* match_data_;

friend Result<std::unique_ptr<IRegex>> create_fallback_regex(
const std::string& pattern);
};

} // namespace tokenizers
14 changes: 8 additions & 6 deletions include/pytorch/tokenizers/re2_regex.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ namespace tokenizers {
class Re2Regex : public IRegex {
public:
/**
* @brief Construct a RE2 regex with the given pattern.
*
* @brief Construct a RE2 regex.
*/
explicit Re2Regex() {}

/**
* @brief compile the given regex pattern.
* @param pattern The regex pattern to compile.
* @return An Error object indicating success or failure of the compilation.
*/
explicit Re2Regex(const std::string& pattern);
virtual Error compile(const std::string& pattern) override;

/**
* @brief Return all non-overlapping matches found in the input string.
Expand All @@ -36,9 +41,6 @@ class Re2Regex : public IRegex {

private:
std::unique_ptr<re2::RE2> regex_;

friend Result<std::unique_ptr<IRegex>> create_regex(
const std::string& pattern);
};

} // namespace tokenizers
23 changes: 13 additions & 10 deletions include/pytorch/tokenizers/regex.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ class IRegex {
public:
virtual ~IRegex() = default;

/**
* @brief Compile the given regex pattern.
* @param pattern The regex pattern to compile.
* @return An Error object indicating success or failure of the compilation.
*/
virtual Error compile(const std::string& pattern) = 0;

/**
* @brief Find all non-overlapping matches in the input string.
*
Expand All @@ -37,6 +44,9 @@ class IRegex {
virtual std::vector<Match> find_all(const std::string& text) const = 0;
};

// Function pointer type for create_fallback_regex implementations
using FallbackRegexFn = Result<std::unique_ptr<IRegex>> (*)(const std::string&);

/**
* @brief Creates a regex instance. If no strong symbol defined, only
* uses RE2. This is a weak symbol to allow other regex libraries to be
Expand All @@ -47,15 +57,8 @@ class IRegex {
*/
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern);

/**
* @brief Creates a fallback regex instance. If no strong symbol defined,
* returns Error, otherwise uses PCRE2 and std::regex.
* This is a weak symbol to allow other regex libraries to be used.
*
* @param pattern The regex pattern to compile.
* @return A unique pointer to an IRegex-compatible object.
*/
Result<std::unique_ptr<IRegex>> create_fallback_regex(
const std::string& pattern) TK_WEAK;
bool register_override_fallback_regex(FallbackRegexFn fn);

FallbackRegexFn get_fallback_regex();

} // namespace tokenizers
12 changes: 8 additions & 4 deletions include/pytorch/tokenizers/std_regex.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ namespace tokenizers {
class StdRegex : public IRegex {
public:
/**
* @brief Construct a std::regex wrapper with the given pattern.
*
* @brief Construct a std::regex wrapper.
*/
explicit StdRegex() {}

/**
* @brief Compile the given regex pattern.
* @param pattern The regex pattern to compile.
* @throws std::regex_error if the pattern is invalid.
* @return An Error object indicating success or failure of the compilation.
*/
explicit StdRegex(const std::string& pattern);
virtual Error compile(const std::string& pattern) override;

/**
* @brief Find all non-overlapping matches in the input string.
Expand Down
19 changes: 12 additions & 7 deletions src/pcre2_regex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

namespace tokenizers {

Pcre2Regex::Pcre2Regex(const std::string& pattern)
: regex_(nullptr), match_data_(nullptr) {
Error Pcre2Regex::compile(const std::string& pattern) {
int error_code;
PCRE2_SIZE error_offset;

Expand All @@ -30,19 +29,24 @@ Pcre2Regex::Pcre2Regex(const std::string& pattern)
if (regex_ == nullptr) {
PCRE2_UCHAR error_buffer[256];
pcre2_get_error_message(error_code, error_buffer, sizeof(error_buffer));
std::cerr << "PCRE2 compilation failed at offset " << error_offset << ": "
<< error_buffer << std::endl;
return;
TK_LOG(
Error,
"PCRE2 compilation failed at offset %" PRId64 ": %s",
static_cast<int64_t>(error_offset),
error_buffer);
return Error::RegexFailure;
}

// Create match data
match_data_ = pcre2_match_data_create_from_pattern(regex_, nullptr);
if (match_data_ == nullptr) {
pcre2_code_free(regex_);
regex_ = nullptr;
std::cerr << "Failed to create PCRE2 match data" << std::endl;
return;
TK_LOG(Error, "Failed to create PCRE2 match data");
return Error::RegexFailure;
}

return Error::Ok;
}

Pcre2Regex::~Pcre2Regex() {
Expand All @@ -58,6 +62,7 @@ std::vector<Match> Pcre2Regex::find_all(const std::string& text) const {
std::vector<Match> result;

if (!regex_ || !match_data_) {
TK_LOG(Error, "Regex is not compiled or invalid, run compile() first");
return result;
}

Expand Down
16 changes: 15 additions & 1 deletion src/re2_regex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,29 @@

namespace tokenizers {

Re2Regex::Re2Regex(const std::string& pattern) {
Error Re2Regex::compile(const std::string& pattern) {
regex_ = std::make_unique<re2::RE2>(pattern);
// Warmup re2 as it is slow on the first run, void the return value as it's
// not needed Refer to
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
(void)regex_->ReverseProgramSize();
if (regex_->ok()) {
return Error::Ok;
} else {
TK_LOG(
Error,
"Failed to compile regex: %s, error: %s",
pattern.c_str(),
regex_->error().c_str());
return Error::RegexFailure;
}
}

std::vector<Match> Re2Regex::find_all(const std::string& text) const {
if (!regex_ || !regex_->ok()) {
TK_LOG(Error, "Regex is not compiled or invalid, run compile() first");
return std::vector<Match>{};
}
std::vector<Match> result;
re2::StringPiece input(text);
re2::StringPiece piece;
Expand Down
58 changes: 30 additions & 28 deletions src/regex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,52 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
// A weak symbol for create_regex, only using RE2 regex library.
// Default implementation for create_regex, only using RE2 regex library.
// regex_lookahead.cpp has the implementation of create_regex with lookahead
// support, backed by PCRE2 and std::regex.

#include <pytorch/tokenizers/re2_regex.h>
#include <pytorch/tokenizers/regex.h>

#include <iostream>

namespace tokenizers {

// Default implementation that returns failure
static Result<std::unique_ptr<IRegex>> default_create_fallback_regex(
const std::string& pattern) {
(void)pattern;
return tokenizers::Error::RegexFailure;
}

FallbackRegexFn fallback_regex = default_create_fallback_regex;

bool register_override_fallback_regex(FallbackRegexFn fn) {
TK_LOG(Info, "Registering override fallback regex");
fallback_regex = fn;
return true;
}

FallbackRegexFn get_fallback_regex() {
return fallback_regex;
}

Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) {
// Try RE2 first
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");
auto re2 = std::make_unique<Re2Regex>();
auto err = re2->compile("(" + pattern + ")");

if (re2->regex_->ok()) {
if (err == Error::Ok) {
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
}

std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
std::cerr << "Error: " << (re2->regex_->error()) << std::endl;

if (re2->regex_->error_code() == re2::RE2::ErrorBadPerlOp) {
auto res = create_fallback_regex(pattern);
if (!res.ok()) {
std::cerr
<< "RE2 doesn't support lookahead patterns. "
<< "Link with the lookahead-enabled version of this library to enable support."
<< std::endl;
} else {
return res;
}
auto res = get_fallback_regex()(pattern);
if (!res.ok()) {
TK_LOG(
Error,
"RE2 doesn't support lookahead patterns. Link with `regex_lookahead` to enable support.");
} else {
return res;
}

return tokenizers::Error::RegexFailure;
}

#ifdef _MSC_VER
#pragma weak create_fallback_regex
#endif // _MSC_VER
Result<std::unique_ptr<IRegex>> create_fallback_regex(
const std::string& pattern) {
(void)pattern;
return tokenizers::Error::RegexFailure;
}

} // namespace tokenizers
31 changes: 15 additions & 16 deletions src/regex_lookahead.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,33 @@
namespace tokenizers {

/**
* @brief Factory function that creates a regex object using RE2 if possible.
* @brief Implementation of the fallback regex function with lookahead support.
* Falls back to PCRE2 if RE2 rejects the pattern due to lookahead.
* Falls back to std::regex if PCRE2 also fails.
*/

#ifdef _MSC_VER
#pragma weak create_fallback_regex
#endif // _MSC_VER
Result<std::unique_ptr<IRegex>> create_fallback_regex(
const std::string& pattern) {
auto pcre2 = std::make_unique<Pcre2Regex>("(" + pattern + ")");
TK_LOG(Info, "Creating PCRE2 regex");
auto pcre2 = std::make_unique<Pcre2Regex>();
auto err = pcre2->compile(pattern);

if (pcre2->regex_ != nullptr && pcre2->match_data_ != nullptr) {
std::cout
<< "RE2 is unable to support things such as negative lookaheads in "
<< pattern << ", using PCRE2 instead." << std::endl;
if (err == Error::Ok) {
return static_cast<std::unique_ptr<IRegex>>(std::move(pcre2));
}

// If PCRE2 also fails, fall back to std::regex
try {
std::cout << "PCRE2 failed to compile pattern, falling back to std::regex.";
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
auto std_regex = std::make_unique<StdRegex>();
err = std_regex->compile(pattern);
if (err == Error::Ok) {
TK_LOG(
Info, "PCRE2 failed to compile pattern, falling back to std::regex.");
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
} catch (const std::regex_error& e) {
std::cerr << "std::regex failed: " << e.what() << std::endl;
return tokenizers::Error::LoadFailure;
}

return tokenizers::Error::RegexFailure;
}

static bool registered =
register_override_fallback_regex(create_fallback_regex);

} // namespace tokenizers
13 changes: 12 additions & 1 deletion src/std_regex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,25 @@
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*
* @lint-ignore-every LICENSELINT
* @lint-ignore-every CLANGTIDY facebook-hte-StdRegexIsAwful
*/

#include <pytorch/tokenizers/std_regex.h>
#include <regex>

namespace tokenizers {

StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}
Error StdRegex::compile(const std::string& pattern) {
try {
regex_ = std::regex(pattern);
return Error::Ok;
} catch (std::regex_error) {
TK_LOG(Error, "Failed to compile regex: %s", pattern.c_str());
return Error::RegexFailure;
}
}

std::vector<Match> StdRegex::find_all(const std::string& text) const {
std::vector<Match> result;
Expand Down
8 changes: 8 additions & 0 deletions targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,19 @@ def define_common_targets():
"src/std_regex.cpp",
],
exported_deps = [
":regex",
":headers",
],
compiler_flags = [
"-Wno-global-constructors",
"-Wno-missing-prototypes",
],
exported_external_deps = [
"pcre2",
],
# Making sure this library is not being stripped by linker.
# @lint-ignore BUCKLINT: Avoid link_whole=True
link_whole = True,
visibility = [
"@EXECUTORCH_CLIENTS",
"//pytorch/tokenizers/...",
Expand Down
11 changes: 11 additions & 0 deletions test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,17 @@ def define_common_targets():
],
)

runtime.cxx_test(
name = "test_regex",
srcs = [
"test_regex.cpp",
],
deps = [
"//pytorch/tokenizers:regex_lookahead",
"//pytorch/tokenizers:headers",
],
)

runtime.filegroup(
name = "resources",
srcs = native.glob([
Expand Down