Skip to content

Commit 1a15955

Browse files
staviq and ggerganov authored
tokenizer : special token handling (#3538)
* Rewrite special token handling from #1931 * shorten param name, add st verification by type * use offsets instead of copy by substr * formatting, remove copying iterator on delete * llama : normalize code-style * swift fix * print pfx/sfx if verb, main: split pfx input sfx * dont add space when using special tokens * minor : comment + spacing --------- Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 281ef73 commit 1a15955

File tree

7 files changed

+332
-39
lines changed

7 files changed

+332
-39
lines changed

common/common.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -879,21 +879,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
879879
std::vector<llama_token> llama_tokenize(
880880
const struct llama_context * ctx,
881881
const std::string & text,
882-
bool add_bos) {
883-
return llama_tokenize(llama_get_model(ctx), text, add_bos);
882+
bool add_bos,
883+
bool special) {
884+
return llama_tokenize(llama_get_model(ctx), text, add_bos, special);
884885
}
885886

886887
std::vector<llama_token> llama_tokenize(
887888
const struct llama_model * model,
888889
const std::string & text,
889-
bool add_bos) {
890+
bool add_bos,
891+
bool special) {
890892
// upper limit for the number of tokens
891893
int n_tokens = text.length() + add_bos;
892894
std::vector<llama_token> result(n_tokens);
893-
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
895+
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
894896
if (n_tokens < 0) {
895897
result.resize(-n_tokens);
896-
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
898+
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special);
897899
GGML_ASSERT(check == -n_tokens);
898900
} else {
899901
result.resize(n_tokens);

common/common.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
137137
std::vector<llama_token> llama_tokenize(
138138
const struct llama_context * ctx,
139139
const std::string & text,
140-
bool add_bos);
140+
bool add_bos,
141+
bool special = false);
141142

142143
std::vector<llama_token> llama_tokenize(
143144
const struct llama_model * model,
144145
const std::string & text,
145-
bool add_bos);
146+
bool add_bos,
147+
bool special = false);
146148

147149
// tokenizes a token into a piece
148150
// should work similar to Python's `tokenizer.id_to_piece`

common/train.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,7 @@ size_t tokenize_file(
863863
(int) buf.size(),
864864
out_tokens.data(),
865865
(int) out_tokens.size(),
866-
false);
866+
false, false);
867867
if (n_tokens < 0) {
868868
out_tokens.resize(-n_tokens);
869869
n_tokens = llama_tokenize(
@@ -872,7 +872,7 @@ size_t tokenize_file(
872872
(int) buf.size(),
873873
out_tokens.data(),
874874
(int) out_tokens.size(),
875-
false);
875+
false, false);
876876
}
877877
if (n_tokens >= 0) {
878878
out_tokens.resize(n_tokens);
@@ -966,15 +966,15 @@ size_t tokenize_file(
966966
(int) buf_sample.size(),
967967
tok_sample.data(),
968968
(int) tok_sample.size(),
969-
false);
969+
false, false);
970970
if (n_tokens < 0) {
971971
tok_sample.resize(-n_tokens);
972972
n_tokens = llama_tokenize(llama_get_model(lctx),
973973
buf_sample.data(),
974974
(int) buf_sample.size(),
975975
tok_sample.data(),
976976
(int) tok_sample.size(),
977-
false);
977+
false, false);
978978
GGML_ASSERT(n_tokens >= 0);
979979
}
980980
GGML_ASSERT(n_tokens <= (int) tok_sample.size());

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ llama_print_timings(context)
209209
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
210210
let n_tokens = text.count + (add_bos ? 1 : 0)
211211
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
212-
let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos)
212+
let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
213213
var swiftTokens: [llama_token] = []
214214
for i in 0 ..< tokenCount {
215215
swiftTokens.append(tokens[Int(i)])

examples/main/main.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ int main(int argc, char ** argv) {
238238

239239
if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
240240
LOG("tokenize the prompt\n");
241-
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
241+
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
242242
} else {
243243
LOG("use session tokens\n");
244244
embd_inp = session_tokens;
@@ -260,10 +260,10 @@ int main(int argc, char ** argv) {
260260
if (ctx_guidance) {
261261
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
262262

263-
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos);
263+
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true);
264264
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));
265265

266-
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
266+
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
267267
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));
268268

269269
original_prompt_len = original_inp.size();
@@ -320,8 +320,8 @@ int main(int argc, char ** argv) {
320320
}
321321

322322
// prefix & suffix for instruct mode
323-
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
324-
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
323+
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
324+
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);
325325

326326
LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
327327
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
@@ -383,6 +383,12 @@ int main(int argc, char ** argv) {
383383
if (!params.antiprompt.empty()) {
384384
for (const auto & antiprompt : params.antiprompt) {
385385
LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str());
386+
if (params.verbose_prompt) {
387+
auto tmp = ::llama_tokenize(ctx, antiprompt, false, true);
388+
for (int i = 0; i < (int) tmp.size(); i++) {
389+
LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
390+
}
391+
}
386392
}
387393
}
388394

@@ -392,10 +398,22 @@ int main(int argc, char ** argv) {
392398

393399
if (!params.input_prefix.empty()) {
394400
LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str());
401+
if (params.verbose_prompt) {
402+
auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true);
403+
for (int i = 0; i < (int) tmp.size(); i++) {
404+
LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
405+
}
406+
}
395407
}
396408

397409
if (!params.input_suffix.empty()) {
398410
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
411+
if (params.verbose_prompt) {
412+
auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true);
413+
for (int i = 0; i < (int) tmp.size(); i++) {
414+
LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str());
415+
}
416+
}
399417
}
400418
}
401419
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
@@ -717,7 +735,7 @@ int main(int argc, char ** argv) {
717735
if (params.interactive) {
718736
if (!params.antiprompt.empty()) {
719737
// tokenize and inject first reverse prompt
720-
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
738+
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
721739
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
722740
is_antiprompt = true;
723741
}
@@ -744,8 +762,7 @@ int main(int argc, char ** argv) {
744762
std::string buffer;
745763
if (!params.input_prefix.empty()) {
746764
LOG("appending input prefix: '%s'\n", params.input_prefix.c_str());
747-
buffer += params.input_prefix;
748-
printf("%s", buffer.c_str());
765+
printf("%s", params.input_prefix.c_str());
749766
}
750767

751768
// color user input only
@@ -767,7 +784,6 @@ int main(int argc, char ** argv) {
767784
// append input suffix if any
768785
if (!params.input_suffix.empty()) {
769786
LOG("appending input suffix: '%s'\n", params.input_suffix.c_str());
770-
buffer += params.input_suffix;
771787
printf("%s", params.input_suffix.c_str());
772788
}
773789

@@ -782,10 +798,14 @@ int main(int argc, char ** argv) {
782798
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
783799
}
784800

785-
const auto line_inp = ::llama_tokenize(ctx, buffer, false);
801+
const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
802+
const auto line_inp = ::llama_tokenize(ctx, buffer, false, false);
803+
const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
786804
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));
787805

806+
embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end());
788807
embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
808+
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end());
789809

790810
// instruct mode: insert response suffix
791811
if (params.instruct) {

0 commit comments

Comments (0)