Skip to content

Commit d0330fd

Browse files
authored
py : add capabiliy to convert from ggml back to torch or hf format for further consumption/training/finetuning (#403)
1 parent 99c5b27 commit d0330fd

File tree

1 file changed

+294
-0
lines changed

1 file changed

+294
-0
lines changed

convert_ggml_to_pth.py

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
# Author: github.com/ductai199x
2+
import argparse
3+
import os
4+
import struct
5+
6+
import numpy as np
7+
import torch
8+
from numba import njit
9+
from tqdm.auto import tqdm
10+
11+
12+
def read_header(fin):
13+
values = struct.unpack("i" * 9, fin.read(4 * 9))
14+
_, _, vocab_size, dim, multiple_of, n_heads, n_layers, rot, ftype = values
15+
return {
16+
"vocab_size": vocab_size,
17+
"dim": dim,
18+
"multiple_of": multiple_of,
19+
"n_heads": n_heads,
20+
"n_layers": n_layers,
21+
}, ftype
22+
23+
24+
def read_tokens(fin, vocab_size):
25+
tokens = []
26+
for _ in range(vocab_size):
27+
text_len = struct.unpack("i", fin.read(4))[0]
28+
text_bytes = fin.read(text_len)
29+
try:
30+
text = text_bytes.decode("utf-8")
31+
except UnicodeDecodeError:
32+
text = text_bytes.decode("utf-8", "replace")
33+
score = struct.unpack("f", fin.read(4))[0]
34+
tokens.append((text, score))
35+
return tokens
36+
37+
38+
@njit
39+
def dequantize_weights_numba(fin_data, n_rows, n_cols):
40+
qk = 32
41+
nb = n_cols // qk
42+
bs = 4 + (qk // 2)
43+
44+
weights = np.zeros((n_rows, n_cols), dtype=np.float32)
45+
data_pos = 0
46+
47+
for row in range(n_rows):
48+
for block in range(nb):
49+
d = np.frombuffer(fin_data[data_pos : data_pos + 4], dtype=np.float32)[0]
50+
data_pos += 4
51+
packed_values = fin_data[data_pos : data_pos + (qk // 2)]
52+
data_pos += qk // 2
53+
54+
for i in range(qk // 2):
55+
packed_value = packed_values[i]
56+
v0 = np.float32((packed_value & 0b00001111) - 8) * d
57+
v1 = np.float32((packed_value >> 4) - 8) * d
58+
59+
weights[row, block * qk + 2 * i] = v0
60+
weights[row, block * qk + 2 * i + 1] = v1
61+
62+
return weights
63+
64+
65+
def dequantize_weights(fin, n_rows, n_cols):
66+
qk = 32
67+
nb = n_cols // qk
68+
data_size = n_rows * n_cols // 2 + n_rows * nb * 4
69+
fin_data = fin.read(data_size)
70+
return dequantize_weights_numba(fin_data, n_rows, n_cols)
71+
72+
73+
def read_variables(fin):
74+
model = {}
75+
pbar = tqdm(total=os.path.getsize(fin.name), unit="B", unit_scale=True, desc="Reading variables")
76+
while True:
77+
start_pos = fin.tell()
78+
try:
79+
n_dims, name_length, ftype_cur = struct.unpack("iii", fin.read(4 * 3))
80+
except struct.error:
81+
break
82+
83+
shape = tuple(struct.unpack("i" * n_dims, fin.read(4 * n_dims)))
84+
shape = shape[::-1]
85+
name = fin.read(name_length).decode("utf-8")
86+
87+
if ftype_cur == 2:
88+
# 4-bit quantized weights
89+
dtype = np.uint8
90+
data = dequantize_weights(fin, shape[0], shape[1])
91+
data = data.reshape(shape)
92+
elif ftype_cur == 0:
93+
dtype = np.float32
94+
data_size = np.prod(shape)
95+
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
96+
elif ftype_cur == 1:
97+
dtype = np.float16
98+
data_size = np.prod(shape)
99+
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
100+
101+
model[name] = torch.tensor(data, dtype=torch.float32 if dtype == np.float32 else torch.float16)
102+
103+
pbar.update(fin.tell() - start_pos)
104+
105+
return model
106+
107+
108+
def convert_to_hf_format(model, hparams):
109+
# This works for llama 7B, need to test with other models
110+
n_layers = hparams["n_layers"]
111+
n_heads = hparams["n_heads"]
112+
dim = hparams["dim"]
113+
dims_per_head = dim // n_heads
114+
base = 10000.0
115+
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
116+
117+
# permute for sliced rotary
118+
def permute(w):
119+
return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
120+
121+
state_dict = {}
122+
for layer_i in range(n_layers):
123+
state_dict.update(
124+
{
125+
f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
126+
model[f"layers.{layer_i}.attention.wq.weight"]
127+
),
128+
f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
129+
model[f"layers.{layer_i}.attention.wk.weight"]
130+
),
131+
f"model.layers.{layer_i}.self_attn.v_proj.weight": model[
132+
f"layers.{layer_i}.attention.wv.weight"
133+
],
134+
f"model.layers.{layer_i}.self_attn.o_proj.weight": model[
135+
f"layers.{layer_i}.attention.wo.weight"
136+
],
137+
f"model.layers.{layer_i}.mlp.gate_proj.weight": model[
138+
f"layers.{layer_i}.feed_forward.w1.weight"
139+
],
140+
f"model.layers.{layer_i}.mlp.down_proj.weight": model[
141+
f"layers.{layer_i}.feed_forward.w2.weight"
142+
],
143+
f"model.layers.{layer_i}.mlp.up_proj.weight": model[
144+
f"layers.{layer_i}.feed_forward.w3.weight"
145+
],
146+
f"model.layers.{layer_i}.input_layernorm.weight": model[
147+
f"layers.{layer_i}.attention_norm.weight"
148+
],
149+
f"model.layers.{layer_i}.post_attention_layernorm.weight": model[
150+
f"layers.{layer_i}.ffn_norm.weight"
151+
],
152+
}
153+
)
154+
state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
155+
state_dict.update(
156+
{
157+
"model.embed_tokens.weight": model["tok_embeddings.weight"],
158+
"model.norm.weight": model["norm.weight"],
159+
"lm_head.weight": model["output.weight"],
160+
}
161+
)
162+
163+
return state_dict
164+
165+
166+
def chat(model, hparams, llama_dir):
167+
from transformers import (GenerationConfig, LlamaForCausalLM,
168+
LlamaTokenizer, StoppingCriteria,
169+
StoppingCriteriaList)
170+
from transformers.models.llama.configuration_llama import LlamaConfig
171+
172+
class StoppingCriteriaSub(StoppingCriteria):
173+
def __init__(self):
174+
super().__init__()
175+
176+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops=[]):
177+
print(tokenizer.decode(input_ids[0]), end="", flush=True)
178+
if input_ids[0][-1] == 13:
179+
return True
180+
181+
return False
182+
183+
config = LlamaConfig(
184+
vocab_size=hparams["vocab_size"],
185+
dim=hparams["dim"],
186+
num_hidden_layers=hparams["n_layers"],
187+
num_attention_heads=hparams["n_heads"],
188+
)
189+
190+
llama = LlamaForCausalLM(config=config)
191+
llama.load_state_dict(state_dict=model, strict=True)
192+
tokenizer = LlamaTokenizer.from_pretrained(llama_dir)
193+
194+
device = torch.device("cpu")
195+
llama = llama.to(device)
196+
197+
ctx = """You are AI.
198+
This is a dialog, where User interacts with AI. AI is helpful, kind, obedient, honest, respectful, direct, concise, should try to protect User's privacy, and knows its own limits. Also, AI must answer User and AI cannot stop the conversation by itself.
199+
User: Hello, AI.
200+
AI: Hello! How can I assist you today?
201+
"""
202+
print(ctx.rstrip("\n"))
203+
while True:
204+
print("-" * 60)
205+
prompt = input(f"User: ")
206+
if ctx != "":
207+
ctx = ctx + "User: " + prompt + "\n"
208+
else:
209+
ctx = prompt + "\nAI:"
210+
211+
ctx = (ctx[-1920:]) if len(ctx) >= 2048 else ctx
212+
213+
print("-" * 60)
214+
if len(ctx.strip()) > 0:
215+
input_ids = tokenizer(ctx, return_tensors="pt")["input_ids"].to(device)
216+
generation_config = GenerationConfig(
217+
temperature=0.8,
218+
top_p=0.95,
219+
top_k=50,
220+
repetition_penalty=1.1764,
221+
)
222+
with torch.no_grad():
223+
generation_output = llama.generate(
224+
input_ids=input_ids,
225+
generation_config=generation_config,
226+
return_dict_in_generate=True,
227+
output_scores=True,
228+
max_length=2048,
229+
do_sample=True,
230+
stopping_criteria=StoppingCriteriaList([StoppingCriteriaSub()]),
231+
)
232+
s = generation_output.sequences[0]
233+
decoded = tokenizer.decode(s)
234+
ctx = decoded + "\n"
235+
236+
237+
def main():
238+
parser = argparse.ArgumentParser()
239+
parser.add_argument(
240+
"--input_dir", "-i", type=str, required=True, help="The input directory containing the ggml files."
241+
)
242+
parser.add_argument(
243+
"--prefix",
244+
"-p",
245+
type=str,
246+
required=True,
247+
help="The prefix of the ggml files (ggml-model-f16 or ggml-model-q4_0).",
248+
)
249+
parser.add_argument(
250+
"--hf",
251+
action="store_true",
252+
help="Whether to save the model in the huggingface format. (default: False)",
253+
)
254+
parser.add_argument(
255+
"--chat", "-c", action="store_true", help="Whether to open a chat with the model. (default: False)"
256+
)
257+
args = parser.parse_args()
258+
259+
llama_dir = os.path.abspath(f"{args.input_dir}/../")
260+
261+
ggml_files = sorted(
262+
[f"{args.input_dir}/{f}" for f in os.listdir(args.input_dir) if f.startswith(args.prefix)]
263+
)
264+
265+
fin = open(ggml_files[0], "rb")
266+
hparams, ftype = read_header(fin)
267+
tokens = read_tokens(fin, hparams["vocab_size"])
268+
model = read_variables(fin)
269+
270+
for f in tqdm(ggml_files[1:]):
271+
fin = open(f, "rb")
272+
read_header(fin)
273+
read_tokens(fin, hparams["vocab_size"])
274+
model.update(read_variables(fin))
275+
276+
if args.hf:
277+
model = convert_to_hf_format(model, hparams)
278+
279+
pth_ckpt = {
280+
"state_dict": model,
281+
"hparams": hparams,
282+
"tokens": tokens,
283+
}
284+
285+
torch.save(pth_ckpt, f"{args.input_dir}/{args.prefix}-to-torch.pth")
286+
287+
if args.chat:
288+
if not args.hf:
289+
model = convert_to_hf_format(model, hparams)
290+
chat(model, hparams, llama_dir)
291+
292+
293+
if __name__ == "__main__":
294+
main()

0 commit comments

Comments
 (0)