Skip to content

Converting GGML back to Torch checkpoint for HuggingFace/Pytorch consumption/training/finetuning #403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 294 additions & 0 deletions convert_ggml_to_pth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,294 @@
# Author: github.com/ductai199x
import argparse
import os
import struct

import numpy as np
import torch
from numba import njit
from tqdm.auto import tqdm


def read_header(fin):
values = struct.unpack("i" * 9, fin.read(4 * 9))
_, _, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = values
return {
"vocab_size": vocab_size,
"dim": dim,
"multiple_of": multiple_of,
"n_heads": n_heads,
"n_layers": n_layers,
}, ftype


def read_tokens(fin, vocab_size):
tokens = []
for _ in range(vocab_size):
text_len = struct.unpack("i", fin.read(4))[0]
text_bytes = fin.read(text_len)
try:
text = text_bytes.decode("utf-8")
except UnicodeDecodeError:
text = text_bytes.decode("utf-8", "replace")
score = struct.unpack("f", fin.read(4))[0]
tokens.append((text, score))
return tokens


@njit
def dequantize_weights_numba(fin_data, n_rows, n_cols):
qk = 32
nb = n_cols // qk
bs = 4 + (qk // 2)

weights = np.zeros((n_rows, n_cols), dtype=np.float32)
data_pos = 0

for row in range(n_rows):
for block in range(nb):
d = np.frombuffer(fin_data[data_pos : data_pos + 4], dtype=np.float32)[0]
data_pos += 4
packed_values = fin_data[data_pos : data_pos + (qk // 2)]
data_pos += qk // 2

for i in range(qk // 2):
packed_value = packed_values[i]
v0 = np.float32((packed_value & 0b00001111) - 8) * d
v1 = np.float32((packed_value >> 4) - 8) * d

weights[row, block * qk + 2 * i] = v0
weights[row, block * qk + 2 * i + 1] = v1

return weights


def dequantize_weights(fin, n_rows, n_cols):
qk = 32
nb = n_cols // qk
data_size = n_rows * n_cols // 2 + n_rows * nb * 4
fin_data = fin.read(data_size)
return dequantize_weights_numba(fin_data, n_rows, n_cols)


def read_variables(fin):
model = {}
pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
while True:
start_pos = fin.tell()
try:
n_dims, name_length, ftype_cur = struct.unpack("iii", fin.read(4 * 3))
except struct.error:
break

shape = tuple(struct.unpack("i" * n_dims, fin.read(4 * n_dims)))
shape = shape[::-1]
name = fin.read(name_length).decode("utf-8")

if ftype_cur == 2:
# 4-bit quantized weights
dtype = np.uint8
data = dequantize_weights(fin, shape[0], shape[1])
data = data.reshape(shape)
elif ftype_cur == 0:
dtype = np.float32
data_size = np.prod(shape)
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
elif ftype_cur == 1:
dtype = np.float16
data_size = np.prod(shape)
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)

model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)

pbar.update(fin.tell() - start_pos)

return model


def convert_to_hf_format(model, hparams):
# This works for llama 7B, need to test with other models
n_layers = hparams["n_layers"]
n_heads = hparams["n_heads"]
dim = hparams["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

# permute for sliced rotary
def permute(w):
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)

state_dict = {}
for layer_i in range(n_layers):
state_dict.update(
{
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
model[f"layers.{layer_i}.attention.wq.weight"]
),
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
model[f"layers.{layer_i}.attention.wk.weight"]
),
f"model.layers.{layer_i}.self_attn.v_proj.weight": model[
f"layers.{layer_i}.attention.wv.weight"
],
f"model.layers.{layer_i}.self_attn.o_proj.weight": model[
f"layers.{layer_i}.attention.wo.weight"
],
f"model.layers.{layer_i}.mlp.gate_proj.weight": model[
f"layers.{layer_i}.feed_forward.w1.weight"
],
f"model.layers.{layer_i}.mlp.down_proj.weight": model[
f"layers.{layer_i}.feed_forward.w2.weight"
],
f"model.layers.{layer_i}.mlp.up_proj.weight": model[
f"layers.{layer_i}.feed_forward.w3.weight"
],
f"model.layers.{layer_i}.input_layernorm.weight": model[
f"layers.{layer_i}.attention_norm.weight"
],
f"model.layers.{layer_i}.post_attention_layernorm.weight": model[
f"layers.{layer_i}.ffn_norm.weight"
],
}
)
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
state_dict.update(
{
"model.embed_tokens.weight": model["tok_embeddings.weight"],
"model.norm.weight": model["norm.weight"],
"lm_head.weight": model["output.weight"],
}
)

