From e700b442173268ebd714bc3aef3ac7bbf6b67d3d Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Mon, 26 Feb 2024 11:58:25 +0100 Subject: [PATCH 01/10] additional methods to read model and ctx parameters --- llama.cpp | 15 +++++++++++++++ llama.h | 4 ++++ 2 files changed, 19 insertions(+) diff --git a/llama.cpp b/llama.cpp index 4225f955590dd..4f92488d1aef6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12496,6 +12496,14 @@ int32_t llama_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } +int32_t llama_n_layers(const struct llama_model * model) { + return model->hparams.n_layer; +} + +int32_t llama_n_heads(const struct llama_model * model) { + return model->hparams.n_head; +} + float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } @@ -13153,6 +13161,13 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_ ctx->cparams.n_threads_batch = n_threads_batch; } +void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) { + assert(n_threads); + assert(n_threads_batch); + *n_threads = ctx->cparams.n_threads; + *n_threads_batch = ctx->cparams.n_threads_batch; +} + void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { ctx->abort_callback = abort_callback; ctx->abort_callback_data = abort_callback_data; diff --git a/llama.h b/llama.h index 3dc162b078d30..3eabaa9b69639 100644 --- a/llama.h +++ b/llama.h @@ -383,6 +383,8 @@ extern "C" { LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_n_layers (const struct llama_model * model); + LLAMA_API int32_t llama_n_heads (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); @@ -640,6 +642,8 @@ extern "C" { // n_threads is the number of threads used for generation (single token) // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); + // Get the number of threads used for decoding + LLAMA_API void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch); // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); From cc3fe18b43844d3a4a4663ee7fc89e9ae40a7b3e Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Tue, 27 Feb 2024 14:34:29 +0100 Subject: [PATCH 02/10] vocab size as a part of a model metadata --- convert.py | 1 + gguf-py/gguf/constants.py | 2 ++ gguf-py/gguf/gguf_writer.py | 3 +++ llama.cpp | 4 +++- 4 files changed, 9 insertions(+), 1 deletion(-) diff --git a/convert.py b/convert.py index c15f8c47ea4f7..3dba0439524ee 100755 --- a/convert.py +++ b/convert.py @@ -977,6 +977,7 @@ def add_meta_arch(self, params: Params) -> None: name = str(params.path_model.parent).split('/')[-1] self.gguf.add_name (name) + self.gguf.add_vocab_size (params.n_vocab) self.gguf.add_context_length (params.n_ctx) self.gguf.add_embedding_length (params.n_embd) self.gguf.add_block_count (params.n_layer) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 
a62139811ef36..3828253b09d3f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -32,6 +32,7 @@ class General: FILE_TYPE = "general.file_type" class LLM: + VOCAB_SIZE = "{arch}.vocab_size" CONTEXT_LENGTH = "{arch}.context_length" EMBEDDING_LENGTH = "{arch}.embedding_length" BLOCK_COUNT = "{arch}.block_count" @@ -711,6 +712,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE # LLM +KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH KEY_BLOCK_COUNT = Keys.LLM.BLOCK_COUNT diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 8011608323c45..4dc02e8ef1891 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -313,6 +313,9 @@ def add_custom_alignment(self, alignment: int) -> None: self.data_alignment = alignment self.add_uint32(Keys.General.ALIGNMENT, alignment) + def add_vocab_size(self, size: int) -> None: + self.add_uint32(Keys.LLM.VOCAB_SIZE.format(arch=self.arch), size) + def add_context_length(self, length: int) -> None: self.add_uint32(Keys.LLM.CONTEXT_LENGTH.format(arch=self.arch), length) diff --git a/llama.cpp b/llama.cpp index 4f92488d1aef6..b1411eec9f566 100644 --- a/llama.cpp +++ b/llama.cpp @@ -256,6 +256,7 @@ enum llm_kv { LLM_KV_GENERAL_SOURCE_URL, LLM_KV_GENERAL_SOURCE_HF_REPO, + LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_BLOCK_COUNT, @@ -314,6 +315,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_GENERAL_SOURCE_URL, "general.source.url" }, { LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" }, + { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, @@ -12485,7 +12487,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { } int32_t llama_n_vocab(const struct llama_model * model) { - return model->vocab.id_to_token.size(); + return model->hparams.n_vocab; } int32_t llama_n_ctx_train(const struct llama_model * model) { From 4f4258fbde7f2242d51b8ddaf3d75ab466b44764 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Tue, 27 Feb 2024 18:29:34 +0100 Subject: [PATCH 03/10] models without vocabulary, convert.py part --- convert.py | 167 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 113 insertions(+), 54 deletions(-) diff --git a/convert.py b/convert.py index 3dba0439524ee..2e61a9bacc45b 100755 --- a/convert.py +++ b/convert.py @@ -385,6 +385,12 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.bpe_tokens() yield from self.added_tokens() + def get_tokenizer_model(self) -> str: + return "gpt2" + + def get_name(self) -> str: + return "bpe" + def __repr__(self) -> str: return f"" @@ -448,6 +454,12 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.sentencepiece_tokens() yield from self.added_tokens() + def get_tokenizer_model(self) -> str: + return "llama" + + def get_name(self) -> str: + return "spm" + def __repr__(self) -> str: return f"" @@ -549,11 +561,28 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.hf_tokens() yield from self.added_tokens() + def get_tokenizer_model(self) -> str: + return "llama" + + def get_name(self) -> str: + return "hfft" + def __repr__(self) -> str: return f"" -Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab" 
+class NoVocab: + def get_tokenizer_model(self) -> str: + return "no_vocab" + + def get_name(self) -> str: + return "no_vocab" + + def __repr__(self) -> str: + return "" + + +Vocab: TypeAlias = "BpeVocab | SentencePieceVocab | HfVocab | NoVocab" # @@ -931,13 +960,18 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc yield result -def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: +def check_vocab_size(params: Params, vocab: Vocab) -> None: # Handle special case where the model's vocab size is not set if params.n_vocab == -1: raise ValueError( - f"The model's vocab size is set to -1 in params.json. Please update it manually. Maybe {vocab.vocab_size}?" + f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}" ) + +def prepare_vocab(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: + assert not isinstance(vocab, NoVocab) + check_vocab_size(params, vocab) + # Check for a vocab size mismatch if params.n_vocab == vocab.vocab_size: print("Ignoring added_tokens.json since model matches vocab size without it.") @@ -1014,20 +1048,6 @@ def add_meta_arch(self, params: Params) -> None: if params.ftype is not None: self.gguf.add_file_type(params.ftype) - def handle_tokenizer_model(self, vocab: Vocab) -> str: - # Map the vocab types to the supported tokenizer models - tokenizer_model = { - SentencePieceVocab: "llama", - HfVocab: "llama", - BpeVocab: "gpt2", - }.get(type(vocab)) - - # Block if vocab type is not predefined - if tokenizer_model is None: - raise ValueError("Unknown vocab type: Not supported") - - return tokenizer_model - def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: tokens = [] scores = [] @@ -1044,11 +1064,8 @@ def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list return tokens, scores, toktypes def add_meta_vocab(self, vocab: Vocab) -> None: - # Handle the tokenizer model - tokenizer_model = self.handle_tokenizer_model(vocab) - # Ensure that tokenizer_model is added to the GGUF model - self.gguf.add_tokenizer_model(tokenizer_model) + self.gguf.add_tokenizer_model(vocab.get_tokenizer_model()) # Extract model vocabulary for model conversion tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) @@ -1075,6 +1092,26 @@ def write_meta(self) -> None: def write_tensor_info(self) -> None: self.gguf.write_ti_data_to_file() + def write_tensor_data(self, ftype: GGMLFileType, model: LazyModel, concurrency: int) -> None: + ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency=concurrency) + if ftype == GGMLFileType.MostlyQ8_0: + ndarrays = bounded_parallel_map( + OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, + use_processpool_executor=True, + ) + else: + ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) + + start = time.time() + for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): + elapsed = time.time() - start + size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) + padi = len(str(len(model))) + print( + f"[{i + 1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" + ) + self.gguf.write_tensor_data(ndarray) + def close(self) -> None: self.gguf.close() @@ -1083,7 +1120,7 @@ def write_vocab_only( fname_out: Path, params: 
Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - check_vocab_size(params, vocab, pad_vocab = pad_vocab) + prepare_vocab(params, vocab, pad_vocab=pad_vocab) of = OutputFile(fname_out, endianess=endianess) @@ -1115,7 +1152,7 @@ def write_all( concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - check_vocab_size(params, vocab, pad_vocab=pad_vocab) + prepare_vocab(params, vocab, pad_vocab=pad_vocab) of = OutputFile(fname_out, endianess=endianess) @@ -1132,24 +1169,33 @@ def write_all( of.write_tensor_info() # tensor data - ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency) - if ftype == GGMLFileType.MostlyQ8_0: - ndarrays = bounded_parallel_map( - OutputFile.maybe_do_quantize, ndarrays_inner, concurrency=concurrency, max_workers=concurrency, - use_processpool_executor=True, - ) - else: - ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner) + of.write_tensor_data(ftype, model, concurrency) - start = time.time() - for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): - elapsed = time.time() - start - size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) - padi = len(str(len(model))) - print( - f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}" - ) - of.gguf.write_tensor_data(ndarray) + of.close() + + @staticmethod + def write_without_vocab( + fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, + concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, + ) -> None: + assert isinstance(vocab, NoVocab) + check_vocab_size(params, vocab) + + of = OutputFile(fname_out, endianess=endianess) + + # meta data + of.add_meta_arch(params) + of.gguf.add_tokenizer_model(vocab.get_tokenizer_model()) + + # tensor info + for name, lazy_tensor in model.items(): + of.add_tensor_info(name, lazy_tensor) + + of.write_meta() + of.write_tensor_info() + + # tensor data + of.write_tensor_data(ftype, model, concurrency) of.close() @@ -1310,8 +1356,8 @@ def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]: return vtype, path raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}") - def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: Path) -> gguf.SpecialVocab: - load_merges = vocabtype == "bpe" + def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab: + load_merges = vocab.get_name() == "bpe" n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None return gguf.SpecialVocab( model_parent_path, @@ -1320,30 +1366,34 @@ def _create_special_vocab(self, vocab: Vocab, vocabtype: str, model_parent_path: n_vocab=n_vocab, ) - def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: + def _create_vocab_by_path(self, vocab_types: list[str]) -> Vocab: vocab_type, path = self._select_file(vocab_types) print(f"Loading vocab file {path!r}, type {vocab_type!r}") added_tokens_path = path.parent / "added_tokens.json" - vocab: Vocab if vocab_type == "bpe": - vocab = BpeVocab( + return BpeVocab( path, added_tokens_path if added_tokens_path.exists() else None ) - elif vocab_type == "spm": - vocab = SentencePieceVocab( + if vocab_type == "spm": + return 
SentencePieceVocab( path, added_tokens_path if added_tokens_path.exists() else None ) - elif vocab_type == "hfft": - vocab = HfVocab( + if vocab_type == "hfft": + return HfVocab( path.parent, added_tokens_path if added_tokens_path.exists() else None ) + raise ValueError(vocab_type) + + def load_vocab(self, vocab_types: list[str], model_parent_path: Path) -> tuple[Vocab, gguf.SpecialVocab]: + vocab: Vocab + if len(vocab_types) == 1 and "no_vocab" in vocab_types: + vocab = NoVocab() else: - raise ValueError(vocab_type) + vocab = self._create_vocab_by_path(vocab_types) # FIXME: Respect --vocab-dir? special_vocab = self._create_special_vocab( vocab, - vocab_type, model_parent_path, ) return vocab, special_vocab @@ -1381,6 +1431,7 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--no-vocab", action="store_true", help="store model without the vocab") parser.add_argument("--outtype", choices=output_choices, help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--vocab-type", help="vocab types to try in order, choose from 'spm', 'bpe', 'hfft' (default: spm,hfft)", default="spm,hfft") @@ -1393,6 +1444,10 @@ def main(args_in: list[str] | None = None) -> None: parser.add_argument("--skip-unknown", action="store_true", help="skip unknown tensor names instead of failing") args = parser.parse_args(args_in) + if args.no_vocab: + if args.vocab_only: + raise ValueError("no need to specify --vocab-only if using --no-vocab") + args.vocab_type = "no_vocab" if args.dump_single: model_plus = lazy_load_file(args.model) @@ -1443,7 +1498,7 @@ def main(args_in: list[str] | None = None) -> None: print(f"Wrote {outfile}") return - if model_plus.vocab is not None and args.vocab_dir is None: + if model_plus.vocab is not None and args.vocab_dir is None and not args.no_vocab: vocab = model_plus.vocab print(f"Vocab info: {vocab}") @@ -1458,8 +1513,12 @@ def main(args_in: list[str] | None = None) -> None: params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, - concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab) + if not args.no_vocab: + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, + concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab) + else: + OutputFile.write_without_vocab(outfile, ftype, params, model, vocab, + concurrency=args.concurrency, endianess=endianess) print(f"Wrote {outfile}") From afa9d0953b400718e61898ddca81b5f8852b8558 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Wed, 28 Feb 2024 10:49:26 +0100 Subject: [PATCH 04/10] models without vocabulary, llama.cpp part --- llama.cpp | 93 +++++++++++++++++++++++++++++++++++-------------------- llama.h | 7 +++-- 2 files changed, 63 insertions(+), 37 deletions(-) diff --git a/llama.cpp b/llama.cpp index b1411eec9f566..13af325a756a3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3037,10 +3037,11 @@ static const char * llama_model_type_name(e_model type) { static const char * 
llama_model_vocab_type_name(enum llama_vocab_type type){ switch (type) { - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - default: return "unknown"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_NO_VOCAB: return "no vocab"; + default: return "unknown"; } } @@ -3071,15 +3072,14 @@ static void llm_load_hparams( // get general kv ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); - // get hparams kv - ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); - ml.get_key (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); - ml.get_key (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); - ml.get_key (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); - ml.get_key (LLM_KV_BLOCK_COUNT, hparams.n_layer); - ml.get_key (LLM_KV_EXPERT_COUNT, hparams.n_expert, false); - ml.get_key (LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); + ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff); + ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); GGML_ASSERT(hparams.n_expert <= LLAMA_MAX_EXPERTS); GGML_ASSERT(hparams.n_expert_used <= hparams.n_expert); @@ -3410,30 +3410,25 @@ static void llm_load_vocab( const auto kv = LLM_KV(model.arch); - const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); - if (token_idx == -1) { - throw std::runtime_error("cannot find tokenizer vocab in model file\n"); - } - - const float * scores = nullptr; - const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); - if (score_idx != -1) { - scores = (const float * ) gguf_get_arr_data(ctx, score_idx); - } - - const int * toktypes = nullptr; - const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); - if (toktype_idx != -1) { - toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); - } - // determine vocab type { std::string tokenizer_name; ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); - if (tokenizer_name == "llama") { + if (tokenizer_name == "no_vocab") { + vocab.type = LLAMA_VOCAB_TYPE_NO_VOCAB; + + // default special tokens + vocab.special_bos_id = -1; + vocab.special_eos_id = -1; + vocab.special_unk_id = -1; + vocab.special_sep_id = -1; + vocab.special_pad_id = -1; + vocab.linefeed_id = -1; + + return; + } else if (tokenizer_name == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; // default special tokens @@ -3499,6 +3494,23 @@ static void llm_load_vocab( } } + const int token_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_LIST).c_str()); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const float * scores = nullptr; + const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); + if (score_idx != -1) { + scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + } + + const int * toktypes = nullptr; + const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); + if (toktype_idx != -1) { + toktypes = (const int * ) 
gguf_get_arr_data(ctx, toktype_idx); + } + const uint32_t n_vocab = gguf_get_arr_n(ctx, token_idx); vocab.id_to_token.resize(n_vocab); @@ -4725,7 +4737,8 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam llm_load_print_meta(ml, model); - if (model.hparams.n_vocab != model.vocab.id_to_token.size()) { + if (model.vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB && + model.hparams.n_vocab != model.vocab.id_to_token.size()) { throw std::runtime_error("vocab size mismatch"); } @@ -8714,26 +8727,32 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { } static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; } static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; } static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; } static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; } static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; } static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NO_VOCAB); GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto& token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { @@ -8754,6 +8773,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NO_VOCAB); static const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: { @@ -9598,6 +9618,8 @@ static std::vector llama_tokenize_internal(const llama_vocab & } } } break; + case LLAMA_VOCAB_TYPE_NO_VOCAB: + GGML_ASSERT(false); } return output; @@ -13164,8 +13186,8 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_ } void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) { - assert(n_threads); - assert(n_threads_batch); + GGML_ASSERT(n_threads); + GGML_ASSERT(n_threads_batch); *n_threads = ctx->cparams.n_threads; *n_threads_batch = ctx->cparams.n_threads_batch; } @@ -13268,14 +13290,17 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id } const char * llama_token_get_text(const struct llama_model * model, llama_token token) { + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return model->vocab.id_to_token[token].text.c_str(); } float llama_token_get_score(const struct llama_model * model, llama_token token) { + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return model->vocab.id_to_token[token].score; } llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) { + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); return 
model->vocab.id_to_token[token].type; } diff --git a/llama.h b/llama.h index 3eabaa9b69639..af58a1d056c3d 100644 --- a/llama.h +++ b/llama.h @@ -59,9 +59,10 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece - LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding - LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece + LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece + LLAMA_VOCAB_TYPE_NO_VOCAB = 3, // For models without vocab }; // note: these values should be synchronized with ggml_rope From e0504d536c0867ddba0b5b09cc5a9923263441ba Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Thu, 29 Feb 2024 18:01:14 +0100 Subject: [PATCH 05/10] PR clean up --- llama.cpp | 16 +--------------- llama.h | 4 ---- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/llama.cpp b/llama.cpp index 13af325a756a3..423333c851c42 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3072,6 +3072,7 @@ static void llm_load_hparams( // get general kv ml.get_key(LLM_KV_GENERAL_NAME, model.name, false); + // get hparams kv ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab); ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); @@ -12520,14 +12521,6 @@ int32_t llama_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } -int32_t llama_n_layers(const struct llama_model * model) { - return model->hparams.n_layer; -} - -int32_t llama_n_heads(const struct llama_model * model) { - return model->hparams.n_head; -} - float llama_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } @@ -13185,13 +13178,6 @@ void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_ ctx->cparams.n_threads_batch = n_threads_batch; } -void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch) { - GGML_ASSERT(n_threads); - GGML_ASSERT(n_threads_batch); - *n_threads = ctx->cparams.n_threads; - *n_threads_batch = ctx->cparams.n_threads_batch; -} - void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)(void * data), void * abort_callback_data) { ctx->abort_callback = abort_callback; ctx->abort_callback_data = abort_callback_data; diff --git a/llama.h b/llama.h index af58a1d056c3d..4c76c2cbb58c8 100644 --- a/llama.h +++ b/llama.h @@ -384,8 +384,6 @@ extern "C" { LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_n_layers (const struct llama_model * model); - LLAMA_API int32_t llama_n_heads (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); @@ -643,8 +641,6 @@ extern "C" { // n_threads is the number of threads used for generation (single token) // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch); - // Get the number of threads used for decoding - LLAMA_API void llama_get_n_threads(struct llama_context * ctx, uint32_t * n_threads, uint32_t * n_threads_batch); // Set abort callback LLAMA_API void 
llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); From 0c69016171cfc3ec900ba34b9aa006312b0bfe04 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Sun, 10 Mar 2024 18:28:10 +0100 Subject: [PATCH 06/10] converter scrypt fixes --- convert.py | 82 +++++++++++++++--------------------------------------- 1 file changed, 22 insertions(+), 60 deletions(-) diff --git a/convert.py b/convert.py index 2e61a9bacc45b..27d0f49c09dc7 100755 --- a/convert.py +++ b/convert.py @@ -332,6 +332,9 @@ def load(model_plus: ModelPlus) -> Params: # class BpeVocab: + tokenizer_model = "gpt2" + name = "bpe" + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) if isinstance(self.bpe_tokenizer.get('model'), dict): @@ -385,17 +388,14 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.bpe_tokens() yield from self.added_tokens() - def get_tokenizer_model(self) -> str: - return "gpt2" - - def get_name(self) -> str: - return "bpe" - def __repr__(self) -> str: return f"" class SentencePieceVocab: + tokenizer_model = "llama" + name = "spm" + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None) -> None: self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) added_tokens: dict[str, int] @@ -454,17 +454,14 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.sentencepiece_tokens() yield from self.added_tokens() - def get_tokenizer_model(self) -> str: - return "llama" - - def get_name(self) -> str: - return "spm" - def __repr__(self) -> str: return f"" class HfVocab: + tokenizer_model = "llama" + name = "hfft" + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Path | None = None) -> None: try: from transformers import AutoTokenizer @@ -561,22 +558,13 @@ def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: yield from self.hf_tokens() yield from self.added_tokens() - def get_tokenizer_model(self) -> str: - return "llama" - - def get_name(self) -> str: - return "hfft" - def __repr__(self) -> str: return f"" class NoVocab: - def get_tokenizer_model(self) -> str: - return "no_vocab" - - def get_name(self) -> str: - return "no_vocab" + tokenizer_model = "no_vocab" + name = "no_vocab" def __repr__(self) -> str: return "" @@ -969,8 +957,9 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None: def prepare_vocab(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: - assert not isinstance(vocab, NoVocab) check_vocab_size(params, vocab) + if vocab.name == "no_vocab": + return # Check for a vocab size mismatch if params.n_vocab == vocab.vocab_size: @@ -1065,7 +1054,7 @@ def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list def add_meta_vocab(self, vocab: Vocab) -> None: # Ensure that tokenizer_model is added to the GGUF model - self.gguf.add_tokenizer_model(vocab.get_tokenizer_model()) + self.gguf.add_tokenizer_model(vocab.tokenizer_model) # Extract model vocabulary for model conversion tokens, scores, toktypes = self.extract_vocabulary_from_model(vocab) @@ -1158,34 +1147,11 @@ def write_all( # meta data of.add_meta_arch(params) - of.add_meta_vocab(vocab) - of.add_meta_special_vocab(svocab) - - # tensor info - for name, lazy_tensor in model.items(): - of.add_tensor_info(name, lazy_tensor) - - of.write_meta() - of.write_tensor_info() - - # tensor data - 
of.write_tensor_data(ftype, model, concurrency) - - of.close() - - @staticmethod - def write_without_vocab( - fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, - concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, - ) -> None: - assert isinstance(vocab, NoVocab) - check_vocab_size(params, vocab) - - of = OutputFile(fname_out, endianess=endianess) - - # meta data - of.add_meta_arch(params) - of.gguf.add_tokenizer_model(vocab.get_tokenizer_model()) + if vocab.name == "no_vocab": + of.gguf.add_tokenizer_model(vocab.tokenizer_model) + else: + of.add_meta_vocab(vocab) + of.add_meta_special_vocab(svocab) # tensor info for name, lazy_tensor in model.items(): @@ -1357,7 +1323,7 @@ def _select_file(self, vocab_types: list[str]) -> tuple[str, Path]: raise FileNotFoundError(f"Could not find any of {[self._FILES[vt] for vt in vocab_types]}") def _create_special_vocab(self, vocab: Vocab, model_parent_path: Path) -> gguf.SpecialVocab: - load_merges = vocab.get_name() == "bpe" + load_merges = vocab.name == "bpe" n_vocab = vocab.vocab_size if hasattr(vocab, "vocab_size") else None return gguf.SpecialVocab( model_parent_path, @@ -1513,12 +1479,8 @@ def main(args_in: list[str] | None = None) -> None: params.ftype = ftype print(f"Writing {outfile}, format {ftype}") - if not args.no_vocab: - OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, - concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab) - else: - OutputFile.write_without_vocab(outfile, ftype, params, model, vocab, - concurrency=args.concurrency, endianess=endianess) + OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, + concurrency=args.concurrency, endianess=endianess, pad_vocab=args.pad_vocab) print(f"Wrote {outfile}") From 80f66a8af7daa274065aedaa7614783f66a1af66 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Sun, 10 Mar 2024 18:32:32 +0100 Subject: [PATCH 07/10] llama_vocab_type update (renamed the new key) --- llama.cpp | 36 ++++++++++++++++++------------------ llama.h | 8 ++++---- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/llama.cpp b/llama.cpp index 423333c851c42..8480c6be960bc 100644 --- a/llama.cpp +++ b/llama.cpp @@ -3037,11 +3037,11 @@ static const char * llama_model_type_name(e_model type) { static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ switch (type) { - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - case LLAMA_VOCAB_TYPE_NO_VOCAB: return "no vocab"; - default: return "unknown"; + case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + default: return "unknown"; } } @@ -3418,7 +3418,7 @@ static void llm_load_vocab( ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name); if (tokenizer_name == "no_vocab") { - vocab.type = LLAMA_VOCAB_TYPE_NO_VOCAB; + vocab.type = LLAMA_VOCAB_TYPE_NONE; // default special tokens vocab.special_bos_id = -1; @@ -4738,7 +4738,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam llm_load_print_meta(ml, model); - if (model.vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB && + if (model.vocab.type != LLAMA_VOCAB_TYPE_NONE && model.hparams.n_vocab != model.vocab.id_to_token.size()) { throw std::runtime_error("vocab size mismatch"); } @@ -8728,32 +8728,32 @@ static enum llama_vocab_type 
llama_vocab_get_type(const llama_vocab & vocab) { } static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; } static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; } static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; } static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; } static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) { - GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE); return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; } static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { - GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(llama_is_byte_token(vocab, id)); const auto& token_data = vocab.id_to_token.at(id); switch (llama_vocab_get_type(vocab)) { @@ -8774,7 +8774,7 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) { } static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { - GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); static const char * hex = "0123456789ABCDEF"; switch (llama_vocab_get_type(vocab)) { case LLAMA_VOCAB_TYPE_SPM: { @@ -9619,7 +9619,7 @@ static std::vector llama_tokenize_internal(const llama_vocab & } } } break; - case LLAMA_VOCAB_TYPE_NO_VOCAB: + case LLAMA_VOCAB_TYPE_NONE: GGML_ASSERT(false); } @@ -13276,17 +13276,17 @@ float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id } const char * llama_token_get_text(const struct llama_model * model, llama_token token) { - GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); return model->vocab.id_to_token[token].text.c_str(); } float llama_token_get_score(const struct llama_model * model, llama_token token) { - GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); return model->vocab.id_to_token[token].score; } llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) { - GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NO_VOCAB); + GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE); return model->vocab.id_to_token[token].type; } diff --git a/llama.h b/llama.h index 4c76c2cbb58c8..44eea70b71804 100644 --- a/llama.h +++ b/llama.h @@ -59,10 +59,10 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece - LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding - LLAMA_VOCAB_TYPE_WPM = 2, // WordPiece - LLAMA_VOCAB_TYPE_NO_VOCAB = 3, // For 
models without vocab + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 2, // Byte Pair Encoding + LLAMA_VOCAB_TYPE_WPM = 3, // WordPiece }; // note: these values should be synchronized with ggml_rope From 0a1322acbdbc73a15700ca9d18ae9be116314e38 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Wed, 13 Mar 2024 11:44:08 +0100 Subject: [PATCH 08/10] pr review fixes --- convert.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/convert.py b/convert.py index 27d0f49c09dc7..9b7d6b6936f01 100755 --- a/convert.py +++ b/convert.py @@ -948,18 +948,14 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc yield result -def check_vocab_size(params: Params, vocab: Vocab) -> None: +def prepare_vocab(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: # Handle special case where the model's vocab size is not set if params.n_vocab == -1: raise ValueError( f"The model's vocab size is set to -1 in params.json. Please update it manually.{f' Maybe {vocab.vocab_size}?' if hasattr(vocab, 'vocab_size') else ''}" ) - - -def prepare_vocab(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: - check_vocab_size(params, vocab) - if vocab.name == "no_vocab": - return + if isinstance(vocab, NoVocab): + return # model has no vocab # Check for a vocab size mismatch if params.n_vocab == vocab.vocab_size: @@ -1147,7 +1143,7 @@ def write_all( # meta data of.add_meta_arch(params) - if vocab.name == "no_vocab": + if isinstance(vocab, NoVocab): of.gguf.add_tokenizer_model(vocab.tokenizer_model) else: of.add_meta_vocab(vocab) From 94a1050e57e2abc403f62f0ecce7e1140bc023c8 Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Wed, 13 Mar 2024 12:07:22 +0100 Subject: [PATCH 09/10] revert function renaming --- convert.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert.py b/convert.py index 9b7d6b6936f01..93d4d3405afea 100755 --- a/convert.py +++ b/convert.py @@ -948,7 +948,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc yield result -def prepare_vocab(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: +def check_vocab_size(params: Params, vocab: Vocab, pad_vocab: bool = False) -> None: # Handle special case where the model's vocab size is not set if params.n_vocab == -1: raise ValueError( @@ -1105,7 +1105,7 @@ def write_vocab_only( fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - prepare_vocab(params, vocab, pad_vocab=pad_vocab) + check_vocab_size(params, vocab, pad_vocab=pad_vocab) of = OutputFile(fname_out, endianess=endianess) @@ -1137,7 +1137,7 @@ def write_all( concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE, pad_vocab: bool = False, ) -> None: - prepare_vocab(params, vocab, pad_vocab=pad_vocab) + check_vocab_size(params, vocab, pad_vocab=pad_vocab) of = OutputFile(fname_out, endianess=endianess) From 9cb1554fb08c233efe8184f8f5be8dace3fb21fb Mon Sep 17 00:00:00 2001 From: Michael Podvitskiy Date: Thu, 14 Mar 2024 09:29:13 +0100 Subject: [PATCH 10/10] one more NoVocab assert --- convert.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/convert.py b/convert.py index 93d4d3405afea..161430f3e717e 100755 --- a/convert.py +++ b/convert.py @@ -1034,6 +1034,8 @@ def add_meta_arch(self, params: Params) -> None: 
self.gguf.add_file_type(params.ftype) def extract_vocabulary_from_model(self, vocab: Vocab) -> tuple[list[bytes], list[float], list[gguf.TokenType]]: + assert not isinstance(vocab, NoVocab) + tokens = [] scores = [] toktypes = []
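
Usage sketch (illustrative only, not part of the patches above): after converting a checkpoint with "python convert.py <model_dir> --no-vocab --outtype f16", the resulting GGUF carries the new {arch}.vocab_size key but no tokenizer data, so tokenizer-dependent llama.h calls must be guarded. The C++ snippet below is a sketch under assumptions: the model path is hypothetical, and the loader/accessor calls (llama_backend_init, llama_load_model_from_file, llama_vocab_type) are the pre-existing llama.h API rather than anything added here; only LLAMA_VOCAB_TYPE_NONE, the hparams-backed llama_n_vocab, and the --no-vocab converter flag come from these patches.

    // illustrative sketch; model path is hypothetical, loader API assumed from llama.h
    #include "llama.h"
    #include <cstdio>

    int main(void) {
        llama_backend_init();

        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model-no-vocab.gguf", mparams);
        if (model == NULL) {
            fprintf(stderr, "failed to load model\n");
            return 1;
        }

        // hparams remain available: n_vocab is now read from the {arch}.vocab_size key
        printf("n_vocab = %d, n_embd = %d\n", llama_n_vocab(model), llama_n_embd(model));

        // token-level APIs (llama_tokenize, llama_token_get_text, ...) hit GGML_ASSERT
        // for such models, so check the vocab type before using them
        if (llama_vocab_type(model) == LLAMA_VOCAB_TYPE_NONE) {
            printf("model carries no vocabulary; skipping tokenizer-dependent code\n");
        }

        llama_free_model(model);
        llama_backend_free();
        return 0;
    }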