Skip to content

Add left recursion check: quit early instead of going into an infinite loop #7083

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13174,6 +13174,58 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
return rejects;
}

static bool llama_grammar_detect_left_recursion(
const std::vector<std::vector<llama_grammar_element>> & rules,
size_t rule_index,
std::vector<bool> * rules_visited,
std::vector<bool> * rules_in_progress,
std::vector<bool> * rules_may_be_empty) {
if ((*rules_in_progress)[rule_index]) {
return true;
}

(*rules_in_progress)[rule_index] = true;

const std::vector<llama_grammar_element> & rule = rules[rule_index];

// First check if the rule might produce the empty string. This could be done combined with the second
// step but it's more readable as two steps.
bool at_rule_start = true;
for (size_t i = 0; i < rule.size(); i++) {
if (llama_grammar_is_end_of_sequence(&rule[i])) {
if (at_rule_start) {
(*rules_may_be_empty)[rule_index] = true;
break;
}
at_rule_start = true;
} else {
at_rule_start = false;
}
}

// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
// be empty)
bool recurse_into_nonterminal = true;
for (size_t i = 0; i < rule.size(); i++) {
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
return true;
}
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
recurse_into_nonterminal = false;
}
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
recurse_into_nonterminal = true;
} else {
recurse_into_nonterminal = false;
}
}

(*rules_in_progress)[rule_index] = false;
(*rules_visited)[rule_index] = true;
return false;
}

//
// grammar - external
//
Expand All @@ -13193,6 +13245,19 @@ struct llama_grammar * llama_grammar_init(
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
}

// Check for left recursion
std::vector<bool> rules_visited(n_rules);
std::vector<bool> rules_in_progress(n_rules);
std::vector<bool> rules_may_be_empty(n_rules);
for (size_t i = 0; i < n_rules; i++) {
if (rules_visited[i]) {
continue;
}
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
}
}

// loop over alternates of start rule to build initial stacks
std::vector<std::vector<const llama_grammar_element *>> stacks;
pos = vec_rules[start_rule_index].data();
Expand All @@ -13215,6 +13280,9 @@ struct llama_grammar * llama_grammar_init(
}
} while (true);

// Important: vec_rules has to be moved here, not copied, because stacks contains
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
// then the pointers would be invalidated when the local vec_rules goes out of scope.
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
}

Expand Down
46 changes: 46 additions & 0 deletions tests/test-grammar-integration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
return grammar;
}

// Tries to build a grammar from `grammar_str` and reports whether building it failed.
// Returns true when build_grammar() throws a std::exception, false when it returns normally.
// Progress and the outcome are printed along the way.
static bool test_build_grammar_fails(const std::string & grammar_str) {
    fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());

    try {
        build_grammar(grammar_str);
    } catch (const std::exception &) {
        // The failure itself is the expected outcome; the exception details are not inspected.
        fprintf(stdout, " ✅︎\n");
        return true;
    }

    // Reaching this point means the grammar was accepted.
    fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
    return false;
}

static bool match_string(const std::string & input, llama_grammar* grammar) {
auto decoded = decode_utf8(input, {});

Expand Down Expand Up @@ -320,13 +333,46 @@ number ::= [0-9]+)""";
fprintf(stderr, " ✅︎ Passed\n");
}

// Checks that building a grammar fails for a series of left-recursive grammars of
// increasing difficulty.
static void test_failure_left_recursion() {
    fprintf(stderr, "⚫ Testing left recursion detection:\n");

    // The raw string bodies start at column 0 on purpose: their content is passed to the
    // grammar builder verbatim.
    const std::string left_recursive_grammars[] = {
        // simple: root refers to itself in leftmost position
        R"""(root ::= "a" | root "a")""",

        // more complicated: the recursive rule sits one reference below root
        R"""(
root ::= asdf
asdf ::= "a" | asdf "a"
)""",

        // even more complicated: the recursion goes through a second rule
        R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | asdf "d" | "e")""",

        // yet even more complicated: the recursive reference follows a rule that may be empty
        R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | empty asdf "d" | "e"
empty ::= "blah" | )""",
    };

    for (const std::string & grammar_text : left_recursive_grammars) {
        assert(test_build_grammar_fails(grammar_text));
    }

    fprintf(stderr, " ✅︎ Passed\n");
}

// Test driver: runs every grammar integration test in order and prints a banner before and after.
int main() {
    fputs("Running grammar integration tests...\n", stdout);

    // grammars that are expected to build and match
    test_simple_grammar();
    test_complex_grammar();
    test_quantifiers();

    // grammars that are expected to be rejected
    test_failure_missing_root();
    test_failure_missing_reference();
    test_failure_left_recursion();

    fputs("All tests passed.\n", stdout);
    return 0;
}
Loading