return state_dict


def chat(model, hparams, llama_dir):
from transformers import (GenerationConfig, LlamaForCausalLM,
LlamaTokenizer, StoppingCriteria,
StoppingCriteriaList)
from transformers.models.llama.configuration_llama import LlamaConfig

class StoppingCriteriaSub(StoppingCriteria):
def __init__(self):
super().__init__()

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
print(tokenizer.decode(input_ids[0]), end="", flush=True)
if input_ids[0][-1] == 13:
return True

return False

config = LlamaConfig(
vocab_size=hparams["vocab_size"],
dim=hparams["dim"],
num_hidden_layers=hparams["n_layers"],
num_attention_heads=hparams["n_heads"],
)

llama = LlamaForCausalLM(config=config)
llama.load_state_dict(state_dict=model, strict=True)
tokenizer = LlamaTokenizer.from_pretrained(llama_dir)

device = torch.device("cpu")
llama = llama.to(device)

ctx = """You are AI.
This is a dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, respectful, direct, concise, should try to protect User's privacy, and knows its own limits. Also, AI must answer User and AI cannot stop the conversation by itself.
User: Hello, AI.
AI: Hello! How can I assist you today?
"""
print(ctx.rstrip("\n"))
while True:
print("-" * 60)
prompt = input(f"User: ")
if ctx != "":
ctx = ctx + "User: " + prompt + "\n"
else:
ctx = prompt + "\nAI:"

ctx = (ctx[-1920:]) if len(ctx) >= 2048 else ctx

print("-" * 60)
if len(ctx.strip()) > 0:
input_ids = tokenizer(ctx, return_tensors="pt")["input_ids"].to(device)
generation_config = GenerationConfig(
temperature=0.8,
top_p=0.95,
top_k=50,
repetition_penalty=1.1764,
)
with torch.no_grad():
generation_output = llama.generate(
input_ids=input_ids,
generation_config=generation_config,
return_dict_in_generate=True,
output_scores=True,
max_length=2048,
do_sample=True,
stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub()]),
)
s = generation_output.sequences[0]
decoded = tokenizer.decode(s)
ctx = decoded + "\n"


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
)
parser.add_argument(
"--prefix",
"-p",
type=str,
required=True,
help="The prefix of the ggml files (ggml-model-f16 or ggml-model-q4_0).",
)
parser.add_argument(
"--hf",
action="store_true",
help="Whether to save the model in the huggingface format. (default: False)",
)
parser.add_argument(
"--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
)
args = parser.parse_args()

llama_dir = os.path.abspath(f"{args.input_dir}/../")

ggml_files = sorted(
[f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
)

fin = open(ggml_files[0], "rb")
hparams, ftype = read_header(fin)
tokens = read_tokens(fin, hparams["vocab_size"])
model = read_variables(fin)

for f in tqdm(ggml_files[1:]):
fin = open(f, "rb")
read_header(fin)
read_tokens(fin, hparams["vocab_size"])
model.update(read_variables(fin))

if args.hf:
model = convert_to_hf_format(model, hparams)

pth_ckpt = {
"state_dict": model,
"hparams": hparams,
"tokens": tokens,
}

torch.save(pth_ckpt, f"{args.input_dir}/{args.prefix}-to-torch.pth")

if args.chat:
if not args.hf:
model = convert_to_hf_format(model, hparams)
chat(model, hparams, llama_dir)


if __name__ == "__main__":
main()