Skip to content

Commit c7b8254

Browse files
nuchiteleprint-me
authored and committed
Add left recursion check: quit early instead of going into an infinite loop (ggml-org#7083)
* Add left recursion check: quit early instead of going into an infinite loop
* Remove custom enum, rename left recursion check and move to "grammar internal" section, add handling for edge case where a leftmost nonterminal may be empty
* Remove unnecessary declaration
1 parent 95390eb commit c7b8254

File tree

2 files changed

+114
-0
lines changed

2 files changed

+114
-0
lines changed

llama.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13191,6 +13191,58 @@ static std::vector<llama_grammar_candidate> llama_grammar_reject_candidates(
1319113191
return rejects;
1319213192
}
1319313193

13194+
static bool llama_grammar_detect_left_recursion(
13195+
const std::vector<std::vector<llama_grammar_element>> & rules,
13196+
size_t rule_index,
13197+
std::vector<bool> * rules_visited,
13198+
std::vector<bool> * rules_in_progress,
13199+
std::vector<bool> * rules_may_be_empty) {
13200+
if ((*rules_in_progress)[rule_index]) {
13201+
return true;
13202+
}
13203+
13204+
(*rules_in_progress)[rule_index] = true;
13205+
13206+
const std::vector<llama_grammar_element> & rule = rules[rule_index];
13207+
13208+
// First check if the rule might produce the empty string. This could be done combined with the second
13209+
// step but it's more readable as two steps.
13210+
bool at_rule_start = true;
13211+
for (size_t i = 0; i < rule.size(); i++) {
13212+
if (llama_grammar_is_end_of_sequence(&rule[i])) {
13213+
if (at_rule_start) {
13214+
(*rules_may_be_empty)[rule_index] = true;
13215+
break;
13216+
}
13217+
at_rule_start = true;
13218+
} else {
13219+
at_rule_start = false;
13220+
}
13221+
}
13222+
13223+
// Second, recurse into leftmost nonterminals (or next-leftmost as long as the previous nonterminal may
13224+
// be empty)
13225+
bool recurse_into_nonterminal = true;
13226+
for (size_t i = 0; i < rule.size(); i++) {
13227+
if (rule[i].type == LLAMA_GRETYPE_RULE_REF && recurse_into_nonterminal) {
13228+
if (llama_grammar_detect_left_recursion(rules, (size_t)rule[i].value, rules_visited, rules_in_progress, rules_may_be_empty)) {
13229+
return true;
13230+
}
13231+
if (!((*rules_may_be_empty)[(size_t)rule[i].value])) {
13232+
recurse_into_nonterminal = false;
13233+
}
13234+
} else if (llama_grammar_is_end_of_sequence(&rule[i])) {
13235+
recurse_into_nonterminal = true;
13236+
} else {
13237+
recurse_into_nonterminal = false;
13238+
}
13239+
}
13240+
13241+
(*rules_in_progress)[rule_index] = false;
13242+
(*rules_visited)[rule_index] = true;
13243+
return false;
13244+
}
13245+
1319413246
//
1319513247
// grammar - external
1319613248
//
@@ -13210,6 +13262,19 @@ struct llama_grammar * llama_grammar_init(
1321013262
vec_rules[i].push_back({LLAMA_GRETYPE_END, 0});
1321113263
}
1321213264

13265+
// Check for left recursion
13266+
std::vector<bool> rules_visited(n_rules);
13267+
std::vector<bool> rules_in_progress(n_rules);
13268+
std::vector<bool> rules_may_be_empty(n_rules);
13269+
for (size_t i = 0; i < n_rules; i++) {
13270+
if (rules_visited[i]) {
13271+
continue;
13272+
}
13273+
if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) {
13274+
throw std::runtime_error(format("unsupported grammar, left recursion detected for nonterminal at index %zu", i));
13275+
}
13276+
}
13277+
1321313278
// loop over alternates of start rule to build initial stacks
1321413279
std::vector<std::vector<const llama_grammar_element *>> stacks;
1321513280
pos = vec_rules[start_rule_index].data();
@@ -13232,6 +13297,9 @@ struct llama_grammar * llama_grammar_init(
1323213297
}
1323313298
} while (true);
1323413299

13300+
// Important: vec_rules has to be moved here, not copied, because stacks contains
13301+
// pointers to elements of vec_rules. If vec_rules were copied into llama_grammar
13302+
// then the pointers would be invalidated when the local vec_rules goes out of scope.
1323513303
return new llama_grammar{ std::move(vec_rules), std::move(stacks), {} };
1323613304
}
1323713305

tests/test-grammar-integration.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,19 @@ static llama_grammar* build_grammar(const std::string & grammar_str) {
2828
return grammar;
2929
}
3030

31+
static bool test_build_grammar_fails(const std::string & grammar_str) {
32+
fprintf(stderr, "⚫ Testing failure for grammar: %s\n", grammar_str.c_str());
33+
bool grammar_fails = false;
34+
try {
35+
build_grammar(grammar_str);
36+
fprintf(stderr, " ❌ Expected build failure, but succeeded\n");
37+
} catch (const std::exception & err) {
38+
grammar_fails = true;
39+
fprintf(stdout, " ✅︎\n");
40+
}
41+
return grammar_fails;
42+
}
43+
3144
static bool match_string(const std::string & input, llama_grammar* grammar) {
3245
auto decoded = decode_utf8(input, {});
3346

@@ -320,13 +333,46 @@ number ::= [0-9]+)""";
320333
fprintf(stderr, " ✅︎ Passed\n");
321334
}
322335

336+
// Checks that building a grammar fails for left-recursive grammars of increasing complexity.
static void test_failure_left_recursion() {
    fprintf(stderr, "⚫ Testing left recursion detection:\n");

    const std::string grammars[] = {
        // Test simple left recursion detection
        R"""(root ::= "a" | root "a")""",

        // Test more complicated left recursion detection
        R"""(
root ::= asdf
asdf ::= "a" | asdf "a"
)""",

        // Test even more complicated left recursion detection
        R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | asdf "d" | "e")""",

        // Test yet even more complicated left recursion detection
        R"""(
root ::= asdf
asdf ::= "a" | foo "b"
foo ::= "c" | empty asdf "d" | "e"
empty ::= "blah" | )""",
    };

    for (const std::string & grammar_str : grammars) {
        // Run the check outside of assert(): the argument of assert() is not evaluated when NDEBUG
        // is defined, which would silently skip the check altogether.
        const bool build_failed = test_build_grammar_fails(grammar_str);
        assert(build_failed);
        (void) build_failed; // only read by assert()
    }

    fprintf(stderr, " ✅︎ Passed\n");
}
367+
323368
int main() {
324369
fprintf(stdout, "Running grammar integration tests...\n");
325370
test_simple_grammar();
326371
test_complex_grammar();
327372
test_quantifiers();
328373
test_failure_missing_root();
329374
test_failure_missing_reference();
375+
test_failure_left_recursion();
330376
fprintf(stdout, "All tests passed.\n");
331377
return 0;
332378
}

0 commit comments

Comments
 (0)