@@ -544,166 +544,4 @@ struct llama_ctx_buffer {
544
544
typedef llama_buffer llama_ctx_buffer;
545
545
#endif
546
546
547
- struct llama_trie_node {
548
- llama_trie_node (): is_terminator(false ) {}
549
-
550
- std::unordered_map<char , std::unique_ptr<llama_trie_node>> children;
551
- bool is_terminator;
552
- };
553
-
554
- // Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass
555
- // Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52
556
- struct llama_trie {
557
- public:
558
- llama_trie (): root_(new llama_trie_node()) {}
559
-
560
- void add (const std::string & word) {
561
- if (word.empty ()) {
562
- return ;
563
- }
564
-
565
- llama_trie_node *ref = root_.get ();
566
- for (char c : word) {
567
- if (ref->children .find (c) == ref->children .end ()) {
568
- ref->children [c].reset (new llama_trie_node ());
569
- }
570
- ref = ref->children [c].get ();
571
- }
572
- ref->is_terminator = true ;
573
- }
574
-
575
- // Will look for the words added to the trie within `text`. Output is the boundaries of the words found.
576
- // Note that this trie will match the longest possible word first!
577
- std::vector<size_t > split (const std::string & text) const {
578
- std::map<size_t , llama_trie_node*> states;
579
- std::vector<size_t > offsets{0 };
580
-
581
- size_t skip = 0 ;
582
- for (size_t current = 0 ; current < text.size (); current++) {
583
- char current_char = text[current];
584
- if (skip > 0 && current < skip) {
585
- // Prevents the lookahead for matching twice
586
- // like extra_id_100 and id_100
587
- continue ;
588
- }
589
-
590
- // Whenever we found a match, we need to drop everything
591
- // this is a greedy algorithm, it will match on the first found token
592
- bool reset = false ;
593
-
594
- // In this case, we already have partial matches (But unfinished)
595
- for (auto state = states.begin (); state != states.end (); ) {
596
- size_t start = state->first ;
597
- llama_trie_node *trie_pointer = state->second ;
598
- if (trie_pointer->is_terminator ) {
599
- // This is a final match, we need to reset and
600
- // store the results in `offsets`.
601
-
602
- // Lookahead to match longest first
603
- // Important in case of extra_id_1 vs extra_id_100
604
- // Here we are also actively looking for other earlier partial
605
- // matches
606
- // "[CLS]", "L", we need to match CLS even if L is special
607
- size_t end = 0 ;
608
- for (const auto & look : states) {
609
- size_t lookstart = look.first ;
610
- llama_trie_node *looktrie_pointer = look.second ;
611
- size_t lookahead_index = 0 ;
612
- if (lookstart > start) {
613
- // This partial match is later, we can stop looking
614
- break ;
615
- }
616
- if (lookstart < start) {
617
- // This partial match is earlier, the trie pointer
618
- // was already updated, so index is + 1
619
- lookahead_index = current + 1 ;
620
- end = current + 1 ;
621
- } else {
622
- // Here lookstart == start and
623
- // looktrie_pointer == trie_pointer
624
- // It wasn't updated yet so indices are current ones
625
- lookahead_index = current;
626
- end = current;
627
- }
628
- char next_char = lookahead_index < text.size () ? text[lookahead_index] : ' \0 ' ;
629
- if (looktrie_pointer->is_terminator ) {
630
- start = lookstart;
631
- end = lookahead_index;
632
- skip = lookahead_index;
633
- }
634
-
635
- auto looktrie_pointer_it = looktrie_pointer->children .find (next_char);
636
- while (looktrie_pointer_it != looktrie_pointer->children .end ()) {
637
- looktrie_pointer = looktrie_pointer_it->second .get ();
638
- lookahead_index++;
639
- if (looktrie_pointer->is_terminator ) {
640
- start = lookstart;
641
- end = lookahead_index;
642
- skip = lookahead_index;
643
- }
644
-
645
- if (lookahead_index == text.size ()) {
646
- // End of string
647
- break ;
648
- }
649
- next_char = text[lookahead_index];
650
- looktrie_pointer_it = looktrie_pointer->children .find (next_char);
651
- }
652
- }
653
-
654
- offsets.push_back (start);
655
- offsets.push_back (end);
656
- reset = true ;
657
- break ;
658
- }
659
-
660
- auto trie_pointer_it = trie_pointer->children .find (current_char);
661
- if (trie_pointer_it != trie_pointer->children .end ()) {
662
- // The current character being looked at has a match within the trie
663
- // update the pointer (it will be stored back into states later).
664
- trie_pointer = trie_pointer_it->second .get ();
665
- states[start] = trie_pointer;
666
- ++state;
667
- } else {
668
- // The new character has not match in the trie, we need
669
- // to stop keeping track of this partial match.
670
- state = states.erase (state);
671
- }
672
- }
673
-
674
- if (reset) {
675
- // Clear the full start (we found a real match)
676
- states.clear ();
677
- }
678
-
679
- // If this character is a starting character within the trie
680
- // start keeping track of this partial match.
681
- auto children_it = root_->children .find (current_char);
682
- if (current >= skip && children_it != root_->children .end ()) {
683
- states[current] = children_it->second .get ();
684
- }
685
- }
686
-
687
- // We have a cut at the end with states.
688
- for (const auto & state : states) {
689
- size_t start = state.first ;
690
- llama_trie_node *trie_pointer = state.second ;
691
- if (trie_pointer->is_terminator ) {
692
- // This is a final match, we need to reset and
693
- // store the results in `offsets`.
694
- size_t end = text.size ();
695
- offsets.push_back (start);
696
- offsets.push_back (end);
697
- break ;
698
- }
699
- }
700
-
701
- offsets.push_back (text.size ());
702
- return offsets;
703
- }
704
-
705
- private:
706
- std::unique_ptr<llama_trie_node> root_;
707
- };
708
-
709
547
#endif
0 commit comments