|
14 | 14 |
|
15 | 15 | #include <string>
|
16 | 16 | #include <vector>
|
17 |
| -#include <map> |
18 |
| -#include <unordered_map> |
19 |
| -#include <memory> |
20 | 17 | #include <stdexcept>
|
21 | 18 |
|
22 | 19 | #ifdef __has_include
|
@@ -544,166 +541,4 @@ struct llama_ctx_buffer {
|
544 | 541 | typedef llama_buffer llama_ctx_buffer;
|
545 | 542 | #endif
|
546 | 543 |
|
547 |
| -struct llama_trie_node { |
548 |
| - llama_trie_node(): is_terminator(false) {} |
549 |
| - |
550 |
| - std::unordered_map<char, std::unique_ptr<llama_trie_node>> children; |
551 |
| - bool is_terminator; |
552 |
| -}; |
553 |
| - |
554 |
| -// Trie in C++. Creates a Trie out of a list of words. The trie is used to split on multiple delimiters in one pass |
555 |
| -// Ported from: https://github.com/huggingface/transformers/blob/ee88ae59940fd4b2c8fc119373143d7a1175c651/src/transformers/tokenization_utils.py#L52 |
556 |
| -struct llama_trie { |
557 |
| -public: |
558 |
| - llama_trie(): root_(new llama_trie_node()) {} |
559 |
| - |
560 |
| - void add(const std::string & word) { |
561 |
| - if (word.empty()) { |
562 |
| - return; |
563 |
| - } |
564 |
| - |
565 |
| - llama_trie_node *ref = root_.get(); |
566 |
| - for (char c : word) { |
567 |
| - if (ref->children.find(c) == ref->children.end()) { |
568 |
| - ref->children[c].reset(new llama_trie_node()); |
569 |
| - } |
570 |
| - ref = ref->children[c].get(); |
571 |
| - } |
572 |
| - ref->is_terminator = true; |
573 |
| - } |
574 |
| - |
575 |
| - // Will look for the words added to the trie within `text`. Output is the boundaries of the words found. |
576 |
| - // Note that this trie will match the longest possible word first! |
577 |
| - std::vector<size_t> split(const std::string & text) const { |
578 |
| - std::map<size_t, llama_trie_node*> states; |
579 |
| - std::vector<size_t> offsets{0}; |
580 |
| - |
581 |
| - size_t skip = 0; |
582 |
| - for (size_t current = 0; current < text.size(); current++) { |
583 |
| - char current_char = text[current]; |
584 |
| - if (skip > 0 && current < skip) { |
585 |
| - // Prevents the lookahead for matching twice |
586 |
| - // like extra_id_100 and id_100 |
587 |
| - continue; |
588 |
| - } |
589 |
| - |
590 |
| - // Whenever we found a match, we need to drop everything |
591 |
| - // this is a greedy algorithm, it will match on the first found token |
592 |
| - bool reset = false; |
593 |
| - |
594 |
| - // In this case, we already have partial matches (But unfinished) |
595 |
| - for (auto state = states.begin(); state != states.end(); ) { |
596 |
| - size_t start = state->first; |
597 |
| - llama_trie_node *trie_pointer = state->second; |
598 |
| - if (trie_pointer->is_terminator) { |
599 |
| - // This is a final match, we need to reset and |
600 |
| - // store the results in `offsets`. |
601 |
| - |
602 |
| - // Lookahead to match longest first |
603 |
| - // Important in case of extra_id_1 vs extra_id_100 |
604 |
| - // Here we are also actively looking for other earlier partial |
605 |
| - // matches |
606 |
| - // "[CLS]", "L", we need to match CLS even if L is special |
607 |
| - size_t end = 0; |
608 |
| - for (const auto & look : states) { |
609 |
| - size_t lookstart = look.first; |
610 |
| - llama_trie_node *looktrie_pointer = look.second; |
611 |
| - size_t lookahead_index = 0; |
612 |
| - if (lookstart > start) { |
613 |
| - // This partial match is later, we can stop looking |
614 |
| - break; |
615 |
| - } |
616 |
| - if (lookstart < start) { |
617 |
| - // This partial match is earlier, the trie pointer |
618 |
| - // was already updated, so index is + 1 |
619 |
| - lookahead_index = current + 1; |
620 |
| - end = current + 1; |
621 |
| - } else { |
622 |
| - // Here lookstart == start and |
623 |
| - // looktrie_pointer == trie_pointer |
624 |
| - // It wasn't updated yet so indices are current ones |
625 |
| - lookahead_index = current; |
626 |
| - end = current; |
627 |
| - } |
628 |
| - char next_char = lookahead_index < text.size() ? text[lookahead_index] : '\0'; |
629 |
| - if (looktrie_pointer->is_terminator) { |
630 |
| - start = lookstart; |
631 |
| - end = lookahead_index; |
632 |
| - skip = lookahead_index; |
633 |
| - } |
634 |
| - |
635 |
| - auto looktrie_pointer_it = looktrie_pointer->children.find(next_char); |
636 |
| - while (looktrie_pointer_it != looktrie_pointer->children.end()) { |
637 |
| - looktrie_pointer = looktrie_pointer_it->second.get(); |
638 |
| - lookahead_index++; |
639 |
| - if (looktrie_pointer->is_terminator) { |
640 |
| - start = lookstart; |
641 |
| - end = lookahead_index; |
642 |
| - skip = lookahead_index; |
643 |
| - } |
644 |
| - |
645 |
| - if (lookahead_index == text.size()) { |
646 |
| - // End of string |
647 |
| - break; |
648 |
| - } |
649 |
| - next_char = text[lookahead_index]; |
650 |
| - looktrie_pointer_it = looktrie_pointer->children.find(next_char); |
651 |
| - } |
652 |
| - } |
653 |
| - |
654 |
| - offsets.push_back(start); |
655 |
| - offsets.push_back(end); |
656 |
| - reset = true; |
657 |
| - break; |
658 |
| - } |
659 |
| - |
660 |
| - auto trie_pointer_it = trie_pointer->children.find(current_char); |
661 |
| - if (trie_pointer_it != trie_pointer->children.end()) { |
662 |
| - // The current character being looked at has a match within the trie |
663 |
| - // update the pointer (it will be stored back into states later). |
664 |
| - trie_pointer = trie_pointer_it->second.get(); |
665 |
| - states[start] = trie_pointer; |
666 |
| - ++state; |
667 |
| - } else { |
668 |
| - // The new character has not match in the trie, we need |
669 |
| - // to stop keeping track of this partial match. |
670 |
| - state = states.erase(state); |
671 |
| - } |
672 |
| - } |
673 |
| - |
674 |
| - if (reset) { |
675 |
| - // Clear the full start (we found a real match) |
676 |
| - states.clear(); |
677 |
| - } |
678 |
| - |
679 |
| - // If this character is a starting character within the trie |
680 |
| - // start keeping track of this partial match. |
681 |
| - auto children_it = root_->children.find(current_char); |
682 |
| - if (current >= skip && children_it != root_->children.end()) { |
683 |
| - states[current] = children_it->second.get(); |
684 |
| - } |
685 |
| - } |
686 |
| - |
687 |
| - // We have a cut at the end with states. |
688 |
| - for (const auto & state : states) { |
689 |
| - size_t start = state.first; |
690 |
| - llama_trie_node *trie_pointer = state.second; |
691 |
| - if (trie_pointer->is_terminator) { |
692 |
| - // This is a final match, we need to reset and |
693 |
| - // store the results in `offsets`. |
694 |
| - size_t end = text.size(); |
695 |
| - offsets.push_back(start); |
696 |
| - offsets.push_back(end); |
697 |
| - break; |
698 |
| - } |
699 |
| - } |
700 |
| - |
701 |
| - offsets.push_back(text.size()); |
702 |
| - return offsets; |
703 |
| - } |
704 |
| - |
705 |
| -private: |
706 |
| - std::unique_ptr<llama_trie_node> root_; |
707 |
| -}; |
708 |
| - |
709 | 544 | #endif
|
0 commit comments