@@ -445,15 +445,15 @@ void llama_grammar_free_impl(struct llama_grammar * grammar) {
445
445
delete grammar;
446
446
}
447
447
448
- struct llama_grammar * llama_grammar_copy_impl (const struct llama_grammar * grammar) {
449
- llama_grammar * result = new llama_grammar{ grammar-> rules , grammar-> stacks , grammar-> partial_utf8 };
448
+ struct llama_grammar * llama_grammar_copy_impl (const struct llama_grammar & grammar) {
449
+ llama_grammar * result = new llama_grammar{ grammar. rules , grammar. stacks , grammar. partial_utf8 };
450
450
451
451
// redirect elements in stacks to point to new rules
452
452
for (size_t is = 0 ; is < result->stacks .size (); is++) {
453
453
for (size_t ie = 0 ; ie < result->stacks [is].size (); ie++) {
454
- for (size_t ir0 = 0 ; ir0 < grammar-> rules .size (); ir0++) {
455
- for (size_t ir1 = 0 ; ir1 < grammar-> rules [ir0].size (); ir1++) {
456
- if (grammar-> stacks [is][ie] == &grammar-> rules [ir0][ir1]) {
454
+ for (size_t ir0 = 0 ; ir0 < grammar. rules .size (); ir0++) {
455
+ for (size_t ir1 = 0 ; ir1 < grammar. rules [ir0].size (); ir1++) {
456
+ if (grammar. stacks [is][ie] == &grammar. rules [ir0][ir1]) {
457
457
result->stacks [is][ie] = &result->rules [ir0][ir1];
458
458
}
459
459
}
@@ -464,14 +464,9 @@ struct llama_grammar * llama_grammar_copy_impl(const struct llama_grammar * gram
464
464
return result;
465
465
}
466
466
467
- void llama_grammar_sample_impl (const struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token_data_array * candidates) {
468
- GGML_ASSERT (grammar);
469
- GGML_ASSERT (vocab);
470
-
471
- int64_t t_start_sample_us = ggml_time_us ();
472
-
467
+ void llama_grammar_sample_impl (const struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token_data_array * candidates) {
473
468
bool allow_eog = false ;
474
- for (const auto & stack : grammar-> stacks ) {
469
+ for (const auto & stack : grammar. stacks ) {
475
470
if (stack.empty ()) {
476
471
allow_eog = true ;
477
472
break ;
@@ -486,54 +481,48 @@ void llama_grammar_sample_impl(const struct llama_grammar * grammar, const struc
486
481
487
482
for (size_t i = 0 ; i < candidates->size ; ++i) {
488
483
const llama_token id = candidates->data [i].id ;
489
- const std::string & piece = vocab-> cache_token_to_piece .at (id);
484
+ const std::string & piece = vocab. cache_token_to_piece .at (id);
490
485
491
- if (llama_token_is_eog_impl (* vocab, id)) {
486
+ if (llama_token_is_eog_impl (vocab, id)) {
492
487
if (!allow_eog) {
493
488
candidates->data [i].logit = -INFINITY;
494
489
}
495
490
} else if (piece.empty () || piece[0 ] == 0 ) {
496
491
candidates->data [i].logit = -INFINITY;
497
492
} else {
498
- candidates_decoded.push_back (decode_utf8 (piece, grammar-> partial_utf8 ));
493
+ candidates_decoded.push_back (decode_utf8 (piece, grammar. partial_utf8 ));
499
494
candidates_grammar.push_back ({ i, candidates_decoded.back ().first .data (), candidates_decoded.back ().second });
500
495
}
501
496
}
502
497
503
- const auto rejects = llama_grammar_reject_candidates (grammar-> rules , grammar-> stacks , candidates_grammar);
498
+ const auto rejects = llama_grammar_reject_candidates (grammar. rules , grammar. stacks , candidates_grammar);
504
499
for (const auto & reject : rejects) {
505
500
candidates->data [reject.index ].logit = -INFINITY;
506
501
}
507
-
508
- smpl->t_sample_us += ggml_time_us () - t_start_sample_us;
509
502
}
510
503
511
- void llama_grammar_accept_token_impl (struct llama_grammar * grammar, const struct llama_vocab * vocab, const struct llama_sampling * smpl, llama_token token) {
512
- const int64_t t_start_sample_us = ggml_time_us ();
513
-
514
- if (llama_token_is_eog_impl (*vocab, token)) {
515
- for (const auto & stack : grammar->stacks ) {
504
+ void llama_grammar_accept_token_impl (struct llama_grammar & grammar, const struct llama_vocab & vocab, llama_token token) {
505
+ if (llama_token_is_eog_impl (vocab, token)) {
506
+ for (const auto & stack : grammar.stacks ) {
516
507
if (stack.empty ()) {
517
508
return ;
518
509
}
519
510
}
520
511
GGML_ASSERT (false );
521
512
}
522
513
523
- const std::string & piece = vocab-> cache_token_to_piece .at (token);
514
+ const std::string & piece = vocab. cache_token_to_piece .at (token);
524
515
525
516
// Note terminating 0 in decoded string
526
- const auto decoded = decode_utf8 (piece, grammar-> partial_utf8 );
517
+ const auto decoded = decode_utf8 (piece, grammar. partial_utf8 );
527
518
const auto & code_points = decoded.first ;
528
519
529
520
llama_grammar_stacks tmp_new_stacks;
530
521
for (auto it = code_points.begin (), end = code_points.end () - 1 ; it != end; ++it) {
531
- llama_grammar_accept (grammar-> rules , grammar-> stacks , *it, tmp_new_stacks);
532
- grammar-> stacks = tmp_new_stacks;
522
+ llama_grammar_accept (grammar. rules , grammar. stacks , *it, tmp_new_stacks);
523
+ grammar. stacks = tmp_new_stacks;
533
524
}
534
525
535
- grammar->partial_utf8 = decoded.second ;
536
- GGML_ASSERT (!grammar->stacks .empty ());
537
-
538
- smpl->t_sample_us += ggml_time_us () - t_start_sample_us;
526
+ grammar.partial_utf8 = decoded.second ;
527
+ GGML_ASSERT (!grammar.stacks .empty ());
539
528
}
0 commit comments