
Commit 04e48c0 (parent: 69a2045)
Author: junhuihe

Enable conversion from .safetensors checkpoints to gguf files

File tree

5 files changed: +232 additions, −0 deletions

README.md

Lines changed: 11 additions & 0 deletions

Added after the e2e benchmark example (around README line 295), just before the "FAQ (Frequently Asked Questions)📌" section:

### Convert from `.safetensors` Checkpoints

```sh
# Prepare the .safetensors model file
huggingface-cli download microsoft/bitnet-b1.58-2B-4T-bf16 --local-dir ./models/bitnet-b1.58-2B-4T-bf16

# Convert to gguf model
python ./utils/convert-helper-bitnet.py ./models/bitnet-b1.58-2B-4T-bf16
```
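For a quick sanity check of the converted model, the GGUF header can be read back with the `gguf` Python package — a minimal sketch, assuming the package is installed (`pip install gguf`) and the output path the helper produces:

```python
# Sketch: read back the converted GGUF's metadata and first few tensors.
# Assumes the `gguf` package is installed and the helper's default output path.
from gguf import GGUFReader

reader = GGUFReader("./models/bitnet-b1.58-2B-4T-bf16/ggml-model-i2s-bitnet.gguf")
print(f"{len(reader.fields)} metadata fields, {len(reader.tensors)} tensors")
for tensor in reader.tensors[:5]:
    print(tensor.name, tensor.shape, tensor.tensor_type.name)
```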

utils/convert-helper-bitnet.py

Lines changed: 134 additions & 0 deletions

```python
#!/usr/bin/env python3

import sys
import shutil
import subprocess
from pathlib import Path


def run_command(command_list, cwd=None, check=True):
    print(f"Executing: {' '.join(map(str, command_list))}")
    try:
        process = subprocess.run(command_list, cwd=cwd, check=check, capture_output=False, text=True)
        return process
    except subprocess.CalledProcessError as e:
        print(f"Error executing command: {' '.join(map(str, e.cmd))}")
        print(f"Return code: {e.returncode}")
        raise


def main():
    if len(sys.argv) < 2:
        script_name = Path(sys.argv[0]).name
        print(f"Usage: python {script_name} <model-directory>")
        sys.exit(1)

    model_dir_arg = sys.argv[1]
    model_dir = Path(model_dir_arg).resolve()

    if not model_dir.is_dir():
        print(f"Error: Model directory '{model_dir}' not found or is not a directory.")
        sys.exit(1)

    utils_dir = Path(__file__).parent.resolve()
    project_root_dir = utils_dir.parent

    preprocess_script = utils_dir / "preprocess-huggingface-bitnet.py"
    convert_script = utils_dir / "convert-ms-to-gguf-bitnet.py"

    llama_quantize_binary = project_root_dir / "build" / "bin" / "llama-quantize"

    input_file = model_dir / "model.safetensors"
    input_backup_file = model_dir / "model.safetensors.backup"
    preprocessed_output_file = model_dir / "model.safetensors"

    gguf_f32_output = model_dir / "ggml-model-f32-bitnet.gguf"
    gguf_i2s_output = model_dir / "ggml-model-i2s-bitnet.gguf"

    if not preprocess_script.is_file():
        print(f"Error: Preprocess script not found at '{preprocess_script}'")
        sys.exit(1)
    if not convert_script.is_file():
        print(f"Error: Convert script not found at '{convert_script}'")
        sys.exit(1)
    if not llama_quantize_binary.is_file():
        print(f"Error: llama-quantize binary not found at '{llama_quantize_binary}'")
        sys.exit(1)

    if not input_file.is_file():
        print(f"Error: Input safetensors file not found at '{input_file}'")
        sys.exit(1)

    try:
        print(f"Backing up '{input_file}' to '{input_backup_file}'")
        if input_backup_file.exists():
            print(f"Warning: Removing existing backup file '{input_backup_file}'")
            input_backup_file.unlink()
        shutil.move(input_file, input_backup_file)

        print("Preprocessing huggingface checkpoint...")
        cmd_preprocess = [
            sys.executable,
            str(preprocess_script),
            "--input", str(input_backup_file),
            "--output", str(preprocessed_output_file)
        ]
        run_command(cmd_preprocess)

        print("Converting to GGUF (f32)...")
        cmd_convert = [
            sys.executable,
            str(convert_script),
            str(model_dir),
            "--vocab-type", "bpe",
            "--outtype", "f32",
            "--concurrency", "1",
            "--outfile", str(gguf_f32_output)
        ]
        run_command(cmd_convert)

        print("Quantizing model to I2_S...")
        cmd_quantize = [
            str(llama_quantize_binary),
            str(gguf_f32_output),
            str(gguf_i2s_output),
            "I2_S",
            "1"
        ]
        run_command(cmd_quantize)

        print("Converted successfully.")

    except Exception as e:
        print(f"An error occurred: {e}")
    finally:
        print("Cleaning up intermediate files...")
        if preprocessed_output_file.exists() and preprocessed_output_file != input_backup_file:
            print(f"Removing preprocessed file: {preprocessed_output_file}")
            try:
                preprocessed_output_file.unlink()
            except OSError as e:
                print(f"Warning: Could not remove {preprocessed_output_file}: {e}")

        if gguf_f32_output.exists():
            print(f"Removing f32 GGUF: {gguf_f32_output}")
            try:
                gguf_f32_output.unlink()
            except OSError as e:
                print(f"Warning: Could not remove {gguf_f32_output}: {e}")

        if input_backup_file.exists():
            if not input_file.exists():
                print(f"Restoring original '{input_file}' from '{input_backup_file}'")
                try:
                    shutil.move(input_backup_file, input_file)
                except Exception as e:
                    print(f"Warning: Could not restore {input_file} from backup: {e}")
            else:
                print(f"Removing backup '{input_backup_file}' as original '{input_file}' should be present.")
                try:
                    input_backup_file.unlink()
                except OSError as e:
                    print(f"Warning: Could not remove backup {input_backup_file}: {e}")


if __name__ == "__main__":
    main()
```
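The helper's control flow is a backup/work/restore pattern: the original `model.safetensors` is moved aside, the preprocessed copy takes its place for conversion, and the `finally` block removes intermediates and restores the original whether or not the pipeline succeeded. A distilled sketch of that pattern, with illustrative names not taken from the script:

```python
# Distilled sketch of the backup/restore pattern used above (names illustrative).
import shutil
from pathlib import Path

def with_backup(path: Path, work) -> None:
    backup = path.with_name(path.name + ".backup")
    shutil.move(path, backup)           # keep the pristine original aside
    try:
        work(backup, path)              # reads the backup, writes a derived file at `path`
    finally:
        if backup.exists():
            if path.exists():
                path.unlink()           # drop the derived/intermediate file
            shutil.move(backup, path)   # always put the original back
```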

utils/convert-ms-to-gguf-bitnet.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -1417,6 +1417,9 @@ def write_all(

     of = OutputFile(fname_out, endianess=endianess)

+    if 'bitnet' in of.gguf.arch:
+        svocab.chat_template = "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'] + '\\n\\nBITNETAssistant: ' + eos_token }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token }}{% endif %}{% endfor %}"
+
     # meta data
     of.add_meta_arch(params)
     if isinstance(vocab, Vocab):
```
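The added chat template is ordinary Jinja, so its effect can be previewed outside the converter. A minimal sketch using `jinja2` directly; the `bos_token`/`eos_token` strings here are placeholders, since the converter supplies the real ones from the vocab:

```python
# Sketch: render the chat template added above with jinja2 (placeholder tokens).
from jinja2 import Template

chat_template = (
    "{% for message in messages %}{% if loop.first %}{{ bos_token }}{% endif %}"
    "{% if message['role'] == 'user' %}"
    "{{ 'Human: ' + message['content'] + '\\n\\nBITNETAssistant: ' + eos_token }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] + eos_token }}{% endif %}{% endfor %}"
)
prompt = Template(chat_template).render(
    messages=[
        {"role": "user", "content": "Hello"},
        {"role": "assistant", "content": "Hi there"},
    ],
    bos_token="<s>",
    eos_token="</s>",
)
print(prompt)
# <s>Human: Hello
#
# BITNETAssistant: </s>Hi there</s>
```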

utils/convert.sh

Lines changed: 34 additions & 0 deletions

```bash
#!/bin/bash

if [ -z "$1" ]; then
    echo "Usage: $0 <model-directory>"
    exit 1
fi

MODEL_DIR=$(realpath "$1")

PREPROCESS_SCRIPT="./utils/preprocess-safetensors.py"
CONVERT_SCRIPT="./utils/convert-ms-to-gguf-bitnet.py"

INPUT_FILE="$MODEL_DIR/model.safetensors"
OUTPUT_FILE="$MODEL_DIR/model.safetensors"

echo "Preprocessing safetensors..."
mv "$INPUT_FILE" "${INPUT_FILE}.backup"
python "$PREPROCESS_SCRIPT" --input "${INPUT_FILE}.backup" --output "$OUTPUT_FILE"

GGUF_F32_OUTPUT="$MODEL_DIR/ggml-model-f32-bitnet.gguf"

echo "Converting to GGUF (f32)..."
python "$CONVERT_SCRIPT" "$MODEL_DIR" --vocab-type bpe --outtype f32 --concurrency 1 --outfile "$GGUF_F32_OUTPUT"

GGUF_I2S_OUTPUT="$MODEL_DIR/ggml-model-i2s-bitnet.gguf"

echo "Quantizing model to I2_S..."
./build/bin/llama-quantize "$GGUF_F32_OUTPUT" "$GGUF_I2S_OUTPUT" I2_S 1

echo "Cleaning up intermediate files..."
rm "$OUTPUT_FILE" "$GGUF_F32_OUTPUT"
mv "${INPUT_FILE}.backup" "$INPUT_FILE"

echo "Converted successfully."
```
utils/preprocess-huggingface-bitnet.py

Lines changed: 50 additions & 0 deletions

```python
from safetensors import safe_open
from safetensors.torch import save_file
import torch


def quant_weight_fp16(weight):
    weight = weight.to(torch.float)
    s = 1.0 / weight.abs().mean().clamp_(min=1e-5)
    new_weight = (weight * s).round().clamp(-1, 1) / s
    return new_weight


def quant_model(input, output):
    tensors = {}

    with safe_open(input, framework='pt') as f:
        for name in f.keys():
            tensors[name] = f.get_tensor(name)

    keyword_list = [
        'q_proj.weight',
        'k_proj.weight',
        'v_proj.weight',
        'o_proj.weight',
        'gate_proj.weight',
        'up_proj.weight',
        'down_proj.weight'
    ]

    for name in tensors:
        if any(keyword in name for keyword in keyword_list):
            print(f'[INFO] Quantizing {name}')
            tensors[name] = quant_weight_fp16(tensors[name])

    print(f'[INFO] Saving to {output}\nThis may take a while.')
    save_file(tensors, output)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Quantize the BitNet projection weights in a .safetensors checkpoint")
    parser.add_argument("--input", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    args = parser.parse_args()

    quant_model(
        input=args.input,
        output=args.output,
    )
```
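In `quant_weight_fp16`, the scale `s` is the reciprocal of the mean absolute weight; scaling, rounding, and clamping maps every weight onto `{-1, 0, +1} / s`, i.e. fake-quantized ternary values at the original dtype. A small worked example with arbitrary numbers:

```python
# Worked example of the absmean ternary quantization above (arbitrary values).
import torch

w = torch.tensor([0.4, -1.2, 0.05, 2.0])
s = 1.0 / w.abs().mean().clamp_(min=1e-5)  # mean |w| = 0.9125, so s ≈ 1.0959
codes = (w * s).round().clamp(-1, 1)       # tensor([ 0., -1.,  0.,  1.])
print(codes / s)                           # tensor([ 0.0000, -0.9125,  0.0000,  0.9125])
```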
