
Commit 0722ad1

llama: support dbrx (#6344)

1 parent 1d8de31 commit 0722ad1

3 files changed: 259 additions, 2 deletions

convert-hf-to-gguf.py

Lines changed: 1 addition & 1 deletion
@@ -1438,7 +1438,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_head_count(self.hparams["n_heads"])
         self.gguf_writer.add_head_count_kv(attn_config["kv_n_heads"])
         self.gguf_writer.add_rope_freq_base(attn_config["rope_theta"])
-        self.gguf_writer.add_clip_kqv(attn_config["clip_qkv"])
+        self.gguf_writer.add_clamp_kqv(attn_config["clip_qkv"])
         self.gguf_writer.add_file_type(self.ftype)

         self.gguf_writer.add_expert_count(ffn_config["moe_num_experts"])
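The point of this change is that DBRX's clip_qkv hyperparameter maps onto the GGUF key that already exists in gguf-py (clamp_kqv), instead of introducing a new clip_kqv key. Below is a minimal sketch, not part of the commit, of how that attention metadata would be emitted with the gguf Python package from gguf-py; the file name and all numeric values are placeholders, not DBRX's real configuration.

# Sketch only: writing DBRX-style attention metadata with gguf-py.
# All values below are placeholders, not DBRX's actual hyperparameters.
import gguf

writer = gguf.GGUFWriter("dbrx-metadata-demo.gguf", arch="dbrx")
attn_config = {"kv_n_heads": 8, "rope_theta": 500000.0, "clip_qkv": 8.0}  # placeholders

writer.add_head_count(48)                              # placeholder head count
writer.add_head_count_kv(attn_config["kv_n_heads"])
writer.add_rope_freq_base(attn_config["rope_theta"])
writer.add_clamp_kqv(attn_config["clip_qkv"])          # reuses the existing clamp_kqv key

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()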

gguf-py/gguf/constants.py

Lines changed: 0 additions & 1 deletion
@@ -54,7 +54,6 @@ class Attention:
         LAYERNORM_EPS     = "{arch}.attention.layer_norm_epsilon"
         LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
         CAUSAL            = "{arch}.attention.causal"
-        CLIP_KQV          = "{arch}.attention.clip_kqv"

     class Rope:
         DIMENSION_COUNT  = "{arch}.rope.dimension_count"
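The duplicate CLIP_KQV key is dropped because gguf-py already defines a clamp_kqv key for the same purpose. A quick illustration, not part of the commit, of how that surviving key template expands per architecture:

# Illustration only: the retained key template, expanded for the "dbrx" architecture.
CLAMP_KQV = "{arch}.attention.clamp_kqv"
print(CLAMP_KQV.format(arch="dbrx"))  # -> dbrx.attention.clamp_kqv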

llama.cpp

Lines changed: 258 additions & 0 deletions
@@ -220,6 +220,7 @@ enum llm_arch {
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
+    LLM_ARCH_DBRX,
     LLM_ARCH_UNKNOWN,
 };


@@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MAMBA,           "mamba"      },
     { LLM_ARCH_XVERSE,          "xverse"     },
     { LLM_ARCH_COMMAND_R,       "command-r"  },
+    { LLM_ARCH_DBRX,            "dbrx"       },
     { LLM_ARCH_UNKNOWN,         "(unknown)"  },
 };


@@ -926,6 +928,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_DBRX,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2,     "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,   "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
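For reference, the per-block templates above are expanded with the block index and a suffix when tensors are created (see the tn(..., "weight", i) calls later in this diff). A small sketch, not part of the commit, of that expansion; the expand() helper is hypothetical and only mirrors what the tn() formatter effectively produces:

# Sketch only: how a per-block tensor name template becomes a concrete GGUF tensor name.
def expand(template: str, suffix: str, bid: int) -> str:
    return f"{template % bid}.{suffix}"

print(expand("blk.%d.attn_qkv",      "weight", 0))   # blk.0.attn_qkv.weight
print(expand("blk.%d.ffn_gate_exps", "weight", 39))  # blk.39.ffn_gate_exps.weight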
@@ -1692,6 +1711,7 @@ enum e_model {
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
+    MODEL_132B,
     MODEL_314B,
     MODEL_SMALL,
     MODEL_MEDIUM,
@@ -3961,6 +3981,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_132B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }

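The only DBRX-specific hyperparameter read here is the QKV clamp. Conceptually it bounds the fused QKV projection output to [-clip_qkv, clip_qkv], which is what ggml_clamp applies in build_dbrx() further down. A minimal NumPy sketch, not part of the commit; the clamp value and the activation tensor are made up:

# Sketch only: the effect of f_clamp_kqv on fused QKV activations.
import numpy as np

f_clamp_kqv = 8.0                                   # placeholder, read from <arch>.attention.clamp_kqv
qkv = np.random.randn(4, 16) * 20.0                 # fake activations for 4 tokens
qkv_clamped = np.clip(qkv, -f_clamp_kqv, f_clamp_kqv) if f_clamp_kqv > 0.0 else qkv
assert qkv_clamped.min() >= -f_clamp_kqv and qkv_clamped.max() <= f_clamp_kqv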

@@ -4635,6 +4664,46 @@ static bool llm_load_tensors(
                         layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
                     }
                 } break;
+            case LLM_ARCH_DBRX:
+            {
+                if (n_expert == 0) {
+                    throw std::runtime_error("DBRX model cannot have zero experts");
+                }
+
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                // output
+                {
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                        ml.n_created--; // artificial tensor
+                        ml.size_data += ggml_nbytes(model.output);
+                    }
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i];
+
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
+                    layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd});
+                    layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                    layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
+                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
+                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
+                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert});
+
+                    layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                }
+            } break;
             case LLM_ARCH_BAICHUAN:
                 {
                     model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -7030,6 +7099,190 @@ struct llm_build_context {
         return gf;
     }

+    struct ggml_cgraph * build_dbrx() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // multiply by embedding_multiplier_scale of 78.38367176906169
+        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+
+            // self-attention
+            {
+                if (model.layers[il].attn_norm_2) {
+                    // DBRX
+                    cur = llm_build_norm(ctx0, inpL, hparams,
+                            model.layers[il].attn_norm_2,
+                            NULL,
+                            LLM_NORM, cb, il);
+                    cb(cur, "attn_norm_2", il);
+                }
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(cur, "wqkv_clamped", il);
+                }
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                cb(Vcur, "Vcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+            cb(logits, "ffn_moe_logits", il);
+
+            ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+            cb(probs, "ffn_moe_probs", il);
+
+            // select experts
+            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+            cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+            ggml_tensor * weights = ggml_get_rows(ctx0,
+                    ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+            cb(weights, "ffn_moe_weights", il);
+
+            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+            cb(weights_sum, "ffn_moe_weights_sum", il);
+
+            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+            cb(weights, "ffn_moe_weights_norm", il);
+
+            // compute expert outputs
+            ggml_tensor * moe_out = nullptr;
+
+            for (int i = 0; i < n_expert_used; ++i) {
+                ggml_tensor * cur_expert;
+
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
+                cb(cur_up, "ffn_moe_up", il);
+
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
+                cb(cur_gate, "ffn_moe_gate", il);
+
+                //GeLU
+                cur_gate = ggml_gelu(ctx0, cur_gate);
+                cb(cur_gate, "ffn_moe_gelu", il);
+
+                cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
+                cb(cur_expert, "ffn_moe_gate_par", il);
+
+                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cb(cur_expert, "ffn_moe_down", il);
+
+                cur_expert = ggml_mul(ctx0, cur_expert,
+                        ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                cb(cur_expert, "ffn_moe_weighted", il);
+
+                if (i == 0) {
+                    moe_out = cur_expert;
+                } else {
+                    moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                    cb(moe_out, "ffn_moe_out", il);
+                }
+            }
+
+            cur = moe_out;
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_starcoder() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

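build_dbrx() routes each token through n_expert_used of n_expert feed-forward experts: router logits, softmax, top-k selection, renormalization of the selected weights, then a weighted sum of the gated-GeLU expert outputs. Below is a NumPy sketch, not part of the commit, of that routing math for a single token; the dimensions and weight tensors are made up, and the GeLU uses the tanh approximation only for illustration.

# Sketch only: per-token MoE routing as built by build_dbrx(), in NumPy.
import numpy as np

rng = np.random.default_rng(0)
n_embd, n_ff, n_expert, n_expert_used = 8, 16, 4, 2     # toy sizes, not DBRX's real ones

x      = rng.standard_normal(n_embd)                    # normalized FFN input for one token
router = rng.standard_normal((n_expert, n_embd))        # ffn_gate_inp
w_up   = rng.standard_normal((n_expert, n_ff, n_embd))  # ffn_up_exps
w_gate = rng.standard_normal((n_expert, n_ff, n_embd))  # ffn_gate_exps
w_down = rng.standard_normal((n_expert, n_embd, n_ff))  # ffn_down_exps

logits = router @ x                                     # ffn_moe_logits
probs  = np.exp(logits - logits.max()); probs /= probs.sum()  # ffn_moe_probs (softmax)

selected = np.argsort(-probs)[:n_expert_used]           # top-k expert ids (ggml_top_k)
weights  = probs[selected]
weights /= weights.sum()                                # ffn_moe_weights_norm

def gelu(v):
    # tanh approximation of GeLU, close enough for the sketch
    return 0.5 * v * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (v + 0.044715 * v**3)))

moe_out = np.zeros(n_embd)
for w, e in zip(weights, selected):
    up   = w_up[e] @ x                                  # ffn_moe_up
    gate = gelu(w_gate[e] @ x)                          # ffn_moe_gate + GeLU
    moe_out += w * (w_down[e] @ (up * gate))            # ffn_moe_down, weighted and summed

print(moe_out.shape)  # (n_embd,)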

@@ -9719,6 +9972,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                result = llm.build_dbrx();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -14525,6 +14782,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_XVERSE:
         case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_DBRX: // FIXME REVIEW @ggerganov I am not sure what to put here
            return LLAMA_ROPE_TYPE_NORM;

        // the pairs of head values are offset by n_rot/2
