
convert.py: --pad-vocab not working with SPM, 'SentencePieceVocab' object has no attribute 'added_tokens_dict'. Did you mean: 'added_tokens_list'? #4958


Closed
TheBloke opened this issue Jan 15, 2024 · 9 comments · Fixed by #4971

Comments

@TheBloke
Contributor

TheBloke commented Jan 15, 2024

Hi guys

I've just noticed that since the recent convert.py refactor, the new --pad-vocab feature does not work with SPM vocabs. It does work as expected with HFFT. EDIT: actually there might be a different bug with HFFT; see the next post on that.

Example command, converting model: https://huggingface.co/TigerResearch/tigerbot-13b-chat-v5

python3 ./convert.py /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source --outtype f16 --outfile /workspace/process/tigerresearch_tigerbot-13b-chat-v5/gguf/tigerbot-13b-chat-v5.fp16.gguf --pad-vocab

Error message:

Writing /workspace/process/tigerresearch_tigerbot-13b-chat-v5/gguf/tigerbot-13b-chat-v5.fp16.gguf, format 1
Padding vocab with 2 token(s) - <dummy00001> through <dummy00002>
Traceback (most recent call last):
  File "/workspace/git/llama.cpp/./convert.py", line 1658, in <module>
    main(sys.argv[1:])  # Exclude the first element (script name) from sys.argv
    ^^^^^^^^^^^^^^^^^^
  File "/workspace/git/llama.cpp/./convert.py", line 1643, in main
    OutputFile.write_all(
  File "/workspace/git/llama.cpp/./convert.py", line 1188, in write_all
    check_vocab_size(params, vocab, pad_vocab=pad_vocab)
  File "/workspace/git/llama.cpp/./convert.py", line 1008, in check_vocab_size
    vocab.added_tokens_dict[f"<dummy{i:05}>"] = -1
    ^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'SentencePieceVocab' object has no attribute 'added_tokens_dict'. Did you mean: 'added_tokens_list'?

In this example, I did the conversion with --vocab-type hfft instead, which worked OK.

Thanks in advance for looking at this.
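
A minimal sketch of the mismatch behind the traceback above (not the actual fix that landed in #4971): check_vocab_size() writes the padding entries into vocab.added_tokens_dict, but SentencePieceVocab only exposes added_tokens_list, hence the AttributeError. One hedged way to support both vocab classes would be to branch on whichever attribute exists; the helper name below is hypothetical, not convert.py's real code.

def pad_with_dummies(vocab, pad_count: int) -> None:
    # Append <dummyNNNNN> placeholders to whichever container this vocab class provides.
    for i in range(1, pad_count + 1):
        name = f"<dummy{i:05}>"
        if hasattr(vocab, "added_tokens_dict"):
            # HF-style vocabs (hfft) keep a name -> id mapping; -1 mirrors the
            # sentinel written by the failing line in check_vocab_size.
            vocab.added_tokens_dict[name] = -1
        else:
            # SentencePieceVocab (spm) only keeps a plain list of added token strings.
            vocab.added_tokens_list.append(name)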

@TheBloke
Contributor Author

TheBloke commented Jan 15, 2024

I've got another issue with the same model. It is not directly related to the SPM issue, but I believe it might be related to the convert.py refactor, so I'm including it here as well.

As the SPM convert wasn't working, I made the TigerResearch model using the HFFT conversion, like so:

 python3 ./convert.py /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source --outtype f16 --outfile /workspace/process/tigerresearch_tigerbot-13b-chat-v5/gguf/tigerbot-13b-chat-v5.fp16.gguf  --vocab-type hfft

This ran successfully. But the resulting FP16 cannot be used:

llama_model_load: error loading model: create_tensor: tensor 'token_embd.weight' has wrong shape; expected  5120, 65110, got  5120, 65112,     1,     1
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/workspace/process/tigerresearch_tigerbot-13b-chat-v5/gguf/tigerbot-13b-chat-v5.fp16.gguf'
main: error: unable to load model

The AWQ has completed and runs fine, so the source model appears fine.

If I'm understanding correctly, it looks like convert.py --vocab-type hfft --pad-vocab is not incrementing n_vocab as it should be.

Here is the top of the logs from convert.py:

 [py11torch] tomj@MC:/workspace/git/llama.cpp (master ✘)✭ ᐅ python3 ./convert.py /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source --outtype f16 --outfile /workspace/process/tigerresearch_tigerbot-13b-chat-v5/gguf/tigerbot-13b-chat-v5.fp16.gguf  --vocab-type hfft
/workspace/git/llama.cpp/gguf-py
Loading model file /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source/pytorch_model-00001-of-00003.bin
Loading model file /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source/pytorch_model-00001-of-00003.bin
Loading model file /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source/pytorch_model-00002-of-00003.bin
Loading model file /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source/pytorch_model-00003-of-00003.bin
params = Params(n_vocab=65112, n_embd=5120, n_layer=40, n_ctx=2048, n_ff=13824, n_head=40, n_head_kv=40, f_norm_eps=1e-05, n_experts=None, n_experts_used=None, rope_scaling_type=None, f_rope_freq_base=10000, f_rope_scale=None, n_orig_ctx=None, rope_finetuned=None, ftype=<GGMLFileType.MostlyF16: 1>, path_model=PosixPath('/workspace/process/tigerresearch_tigerbot-13b-chat-v5/source'))
Loading vocab file '/workspace/process/tigerresearch_tigerbot-13b-chat-v5/source', type 'hfft'
fname_tokenizer: /workspace/process/tigerresearch_tigerbot-13b-chat-v5/source
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Permuting layer 0
...
model.embed_tokens.weight                        -> token_embd.weight                        | BF16   | [65112, 5120]

Note n_vocab = 65112

But then on running main:

llm_load_vocab: mismatch in special tokens definition ( 262/65110 vs 6/65110 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = llama
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 65110
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 2048
llm_load_print_meta: n_embd           = 5120
llm_load_print_meta: n_head           = 40
llm_load_print_meta: n_head_kv        = 40
llm_load_print_meta: n_layer          = 40
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 1
llm_load_print_meta: n_embd_k_gqa     = 5120
llm_load_print_meta: n_embd_v_gqa     = 5120
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-05
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: n_ff             = 13824
llm_load_print_meta: n_expert         = 0
llm_load_print_meta: n_expert_used    = 0
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 2048
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: model type       = 13B
llm_load_print_meta: model ftype      = F16
llm_load_print_meta: model params     = 13.35 B
llm_load_print_meta: model size       = 24.88 GiB (16.00 BPW)
llm_load_print_meta: general.name     = tigerresearch_tigerbot-13b-chat-v5
llm_load_print_meta: BOS token        = 1 '<s>'
llm_load_print_meta: EOS token        = 2 '</s>'
llm_load_print_meta: UNK token        = 0 '<unk>'
llm_load_print_meta: PAD token        = 65109 '<pad>'
llm_load_print_meta: LF token         = 13 '<0x0A>'
llm_load_tensors: ggml ctx size =    0.14 MiB
llama_model_load: error loading model: create_tensor: tensor 'token_embd.weight' has wrong shape; expected  5120, 65110, got  5120, 65112,     1,     1
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/workspace/process/tigerresearch_tigerbot-13b-chat-v5/gguf/tigerbot-13b-chat-v5.fp16.gguf'
main: error: unable to load model

Now it shows llm_load_print_meta: n_vocab = 65110

So maybe the HFFT pad-vocab path has been broken since the refactor?
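
To make the expected behaviour concrete, here is a rough sketch of the arithmetic the padding path should perform (a hypothetical helper, not convert.py's real code): the GGUF vocab has to be grown until it matches params.n_vocab, which here is the 65112 rows of token_embd.weight rather than the 65110 entries the tokenizer actually defines.

def expected_pad_count(n_vocab_from_params: int, n_tokens_in_tokenizer: int) -> int:
    # The writer must emit exactly n_vocab_from_params vocab entries, so the
    # shortfall has to be filled with placeholder tokens.
    pad = n_vocab_from_params - n_tokens_in_tokenizer
    if pad < 0:
        raise ValueError("tokenizer defines more tokens than the model expects")
    return pad

# tigerbot-13b-chat-v5: 65112 embedding rows vs 65110 tokenizer entries,
# i.e. 2 placeholder tokens, matching the "Padding vocab with 2 token(s)" message above.
assert expected_pad_count(65112, 65110) == 2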

@ManuelFay

Same problem here as the second issue:

llama_model_load: error loading model: _Map_base::at
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model '/home/manuel/base_190k-GGUF/base_190k.Q4_K_M.gguf'
main: error: unable to load model

@ddh0
Contributor

ddh0 commented Jan 16, 2024

Same error here.

@jhen0409
Collaborator

I can confirm this is caused by 6efb8eb (#4818); I just tested it by reverting that commit.

@BrickBee

There's a potentially related issue when converting the https://huggingface.co/ise-uiuc/Magicoder-DS-6.7B model:
As a quick test, I manually added pad tokens to added_tokens.json to get from the provided 32022 tokens to the model's vocab_size of 32256. This allowed the conversion to work, but the resulting model mostly produced unreadable output, not even word fragments as with other errors. Maybe something else needs to be done here in the context of #4419.
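
For reference, a rough sketch of the manual workaround described above: extend added_tokens.json with placeholder entries until the tokenizer covers the model's declared vocab_size (32022 to 32256 for Magicoder-DS-6.7B). The placeholder names and the in-place rewrite are assumptions; any otherwise unused strings would work.

import json

# added_tokens.json maps added token strings to their ids.
with open("added_tokens.json") as f:
    added = json.load(f)

current_count, target_vocab_size = 32022, 32256
for tok_id in range(current_count, target_vocab_size):
    added[f"<pad_{tok_id}>"] = tok_id  # arbitrary placeholder names

with open("added_tokens.json", "w") as f:
    json.dump(added, f, indent=2)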

@ggerganov
Member

Please check if #4971 resolves the original issue.

However, notice that there seems to be something wrong with this model: token_embd.weight and output.weight assume a vocabulary of size 65112, but only 65110 tokens are defined in the tokenizer. I've extended the convert.py script to pad with unknown tokens <unk65110> and <unk65111> to work around this, though I think it is some sort of error in the model data.

@BrickBee

Please check if #4971 resolves the original issue.

Conversion with --pad-vocab works for me now. The converted Magicoder-DS still generates garbage, but that's potentially a separate issue.

@TheBloke
Contributor Author

TheBloke commented Jan 16, 2024

@ggerganov Thanks for the fix, I can confirm that this model now converts in both SPM and HFFT tokenisation modes.

Regarding:

However, notice that there seems to be something wrong with this model: token_embd.weight and output.weight assume a vocabulary of size 65112, but only 65110 tokens are defined in the tokenizer. I've extended the convert.py script to pad with unknown tokens <unk65110> and <unk65111> to work around this, though I think it is some sort of error in the model data.

It's not uncommon for models to require a larger vocab than they provide tokens for. Some model trainers do this deliberately, increasing the vocab size to a number that divides evenly for tensor parallelism. I don't know whether that's the case here or whether this one is just a mistake, but in any case this mismatch is fairly common to see, which is why I originally asked for --pad-vocab (before that was added, I used to add dummy tokens in my own code, artificially padding added_tokens.json with extra dummy entries).

The extra tokens are no issue for Transformers models, but they of course break llama.cpp, so we need the workaround of adding extra dummy tokens.

FYI with your fix, we now seem to have a bit of a duplication / re-definition of the extra tokens added:

Padding vocab with 2 token(s) - <dummy00001> through <dummy00002>
gguf: This GGUF file is for Little Endian only
Warning: token 65110 not found in vocab - padding with b'<unk65110>'
Warning: token 65111 not found in vocab - padding with b'<unk65111>'

The --pad-vocab flag reports that it's adding N new tokens named dummy, but with your PR that is then overridden by the addition of unk tokens instead.

This doesn't really matter, as it now works fine and the names of the extra tokens are irrelevant, but I thought I'd highlight the discrepancy for a future patch.

Thanks very much for the fix!

@ggerganov
Member

Thanks for the feedback! Yup, the implementation is probably not great; I will revisit tomorrow and see if I can improve it and avoid this duplication.
