From 65176e78ec900416cd59053a790c49fced906189 Mon Sep 17 00:00:00 2001 From: Haggai Nuchi Date: Sat, 4 May 2024 22:43:08 -0700 Subject: [PATCH 1/3] Add left recursion check: quit early instead of going into an infinite loop --- llama.cpp | 72 ++++++++++++++++++++++++++++++ tests/test-grammar-integration.cpp | 37 +++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/llama.cpp b/llama.cpp index e91ad7285da99..2571bb0988554 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13178,6 +13178,25 @@ static std::vector llama_grammar_reject_candidates( // grammar - external // +enum detect_left_recursion_status { + // haven't searched this nonterminal + LLAMA_LEFT_REC_NOT_SEARCHED = 0, + + // searching this nonterminal in progress + LLAMA_LEFT_REC_IN_PROGRESS = 1, + + // finished searching this nonterminal + LLAMA_LEFT_REC_FINISHED_SEARCH = 2, + + // detected a cycle + LLAMA_LEFT_REC_FOUND_CYCLE = 3, +}; + +static void detect_left_recursion( + const std::vector> & rules, + size_t rule_index, + std::vector * rules_visited); + struct llama_grammar * llama_grammar_init( const llama_grammar_element ** rules, size_t n_rules, @@ -13193,6 +13212,17 @@ struct llama_grammar * llama_grammar_init( vec_rules[i].push_back({LLAMA_GRETYPE_END, 0}); } + // Check for left recursion + std::vector rules_visited(n_rules); + for (size_t i = 0; i < n_rules; i++) { + detect_left_recursion(vec_rules, i, &rules_visited); + } + + auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE); + if (iter != rules_visited.end()) { + throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin()))); + } + // loop over alternates of start rule to build initial stacks std::vector> stacks; pos = vec_rules[start_rule_index].data(); @@ -13215,9 +13245,51 @@ struct llama_grammar * llama_grammar_init( } } while (true); + // Important: vec_rules has to be moved here, not copied, because stacks contains + // pointers to elements of vec_rules. If vec_rules were copied into llama_grammar + // then the pointers would be invalidated when the local vec_rules goes out of scope. return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; } +static void detect_left_recursion( + const std::vector> & rules, + size_t rule_index, + std::vector * rules_visited) { + + int visit_status = (*rules_visited)[rule_index]; + if (visit_status == LLAMA_LEFT_REC_IN_PROGRESS) { + // in progress -- we're in a cycle + (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FOUND_CYCLE; + return; + } else if (visit_status == LLAMA_LEFT_REC_NOT_SEARCHED) { + // haven't visited yet. mark in progress, recurse, then mark complete. + // mark in progress + (*rules_visited)[rule_index] = LLAMA_LEFT_REC_IN_PROGRESS; + + // recurse + const std::vector & rule = rules[rule_index]; + size_t i = 0; + do { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF) { + detect_left_recursion(rules, (size_t)rule[i].value, rules_visited); + } + while (!llama_grammar_is_end_of_sequence(&rule[i])) { + i++; + } + i++; + } while (i < rule.size()); + + // mark complete, but only if the recursive call didn't mark a cycle. + // that doesn't mean there's definitely no cycle *for this rule* -- the recursive call + // might have found a different cycle and stopped early. + if ((*rules_visited)[rule_index] == LLAMA_LEFT_REC_IN_PROGRESS) { + (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FINISHED_SEARCH; + } + } + + return; +} + void llama_grammar_free(struct llama_grammar * grammar) { delete grammar; } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 1a4004e2ab175..0ae7b09723026 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -28,6 +28,18 @@ static llama_grammar* build_grammar(const std::string & grammar_str) { return grammar; } +static bool test_build_grammar_fails(const std::string & grammar_str) { + bool grammar_fails = false; + try { + build_grammar(grammar_str); + fprintf(stderr, "❌ Expected build failure, but succeeded: %s\n", grammar_str.c_str()); + } catch (const std::exception & err) { + grammar_fails = true; + fprintf(stdout, "✅︎\n"); + } + return grammar_fails; +} + static bool match_string(const std::string & input, llama_grammar* grammar) { auto decoded = decode_utf8(input, {}); @@ -320,6 +332,30 @@ number ::= [0-9]+)"""; fprintf(stderr, " ✅︎ Passed\n"); } +static void test_failure_left_recursion() { + fprintf(stderr, "⚫ Testing left recursion detection:\n"); + + // Test simple left recursion detection + const std::string simple_str = R"""(root ::= "a" | root "a")"""; + assert(test_build_grammar_fails(simple_str)); + + // Test more complicated left recursion detection + const std::string medium_str = R"""( +root ::= asdf +asdf ::= "a" | asdf "a" +)"""; + assert(test_build_grammar_fails(medium_str)); + + // Test even more complicated left recursion detection + const std::string hard_str = R"""( +root ::= asdf +asdf ::= "a" | foo "b" +foo ::= "c" | asdf "d" | "e")"""; + assert(test_build_grammar_fails(hard_str)); + + fprintf(stderr, " ✅︎ Passed\n"); +} + int main() { fprintf(stdout, "Running grammar integration tests...\n"); test_simple_grammar(); @@ -327,6 +363,7 @@ int main() { test_quantifiers(); test_failure_missing_root(); test_failure_missing_reference(); + test_failure_left_recursion(); fprintf(stdout, "All tests passed.\n"); return 0; } From 5c5911d69dcc82cbcb4b178e4fe4c9eab3d4f11f Mon Sep 17 00:00:00 2001 From: Haggai Nuchi Date: Sat, 11 May 2024 16:28:59 -0700 Subject: [PATCH 2/3] Remove custom enum, rename left recursion check and move to "grammar internal" section, add handling for edge case where a leftmost nonterminal may be empty --- llama.cpp | 129 +++++++++++++++-------------- tests/test-grammar-integration.cpp | 13 ++- 2 files changed, 77 insertions(+), 65 deletions(-) diff --git a/llama.cpp b/llama.cpp index 2571bb0988554..e5c9bd2b30c38 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13174,28 +13174,68 @@ static std::vector llama_grammar_reject_candidates( return rejects; } -// -// grammar - external -// +static bool llama_grammar_detect_left_recursion( + const std::vector> & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty); -enum detect_left_recursion_status { - // haven't searched this nonterminal - LLAMA_LEFT_REC_NOT_SEARCHED = 0, +static bool llama_grammar_detect_left_recursion( + const std::vector> & rules, + size_t rule_index, + std::vector * rules_visited, + std::vector * rules_in_progress, + std::vector * rules_may_be_empty) { + if ((*rules_in_progress)[rule_index]) { + return true; + } - // searching this nonterminal in progress - LLAMA_LEFT_REC_IN_PROGRESS = 1, + (*rules_in_progress)[rule_index] = true; - // finished searching this nonterminal - LLAMA_LEFT_REC_FINISHED_SEARCH = 2, + const std::vector & rule = rules[rule_index]; - // detected a cycle - LLAMA_LEFT_REC_FOUND_CYCLE = 3, -}; + // First check if the rule might produce the empty string. This could be done combined with the second + // step but it's more readable as two steps. + bool at_rule_start = true; + for (size_t i = 0; i < rule.size(); i++) { + if (llama_grammar_is_end_of_sequence(&rule[i])) { + if (at_rule_start) { + (*rules_may_be_empty)[rule_index] = true; + break; + } + at_rule_start = true; + } else { + at_rule_start = false; + } + } -static void detect_left_recursion( - const std::vector> & rules, - size_t rule_index, - std::vector * rules_visited); + // Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may + // be empty) + bool recurse_into_nonterminal = true; + for (size_t i = 0; i < rule.size(); i++) { + if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) { + if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) { + return true; + } + if (!((*rules_may_be_empty)[(size_t)rule[i].value])) { + recurse_into_nonterminal = false; + } + } else if (llama_grammar_is_end_of_sequence(&rule[i])) { + recurse_into_nonterminal = true; + } else { + recurse_into_nonterminal = false; + } + } + + (*rules_in_progress)[rule_index] = false; + (*rules_visited)[rule_index] = true; + return false; +} + +// +// grammar - external +// struct llama_grammar * llama_grammar_init( const llama_grammar_element ** rules, @@ -13213,14 +13253,16 @@ struct llama_grammar * llama_grammar_init( } // Check for left recursion - std::vector rules_visited(n_rules); + std::vector rules_visited(n_rules); + std::vector rules_in_progress(n_rules); + std::vector rules_may_be_empty(n_rules); for (size_t i = 0; i < n_rules; i++) { - detect_left_recursion(vec_rules, i, &rules_visited); - } - - auto iter = std::find(rules_visited.begin(), rules_visited.end(), LLAMA_LEFT_REC_FOUND_CYCLE); - if (iter != rules_visited.end()) { - throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %d", (int)(iter - rules_visited.begin()))); + if (rules_visited[i]) { + continue; + } + if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { + throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i)); + } } // loop over alternates of start rule to build initial stacks @@ -13251,45 +13293,6 @@ struct llama_grammar * llama_grammar_init( return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} }; } -static void detect_left_recursion( - const std::vector> & rules, - size_t rule_index, - std::vector * rules_visited) { - - int visit_status = (*rules_visited)[rule_index]; - if (visit_status == LLAMA_LEFT_REC_IN_PROGRESS) { - // in progress -- we're in a cycle - (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FOUND_CYCLE; - return; - } else if (visit_status == LLAMA_LEFT_REC_NOT_SEARCHED) { - // haven't visited yet. mark in progress, recurse, then mark complete. - // mark in progress - (*rules_visited)[rule_index] = LLAMA_LEFT_REC_IN_PROGRESS; - - // recurse - const std::vector & rule = rules[rule_index]; - size_t i = 0; - do { - if (rule[i].type == LLAMA_GRETYPE_RULE_REF) { - detect_left_recursion(rules, (size_t)rule[i].value, rules_visited); - } - while (!llama_grammar_is_end_of_sequence(&rule[i])) { - i++; - } - i++; - } while (i < rule.size()); - - // mark complete, but only if the recursive call didn't mark a cycle. - // that doesn't mean there's definitely no cycle *for this rule* -- the recursive call - // might have found a different cycle and stopped early. - if ((*rules_visited)[rule_index] == LLAMA_LEFT_REC_IN_PROGRESS) { - (*rules_visited)[rule_index] = LLAMA_LEFT_REC_FINISHED_SEARCH; - } - } - - return; -} - void llama_grammar_free(struct llama_grammar * grammar) { delete grammar; } diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 0ae7b09723026..01c5bb27aabb9 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -29,13 +29,14 @@ static llama_grammar* build_grammar(const std::string & grammar_str) { } static bool test_build_grammar_fails(const std::string & grammar_str) { + fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str()); bool grammar_fails = false; try { build_grammar(grammar_str); - fprintf(stderr, "❌ Expected build failure, but succeeded: %s\n", grammar_str.c_str()); + fprintf(stderr, " ❌ Expected build failure, but succeeded\n"); } catch (const std::exception & err) { grammar_fails = true; - fprintf(stdout, "✅︎\n"); + fprintf(stdout, " ✅︎\n"); } return grammar_fails; } @@ -353,6 +354,14 @@ asdf ::= "a" | foo "b" foo ::= "c" | asdf "d" | "e")"""; assert(test_build_grammar_fails(hard_str)); + // Test yet even more complicated left recursion detection + const std::string hardest_str = R"""( +root ::= asdf +asdf ::= "a" | foo "b" +foo ::= "c" | empty asdf "d" | "e" +empty ::= "blah" | )"""; + assert(test_build_grammar_fails(hardest_str)); + fprintf(stderr, " ✅︎ Passed\n"); } From 1bbf54a13c93e30d32f22326d4c9f46890fa1eb2 Mon Sep 17 00:00:00 2001 From: Haggai Nuchi Date: Sun, 12 May 2024 12:59:30 -0700 Subject: [PATCH 3/3] Remove unnecessary declaration --- llama.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/llama.cpp b/llama.cpp index e5c9bd2b30c38..96d1c18956d68 100644 --- a/llama.cpp +++ b/llama.cpp @@ -13174,13 +13174,6 @@ static std::vector llama_grammar_reject_candidates( return rejects; } -static bool llama_grammar_detect_left_recursion( - const std::vector> & rules, - size_t rule_index, - std::vector * rules_visited, - std::vector * rules_in_progress, - std::vector * rules_may_be_empty); - static bool llama_grammar_detect_left_recursion( const std::vector> & rules, size_t rule_index,