Commit b592c70

Rewrite special token handling from ggml-org#1931

1 parent c47066d commit b592c70

File tree

6 files changed: 243 additions, 31 deletions

common/common.cpp

Lines changed: 7 additions & 5 deletions
@@ -862,21 +862,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
 std::vector<llama_token> llama_tokenize(
     const struct llama_context * ctx,
     const std::string & text,
-    bool add_bos) {
-    return llama_tokenize(llama_get_model(ctx), text, add_bos);
+    bool add_bos,
+    bool allow_special_tokens) {
+    return llama_tokenize(llama_get_model(ctx), text, add_bos, allow_special_tokens);
 }
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
     const std::string & text,
-    bool add_bos) {
+    bool add_bos,
+    bool allow_special_tokens) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + add_bos;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
+        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
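
The wrappers keep the C API's negative-return sizing contract and simply forward the new flag. A minimal caller sketch, assuming a context obtained from llama_init_from_gpt_params (the prompt string and flag values are illustrative, and whether "</s>" actually matches a special token depends on the model's vocabulary):

    // Tokenize a prompt that spells out a special token in plain text.
    // With allow_special_tokens == true the tokenizer is permitted to match
    // it as a special token; with the default false it stays literal text.
    std::vector<llama_token> toks = llama_tokenize(
        ctx, "Hello</s>", /*add_bos=*/true, /*allow_special_tokens=*/true);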

common/common.h

Lines changed: 4 additions & 2 deletions
@@ -151,12 +151,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
 std::vector<llama_token> llama_tokenize(
     const struct llama_context * ctx,
     const std::string & text,
-    bool add_bos);
+    bool add_bos,
+    bool allow_special_tokens = false);
 
 std::vector<llama_token> llama_tokenize(
     const struct llama_model * model,
     const std::string & text,
-    bool add_bos);
+    bool add_bos,
+    bool allow_special_tokens = false);
 
 // tokenizes a token into a piece
 // should work similar to Python's `tokenizer.id_to_piece`
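
Defaulting allow_special_tokens to false keeps every existing call site source-compatible: the old three-argument form still compiles and behaves as before, and only callers that pass the fourth argument opt in. Both call shapes, sketched against the declarations above:

    auto plain   = llama_tokenize(ctx, text, /*add_bos=*/true);        // old form, specials off
    auto special = llama_tokenize(ctx, text, /*add_bos=*/true, true);  // opt in to special tokens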

common/train.cpp

Lines changed: 4 additions & 4 deletions
@@ -863,7 +863,7 @@ size_t tokenize_file(
             (int) buf.size(),
             out_tokens.data(),
             (int) out_tokens.size(),
-            false);
+            false, false);
     if (n_tokens < 0) {
         out_tokens.resize(-n_tokens);
         n_tokens = llama_tokenize(
@@ -872,7 +872,7 @@ size_t tokenize_file(
             (int) buf.size(),
             out_tokens.data(),
             (int) out_tokens.size(),
-            false);
+            false, false);
     }
     if (n_tokens >= 0) {
         out_tokens.resize(n_tokens);
@@ -966,15 +966,15 @@ size_t tokenize_file(
                 (int) buf_sample.size(),
                 tok_sample.data(),
                 (int) tok_sample.size(),
-                false);
+                false, false);
             if (n_tokens < 0) {
                 tok_sample.resize(-n_tokens);
                 n_tokens = llama_tokenize(llama_get_model(lctx),
                     buf_sample.data(),
                     (int) buf_sample.size(),
                     tok_sample.data(),
                     (int) tok_sample.size(),
-                    false);
+                    false, false);
                 GGML_ASSERT(n_tokens >= 0);
             }
             GGML_ASSERT(n_tokens <= (int) tok_sample.size());
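
tokenize_file passes false for the new flag at every call, so special-token text inside a training file stays literal. The retry around each call follows the C API's size negotiation: a negative return value is the required token count, negated. A condensed sketch of that contract (buffer and variable names assumed from the hunks above):

    std::vector<llama_token> toks(buf.size());
    int n_tokens = llama_tokenize(llama_get_model(lctx),
        buf.data(), (int) buf.size(),
        toks.data(), (int) toks.size(),
        false, false);
    if (n_tokens < 0) {
        toks.resize(-n_tokens);  // grow to the size the tokenizer reported, then retry
        n_tokens = llama_tokenize(llama_get_model(lctx),
            buf.data(), (int) buf.size(),
            toks.data(), (int) toks.size(),
            false, false);
    }
    GGML_ASSERT(n_tokens >= 0);
    toks.resize(n_tokens);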

examples/main/main.cpp

Lines changed: 7 additions & 7 deletions
@@ -237,7 +237,7 @@ int main(int argc, char ** argv) {
 
     if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
         LOG("tokenize the prompt\n");
-        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
     } else {
         LOG("use session tokens\n");
         embd_inp = session_tokens;
@@ -259,10 +259,10 @@ int main(int argc, char ** argv) {
     if (ctx_guidance) {
         LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));
 
-        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
+        guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos, true);
         LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
 
-        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
+        std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
         LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
 
         original_prompt_len = original_inp.size();
@@ -316,8 +316,8 @@ int main(int argc, char ** argv) {
     }
 
     // prefix & suffix for instruct mode
-    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
-    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
+    const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
+    const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
 
     LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
     LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
@@ -715,7 +715,7 @@ int main(int argc, char ** argv) {
     if (params.interactive) {
         if (!params.antiprompt.empty()) {
             // tokenize and inject first reverse prompt
-            const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
+            const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
             embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
             is_antiprompt = true;
         }
@@ -780,7 +780,7 @@ int main(int argc, char ** argv) {
             embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
         }
 
-        const auto line_inp = ::llama_tokenize(ctx, buffer, false);
+        const auto line_inp = ::llama_tokenize(ctx, buffer, false, true);
         LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
 
         embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
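
main is the user-facing path, so every tokenize call here opts in with true: special-token text typed into the prompt, the negative prompt, a reverse prompt, or an interactive line can resolve to the model's control tokens instead of being split into plain-text pieces. An illustration of the difference (whether "</s>" maps to a single token, and which ids appear, is model-dependent; the ids below are placeholders):

    const auto with_special    = ::llama_tokenize(ctx, "</s>", false, true);   // e.g. { 2 }
    const auto without_special = ::llama_tokenize(ctx, "</s>", false, false);  // e.g. { 829, 29879, 29958 }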
