Commit 1732737

convert : add phi-3 support
1 parent 4e96a81 commit 1732737

File tree: convert-hf-to-gguf.py, llama.cpp

2 files changed: +38 -3 lines changed

convert-hf-to-gguf.py

Lines changed: 37 additions & 3 deletions
@@ -62,6 +62,7 @@ def __init__(self, dir_model: Path, ftype: int, fname_out: Path, is_big_endian:
     def model_arch(self) -> gguf.MODEL_ARCH:
         pass
 
+    # TODO: add "default" argument
     def find_hparam(self, keys: Sequence[str], optional: bool = False) -> Any:
         key = next((k for k in keys if k in self.hparams), None)
         if key is not None:
@@ -89,7 +90,12 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
             yield name, data
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_name(self.dir_model.name)
+        if (mtype := self.find_hparam(["model_type"], optional=True)) is not None:
+            self.gguf_writer.add_name(mtype)
+            print(f"gguf: model type = {mtype}")
+        else:
+            self.gguf_writer.add_name(self.dir_model.name)
+
         self.gguf_writer.add_block_count(self.block_count)
 
         if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx"], optional=True)) is not None:
@@ -363,6 +369,13 @@ def _set_vocab_sentencepiece(self):
             scores.append(-1000.0)
             toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
 
+        # pad remaining tokens
+        for i in range(vocab_size - len(tokens)):
+            print(f"gguf: padding token {i}")
+            tokens.append(f"[PAD{i}]")
+            scores.append(-1000.0)
+            toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
+
         assert len(tokens) == vocab_size
 
         self.gguf_writer.add_tokenizer_model("llama")
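
Note on the padding block above: some models, Phi-3 among them, declare a vocab_size in their config that is larger than the number of pieces in the SentencePiece model plus the entries in added_tokens.json, so without the [PAD{i}] placeholders the assert len(tokens) == vocab_size check would fail. Below is a minimal standalone sketch of the same idea; the counts and the integer token-type values are assumptions for illustration, not taken from any particular model.

# Standalone sketch of the vocab-padding step (hypothetical counts).
vocab_size = 32064                               # value declared by the model config (assumed)
tokens     = [f"tok_{i}" for i in range(32011)]  # pieces + added tokens (assumed count)
scores     = [0.0] * len(tokens)
toktypes   = [1] * len(tokens)                   # 1 ~ NORMAL in the converter's SentencePieceTokenTypes

# pad remaining tokens so the GGUF vocab matches the declared vocab_size
for i in range(vocab_size - len(tokens)):
    tokens.append(f"[PAD{i}]")
    scores.append(-1000.0)
    toktypes.append(4)                           # 4 ~ USER_DEFINED

assert len(tokens) == vocab_size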
@@ -1293,7 +1306,7 @@ def _stack_qk_norm(self, block_count, name, tensor_map, n_head, norms, n_dims, l
         self.gguf_writer.add_tensor(new_name, data)
 
 
-@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
+@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "Phi3ForCausalLM")
 class LlamaModel(Model):
     model_arch = gguf.MODEL_ARCH.LLAMA
 
@@ -1322,18 +1335,39 @@ def set_vocab(self):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
+
         self.gguf_writer.add_vocab_size(hparams["vocab_size"])
         self.gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"])
 
     # Same as super class, but permuting q_proj, k_proj
     def write_tensors(self):
         block_count = self.hparams.get("n_layers", self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")))
         tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+        n_embd = self.hparams.get("hidden_size")
         n_head = self.hparams.get("num_attention_heads")
         n_kv_head = self.hparams.get("num_key_value_heads")
         n_experts = self.hparams.get("num_local_experts")
         experts = dict()
-        for name, data_torch in self.get_tensors():
+
+        head_dim = n_embd // n_head
+
+        tensors = dict(self.get_tensors())
+        for i in range(block_count):
+            # Phi-3 transformations
+            # ref: https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/8b29aca7bb785d6336fc19819b045bc7bc584b06/modeling_phi3.py#L379-L384
+            if (w := tensors.get(f"model.layers.{i}.self_attn.qkv_proj.weight")) is not None:
+                qpos = n_head * head_dim
+                tensors[f"model.layers.{i}.self_attn.q_proj.weight"] = w[:qpos]
+                tensors[f"model.layers.{i}.self_attn.k_proj.weight"] = w[qpos:qpos + n_kv_head * head_dim]
+                tensors[f"model.layers.{i}.self_attn.v_proj.weight"] = w[qpos + n_kv_head * head_dim:]
+                del tensors[f"model.layers.{i}.self_attn.qkv_proj.weight"]
+            if (w := tensors.get(f"model.layers.{i}.mlp.gate_up_proj.weight")) is not None:
+                ff_dim = w.shape[0] // 2
+                tensors[f"model.layers.{i}.mlp.gate_proj.weight"] = w[:ff_dim]
+                tensors[f"model.layers.{i}.mlp.up_proj.weight"] = w[ff_dim:]
+                del tensors[f"model.layers.{i}.mlp.gate_up_proj.weight"]
+
+        for name, data_torch in tensors.items():
             # we don't need these
             if name.endswith((".attention.masked_bias", ".attention.bias", ".attention.rotary_emb.inv_freq")):
                 continue
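
The loop added above handles the fact that Phi-3 checkpoints store the attention projection as one fused qkv_proj tensor (Q, K and V concatenated along the output dimension) and the MLP input as one fused gate_up_proj tensor, while llama.cpp's LLAMA architecture expects separate q_proj/k_proj/v_proj and gate_proj/up_proj weights (see the referenced modeling_phi3.py lines). The sketch below reproduces just the slicing on randomly initialized tensors with small made-up dimensions, independent of the converter.

# Standalone sketch of the Phi-3 qkv_proj / gate_up_proj splits (dummy sizes, not Phi-3's real ones).
import torch

n_embd, n_head, n_kv_head, ff_dim = 64, 8, 4, 128
head_dim = n_embd // n_head  # 8

# fused weights as stored in the HF checkpoint: rows are [Q | K | V] and [gate | up]
qkv_proj     = torch.randn((n_head + 2 * n_kv_head) * head_dim, n_embd)
gate_up_proj = torch.randn(2 * ff_dim, n_embd)

qpos   = n_head * head_dim
q_proj = qkv_proj[:qpos]
k_proj = qkv_proj[qpos:qpos + n_kv_head * head_dim]
v_proj = qkv_proj[qpos + n_kv_head * head_dim:]

ff_half   = gate_up_proj.shape[0] // 2
gate_proj = gate_up_proj[:ff_half]
up_proj   = gate_up_proj[ff_half:]

assert q_proj.shape == (n_head * head_dim, n_embd)
assert k_proj.shape == v_proj.shape == (n_kv_head * head_dim, n_embd)
assert gate_proj.shape == up_proj.shape == (ff_dim, n_embd)

Because Phi3ForCausalLM is registered on LlamaModel, the split tensors then flow through the standard LLAMA tensor-name map, so no Phi-3-specific tensor handling is needed on the C++ side for these weights.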

llama.cpp

Lines changed: 1 addition & 0 deletions
@@ -4352,6 +4352,7 @@ static void llm_load_vocab(
                         //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
                         (t.first == "<|eot_id|>" ||
                          t.first == "<|im_end|>" ||
+                         t.first == "<|end|>" ||
                          t.first == "<end_of_turn>"
                         )
                     ) {
