
Commit b97bc39

pcuenca, pacman100, and ggerganov authored
llama : support Llama 3 HF conversion (#6745)
* Support Llama 3 conversion. The tokenizer is BPE.
* style
* Accept suggestion (Co-authored-by: Sourab Mangrulkar <[email protected]>)
* llama : add llama_token_is_eog() (ggml-ci)
* llama : auto-detect more EOT tokens when missing in KV data
* convert : replacing EOS token is a hack
* llama : fix codegemma EOT token + add TODOs
* llama : fix model type string for 8B model

Co-authored-by: Sourab Mangrulkar <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent b8109bc commit b97bc39
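The change that repeats across the example programs below is the switch from comparing a sampled token against the single EOS id to the new `llama_token_is_eog()` helper, which also covers EOT and other end-of-generation tokens (models such as Llama 3 end a turn with a token other than EOS). A minimal sketch of the before/after pattern, for orientation only; the `should_stop()` wrapper and its parameter names are illustrative and not part of this commit:

```cpp
#include "llama.h"

// Illustrative helper (not from this commit): decide whether a decoded token
// should end the stream. `model`, `new_token_id`, `n_cur`, `n_len` play the
// same roles as in examples/simple/simple.cpp.
static bool should_stop(const llama_model * model, llama_token new_token_id, int n_cur, int n_len) {
    // before this commit: only the single EOS token ended the stream
    //     if (new_token_id == llama_token_eos(model) || n_cur == n_len) { ... }
    // after: any end-of-generation token (EOS, EOT, ...) ends it
    return llama_token_is_eog(model, new_token_id) || n_cur == n_len;
}
```

The removal of the hard-coded `<|im_end|>` and `<end_of_turn>` stop words in examples/server/utils.hpp follows the same idea: end-of-turn tokens are now handled through `llama_token_is_eog()` rather than by string matching.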

File tree: 20 files changed, +124 additions, −65 deletions

convert-hf-to-gguf.py

Lines changed: 34 additions & 15 deletions
@@ -1301,15 +1301,23 @@ def set_vocab(self):
         try:
             self. _set_vocab_sentencepiece()
         except FileNotFoundError:
-            self._set_vocab_llama_hf()
-
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
-                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
-        special_vocab._set_special_token("prefix", 32007)
-        special_vocab._set_special_token("suffix", 32008)
-        special_vocab._set_special_token("middle", 32009)
-        special_vocab._set_special_token("eot", 32010)
-        special_vocab.add_to_gguf(self.gguf_writer)
+            try:
+                self._set_vocab_llama_hf()
+            except (FileNotFoundError, TypeError):
+                # Llama 3
+                self._set_vocab_gpt2()
+
+        # Apply to CodeLlama only (and ignore for Llama 3 with a vocab size of 128256)
+        if self.hparams.get("vocab_size", 32000) == 32016:
+            special_vocab = gguf.SpecialVocab(
+                self.dir_model, load_merges=False,
+                special_token_types = ['prefix', 'suffix', 'middle', 'eot']
+            )
+            special_vocab._set_special_token("prefix", 32007)
+            special_vocab._set_special_token("suffix", 32008)
+            special_vocab._set_special_token("middle", 32009)
+            special_vocab._set_special_token("eot", 32010)
+            special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -2194,6 +2202,8 @@ def set_vocab(self):
         old_eos = special_vocab.special_token_ids["eos"]
         if "chat" in os.path.basename(self.dir_model.absolute()):
             # For the chat model, we replace the eos with '<|im_end|>'.
+            # TODO: this is a hack, should be fixed
+            #       https://github.com/ggerganov/llama.cpp/pull/6745#issuecomment-2067687048
             special_vocab.special_token_ids["eos"] = self._try_get_sft_eos(tokenizer)
             print(f"Replace eos:{old_eos} with a special token:{special_vocab.special_token_ids['eos']} \
 in chat mode so that the conversation can end normally.")
@@ -2429,12 +2439,15 @@ class GemmaModel(Model):
 
     def set_vocab(self):
         self._set_vocab_sentencepiece()
+
+        # TODO: these special tokens should be exported only for the CodeGemma family
         special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
-                                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
         special_vocab._set_special_token("prefix", 67)
         special_vocab._set_special_token("suffix", 69)
         special_vocab._set_special_token("middle", 68)
-        special_vocab._set_special_token("eot", 70)
+        special_vocab._set_special_token("fsep", 70)
+        special_vocab._set_special_token("eot", 107)
         special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
@@ -2523,28 +2536,34 @@ def set_vocab(self):
 
         field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
         self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
+
         field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
         self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
+
         field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
         self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
+
         field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
         self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
+
         field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
         self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
+
         field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
         self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
+
         field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
         self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])
 
     def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model"])
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_model = self.find_hparam(["hidden_size",       "d_model"])
+        d_conv  = self.find_hparam(["conv_kernel",       "d_conv"],  optional=True) or 4
         d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
+        d_state = self.find_hparam(["state_size",        "d_state"], optional=True) or 16
         # ceiling division
         # ref: https://stackoverflow.com/a/17511341/22827863
         # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
-        dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
+        dt_rank      = self.find_hparam(["time_step_rank",     "dt_rank"],      optional=True) or -(d_model // -16)
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
         # Fail early for models which don't have a block expansion factor of 2

convert.py

Lines changed: 8 additions & 1 deletion
@@ -525,7 +525,14 @@ def __init__(self, base_path: Path):
 
         # pre-check so we know if we need transformers
         tokenizer_model: dict[str, Any] = tokenizer_json['model']
-        if (
+        is_llama3 = (
+            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
+            and not tokenizer_model.get('byte_fallback', True)
+        )
+        if is_llama3:
+            raise TypeError('Llama 3 must be converted with BpeVocab')
+
+        if not is_llama3 and (
             tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
             or tokenizer_json['decoder']['type'] != 'Sequence'
         ):

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ while n_cur <= n_len {
         // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
         // is it an end of stream? -> mark the stream as finished
-        if new_token_id == llama_token_eos(model) || n_cur == n_len {
+        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             i_batch[i] = -1
             // print("")
             if n_parallel > 1 {

examples/batched/batched.cpp

Lines changed: 2 additions & 2 deletions
@@ -191,8 +191,8 @@ int main(int argc, char ** argv) {
 
             //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
-            // is it an end of stream? -> mark the stream as finished
-            if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+            // is it an end of generation? -> mark the stream as finished
+            if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
                 i_batch[i] = -1;
                 LOG_TEE("\n");
                 if (n_parallel > 1) {

examples/beam-search/beam-search.cpp

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ struct beam_search_callback_data {
 // In this case, end-of-beam (eob) is equivalent to end-of-sentence (eos) but this need not always be the same.
 // For example, eob can be flagged due to maximum token length, stop words, etc.
 static bool is_at_eob(const beam_search_callback_data & callback_data, const llama_token * tokens, size_t n_tokens) {
-    return n_tokens && tokens[n_tokens-1] == llama_token_eos(llama_get_model(callback_data.ctx));
+    return n_tokens && llama_token_is_eog(llama_get_model(callback_data.ctx), tokens[n_tokens-1]);
 }
 
 // Function matching type llama_beam_search_callback_fn_t.

examples/infill/infill.cpp

Lines changed: 5 additions & 5 deletions
@@ -586,7 +586,7 @@ int main(int argc, char ** argv) {
 
             // deal with eot token in infill mode
             if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
-                if(is_interacting && !params.interactive_first) {
+                if (is_interacting && !params.interactive_first) {
                     // print an eot token
                     printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
                 }
@@ -651,8 +651,8 @@ int main(int argc, char ** argv) {
                 // LOG_TEE("took new input\n");
                 is_interacting = false;
             }
-            // deal with end of text token in interactive mode
-            else if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
+            // deal with end of generation tokens in interactive mode
+            else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -731,8 +731,8 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(model) && !params.interactive) {
+        // end of generation
+        if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) {
             break;
         }
 

examples/llama.android/app/src/main/cpp/llama-android.cpp

Lines changed: 1 addition & 1 deletion
@@ -408,7 +408,7 @@ Java_com_example_llama_Llm_completion_1loop(
     const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
-    if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+    if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
         return env->NewStringUTF("");
     }
 

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 1 addition & 1 deletion
@@ -158,7 +158,7 @@ actor LlamaContext {
             new_token_id = llama_sample_token_greedy(context, &candidates_p)
         }
 
-        if new_token_id == llama_token_eos(model) || n_cur == n_len {
+        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
             print("\n")
             let new_token_str = String(cString: temporary_invalid_cchars + [0])
             temporary_invalid_cchars.removeAll()

examples/llava/llava-cli.cpp

Lines changed: 1 addition & 1 deletion
@@ -45,7 +45,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
     const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
     llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
     static std::string ret;
-    if (id == llama_token_eos(llama_get_model(ctx_llama))) {
+    if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
         ret = "</s>";
     } else {
         ret = llama_token_to_piece(ctx_llama, id);

examples/lookahead/lookahead.cpp

Lines changed: 1 addition & 1 deletion
@@ -299,7 +299,7 @@ int main(int argc, char ** argv) {
             }
             fflush(stdout);
 
-            if (id == llama_token_eos(model)) {
+            if (llama_token_is_eog(model, id)) {
                 has_eos = true;
             }

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ int main(int argc, char ** argv){
                 printf("%s", token_str.c_str());
             }
 
-            if (id == llama_token_eos(model)) {
+            if (llama_token_is_eog(model, id)) {
                 has_eos = true;
             }

examples/main/main.cpp

Lines changed: 4 additions & 4 deletions
@@ -795,8 +795,8 @@ int main(int argc, char ** argv) {
                 }
             }
 
-            // deal with end of text token in interactive mode
-            if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) {
+            // deal with end of generation tokens in interactive mode
+            if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
                 LOG("found EOS token\n");
 
                 if (params.interactive) {
@@ -920,8 +920,8 @@ int main(int argc, char ** argv) {
             }
         }
 
-        // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive || params.chatml)) {
+        // end of generation
+        if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.instruct || params.interactive || params.chatml)) {
             LOG_TEE(" [end of text]\n");
             break;
         }

examples/parallel/parallel.cpp

Lines changed: 1 addition & 1 deletion
@@ -359,7 +359,7 @@ int main(int argc, char ** argv) {
                 //       client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
 
                 if (client.n_decoded > 2 &&
-                    (id == llama_token_eos(model) ||
+                    (llama_token_is_eog(model, id) ||
                      (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
                      client.response.find("User:") != std::string::npos ||
                      client.response.find('\n') != std::string::npos)) {

examples/passkey/passkey.cpp

Lines changed: 2 additions & 2 deletions
@@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
         // sample the most likely token
         const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
-        // is it an end of stream?
-        if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+        // is it an end of generation?
+        if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
             LOG_TEE("\n");
 
             break;

examples/server/server.cpp

Lines changed: 1 addition & 1 deletion
@@ -1201,7 +1201,7 @@ struct server_context {
                 });
             }
 
-            if (result.tok == llama_token_eos(model)) {
+            if (llama_token_is_eog(model, result.tok)) {
                 slot.stopped_eos    = true;
                 slot.has_next_token = false;

examples/server/utils.hpp

Lines changed: 0 additions & 4 deletions
@@ -381,10 +381,6 @@ static json oaicompat_completion_params_parse(
     } else {
         llama_params["stop"] = json_value(body, "stop", json::array());
     }
-    // Some chat templates don't use EOS token to stop generation
-    // We must add their end sequences to list of stop words
-    llama_params["stop"].push_back("<|im_end|>"); // chatml
-    llama_params["stop"].push_back("<end_of_turn>"); // gemma
 
     // Handle "response_format" field
     if (body.contains("response_format")) {

examples/simple/simple.cpp

Lines changed: 2 additions & 2 deletions
@@ -133,8 +133,8 @@ int main(int argc, char ** argv) {
         // sample the most likely token
        const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
 
-        // is it an end of stream?
-        if (new_token_id == llama_token_eos(model) || n_cur == n_len) {
+        // is it an end of generation?
+        if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
             LOG_TEE("\n");
 
             break;

examples/speculative/speculative.cpp

Lines changed: 1 addition & 1 deletion
@@ -360,7 +360,7 @@ int main(int argc, char ** argv) {
             }
         }
 
-        if (token_id == llama_token_eos(model_tgt)) {
+        if (llama_token_is_eog(model_tgt, token_id)) {
             has_eos = true;
         }
         ++n_predict;
