Description
As reported in #6944 (comment)
The llama.cpp tokenizers give different results than HF for old GGUF files.
This is a subtle footgun, and at the very least there should be a warning, since it is currently impossible to tell at what vintage your old GGUF models suddenly spoil.
Right now, the only reliable way to determine this is by running perplexity.cpp and comparing against HF. The key numbers for the first 512 tokens of the wikitext-2-raw test set are as follows.
Using the llama.cpp tokenizer:

| model | quant | perplexity |
| --- | --- | --- |
| llama.cpp | Q8_0 | 15.4660 |
| llama-cpp-python | Q8_0 | 15.392684936523438 |
| llama.cpp | Q5_K_M | 15.6994 |
| llama-cpp-python | Q5_K_M | 15.637877464294434 |

Using the HF tokenizer and passing those tokens into the different model implementations:

| model | quant | perplexity |
| --- | --- | --- |
| Huggingface | (unquantized) | 6.205880641937256 |
| llama-cpp-python | Q8_0 | 6.204566478729248 |
| llama-cpp-python | Q5_K_M | 6.228440761566162 |
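As a quicker spot check than a full perplexity run, you can also diff the token ids directly against HF. This snippet is a sketch of mine (not part of the notebook); it assumes the same TinyLlama GGUF and wikitext files that the notebook downloads below:

```python
# Spot check (sketch): compare HF and llama.cpp (via llama-cpp-python) token ids on the same text.
from llama_cpp import Llama
from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
llm = Llama(model_path="tinyllama-1.1b-chat-v1.0.Q8_0.gguf", vocab_only=True)  # weights not needed just to tokenize

text = open("wikitext-2-raw/wiki.test.raw").read()[:3612]
hf_ids = hf_tok.encode(text, add_special_tokens=True)
gguf_ids = llm.tokenize(text.encode("utf-8"))

# First position where the two tokenizations disagree (None if they agree on the shared prefix).
first_diff = next((i for i, (a, b) in enumerate(zip(hf_ids, gguf_ids)) if a != b), None)
print(len(hf_ids), len(gguf_ids), "first divergence at index:", first_diff)
```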
This is demonstrated in an attached notebook, which you can play with at this Colab. I'll paste the code below too.
# -*- coding: utf-8 -*-
"""Tokenizer: HF vs llama-cpp-python vs llama.cpp (perplexity.cpp)
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1RYlEj2UhylYWyaASFo-LLATzZ8d29Z0T
# Tokenizer: HF vs llama-cpp-python vs. llama.cpp (perplexity.cpp)
We show that an OLD, previously converted TinyLlama GGUF (a) has buggy tokenization in llama.cpp and (b) llama.cpp doesn't provide any warning about it.
This leads to unusually bad perplexity. For simplicity, we report the perplexity of the first 512-token window of wikitext-2-raw-test.
If we use the HF tokenizer and feed its output to llama.cpp, we get the perplexity we expect.
Using llama.cpp tokenizer:
model              quant    perplexity
llama.cpp          Q8_0     15.4660
llama-cpp-python   Q8_0     15.392684936523438
llama.cpp          Q5_K_M   15.6994
llama-cpp-python   Q5_K_M   15.637877464294434
Using HF tokenizer and passing those tokens into different model implementations:
model              quant          perplexity
Huggingface        (unquantized)  6.205880641937256
llama-cpp-python   Q8_0           6.204566478729248
llama-cpp-python   Q5_K_M         6.228440761566162
"""
N_CTX = 512 # perplexity.cpp default
"""We trim to roughly 1024 tokens of text because the context window is 512 and due to pecularities in how perplexity.cpp is implemented we can't just work with 512 tokens of text.
Note that all our measurements are JUST over the first 512 tokens of wikitext-2-raw-test. Other windows are dummies needed to get perplexity.cpp to run.
"""
!wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip
!unzip wikitext-2-raw-v1.zip
# Install llama-cpp-python
!pip install llama-cpp-python \
--extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
# Download quantized TinyLlama models for llama-cpp-python
!wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf
!wget https://huggingface.co/TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF/resolve/main/tinyllama-1.1b-chat-v1.0.Q8_0.gguf
# Download Huggingface TinyLlama model and tokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
hf_tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
hf_model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# Load llama-cpp-python quantized models
from llama_cpp import Llama
gguf_models = [
    "tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf",
    "tinyllama-1.1b-chat-v1.0.Q8_0.gguf",
]
llama_models = [Llama(model_path=path, logits_all=True) for path in gguf_models]
"""3612 characters gives us 1024 tokens in llama.cpp tokenizer and 1032 tokens in HF tokenizer."""
text = open("wikitext-2-raw/wiki.test.raw").read()[:3612]
open("wikitext-2-test-3612.raw", "wt").write(text)
hf_tokens = hf_tokenizer.encode(text, add_special_tokens=True)
print(f"n HF tokens: {len(hf_tokens)}")
for llama_model, gguf_file in zip(llama_models, gguf_models):
    llama_tokens = llama_model.tokenize(text.encode("utf-8"))
    print(f"n Llama-CPP {gguf_file} tokens: {len(llama_tokens)}")
"""As a sanity check, we will use both a hand-rolled perplexity implementation AND the one from perplexity.cpp (later in the notebook)"""
# Helper Functions
import torch
import numpy as np
from scipy.special import log_softmax
def get_logits_hf(model, tokens):
    """Get logits from a Huggingface Transformers model.
    Preconditions:
    - model is a valid Huggingface Transformers model
    - tokens is a list of integers
    Postconditions:
    - logits is a 2D numpy array of shape (len(tokens), vocab_size)
    """
    assert isinstance(tokens, list)
    assert all(isinstance(t, int) for t in tokens)
    input_ids = torch.tensor(tokens).unsqueeze(0)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    logits = outputs.logits.squeeze(0).cpu().numpy()
    assert logits.ndim == 2
    assert logits.shape[0] == len(tokens), f"logits.shape: {logits.shape}, tokens: {len(tokens)} = {tokens}"
    return logits
# Method to clear llama context
def clear_llama_context(llama):
    """Clear the llama context and reset the number of tokens."""
    llama.reset()
    llama._ctx.kv_cache_clear()
    llama.input_ids.fill(0)
    llama.scores.fill(0)
def get_logits_llama(model, tokens):
    """Get logits from a llama-cpp-python model.
    Preconditions:
    - model is a valid llama-cpp-python model
    - tokens is a list of integers
    Postconditions:
    - logits is a 2D numpy array of shape (len(tokens), vocab_size)
    """
    assert isinstance(model, Llama)
    assert isinstance(tokens, list)
    assert all(isinstance(t, int) for t in tokens)
    clear_llama_context(model)
    assert model.n_tokens == 0
    model.eval(tokens)
    assert model.n_tokens == len(tokens)
    logits = np.array(model.scores).reshape(-1, model.n_vocab())[:len(tokens)]
    assert logits.ndim == 2
    assert logits.shape[0] == len(tokens), f"logits.shape: {logits.shape}, tokens: {len(tokens)} = {tokens}"
    return logits
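"""Quick smoke test (my addition): both helpers should return one row of logits per input
token, so the two shapes printed below should match."""
smoke_tokens = hf_tokenizer.encode("Hello world", add_special_tokens=True)
print("HF logits shape:", get_logits_hf(hf_model, smoke_tokens).shape)
print("llama-cpp logits shape:", get_logits_llama(llama_models[0], smoke_tokens).shape)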
def compute_token_nlls(logits, tokens):
    """Compute NLLs of all tokens from logits and target tokens.
    Preconditions:
    - logits is a 2D numpy array of shape (len(tokens), vocab_size)
    - tokens is a list of integers with at least 2 elements
    Postconditions:
    - nlls is a numpy array of floats representing NLL of each token
    """
    assert logits.ndim == 2
    assert len(tokens) >= 2
    assert logits.shape[0] == len(tokens), f"logits.shape: {logits.shape}, tokens: {len(tokens)} = {tokens}"
    target_tokens = tokens[1:]
    log_probs = log_softmax(logits[:-1], axis=-1)
    nlls = -log_probs[np.arange(len(target_tokens)), target_tokens]
    assert isinstance(nlls, np.ndarray)
    assert nlls.ndim == 1
    assert len(nlls) == len(target_tokens)
    return nlls
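"""Tiny self-check of the hand-rolled perplexity (my addition): with uniform logits over a
V-token vocabulary every target token has probability 1/V, so the perplexity exp(mean NLL)
should come out to V."""
V = 50
toy_logits = np.zeros((10, V))  # all-zero logits = uniform distribution at every position
toy_tokens = [int(t) for t in np.random.randint(0, V, size=10)]
toy_nlls = compute_token_nlls(toy_logits, toy_tokens)
print("expected perplexity:", V, "observed:", np.exp(np.mean(toy_nlls)))  # ~50.0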
"""Here we see the huggingface tokens and perplexity:"""
# Compute NLLs for Huggingface model
hf_logits = get_logits_hf(hf_model, hf_tokens[:N_CTX])
assert hf_logits.shape[0] == N_CTX
halfway = N_CTX // 2
hf_nll = compute_token_nlls(hf_logits[halfway:,...], hf_tokens[halfway:N_CTX])
print(f"Huggingface tokens:", hf_tokens[:N_CTX])
print(f"Huggingface NLL: {np.mean(hf_nll)}")
print(f"Huggingface perplexity: {np.exp(np.mean(hf_nll))}")
"""Now we see llama-cpp-python tokens and perplexity.
The GGUF llama-cpp-python tokenizers are different after the first few tokens. This really hurts perplexity and NLL.
"""
# Compute NLLs for llama-cpp-python models
for model, gguf_file in zip(llama_models, gguf_models):
    # llama_tokens = model.tokenize(text.encode("utf-8"), add_bos=True)
    llama_tokens = model.tokenize(text.encode("utf-8"))
    llama_logits = get_logits_llama(model, llama_tokens[:N_CTX])
    assert llama_logits.shape[0] == N_CTX
    halfway = N_CTX // 2
    llama_nll = compute_token_nlls(llama_logits[halfway:,...], llama_tokens[halfway:N_CTX])
    print(f"Llama-CPP {gguf_file} tokens:", llama_tokens[:N_CTX])
    print(f"Llama-CPP {gguf_file} NLL: {np.mean(llama_nll)}")
    print(f"Llama-CPP {gguf_file} perplexity: {np.exp(np.mean(llama_nll))}")
"""But if we use HF tokens with llama-cpp-python, we get identical tokens and nearly identical perplexity:"""
# Compute NLLs for llama-cpp-python models, but using HF tokens
for model, gguf_file in zip(llama_models, gguf_models):
    llama_logits = get_logits_llama(model, hf_tokens[:N_CTX])
    assert llama_logits.shape[0] == N_CTX
    halfway = N_CTX // 2
    llama_nll = compute_token_nlls(llama_logits[halfway:,...], hf_tokens[halfway:N_CTX])
    print(f"HF tokens:", hf_tokens[:N_CTX])
    print(f"Llama-CPP {gguf_file} on HF tokens NLL: {np.mean(llama_nll)}")
    print(f"Llama-CPP {gguf_file} on HF tokens perplexity: {np.exp(np.mean(llama_nll))}")
"""Here I use llama.cpp main, so we can make sure the tokens and perplexity are same as llama-cpp-python above."""
!git clone https://github.com/ggerganov/llama.cpp.git
!cd llama.cpp ; git log -1 --format="Commit Hash: %H%nCommit Date: %cd" --date=iso
"""This gnarly little one-liner just patches perplexity.cpp to output the token lists."""
#!perl -i -pe 's/(std::vector<llama_token> tokens = ::llama_tokenize\(.*\);)/$1 fprintf\(stderr, "%s: %d tokens\\n", __func__, int\(tokens.size\(\)\)\);/' llama.cpp/examples/perplexity/perplexity.cpp
!perl -i -pe 's/(std::vector<llama_token> tokens = ::llama_tokenize\(.*\);)/$1 fprintf\(stderr, "%s: %d tokens: ", __func__, int\(tokens.size\(\)\)\); for \(const auto& token : tokens\) { fprintf\(stderr, "%d ", int\(token\)\); } fprintf\(stderr, "\\n"\);/' llama.cpp/examples/perplexity/perplexity.cpp
!cd llama.cpp ; git diff
!cd llama.cpp && make -j2
"""Now, we see that llama.cpp (and not just llama-cpp-python) give bogus non-HF tokens with this old GGUF. And the perplexity scores of the FIRST WINDOW (`[1]`) are almost identical those above, and differ slightly possibly due to seeding which I didn't control for:
below:
llama.cpp Q5_K_M perplexity: 15.6994
llama.cpp Q5_K_M perplexity: 15.4660
above:
Llama-CPP tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf perplexity: 15.637877464294434
Llama-CPP tinyllama-1.1b-chat-v1.0.Q8_0.gguf perplexity: 15.392684936523438
"""
!llama.cpp/perplexity -f "wikitext-2-test-3612.raw" -m tinyllama-1.1b-chat-v1.0.Q5_K_M.gguf
!llama.cpp/perplexity -f "wikitext-2-test-3612.raw" -m tinyllama-1.1b-chat-v1.0.Q8_0.gguf