@@ -128,13 +128,12 @@ static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 // default hparams (LLaMA 7B)
 struct llama_hparams {
     uint32_t n_vocab = 32000;
-    uint32_t n_vocab_sp = 0;
+    uint32_t n_vocab_base = 32000;
     uint32_t n_ctx   = 512;   // this is provided as user input?
     uint32_t n_embd  = 4096;
     uint32_t n_mult  = 256;
     uint32_t n_head  = 32;
     uint32_t n_layer = 32;
-    uint32_t n_rot   = 64;
     enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16;

     bool operator!=(const llama_hparams & other) const {
@@ -460,7 +459,6 @@ enum llama_file_version {
     LLAMA_FILE_VERSION_GGJT_V1, // added padding
     LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
     LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format
-    LLAMA_FILE_VERSION_GGJT_V4, // improved support for added/special tokens
 };

 struct llama_file_loader {
@@ -476,6 +474,7 @@ struct llama_file_loader {
         read_hparams();
         read_vocab();
         read_tensor_metadata(file_idx, tensors_map);
+        set_vocab_sp();
     }
     void read_magic() {
         uint32_t magic = file.read_u32();
@@ -498,7 +497,6 @@ struct llama_file_loader {
             case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return;
             case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return;
             case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return;
-            case 4: file_version = LLAMA_FILE_VERSION_GGJT_V4; return;
         }
     }

@@ -507,12 +505,12 @@ struct llama_file_loader {
     }
     void read_hparams() {
         hparams.n_vocab = file.read_u32();
-        hparams.n_vocab_sp = file_version >= LLAMA_FILE_VERSION_GGJT_V4 ? file.read_u32() : 0;
         hparams.n_embd = file.read_u32();
         hparams.n_mult = file.read_u32();
         hparams.n_head = file.read_u32();
         hparams.n_layer = file.read_u32();
-        hparams.n_rot  = file.read_u32();
+        hparams.n_vocab_base = file.read_u32();
+        hparams.n_vocab_base = (hparams.n_vocab_base & 0xF0000000) == 0 ? hparams.n_vocab : (hparams.n_vocab_base & ~0xF0000000); // this bitwise operation is necessary for compatibility with older models
         hparams.ftype = (enum llama_ftype) file.read_u32();
     }
     void read_vocab() {
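
The pair of + lines above, together with the matching change in write_hparams() further down, implement a small compatibility trick: the header word that previously stored n_rot now carries n_vocab_base with its top nibble set as a tag. Plausible n_rot values (64, 128) never have that nibble set, so a loader can tell the two layouts apart. A minimal standalone sketch of the round trip, assuming hypothetical helper names (encode_vocab_base/decode_vocab_base are not part of the patch):

// Hypothetical demo of the n_vocab_base tagging this patch uses.
#include <cassert>
#include <cstdint>

static uint32_t encode_vocab_base(uint32_t n_vocab_base) {
    return n_vocab_base | 0xF0000000; // tag: this slot holds a vocab size, not n_rot
}

static uint32_t decode_vocab_base(uint32_t raw, uint32_t n_vocab) {
    // An untagged value comes from an old file where this slot held n_rot,
    // so assume the whole vocabulary is the base vocabulary.
    return (raw & 0xF0000000) == 0 ? n_vocab : (raw & ~0xF0000000);
}

int main() {
    assert(decode_vocab_base(encode_vocab_base(32000), 32016) == 32000); // new-style file
    assert(decode_vocab_base(64, 32000) == 32000);                       // old file: slot held n_rot = 64
    return 0;
}

Concretely, n_vocab_base = 32000 (0x7D00) is written as 0xF0007D00; an old file whose slot holds n_rot = 64 fails the tag test, so the loader falls back to n_vocab_base = n_vocab (no added tokens).
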
@@ -533,20 +531,6 @@ struct llama_file_loader {
             tok_score.tok = std::move(word);
             tok_score.score = score;
         }
-
-        vocab.special_token_to_id.reserve(hparams.n_vocab_sp);
-
-        for (uint32_t i = 0; i < hparams.n_vocab_sp; i++) {
-            llama_vocab::id token_id = file.read_u32();
-            const auto & word = vocab.id_to_token[token_id].tok;
-
-            vocab.special_token_trie.add(word);
-            vocab.special_token_to_id[word] = token_id;
-
-            if (vocab.max_special_token_length < word.size()) {
-                vocab.max_special_token_length = word.size();
-            }
-        }
     }
     void read_tensor_metadata(size_t file_idx, llama_load_tensors_map & tensors_map) {
         while (file.tell() < file.size) {
@@ -601,6 +585,24 @@ struct llama_file_loader {
             tensors_map.tensors.at(idx).shards.push_back(shard);
         }
     }
+    void set_vocab_sp() {
+        uint32_t vocab_sp = 3 + hparams.n_vocab - hparams.n_vocab_base;
+        vocab.special_token_to_id.reserve(vocab_sp);
+        for (uint32_t i = 0; i < vocab_sp; i++) {
+            llama_vocab::id token_id = i > 2 ? hparams.n_vocab_base + i : i;
+            const auto & word = vocab.id_to_token[token_id].tok;
+            if (word.empty()) {
+                continue;
+            }
+
+            vocab.special_token_trie.add(word);
+            vocab.special_token_to_id[word] = token_id;
+
+            if (vocab.max_special_token_length < word.size()) {
+                vocab.max_special_token_length = word.size();
+            }
+        }
+    }
 };

 struct llama_file_saver {
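
The new set_vocab_sp() replaces the explicit special-token table that GGJT v4 files carried: it infers special tokens as ids 0-2 plus the ids derived from everything beyond n_vocab_base, skips empty placeholder entries, and tracks the longest special-token text it sees. That maximum presumably bounds the search window when scanning input for special tokens ahead of ordinary tokenization; here is a sketch of such a longest-match lookup under that assumption, with a plain std::map standing in for the patch's special_token_trie (match_special and its parameters are illustrative names, not part of the patch):

// Sketch only: longest-match special-token lookup. Assumes pos < text.size().
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>

int32_t match_special(const std::string & text, size_t pos,
                      const std::map<std::string, int32_t> & special_token_to_id,
                      size_t max_special_token_length) {
    const size_t limit = std::min(max_special_token_length, text.size() - pos);
    for (size_t len = limit; len > 0; len--) {
        auto it = special_token_to_id.find(text.substr(pos, len));
        if (it != special_token_to_id.end()) {
            return it->second; // prefer the longest special token at this position
        }
    }
    return -1; // no special token starts at pos
}

Scanning from each position and preferring the longest hit mirrors what the trie provides more efficiently; the map version is simply easier to show self-contained.
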
@@ -620,12 +622,11 @@ struct llama_file_saver {
     void write_hparams(enum llama_ftype new_ftype) {
         const llama_hparams & hparams = any_file_loader->hparams;
         file.write_u32(hparams.n_vocab);
-        file.write_u32(hparams.n_vocab_sp);
         file.write_u32(hparams.n_embd);
         file.write_u32(hparams.n_mult);
         file.write_u32(hparams.n_head);
         file.write_u32(hparams.n_layer);
-        file.write_u32(hparams.n_rot);
+        file.write_u32(hparams.n_vocab_base | 0xF0000000); // this bitwise operation is necessary for compatibility with older models
         file.write_u32(new_ftype);
     }
     void write_vocab() {
@@ -639,9 +640,6 @@ struct llama_file_saver {
             file.write_raw(token_score.tok.data(), token_score.tok.size());
             file.write_raw(&token_score.score, sizeof(token_score.score));
         }
-        for (const auto & pair : any_file_loader->vocab.special_token_to_id) {
-            file.write_u32(pair.second);
-        }
     }
     void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) {
         switch (new_type) {
@@ -1015,8 +1013,7 @@ static const char *llama_file_version_name(llama_file_version version) {
         case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)";
         case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)";
         case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)";
-        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (pre #1931)";
-        case LLAMA_FILE_VERSION_GGJT_V4: return "ggjt v4 (latest)";
+        case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)";
     }

     return "unknown";
@@ -1113,7 +1110,7 @@ static void llama_model_load_internal(
         fprintf(stderr, "%s: n_mult     = %u\n",  __func__, hparams.n_mult);
         fprintf(stderr, "%s: n_head     = %u\n",  __func__, hparams.n_head);
         fprintf(stderr, "%s: n_layer    = %u\n",  __func__, hparams.n_layer);
-        fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_rot);
+        fprintf(stderr, "%s: n_rot      = %u\n",  __func__, hparams.n_embd/hparams.n_head);
         fprintf(stderr, "%s: ftype      = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype));
         fprintf(stderr, "%s: n_ff       = %u\n",  __func__, n_ff);
         fprintf(stderr, "%s: n_parts    = %zu\n", __func__, ml->file_loaders.size());
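
Since n_rot no longer exists in llama_hparams, the log line now derives it from the model shape; for the default 7B geometry in this file that is n_embd / n_head = 4096 / 32 = 128 rotary dimensions per head.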