@@ -508,17 +508,16 @@ struct gguf_load_tensors_map {
 
 enum gguf_file_version {
     GGUF_FILE_VERSION_V1 = 1,
-
 };
 
-
 struct gguf_file_loader {
     gguf_file file;
     gguf_context * gguf_ctx;
     gguf_file_version file_version;
     llama_hparams hparams;
     llama_vocab vocab;
-struct ggml_context * ctx_data = NULL;
+
+    struct ggml_context * ctx_data = NULL;
 
     gguf_file_loader(const char * fname, gguf_load_tensors_map & tensors_map)
         : file(fname, "rb") {
@@ -537,7 +536,7 @@ struct ggml_context * ctx_data = NULL;
         read_tensor_metadata(tensors_map);
     }
 
-    uint32_t read_u32(const char * key) {
+    uint32_t read_u32(const char * key) const {
         int i = gguf_find_key(gguf_ctx, key);
         if (i == -1) {
             throw std::runtime_error(format("cannot find param with key %s\n", key));
@@ -546,7 +545,7 @@ struct ggml_context * ctx_data = NULL;
         return gguf_get_val_u32(gguf_ctx, i);
     }
 
-    float read_f32(const char * key) {
+    float read_f32(const char * key) const {
         int i = gguf_find_key(gguf_ctx, key);
         if (i == -1) {
             throw std::runtime_error(format("cannot find param with key %s\n", key));
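Taken together, `read_u32` and `read_f32` are thin wrappers over the same two-step gguf lookup: resolve the key to an index with `gguf_find_key` (which returns -1 when the key is absent), then fetch the typed value by index. A minimal free-standing sketch of that pattern, assuming only the C gguf API already used in this diff:

```cpp
#include <cstdint>
#include <stdexcept>
#include <string>

// Hedged sketch, not part of the patch: the lookup-then-read pattern that
// read_u32()/read_f32() wrap. gguf_find_key() returns -1 for a missing key,
// so the index must be checked before any gguf_get_val_*() call.
static uint32_t read_u32_or_throw(const struct gguf_context * ctx, const char * key) {
    const int i = gguf_find_key(ctx, key);
    if (i == -1) {
        throw std::runtime_error(std::string("cannot find param with key ") + key);
    }
    return gguf_get_val_u32(ctx, i);
}
```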
@@ -555,27 +554,26 @@ struct ggml_context * ctx_data = NULL;
         return gguf_get_val_f32(gguf_ctx, i);
     }
 
-    int read_n_vocab() {
+    int read_n_vocab() const {
         int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens");
-    if (i == -1) {
-        throw std::runtime_error("cannot find token list in GGUF file\n");
-    }
+        if (i == -1) {
+            throw std::runtime_error("cannot find token list in GGUF file\n");
+        }
 
-    return gguf_get_arr_n(gguf_ctx, i);
+        return gguf_get_arr_n(gguf_ctx, i);
     }
 
     void read_hparams() {
-
         // TODO define keys as constants in header
         // TODO: read all hparams from file
 
-        hparams.n_vocab = read_n_vocab();
-        hparams.n_ctx = read_u32("llama.context_length");
-        hparams.n_embd = read_u32("llama.embedding_length");
-        hparams.n_ff = read_u32("llama.feed_forward_length");
-        hparams.n_head = read_u32("llama.attention.head_count");
-        hparams.n_layer = read_u32("llama.layer_count");
-        hparams.n_rot = read_u32("llama.rope.dimension_count");
+        hparams.n_vocab = read_n_vocab();
+        hparams.n_ctx   = read_u32("llama.context_length");
+        hparams.n_embd  = read_u32("llama.embedding_length");
+        hparams.n_ff    = read_u32("llama.feed_forward_length");
+        hparams.n_head  = read_u32("llama.attention.head_count");
+        hparams.n_layer = read_u32("llama.layer_count");
+        hparams.n_rot   = read_u32("llama.rope.dimension_count");
         hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon");
 
         // LLaMAv2
@@ -606,24 +604,27 @@ struct ggml_context * ctx_data = NULL;
         }
     }
 
-    void read_tensor_metadata(gguf_load_tensors_map & tensors_map) {
+    void read_tensor_metadata(gguf_load_tensors_map & tensors_map) const {
         const int n_tensors = gguf_get_n_tensors(gguf_ctx);
 
         for (int i = 0; i < n_tensors; ++i) {
             gguf_load_tensor tensor;
             const char * name = gguf_get_tensor_name(gguf_ctx, i);
 
             struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
-            uint32_t n_dims = cur->n_dims;
+
+            const uint32_t n_dims = cur->n_dims;
             tensor.type = cur->type;
             tensor.ne.resize(n_dims);
+
             for (uint32_t j = 0; j < n_dims; ++j) {
-            tensor.ne[j] = cur->ne[j];
+                tensor.ne[j] = cur->ne[j];
             }
 
             if (n_dims < 1 || n_dims > 2) {
                 throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name, n_dims));
             }
+
             switch (tensor.type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
@@ -643,7 +644,6 @@ struct ggml_context * ctx_data = NULL;
                 }
             }
 
-
             tensor.file_off = gguf_get_data_offset(gguf_ctx) + gguf_get_tensor_offset(gguf_ctx, i);
 
             tensor.name = name;
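Note how `read_tensor_metadata` computes each tensor's absolute file position: the shared data-section base plus the per-tensor relative offset. A hedged sketch of the same walk using the public accessors, assuming a valid `gguf_context * ctx` as in the earlier sketch:

```cpp
// Enumerate tensors and recover their absolute file offsets, mirroring the
// file_off computation in read_tensor_metadata() above.
const int n_tensors = gguf_get_n_tensors(ctx);
for (int i = 0; i < n_tensors; ++i) {
    const char * name = gguf_get_tensor_name(ctx, i);
    const size_t off  = gguf_get_data_offset(ctx)       // start of the data section
                      + gguf_get_tensor_offset(ctx, i); // offset within it
    printf("%s @ %zu\n", name, off);
}
```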
@@ -670,47 +670,47 @@ struct gguf_file_saver {
 
     gguf_file_saver(const char * fname, gguf_file_loader * fl, enum llama_ftype new_ftype)
         : file(fname, "wb"), fl(fl) {
-            fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
-            write_header();
-            write_hparams(new_ftype);
-        }
+        fprintf(stderr, "llama.cpp: saving model to %s\n", fname);
+        write_header();
+        write_hparams(new_ftype);
+    }
 
     void write_header() {
         const int32_t magic = GGUF_MAGIC;
         file.write_i32(magic);
 
-            const int32_t version = GGUF_VERSION;
-            file.write_i32(version);
+        const int32_t version = GGUF_VERSION;
+        file.write_i32(version);
 
-            const int32_t n_tensors = gguf_get_n_tensors(fl->gguf_ctx);
-            file.write_i32(n_tensors);
+        const int32_t n_tensors = gguf_get_n_tensors(fl->gguf_ctx);
+        file.write_i32(n_tensors);
 
-            const int32_t n_kv = gguf_get_n_kv(fl->gguf_ctx);
-            file.write_i32(n_kv);
-        }
+        const int32_t n_kv = gguf_get_n_kv(fl->gguf_ctx);
+        file.write_i32(n_kv);
+    }
 
-        void write_hparam_arr_str(const std::string & key, enum gguf_type type, int i, int n_arr) {
-            std::vector<std::string> data(n_arr);
+    void write_hparam_arr_str(const std::string & key, enum gguf_type type, int i, int n_arr) {
+        std::vector<std::string> data(n_arr);
 
-            for (int j = 0; j < n_arr; ++j) {
-                std::string val = gguf_get_arr_str(fl->gguf_ctx, i, j);
-                data[j] = val;
-            }
-
-            file.write_arr<std::string>(key, type, data);
+        for (int j = 0; j < n_arr; ++j) {
+            std::string val = gguf_get_arr_str(fl->gguf_ctx, i, j);
+            data[j] = val;
         }
 
-        void write_hparam_arr_f32(const std::string & key, enum gguf_type type, int i, int n_arr) {
-            std::vector<float> data(n_arr);
+        file.write_arr<std::string>(key, type, data);
+    }
 
-            for (int j = 0; j < n_arr; ++j) {
-                float val = gguf_get_arr_f32(fl->gguf_ctx, i, j);
-                data[j] = val;
-            }
+    void write_hparam_arr_f32(const std::string & key, enum gguf_type type, int i, int n_arr) {
+        std::vector<float> data(n_arr);
 
-            file.write_arr<float>(key, type, data);
+        for (int j = 0; j < n_arr; ++j) {
+            float val = gguf_get_arr_f32(fl->gguf_ctx, i, j);
+            data[j] = val;
         }
 
+        file.write_arr<float>(key, type, data);
+    }
+
     void write_hparams(enum llama_ftype new_ftype) {
        const int32_t n_kv = gguf_get_n_kv(fl->gguf_ctx);
        for (int i = 0; i < n_kv; ++i) {
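`write_header` serializes exactly four consecutive 32-bit fields. As a reading aid, the on-disk layout it produces could be described by a struct like the following, a sketch for illustration rather than a type defined anywhere in the patch:

```cpp
#include <cstdint>

// Field order matches the write_i32() calls in write_header() above;
// GGUF stores these fields little-endian.
struct gguf_header_sketch {
    int32_t magic;      // GGUF_MAGIC
    int32_t version;    // GGUF_VERSION (GGUF_FILE_VERSION_V1 == 1 at this point)
    int32_t n_tensors;  // number of tensor-info records that follow the kv section
    int32_t n_kv;       // number of key/value metadata pairs
};
```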
@@ -734,59 +734,62 @@ struct gguf_file_saver {
 
                 switch (vtype) {
                     case GGUF_TYPE_BOOL:
-                bool_val = gguf_get_val_bool(fl->gguf_ctx, i);
-                file.write_val<bool>(key, GGUF_TYPE_BOOL, bool_val);
-                break;
+                        bool_val = gguf_get_val_bool(fl->gguf_ctx, i);
+                        file.write_val<bool>(key, GGUF_TYPE_BOOL, bool_val);
+                        break;
                     case GGUF_TYPE_FLOAT32:
-                f32_val = gguf_get_val_f32(fl->gguf_ctx, i);
-                file.write_val<float>(key, GGUF_TYPE_FLOAT32, f32_val);
-                break;
+                        f32_val = gguf_get_val_f32(fl->gguf_ctx, i);
+                        file.write_val<float>(key, GGUF_TYPE_FLOAT32, f32_val);
+                        break;
                     case GGUF_TYPE_INT16:
-                i16_val = gguf_get_val_i16(fl->gguf_ctx, i);
-                file.write_val<int16_t>(key, GGUF_TYPE_INT16, i16_val);
-                break;
+                        i16_val = gguf_get_val_i16(fl->gguf_ctx, i);
+                        file.write_val<int16_t>(key, GGUF_TYPE_INT16, i16_val);
+                        break;
                     case GGUF_TYPE_INT32:
-                i32_val = gguf_get_val_i32(fl->gguf_ctx, i);
-                file.write_val<int32_t>(key, GGUF_TYPE_INT32, i32_val);
-                break;
+                        i32_val = gguf_get_val_i32(fl->gguf_ctx, i);
+                        file.write_val<int32_t>(key, GGUF_TYPE_INT32, i32_val);
+                        break;
                     case GGUF_TYPE_INT8:
-                i8_val = gguf_get_val_i8(fl->gguf_ctx, i);
-                file.write_val<int8_t>(key, GGUF_TYPE_INT8, i8_val);
-                break;
+                        i8_val = gguf_get_val_i8(fl->gguf_ctx, i);
+                        file.write_val<int8_t>(key, GGUF_TYPE_INT8, i8_val);
+                        break;
                     case GGUF_TYPE_STRING:
-                str_val = gguf_get_val_str(fl->gguf_ctx, i);
-                file.write_val<std::string>(key, GGUF_TYPE_STRING, str_val);
-                break;
+                        str_val = gguf_get_val_str(fl->gguf_ctx, i);
+                        file.write_val<std::string>(key, GGUF_TYPE_STRING, str_val);
+                        break;
                     case GGUF_TYPE_UINT16:
-                u16_val = gguf_get_val_u16(fl->gguf_ctx, i);
-                file.write_val<uint16_t>(key, GGUF_TYPE_UINT16, u16_val);
-                break;
+                        u16_val = gguf_get_val_u16(fl->gguf_ctx, i);
+                        file.write_val<uint16_t>(key, GGUF_TYPE_UINT16, u16_val);
+                        break;
                     case GGUF_TYPE_UINT32:
-                u32_val = gguf_get_val_u32(fl->gguf_ctx, i);
-                file.write_val<uint32_t>(key, GGUF_TYPE_UINT32, u32_val);
-                break;
+                        u32_val = gguf_get_val_u32(fl->gguf_ctx, i);
+                        file.write_val<uint32_t>(key, GGUF_TYPE_UINT32, u32_val);
+                        break;
                     case GGUF_TYPE_UINT8:
-                u8_val = gguf_get_val_u8(fl->gguf_ctx, i);
-                file.write_val<uint8_t>(key, GGUF_TYPE_UINT8, u8_val);
-                break;
+                        u8_val = gguf_get_val_u8(fl->gguf_ctx, i);
+                        file.write_val<uint8_t>(key, GGUF_TYPE_UINT8, u8_val);
+                        break;
                     case GGUF_TYPE_ARRAY:
-                arr_type = gguf_get_arr_type(fl->gguf_ctx, i);
-                n_arr = gguf_get_arr_n(fl->gguf_ctx, i);
-                if (arr_type == GGUF_TYPE_FLOAT32) {
-                    write_hparam_arr_f32(key, arr_type, i, n_arr);
+                        arr_type = gguf_get_arr_type(fl->gguf_ctx, i);
+                        n_arr = gguf_get_arr_n(fl->gguf_ctx, i);
+                        if (arr_type == GGUF_TYPE_FLOAT32) {
+                            write_hparam_arr_f32(key, arr_type, i, n_arr);
                         } else if (arr_type == GGUF_TYPE_STRING) {
                             write_hparam_arr_str(key, GGUF_TYPE_STRING, i, n_arr);
                         } else {
                             throw std::runtime_error("not implemented");
                         }
-                break;
+                        break;
                     default:
-                throw std::runtime_error(format("cannot recognize value type for key %s\n", key));
+                        throw std::runtime_error(format("cannot recognize value type for key %s\n", key));
                 }
             }
         }
 
-            info_offset = file.tell();
+        info_offset = file.tell();
+
+        GGML_ASSERT(gguf_get_data_offset(fl->gguf_ctx) >= info_offset);
+
         size_t count = gguf_get_data_offset(fl->gguf_ctx) - info_offset;
         file.write_zeros(count);
         file.seek(info_offset, SEEK_SET);
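The tail of `write_hparams` records where the metadata ends (`info_offset`), zero-fills the gap up to the source file's tensor-data offset, then seeks back so the reserved region can be filled in later. The newly added `GGML_ASSERT` matters because the gap is computed with unsigned arithmetic. A worked example with illustrative numbers:

```cpp
// Illustrative values only: suppose metadata ends at byte 4000 and the
// source file's tensor data begins at byte 4096.
size_t info_offset = 4000;                       // file.tell() after the kv pairs
size_t data_offset = 4096;                       // gguf_get_data_offset(fl->gguf_ctx)
size_t count       = data_offset - info_offset;  // 96 bytes of zero padding

// Without the assert, data_offset < info_offset would make this size_t
// subtraction wrap to a huge value and write_zeros() would run away.
```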