Commit bdd9d9b

[CPP] Separate common utils out from llm_chat.cc (mlc-ai#1044)
This PR separates the tokenizer creation function and the random number generator out of `llm_chat.cc`, as a preparation step for batching inference support, since these functions/modules are used in the same way in batching inference.
1 parent a58605f commit bdd9d9b
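
To illustrate the intended reuse, here is a minimal sketch of how a batching component might consume the extracted headers the same way `llm_chat.cc` now does. The file name, `InitBatchingSession`, and its parameters are hypothetical; only the headers, `TokenizerFromPath`, and `RandomGenerator` come from this commit.

// hypothetical_batching.cc -- illustrative sketch only, not part of this commit.
#include <memory>
#include <string>

#include "random.h"
#include "support.h"
#include "tokenizers.h"

namespace mlc {
namespace llm {

// Hypothetical setup step: batched requests would share the same tokenizer
// factory and process-wide RNG singleton that llm_chat.cc uses.
void InitBatchingSession(const std::string& model_path, int seed) {
  std::unique_ptr<Tokenizer> tokenizer = TokenizerFromPath(model_path);
  RandomGenerator::GetInstance().SetSeed(seed);
  (void)tokenizer;  // a real engine would keep this in its session state
}

}  // namespace llm
}  // namespace mlc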

File tree (8 files changed: +186 -104 lines)

  cpp/base.h
  cpp/image_embed.h
  cpp/llm_chat.cc
  cpp/llm_chat.h
  cpp/random.h
  cpp/support.h
  cpp/tokenizers.cc
  cpp/tokenizers.h


cpp/base.h

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file base.h
+ */
+
+#ifndef MLC_LLM_DLL
+#ifdef _WIN32
+#ifdef MLC_LLM_EXPORTS
+#define MLC_LLM_DLL __declspec(dllexport)
+#else
+#define MLC_LLM_DLL __declspec(dllimport)
+#endif
+#else
+#define MLC_LLM_DLL __attribute__((visibility("default")))
+#endif
+#endif
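
The new header centralizes the export/import macro previously duplicated in `llm_chat.h` and `image_embed.h`: on Windows it expands to `__declspec(dllexport)` when building the library (`MLC_LLM_EXPORTS` defined) and `__declspec(dllimport)` for consumers; elsewhere it marks default symbol visibility. A minimal sketch of applying it (the function name is hypothetical; the commit's real use is the `TokenizerFromPath` declaration in `cpp/tokenizers.h` below):

#include "base.h"

// Hypothetical exported symbol -- any public API function declared in a
// header would be annotated the same way as TokenizerFromPath below.
MLC_LLM_DLL int HypotheticalExportedFunction(int x);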

cpp/image_embed.h

Lines changed: 1 addition & 11 deletions
@@ -6,17 +6,7 @@
 #include <tvm/runtime/container/string.h>
 #include <tvm/runtime/module.h>
 
-#ifndef MLC_LLM_DLL
-#ifdef _WIN32
-#ifdef MLC_LLM_EXPORTS
-#define MLC_LLM_DLL __declspec(dllexport)
-#else
-#define MLC_LLM_DLL __declspec(dllimport)
-#endif
-#else
-#define MLC_LLM_DLL __attribute__((visibility("default")))
-#endif
-#endif
+#include "base.h"
 
 namespace mlc {
 namespace llm {

cpp/llm_chat.cc

Lines changed: 15 additions & 82 deletions
@@ -32,69 +32,16 @@
 #include <vector>
 
 #include "conversation.h"
+#include "random.h"
+#include "support.h"
+#include "tokenizers.h"
 
 namespace mlc {
 namespace llm {
 
 using tvm::Device;
 using namespace tvm::runtime;
 namespace {
-//----------------------------
-// Tokenizers
-//----------------------------
-using tokenizers::Tokenizer;
-
-std::string LoadBytesFromFile(const std::string& path) {
-  std::ifstream fs(path, std::ios::in | std::ios::binary);
-  ICHECK(!fs.fail()) << "Cannot open " << path;
-  std::string data;
-  fs.seekg(0, std::ios::end);
-  size_t size = static_cast<size_t>(fs.tellg());
-  fs.seekg(0, std::ios::beg);
-  data.resize(size);
-  fs.read(data.data(), size);
-  return data;
-}
-
-std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
-  std::filesystem::path path(_path);
-  std::filesystem::path sentencepiece;
-  std::filesystem::path huggingface;
-  std::filesystem::path rwkvworld;
-  CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
-  if (std::filesystem::is_directory(path)) {
-    sentencepiece = path / "tokenizer.model";
-    huggingface = path / "tokenizer.json";
-    rwkvworld = path / "tokenizer_model";
-    // Check ByteLevelBPE
-    {
-      std::filesystem::path merges_path = path / "merges.txt";
-      std::filesystem::path vocab_path = path / "vocab.json";
-      std::filesystem::path added_tokens_path = path / "added_tokens.json";
-      if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) &&
-          std::filesystem::exists(added_tokens_path)) {
-        std::string vocab = LoadBytesFromFile(vocab_path.string());
-        std::string merges = LoadBytesFromFile(merges_path.string());
-        std::string added_tokens = LoadBytesFromFile(added_tokens_path.string());
-        return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens);
-      }
-    }
-  } else {
-    sentencepiece = path.parent_path() / "tokenizer.model";
-    huggingface = path.parent_path() / "tokenizer.json";
-    rwkvworld = path.parent_path() / "tokenizer_model";
-  }
-  if (std::filesystem::exists(sentencepiece)) {
-    return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
-  }
-  if (std::filesystem::exists(huggingface)) {
-    return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
-  }
-  if (std::filesystem::exists(rwkvworld)) {
-    return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
-  }
-  LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
-}
 
 //------------------------------
 // support functions
@@ -315,23 +262,6 @@ struct FunctionTable {
   PackedFunc fkvcache_array_popn_;
 };
 
-class RandomGenerator {
- private:
-  std::mt19937 gen;
-  std::uniform_real_distribution<> dis;
-
-  RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {}
-
- public:
-  static RandomGenerator& GetInstance(int seed = std::random_device{}()) {
-    static RandomGenerator instance(seed);
-    return instance;
-  }
-
-  double GetRandomNumber() { return dis(gen); }
-
-  void SetSeed(int seed) { gen.seed(seed); }
-};
 }  // namespace
 
 //------------------------------
@@ -708,9 +638,10 @@ class LLMChat {
     return view;
   }
 
-  std::vector<int32_t> PrepareBeforeEmbedding(std::string inp, bool append_conversation = true,
-                                              PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
-                                              picojson::object generation_config = picojson::object()) {
+  std::vector<int32_t> PrepareBeforeEmbedding(
+      std::string inp, bool append_conversation = true,
+      PlaceInPrompt place_in_prompt = PlaceInPrompt::kAll,
+      picojson::object generation_config = picojson::object()) {
     if (conversation_.separator_style == SeparatorStyle::kLM ||
         conversation_.separator_style == SeparatorStyle::kCodeCompletion) {
       this->ResetChat();
@@ -742,7 +673,7 @@ class LLMChat {
                        String generation_config_str = "") {
     // process generation settings
    picojson::object generation_config = picojson::object();
-    if(!generation_config_str.empty()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();
@@ -778,7 +709,8 @@ class LLMChat {
    * \param embedding The embedding to prefill with.
    * \param decode_next_token Whether to decode next token.
    */
-  void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true, String generation_config_str = "") {
+  void PrefillWithEmbedStep(NDArray embedding, bool decode_next_token = true,
+                            String generation_config_str = "") {
     if (ft_.use_disco) {
       LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model";
       throw;
@@ -799,7 +731,7 @@ class LLMChat {
 
     // process generation settings
     picojson::object generation_config = picojson::object();
-    if(!generation_config_str.empty()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();
@@ -830,14 +762,15 @@ class LLMChat {
       if (ft_.use_disco) {
         LOG(FATAL) << "NotImplementedError: Distributed inference is not supported for this model";
       }
-      NDArray embedding = Downcast<NDArray>(EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str));
+      NDArray embedding = Downcast<NDArray>(
+          EmbedStep(inp, append_conversation, place_in_prompt, generation_config_str));
       PrefillWithEmbedStep(embedding, decode_next_token, generation_config_str);
       return;
     }
 
     // process generation settings
     picojson::object generation_config = picojson::object();
-    if(!generation_config_str.empty()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();
@@ -876,7 +809,7 @@ class LLMChat {
   void DecodeStep(String generation_config_str = "") {
     // process generation settings
     picojson::object generation_config = picojson::object();
-    if(!generation_config_str.empty()) {
+    if (!generation_config_str.empty()) {
       picojson::value generation_config_json;
       picojson::parse(generation_config_json, generation_config_str);
       generation_config = generation_config_json.get<picojson::object>();

cpp/llm_chat.h

Lines changed: 1 addition & 11 deletions
@@ -6,17 +6,7 @@
 #include <tvm/runtime/container/string.h>
 #include <tvm/runtime/module.h>
 
-#ifndef MLC_LLM_DLL
-#ifdef _WIN32
-#ifdef MLC_LLM_EXPORTS
-#define MLC_LLM_DLL __declspec(dllexport)
-#else
-#define MLC_LLM_DLL __declspec(dllimport)
-#endif
-#else
-#define MLC_LLM_DLL __attribute__((visibility("default")))
-#endif
-#endif
+#include "base.h"
 
 namespace mlc {
 namespace llm {

cpp/random.h

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file random.h
+ * \brief Header of random number generator.
+ */
+
+#ifndef MLC_LLM_RANDOM_H_
+#define MLC_LLM_RANDOM_H_
+
+#include <random>
+
+namespace mlc {
+namespace llm {
+
+// Random number generator
+class RandomGenerator {
+ private:
+  std::mt19937 gen;
+  std::uniform_real_distribution<> dis;
+
+  RandomGenerator(int seed) : gen(seed), dis(0.0, 1.0) {}
+
+ public:
+  static RandomGenerator& GetInstance(int seed = std::random_device{}()) {
+    static RandomGenerator instance(seed);
+    return instance;
+  }
+
+  double GetRandomNumber() { return dis(gen); }
+
+  void SetSeed(int seed) { gen.seed(seed); }
+};
+
+}  // namespace llm
+}  // namespace mlc
+
+#endif  // MLC_LLM_RANDOM_H_
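
Since `RandomGenerator` is a Meyers singleton, the `seed` argument to `GetInstance` only takes effect on the very first call; later callers must re-seed through `SetSeed`. A minimal usage sketch (the `main` wrapper is illustrative only):

#include <iostream>

#include "random.h"

int main() {
  using mlc::llm::RandomGenerator;
  // The static instance is constructed exactly once, so this seed
  // applies only if no one has called GetInstance before.
  RandomGenerator& rng = RandomGenerator::GetInstance(42);
  std::cout << rng.GetRandomNumber() << "\n";  // uniform double in [0, 1)
  rng.SetSeed(7);  // re-seeding must go through SetSeed, not GetInstance
  return 0;
}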

cpp/support.h

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file support.h
+ * \brief Header of utilities.
+ */
+
+#ifndef MLC_LLM_COMMON_H_
+#define MLC_LLM_COMMON_H_
+
+#include <fstream>
+#include <string>
+
+namespace mlc {
+namespace llm {
+
+inline std::string LoadBytesFromFile(const std::string& path) {
+  std::ifstream fs(path, std::ios::in | std::ios::binary);
+  ICHECK(!fs.fail()) << "Cannot open " << path;
+  std::string data;
+  fs.seekg(0, std::ios::end);
+  size_t size = static_cast<size_t>(fs.tellg());
+  fs.seekg(0, std::ios::beg);
+  data.resize(size);
+  fs.read(data.data(), size);
+  return data;
+}
+
+}  // namespace llm
+}  // namespace mlc
+
+#endif  // MLC_LLM_COMMON_H_
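
Note that `support.h` uses `ICHECK` but includes only `<fstream>` and `<string>`, so it relies on its includer to have pulled in TVM's logging header first. A minimal standalone usage sketch (the file path is a placeholder):

#include <tvm/runtime/logging.h>  // provides ICHECK, required before support.h

#include <iostream>

#include "support.h"

int main() {
  // Reads the entire file into a string of raw bytes; ICHECK aborts
  // with an error message if the file cannot be opened.
  std::string blob = mlc::llm::LoadBytesFromFile("tokenizer.json");
  std::cout << "read " << blob.size() << " bytes\n";
  return 0;
}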

cpp/tokenizers.cc

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file tokenizer.cc
+ */
+
+#include "tokenizers.h"
+
+#include <tokenizers_cpp.h>
+#include <tvm/runtime/logging.h>
+
+#include <filesystem>
+#include <fstream>
+#include <string>
+
+#include "support.h"
+
+namespace mlc {
+namespace llm {
+
+std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path) {
+  std::filesystem::path path(_path);
+  std::filesystem::path sentencepiece;
+  std::filesystem::path huggingface;
+  std::filesystem::path rwkvworld;
+  CHECK(std::filesystem::exists(path)) << "Cannot find tokenizer via path: " << _path;
+  if (std::filesystem::is_directory(path)) {
+    sentencepiece = path / "tokenizer.model";
+    huggingface = path / "tokenizer.json";
+    rwkvworld = path / "tokenizer_model";
+    // Check ByteLevelBPE
+    {
+      std::filesystem::path merges_path = path / "merges.txt";
+      std::filesystem::path vocab_path = path / "vocab.json";
+      std::filesystem::path added_tokens_path = path / "added_tokens.json";
+      if (std::filesystem::exists(merges_path) && std::filesystem::exists(vocab_path) &&
+          std::filesystem::exists(added_tokens_path)) {
+        std::string vocab = LoadBytesFromFile(vocab_path.string());
+        std::string merges = LoadBytesFromFile(merges_path.string());
+        std::string added_tokens = LoadBytesFromFile(added_tokens_path.string());
+        return Tokenizer::FromBlobByteLevelBPE(vocab, merges, added_tokens);
+      }
+    }
+  } else {
+    sentencepiece = path.parent_path() / "tokenizer.model";
+    huggingface = path.parent_path() / "tokenizer.json";
+    rwkvworld = path.parent_path() / "tokenizer_model";
+  }
+  if (std::filesystem::exists(sentencepiece)) {
+    return Tokenizer::FromBlobSentencePiece(LoadBytesFromFile(sentencepiece.string()));
+  }
+  if (std::filesystem::exists(huggingface)) {
+    return Tokenizer::FromBlobJSON(LoadBytesFromFile(huggingface.string()));
+  }
+  if (std::filesystem::exists(rwkvworld)) {
+    return Tokenizer::FromBlobRWKVWorld(rwkvworld.string());
+  }
+  LOG(FATAL) << "Cannot find any tokenizer under: " << _path;
+}
+
+}  // namespace llm
+}  // namespace mlc

cpp/tokenizers.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+/*!
+ * Copyright (c) 2023 by Contributors
+ * \file tokenizers.h
+ * \brief Header of tokenizer related functions.
+ */
+
+#ifndef MLC_LLM_TOKENIZER_H_
+#define MLC_LLM_TOKENIZER_H_
+
+#include <tokenizers_cpp.h>
+
+#include "base.h"
+
+namespace mlc {
+namespace llm {
+
+using tokenizers::Tokenizer;
+
+MLC_LLM_DLL std::unique_ptr<Tokenizer> TokenizerFromPath(const std::string& _path);
+
+}  // namespace llm
+}  // namespace mlc
+
+#endif  // MLC_LLM_TOKENIZER_H_
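
A minimal usage sketch of the now-exported factory; the model directory path is a placeholder, and the `Encode` call assumes the `tokenizers_cpp` interface that `Tokenizer` exposes:

#include <memory>
#include <string>
#include <vector>

#include "tokenizers.h"

int main() {
  // Probes the directory for tokenizer.model (SentencePiece),
  // tokenizer.json (HuggingFace), merges.txt/vocab.json/added_tokens.json
  // (ByteLevelBPE), or tokenizer_model (RWKV world).
  std::unique_ptr<mlc::llm::Tokenizer> tok =
      mlc::llm::TokenizerFromPath("dist/llama-2-7b-chat");
  std::vector<int32_t> ids = tok->Encode("Hello world");
  return static_cast<int>(ids.size());
}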
