 #include <vector>

 #include "conversation.h"
+#include "random.h"
+#include "support.h"
+#include "tokenizers.h"

 namespace mlc {
 namespace llm {

 using tvm::Device;
 using namespace tvm::runtime;

 namespace {
-// ----------------------------
-// Tokenizers
-// ----------------------------
-using tokenizers::Tokenizer;
-
-std::string LoadBytesFromFile(const std::string& path) {
-  std::ifstream fs(path, std::ios::in | std::ios::binary);
-  ICHECK(!fs.fail()) << "Cannot open " << path;
-  std::string data;
-  fs.seekg(0, std::ios::end);
-  size_t size = static_cast<size_t>(fs.tellg());
-  fs.seekg(0, std::ios::beg);
-  data.resize(size);
-  fs.read(data.data(), size);
-  return data;
-}
-
-std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
-  std::filesystem::path path(_path);
-  std::filesystem::path sentencepiece;
-  std::filesystem::path huggingface;
-  std::filesystem::path rwkvworld;
-  CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
-  if (std::filesystem::is_directory(path)) {
-    sentencepiece = path / "tokenizer.model";
-    huggingface = path / "tokenizer.json";
-    rwkvworld = path / "tokenizer_model";
-    // Check ByteLevelBPE
-    {
-      std::filesystem::path merges_path = path / "merges.txt";
-      std::filesystem::path vocab_path = path / "vocab.json";
-      std::filesystem::path added_tokens_path = path / "added_tokens.json";
-      if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) &&
-          std::filesystem::exists(added_tokens_path)) {
-        std::string vocab = LoadBytesFromFile(vocab_path.string());
-        std::string merges = LoadBytesFromFile(merges_path.string());
-        std::string added_tokens = LoadBytesFromFile(added_tokens_path.string());
-        return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens);
-      }
-    }
-  } else {
-    sentencepiece = path.parent_path() / "tokenizer.model";
-    huggingface = path.parent_path() / "tokenizer.json";
-    rwkvworld = path.parent_path() / "tokenizer_model";
-  }
-  if (std::filesystem::exists(sentencepiece)) {
-    return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
-  }
-  if (std::filesystem::exists(huggingface)) {
-    return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
-  }
-  if (std::filesystem::exists(rwkvworld)) {
-    return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
-  }
-  LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
-}

 // ------------------------------
 // support functions
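Note: the block deleted above is the tokenizer bootstrap which, judging by the new includes, now lives behind "tokenizers.h". Its probe order is worth keeping in mind: inside a model directory, a ByteLevelBPE triple (merges.txt, vocab.json, added_tokens.json) wins over tokenizer.model (SentencePiece), then tokenizer.json (HuggingFace), then tokenizer_model (RWKV world). A minimal standalone sketch of that probe order, assuming C++17; DetectTokenizer and its string results are illustrative stand-ins for the real Tokenizer::FromBlob* factories, and the file-path branch that probes parent_path() is omitted:

#include <filesystem>
#include <fstream>
#include <iostream>
#include <iterator>
#include <stdexcept>
#include <string>

namespace fs = std::filesystem;

// Read a whole file into a string, as the removed LoadBytesFromFile did
// (minus the TVM ICHECK macro).
std::string LoadBytesFromFile(const fs::path& path) {
  std::ifstream in(path, std::ios::in | std::ios::binary);
  if (!in) throw std::runtime_error("Cannot open " + path.string());
  return std::string(std::istreambuf_iterator<char>(in), {});
}

// Probe a directory in the same priority order as the removed TokenizerFromPath.
std::string DetectTokenizer(const fs::path& dir) {
  if (fs::exists(dir / "merges.txt") && fs::exists(dir / "vocab.json") &&
      fs::exists(dir / "added_tokens.json")) {
    std::string vocab = LoadBytesFromFile(dir / "vocab.json");  // blob for FromBlobByteLevelBPE
    return "ByteLevelBPE (" + std::to_string(vocab.size()) + "-byte vocab)";
  }
  if (fs::exists(dir / "tokenizer.model")) return "SentencePiece";
  if (fs::exists(dir / "tokenizer.json")) return "HuggingFace JSON";
  if (fs::exists(dir / "tokenizer_model")) return "RWKV world";
  throw std::runtime_error("Cannot find any tokenizer under: " + dir.string());
}

int main() {
  try {
    std::cout << DetectTokenizer("path/to/model") << "\n";  // hypothetical model directory
  } catch (const std::exception& e) {
    std::cerr << e.what() << "\n";
  }
}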
@@ -315,23 +262,6 @@ struct FunctionTable {
   PackedFunc fkvcache_array_popn_;
 };

-class RandomGenerator {
- private:
-  std::mt19937 gen;
-  std::uniform_real_distribution<> dis;
-
-  RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {}
-
- public:
-  static RandomGenerator& GetInstance(int seed = std::random_device{}()) {
-    static RandomGenerator instance(seed);
-    return instance;
-  }
-
-  double GetRandomNumber() { return dis(gen); }
-
-  void SetSeed(int seed) { gen.seed(seed); }
-};
 }  // namespace

 // ------------------------------
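Note: the deleted RandomGenerator, now presumably behind "random.h", is a Meyers singleton around std::mt19937: the seed argument only takes effect on the first GetInstance call, since the static local is constructed exactly once, and any later reseeding has to go through SetSeed. A self-contained copy with usage:

#include <iostream>
#include <random>

// Lazily constructed singleton wrapping std::mt19937 with a uniform
// [0, 1) distribution, same shape as the class removed above.
class RandomGenerator {
 private:
  std::mt19937 gen;
  std::uniform_real_distribution<> dis;

  explicit RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {}

 public:
  // The default seed comes from the OS entropy source; it is only used
  // the first time this function runs.
  static RandomGenerator& GetInstance(int seed = std::random_device{}()) {
    static RandomGenerator instance(seed);
    return instance;
  }

  double GetRandomNumber() { return dis(gen); }

  void SetSeed(int seed) { gen.seed(seed); }
};

int main() {
  auto& rng = RandomGenerator::GetInstance(42);  // seeds on first use
  std::cout << rng.GetRandomNumber() << "\n";
  rng.SetSeed(7);  // reseeding later must use SetSeed, not GetInstance
  std::cout << rng.GetRandomNumber() << "\n";
}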
@@ -708,9 +638,10 @@ class LLMChat {
     return view;
   }

-  std::vector<int32_t> PrepareBeforeEmbedding(std::string inp, bool append_conversation = true,
-                                              PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
-                                              picojson::object generation_config = picojson::object()) {
+  std::vector<int32_t> PrepareBeforeEmbedding(
+      std::string inp, bool append_conversation = true,
+      PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
+      picojson::object generation_config = picojson::object()) {
     if (conversation_.separator_style == SeparatorStyle::kLM ||
         conversation_.separator_style == SeparatorStyle::kCodeCompletion) {
       this->ResetChat();
@@ -742,7 +673,7 @@ class LLMChat {
                    String generation_config_str = "") {
     // process generation settings
     picojson::object generation_config = picojson::object();
-    if (!generation_config_str.empty ()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();
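Note: this guard introduces the three-line picojson idiom that the later hunks repeat verbatim: parse the JSON settings string into a picojson::value, then pull out the top-level object. As written it discards the error string returned by picojson::parse, and get<picojson::object>() will abort on non-object input. A hedged standalone variant that checks both; the temperature/top_p keys are illustrative, not taken from this diff:

#include <iostream>
#include <string>

#include "picojson.h"  // header-only JSON parser already used by this file

// Parse a generation-config JSON string into a picojson::object,
// mirroring the idiom in the diff but surfacing parse errors.
picojson::object ParseGenerationConfig(const std::string& config_str) {
  picojson::object generation_config;
  if (!config_str.empty()) {
    picojson::value v;
    std::string err = picojson::parse(v, config_str);  // empty string on success
    if (!err.empty()) {
      std::cerr << "generation_config parse error: " << err << "\n";
    } else if (v.is<picojson::object>()) {
      generation_config = v.get<picojson::object>();
    }
  }
  return generation_config;
}

int main() {
  auto cfg = ParseGenerationConfig(R"({"temperature": 0.7, "top_p": 0.95})");
  std::cout << cfg.size() << " settings parsed\n";
}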
@@ -778,7 +709,8 @@ class LLMChat {
    * \param embedding The embedding to prefill with.
    * \param decode_next_token Whether to decode next token.
    */
-  void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true, String generation_config_str = "") {
+  void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true,
+                            String generation_config_str = "") {
     if (ft_.use_disco) {
       LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model";
       throw;
@@ -799,7 +731,7 @@ class LLMChat {
     // process generation settings
     picojson::object generation_config = picojson::object();
-    if (!generation_config_str.empty ()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();
@@ -830,14 +762,15 @@ class LLMChat {
       if (ft_.use_disco) {
         LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model";
       }
-      NDArray embedding = Downcast<NDArray>(EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str));
+      NDArray embedding = Downcast<NDArray>(
+          EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str));
       PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str);
       return;
     }

     // process generation settings
     picojson::object generation_config = picojson::object();
-    if (!generation_config_str.empty ()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();
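Note: the re-wrapped call above funnels EmbedStep's ObjectRef result through Downcast<NDArray>, TVM's checked reference cast. A small sketch of that cast in isolation, assuming a standard TVM runtime build; the Empty call exists only to have something to cast:

#include <tvm/runtime/ndarray.h>

using namespace tvm::runtime;

int main() {
  // Make a small float32 CPU tensor, erase its static type to ObjectRef,
  // then recover it with Downcast, as the diff does with EmbedStep's result.
  NDArray arr = NDArray::Empty({2, 3}, DLDataType{kDLFloat, 32, 1}, DLDevice{kDLCPU, 0});
  ObjectRef ref = arr;                    // upcast: NDArray is an ObjectRef
  NDArray back = Downcast<NDArray>(ref);  // checked cast; fails loudly on a type mismatch
  return back.defined() ? 0 : 1;
}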
@@ -876,7 +809,7 @@ class LLMChat {
   void DecodeStep(String generation_config_str = "") {
     // process generation settings
     picojson::object generation_config = picojson::object();
-    if (!generation_config_str.empty ()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();