
Commit ee3dff6

Add support for DeepseekV2ForCausalLM (#7519)
* common : increase max number of experts to 160
* common : add tensors ATTN_Q_A, ATTN_Q_A_NORM, ATTN_Q_B, ATTN_KV_A_MQA, ATTN_KV_A_NORM, ATTN_KV_B needed by DeepSeek-V2 MLA (multi-head latent attention) architecture
* common : add model header parameters: leading_dense_block_count, expert_feed_forward_length, expert_shared_count, expert_weights_scale, attention.q_lora_rank, attention.kv_lora_rank, rope.scaling.yarn_log_multiplier
* convert-hf : add model conversion support for DeepseekV2ForCausalLM
* llama : add model types for DeepSeek-V2 and DeepSeek-V2-Lite models
* llama : add two new llm_build_moe_ffn() arguments: scale_w (whether to scale weights of selected MoE experts) and w_scale (numerical value of the scaling factor)
* llama : add inference support for LLM_ARCH_DEEPSEEK2

---------

Co-authored-by: Stanisław Szymczyk <[email protected]>
1 parent edc2943 commit ee3dff6
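The scale_w / w_scale pair mentioned in the llm_build_moe_ffn() bullet above controls whether the router weights of the selected experts are multiplied by DeepSeek-V2's routed_scaling_factor before the expert outputs are combined. The following is a minimal Python sketch of that behavior for orientation only; the function name and argument shapes are assumptions, and the actual implementation is C++ inside llama.cpp:

```python
# Illustrative sketch of MoE expert-weight scaling; not the llama.cpp implementation.
def combine_expert_outputs(expert_outputs, expert_weights, scale_w, w_scale):
    """Weighted sum of selected expert outputs, optionally scaling the router weights.

    expert_outputs: list of per-expert output vectors (lists of floats)
    expert_weights: matching router weights for the selected experts
    scale_w:        whether to scale the weights (new llm_build_moe_ffn argument)
    w_scale:        the scaling factor, i.e. expert_weights_scale / routed_scaling_factor
    """
    if scale_w:
        expert_weights = [w * w_scale for w in expert_weights]
    combined = [0.0] * len(expert_outputs[0])
    for w, out in zip(expert_weights, expert_outputs):
        combined = [c + w * o for c, o in zip(combined, out)]
    return combined
```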

5 files changed, +599 -26 lines changed


convert-hf-to-gguf.py

Lines changed: 79 additions & 0 deletions
@@ -2620,6 +2620,85 @@ def write_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@Model.register("DeepseekV2ForCausalLM")
+class DeepseekV2Model(Model):
+    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+        if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
+            self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
+        self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
+        self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
+        self.gguf_writer.add_value_length(hparams["v_head_dim"])
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
+
+        if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
+            if self.hparams["rope_scaling"].get("type") == "yarn":
+                self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+                self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
+                self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
+                self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"])
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down_proj", "gate_proj", "up_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def write_tensors(self):
+        super().write_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 ###### CONVERSION LOGIC ######

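A standalone illustration of the expert merge that DeepseekV2Model.modify_tensors() performs above: once all per-expert 2D weights of a block have been collected, they are stacked along a new leading dimension into one 3D tensor per projection. The sizes below are assumptions chosen for illustration (roughly DeepSeek-V2-Lite-like), not values taken from the diff:

```python
import torch

# Assumed example sizes; the real values come from the model's config.json.
n_experts, n_ff_exp, n_embd = 64, 1408, 2048

# One 2D weight per expert for a given projection (here: gate_proj) ...
gate_proj_weights = [torch.zeros(n_ff_exp, n_embd) for _ in range(n_experts)]

# ... stacked into a single 3D tensor, as in modify_tensors()
merged = torch.stack(gate_proj_weights, dim=0)
print(merged.shape)  # torch.Size([64, 1408, 2048])
```

The merged tensor is then renamed via map_tensor_name() to the corresponding merged-experts GGUF tensor name before being written out.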
gguf-py/gguf/constants.py

Lines changed: 63 additions & 11 deletions
@@ -33,17 +33,21 @@ class General:
         FILE_TYPE = "general.file_type"
 
     class LLM:
-        VOCAB_SIZE = "{arch}.vocab_size"
-        CONTEXT_LENGTH = "{arch}.context_length"
-        EMBEDDING_LENGTH = "{arch}.embedding_length"
-        BLOCK_COUNT = "{arch}.block_count"
-        FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
-        USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
-        TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
-        EXPERT_COUNT = "{arch}.expert_count"
-        EXPERT_USED_COUNT = "{arch}.expert_used_count"
-        POOLING_TYPE = "{arch}.pooling_type"
-        LOGIT_SCALE = "{arch}.logit_scale"
+        VOCAB_SIZE                 = "{arch}.vocab_size"
+        CONTEXT_LENGTH             = "{arch}.context_length"
+        EMBEDDING_LENGTH           = "{arch}.embedding_length"
+        BLOCK_COUNT                = "{arch}.block_count"
+        LEADING_DENSE_BLOCK_COUNT  = "{arch}.leading_dense_block_count"
+        FEED_FORWARD_LENGTH        = "{arch}.feed_forward_length"
+        EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
+        USE_PARALLEL_RESIDUAL      = "{arch}.use_parallel_residual"
+        TENSOR_DATA_LAYOUT         = "{arch}.tensor_data_layout"
+        EXPERT_COUNT               = "{arch}.expert_count"
+        EXPERT_USED_COUNT          = "{arch}.expert_used_count"
+        EXPERT_SHARED_COUNT        = "{arch}.expert_shared_count"
+        EXPERT_WEIGHTS_SCALE       = "{arch}.expert_weights_scale"
+        POOLING_TYPE               = "{arch}.pooling_type"
+        LOGIT_SCALE                = "{arch}.logit_scale"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
@@ -55,6 +59,8 @@ class Attention:
         LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
         CAUSAL = "{arch}.attention.causal"
+        Q_LORA_RANK = "{arch}.attention.q_lora_rank"
+        KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
 
     class Rope:
         DIMENSION_COUNT = "{arch}.rope.dimension_count"
@@ -64,6 +70,7 @@ class Rope:
         SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
+        SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
 
     class SSM:
         CONV_KERNEL = "{arch}.ssm.conv_kernel"
@@ -140,6 +147,7 @@ class MODEL_ARCH(IntEnum):
     DBRX = auto()
     OLMO = auto()
     ARCTIC = auto()
+    DEEPSEEK2 = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -185,6 +193,12 @@ class MODEL_TENSOR(IntEnum):
     SSM_A = auto()
     SSM_D = auto()
     SSM_OUT = auto()
+    ATTN_Q_A = auto()
+    ATTN_Q_B = auto()
+    ATTN_KV_A_MQA = auto()
+    ATTN_KV_B = auto()
+    ATTN_Q_A_NORM = auto()
+    ATTN_KV_A_NORM = auto()
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -221,6 +235,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
     MODEL_ARCH.ARCTIC: "arctic",
+    MODEL_ARCH.DEEPSEEK2: "deepseek2",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -266,6 +281,12 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
     MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
     MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
+    MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
+    MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
+    MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
+    MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
+    MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
+    MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -757,6 +778,33 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.DEEPSEEK2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_A,
+        MODEL_TENSOR.ATTN_Q_B,
+        MODEL_TENSOR.ATTN_KV_A_MQA,
+        MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_Q_A_NORM,
+        MODEL_TENSOR.ATTN_KV_A_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     # TODO
 }
 
@@ -790,6 +838,10 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.ROPE_FREQS,
         MODEL_TENSOR.ATTN_ROT_EMBD,
     ],
+    MODEL_ARCH.DEEPSEEK2: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 #

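The new keys above are architecture-scoped templates; they only become concrete metadata names once formatted with the model's arch string. A quick check of how they resolve for deepseek2, assuming the gguf-py package from this commit is importable:

```python
from gguf.constants import Keys

arch = "deepseek2"
print(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=arch))  # deepseek2.leading_dense_block_count
print(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=arch))        # deepseek2.expert_shared_count
print(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=arch))       # deepseek2.expert_weights_scale
print(Keys.Attention.Q_LORA_RANK.format(arch=arch))          # deepseek2.attention.q_lora_rank
print(Keys.Attention.KV_LORA_RANK.format(arch=arch))         # deepseek2.attention.kv_lora_rank
print(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=arch))      # deepseek2.rope.scaling.yarn_log_multiplier
```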
gguf-py/gguf/gguf_writer.py

Lines changed: 21 additions & 0 deletions
@@ -374,9 +374,15 @@ def add_embedding_length(self, length: int) -> None:
     def add_block_count(self, length: int) -> None:
         self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
 
+    def add_leading_dense_block_count(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
+
     def add_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
+    def add_expert_feed_forward_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
@@ -407,6 +413,12 @@ def add_expert_count(self, count: int) -> None:
     def add_expert_used_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
 
+    def add_expert_shared_count(self, count: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
+
+    def add_expert_weights_scale(self, value: float) -> None:
+        self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
+
     def add_layer_norm_eps(self, value: float) -> None:
         self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
 
@@ -416,6 +428,12 @@ def add_layer_norm_rms_eps(self, value: float) -> None:
     def add_causal_attention(self, value: bool) -> None:
         self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
 
+    def add_q_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
+
+    def add_kv_lora_rank(self, length: int) -> None:
+        self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
 
@@ -440,6 +458,9 @@ def add_rope_scaling_orig_ctx_len(self, value: int) -> None:
     def add_rope_scaling_finetuned(self, value: bool) -> None:
         self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
 
+    def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)

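A minimal sketch of how the new writer methods might be called together when emitting DeepSeek-V2 metadata. The output filename and the numeric values are assumptions chosen for illustration (loosely DeepSeek-V2-Lite-like); the real conversion path is DeepseekV2Model.set_gguf_parameters() shown earlier, which reads these values from the HF config:

```python
import gguf

# Assumed output path and example values; see set_gguf_parameters() for the real sources.
writer = gguf.GGUFWriter("deepseek2-metadata-demo.gguf", "deepseek2")
writer.add_leading_dense_block_count(1)            # hparams["first_k_dense_replace"]
writer.add_expert_feed_forward_length(1408)        # hparams["moe_intermediate_size"]
writer.add_expert_shared_count(2)                  # hparams["n_shared_experts"]
writer.add_expert_weights_scale(1.0)               # hparams["routed_scaling_factor"]
writer.add_kv_lora_rank(512)                       # hparams["kv_lora_rank"]
writer.add_rope_scaling_yarn_log_mul(0.1 * 0.707)  # 0.1 * rope_scaling["mscale_all_dim"]
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
```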
gguf-py/gguf/tensor_mapping.py

Lines changed: 28 additions & 1 deletion
@@ -256,6 +256,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.up_proj",  # qwen2moe
+            "model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
         ),
 
         # AWQ-activation gate
@@ -285,6 +286,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
+            "model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
         ),
 
         # Feed-forward down
@@ -320,6 +322,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
             "model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
+            "model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
         ),
 
         MODEL_TENSOR.ATTN_Q_NORM: (
@@ -383,6 +386,30 @@ class TensorNameMap:
             "model.layers.{bid}.out_proj",
             "backbone.layers.{bid}.mixer.out_proj",
         ),
+
+        MODEL_TENSOR.ATTN_Q_A: (
+            "model.layers.{bid}.self_attn.q_a_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_Q_B: (
+            "model.layers.{bid}.self_attn.q_b_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_KV_A_MQA: (
+            "model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_KV_B: (
+            "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_Q_A_NORM: (
+            "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_KV_A_NORM: (
+            "model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
+        ),
     }
 
     # architecture-specific block mappings
@@ -415,7 +442,7 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
                 if tensor not in MODEL_TENSORS[arch]:
                     continue
                 # TODO: make this configurable
-                n_experts = 128
+                n_experts = 160
                 for xid in range(n_experts):
                     tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
                     self.mapping[tensor_name] = (tensor, tensor_name)

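Illustrative use of the extended mapping (the block count and tensor name below are assumptions for the example, roughly DeepSeek-V2-Lite-sized): with the deepseek2 entries above, TensorNameMap can translate the HF checkpoint names of the MLA projections into their GGUF counterparts.

```python
import gguf

# 27 blocks is an assumed example value, not taken from the diff.
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.DEEPSEEK2, 27)

name = tmap.get_name("model.layers.0.self_attn.kv_a_proj_with_mqa.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)  # blk.0.attn_kv_a_mqa.weight
```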