
Commit 0e797c2

llm : support Adept Persimmon 8B (#3410)
* Produces garbage output
* wip: correct tensors up to RoPE
* correct tensors thru RoPE
* Correct outputs through masked & softmax'd KQ
* fp32 works
* Rename adept->persimmon
* Produces correct outputs
* clean up convert scripts
* remove printing logic from ggml.c
* remove prints from llama.cpp & fix merge
* trivial cleanups
* Add offload funcs
* update conversion script to directly take adept artifacts rather than .safetensors file
* Fix norm eps bug
* Support sqr and concat on metal, persimmon-8b-q4 runs correctly
* Small changes from review
* Formatting changes
* Minor changes to conversion script
* Remove old script
* Fix editorconfig formatting
* Fix build
* add overlooked offload code

ggml-ci
1 parent 3a716b4 commit 0e797c2

5 files changed (+854, -76 lines)

convert-persimmon-to-gguf.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import torch
import os
from pprint import pprint
import sys
import argparse
from pathlib import Path
from sentencepiece import SentencePieceProcessor
if 'NO_LOCAL_GGUF' not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf

def _flatten_dict(dct, tensors, prefix=None):
    assert isinstance(dct, dict)
    for key in dct.keys():
        new_prefix = prefix + '.' + key if prefix is not None else key
        if isinstance(dct[key], torch.Tensor):
            tensors[new_prefix] = dct[key]
        elif isinstance(dct[key], dict):
            _flatten_dict(dct[key], tensors, new_prefix)
        else:
            raise ValueError(type(dct[key]))
    return None

def _get_sentencepiece_tokenizer_info(dir_model: Path):
    tokenizer_path = dir_model / 'adept_vocab.model'
    print('gguf: getting sentencepiece tokenizer from', tokenizer_path)
    tokenizer = SentencePieceProcessor(str(tokenizer_path))
    print('gguf: adding tokens')
    tokens: list[bytes] = []
    scores: list[float] = []
    toktypes: list[int] = []

    for i in range(tokenizer.vocab_size()):
        text: bytes
        score: float

        piece = tokenizer.id_to_piece(i)
        text = piece.encode("utf-8")
        score = tokenizer.get_score(i)

        toktype = 1
        if tokenizer.is_unknown(i):
            toktype = 2
        if tokenizer.is_control(i):
            toktype = 3
        if tokenizer.is_unused(i):
            toktype = 5
        if tokenizer.is_byte(i):
            toktype = 6

        tokens.append(text)
        scores.append(score)
        toktypes.append(toktype)
        pass
    return tokens, scores, toktypes

def main():
    parser = argparse.ArgumentParser(description="Convert a Persimmon model from Adept (e.g. Persimmon 8b chat) to a GGML compatible file")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("--ckpt-path", type=Path, help="path to persimmon checkpoint .pt file")
    parser.add_argument("--model-dir", type=Path, help="directory containing model e.g. 8b_chat_model_release")
    parser.add_argument("--adept-inference-dir", type=str, help="path to adept-inference code directory")
    args = parser.parse_args()
    sys.path.append(str(args.adept_inference_dir))
    persimmon_model = torch.load(args.ckpt_path)
    hparams = persimmon_model['args']
    pprint(hparams)
    tensors = {}
    _flatten_dict(persimmon_model['model'], tensors, None)

    arch = gguf.MODEL_ARCH.PERSIMMON
    gguf_writer = gguf.GGUFWriter(args.outfile, gguf.MODEL_ARCH_NAMES[arch])

    block_count = hparams.num_layers
    head_count = hparams.num_attention_heads
    head_count_kv = head_count
    ctx_length = hparams.seq_length
    hidden_size = hparams.hidden_size

    gguf_writer.add_name('persimmon-8b-chat')
    gguf_writer.add_context_length(ctx_length)
    gguf_writer.add_embedding_length(hidden_size)
    gguf_writer.add_block_count(block_count)
    gguf_writer.add_feed_forward_length(hparams.ffn_hidden_size)
    gguf_writer.add_rope_dimension_count(hidden_size // head_count)
    gguf_writer.add_head_count(head_count)
    gguf_writer.add_head_count_kv(head_count_kv)
    gguf_writer.add_rope_freq_base(hparams.rotary_emb_base)
    gguf_writer.add_layer_norm_eps(hparams.layernorm_epsilon)

    tokens, scores, toktypes = _get_sentencepiece_tokenizer_info(args.model_dir)
    gguf_writer.add_tokenizer_model('llama')
    gguf_writer.add_token_list(tokens)
    gguf_writer.add_token_scores(scores)
    gguf_writer.add_token_types(toktypes)
    gguf_writer.add_bos_token_id(71013)
    gguf_writer.add_eos_token_id(71013)

    tensor_map = gguf.get_tensor_name_map(arch, block_count)
    print(tensor_map)
    for name in tensors.keys():
        data = tensors[name]
        if name.endswith(".self_attention.rotary_emb.inv_freq"):
            continue
        old_dtype = data.dtype
        # TODO: FP16 conversion produces garbage outputs. (Q8_0 does not, so..?)
        data = data.to(torch.float32).squeeze().numpy()
        new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias"))
        if new_name is None:
            print("Can not map tensor '" + name + "'")
            sys.exit()
        n_dims = len(data.shape)
        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))
        gguf_writer.add_tensor(new_name, data)
    print("gguf: write header")
    gguf_writer.write_header_to_file()
    print("gguf: write metadata")
    gguf_writer.write_kv_data_to_file()
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

    gguf_writer.close()

    print(f"gguf: model successfully exported to '{args.outfile}'")
    print("")


if __name__ == '__main__':
    main()

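The script above writes the Persimmon hyperparameters, the SentencePiece vocabulary, and the flattened tensors into a single GGUF file. As a quick sanity check on the result, a small reader along these lines could list the stored keys and tensor names. This is an illustrative sketch, not part of the commit: the program itself is hypothetical, and it assumes the gguf_* C API bundled with ggml in this tree (gguf_init_from_file, gguf_get_n_kv, gguf_get_key, gguf_get_n_tensors, gguf_get_tensor_name); exact signatures may differ between versions.

// Hedged sketch: list the metadata keys and tensor names of a GGUF file
// produced by convert-persimmon-to-gguf.py, using ggml's GGUF reader.
#include <stdio.h>
#include "ggml.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s persimmon-8b-chat.gguf\n", argv[0]);
        return 1;
    }

    struct ggml_context * ctx_data = NULL;
    struct gguf_init_params params = {
        /*.no_alloc =*/ true,        // read metadata only, do not load tensor data
        /*.ctx      =*/ &ctx_data,
    };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) {
        fprintf(stderr, "failed to open %s\n", argv[1]);
        return 1;
    }

    for (int i = 0; i < gguf_get_n_kv(ctx); i++) {
        printf("kv     %3d: %s\n", i, gguf_get_key(ctx, i));
    }
    for (int i = 0; i < gguf_get_n_tensors(ctx); i++) {
        printf("tensor %3d: %s\n", i, gguf_get_tensor_name(ctx, i));
    }

    gguf_free(ctx);
    if (ctx_data) {
        ggml_free(ctx_data);
    }
    return 0;
}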
ggml-metal.m

Lines changed: 54 additions & 0 deletions
@@ -109,6 +109,8 @@
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(concat);
+    GGML_METAL_DECL_KERNEL(sqr);

 #undef GGML_METAL_DECL_KERNEL
 };
@@ -300,6 +302,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(concat);
+        GGML_METAL_ADD_KERNEL(sqr);

 #undef GGML_METAL_ADD_KERNEL
     }
@@ -375,6 +379,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(concat);
+    GGML_METAL_DEL_KERNEL(sqr);

 #undef GGML_METAL_DEL_KERNEL

@@ -766,6 +772,43 @@ void ggml_metal_graph_compute(
                 {
                     // noop
                 } break;
+            case GGML_OP_CONCAT:
+                {
+
+                    int64_t nb = ne00;
+                    [encoder setComputePipelineState:ctx->pipeline_concat];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                    [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                    [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                    [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                    [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                    [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                    [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
+                    [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
+                    [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                    [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+
+                    const int nth = MIN(1024, ne0);
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                } break;
             case GGML_OP_ADD:
                 {
                     GGML_ASSERT(ggml_is_contiguous(src0));
@@ -903,6 +946,17 @@
                             GGML_ASSERT(false);
                     }
                 } break;
+            case GGML_OP_SQR:
+                {
+                    GGML_ASSERT(ggml_is_contiguous(src0));
+
+                    [encoder setComputePipelineState:ctx->pipeline_sqr];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(dst);
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
             case GGML_OP_SOFT_MAX:
                 {
                     const int nth = MIN(32, ne00);

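The two new cases above dispatch GGML_OP_CONCAT and GGML_OP_SQR to the kernels added in ggml-metal.metal below; the concat dispatch launches an (ne1 x ne2 x ne3) grid of threadgroups whose threads stride across ne0. For orientation, here is a minimal CPU-only sketch of what the two ops compute (ggml_sqr squares element-wise, ggml_concat joins along dim 2). It is not part of the commit: the graph-building calls (ggml_new_graph, ggml_build_forward_expand, ggml_graph_compute_with_ctx) are assumptions about the ggml API in this tree and may differ slightly between versions.

// Hedged sketch: exercise the square and concat ops on the CPU backend.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // two small F32 tensors with matching ne0/ne1; concat joins them along dim 2
    struct ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 2, 1);
    struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 4, 2, 1);
    ggml_set_f32(a, 2.0f);
    ggml_set_f32(b, 3.0f);

    struct ggml_tensor * sq  = ggml_sqr(ctx, a);         // element-wise square -> 4.0
    struct ggml_tensor * cat = ggml_concat(ctx, sq, b);  // result ne2 = 1 + 1 = 2

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, cat);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    printf("cat[0] = %.1f, cat ne2 = %d\n", ggml_get_f32_1d(cat, 0), (int) cat->ne[2]);

    ggml_free(ctx);
    return 0;
}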
ggml-metal.metal

Lines changed: 63 additions & 0 deletions
@@ -132,6 +132,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }

+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
 constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;

@@ -1098,6 +1105,62 @@ kernel void kernel_cpy_f32_f32(
     }
 }

+kernel void kernel_concat(
+    device const char * src0,
+    device const char * src1,
+    device       char * dst,
+    constant  int64_t & ne00,
+    constant  int64_t & ne01,
+    constant  int64_t & ne02,
+    constant  int64_t & ne03,
+    constant uint64_t & nb00,
+    constant uint64_t & nb01,
+    constant uint64_t & nb02,
+    constant uint64_t & nb03,
+    constant  int64_t & ne10,
+    constant  int64_t & ne11,
+    constant  int64_t & ne12,
+    constant  int64_t & ne13,
+    constant uint64_t & nb10,
+    constant uint64_t & nb11,
+    constant uint64_t & nb12,
+    constant uint64_t & nb13,
+    constant  int64_t & ne0,
+    constant  int64_t & ne1,
+    constant  int64_t & ne2,
+    constant  int64_t & ne3,
+    constant uint64_t & nb0,
+    constant uint64_t & nb1,
+    constant uint64_t & nb2,
+    constant uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i02 < ne02) {
+            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+            src0_ptr += ntg.x*nb00;
+        } else {
+            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+            src1_ptr += ntg.x*nb10;
+        }
+        dst_ptr += ntg.x*nb0;
+    }
+}
+
 //============================================ k-quants ======================================================

 #ifndef QK_K

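The pointer arithmetic in kernel_concat is easier to read as plain loops. Below is a hedged CPU restatement of the same indexing, illustration only and not part of the commit; the helper name concat_dim2_ref is made up. It assumes contiguous float tensors whose dims 0, 1 and 3 match between the two sources, so the modulo broadcasts are identities on those dims; along dim 2 the destination takes src0's ne02 slices first and src1's after, mirroring the i02 < ne02 test in the kernel.

#include <stdint.h>

// Hedged CPU restatement of kernel_concat's indexing (contiguous F32 tensors,
// dims 0, 1 and 3 equal for src0 and src1, dst dim-2 extent ne2 = ne02 + ne12).
static void concat_dim2_ref(const float * src0, const float * src1, float * dst,
                            int64_t ne0, int64_t ne1, int64_t ne02, int64_t ne12, int64_t ne3) {
    const int64_t ne2 = ne02 + ne12;
    for (int64_t i3 = 0; i3 < ne3; ++i3) {
        for (int64_t i2 = 0; i2 < ne2; ++i2) {
            for (int64_t i1 = 0; i1 < ne1; ++i1) {
                for (int64_t i0 = 0; i0 < ne0; ++i0) {
                    float v;
                    if (i2 < ne02) {
                        // first ne02 slices come from src0 (kernel: i02 < ne02 branch)
                        v = src0[((i3*ne02 + i2)*ne1 + i1)*ne0 + i0];
                    } else {
                        // remaining slices come from src1; the kernel indexes it at
                        // i02 % ne12, which equals i2 - ne02 when ne02 == ne12
                        v = src1[((i3*ne12 + (i2 % ne12))*ne1 + i1)*ne0 + i0];
                    }
                    dst[((i3*ne2 + i2)*ne1 + i1)*ne0 + i0] = v;
                }
            }
        }
    }
}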