@@ -220,6 +220,7 @@ enum llm_arch {
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
+    LLM_ARCH_DBRX,
     LLM_ARCH_UNKNOWN,
 };
@@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MAMBA,     "mamba"     },
     { LLM_ARCH_XVERSE,    "xverse"    },
     { LLM_ARCH_COMMAND_R, "command-r" },
+    { LLM_ARCH_DBRX,      "dbrx"      },
     { LLM_ARCH_UNKNOWN,   "(unknown)" },
 };
@@ -926,6 +928,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_DBRX,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_QKV,       "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2,    "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
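
Note (illustrative, not part of the diff): the per-layer entries above are printf-style templates; a layer index and a suffix are substituted when tensors are looked up, yielding GGUF names such as blk.0.ffn_gate_exps.weight. A minimal standalone sketch of that expansion, using a hypothetical helper rather than llama.cpp's internal tn() machinery:

    #include <cstdio>
    #include <string>

    // Hypothetical helper: expand a "blk.%d.*" template for layer `il` and append a suffix.
    static std::string expand_tensor_name(const char * tmpl, int il, const char * suffix) {
        char buf[256];
        std::snprintf(buf, sizeof(buf), tmpl, il);
        return std::string(buf) + "." + suffix;
    }

    int main() {
        // prints: blk.0.ffn_gate_exps.weight
        std::printf("%s\n", expand_tensor_name("blk.%d.ffn_gate_exps", 0, "weight").c_str());
        return 0;
    }
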
@@ -1692,6 +1711,7 @@ enum e_model {
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
+    MODEL_132B,
     MODEL_314B,
     MODEL_SMALL,
     MODEL_MEDIUM,
@@ -3961,6 +3981,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_132B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
@@ -4635,6 +4664,46 @@ static bool llm_load_tensors(
                     layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
                 }
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                if (n_expert == 0) {
+                    throw std::runtime_error("DBRX model cannot have zero experts");
+                }
+
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                // output
+                {
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                        ml.n_created--; // artificial tensor
+                        ml.size_data += ggml_nbytes(model.output);
+                    }
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i];
+
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
+                    layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd});
+                    layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                    layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
+                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
+                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
+                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert});
+
+                    layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                }
+            } break;
         case LLM_ARCH_BAICHUAN:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
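
Note (illustrative, not part of the diff): the MoE weights above are created as single 3D tensors with all experts stacked along the last dimension, e.g. ffn_up_exps has shape {n_embd, n_ff, n_expert}, rather than one tensor per expert. A minimal sketch of the addressing this implies, assuming a dense layout where the first dimension varies fastest; the struct and field names are hypothetical:

    #include <cstdint>
    #include <vector>

    // Hypothetical container for an expert-stacked weight of shape {n_embd, n_ff, n_expert}.
    // With the first dimension fastest, expert e owns a contiguous n_embd * n_ff block.
    struct stacked_expert_weight {
        int64_t n_embd, n_ff, n_expert;
        std::vector<float> data; // n_embd * n_ff * n_expert elements (dense f32 assumed for the sketch)

        const float * expert_slice(int64_t e) const {
            return data.data() + e * n_embd * n_ff;
        }
    };

This stacking is why a single create_tensor call per weight kind covers all experts at once.
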
@@ -7030,6 +7099,190 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_dbrx() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // multiply by embedding_multiplier_scale of 78.38367176906169
+        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                if (model.layers[il].attn_norm_2) {
+                    // DBRX
+                    cur = llm_build_norm(ctx0, inpL, hparams,
+                            model.layers[il].attn_norm_2,
+                            NULL,
+                            LLM_NORM, cb, il);
+                    cb(cur, "attn_norm_2", il);
+                }
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(cur, "wqkv_clamped", il);
+                }
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                cb(Vcur, "Vcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
+            cb(logits, "ffn_moe_logits", il);
+
+            ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
+            cb(probs, "ffn_moe_probs", il);
+
+            // select experts
+            ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
+            cb(selected_experts->src[0], "ffn_moe_argsort", il);
+
+            ggml_tensor * weights = ggml_get_rows(ctx0,
+                    ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
+            cb(weights, "ffn_moe_weights", il);
+
+            weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
+
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+            cb(weights_sum, "ffn_moe_weights_sum", il);
+
+            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+            cb(weights, "ffn_moe_weights_norm", il);
+
+            // compute expert outputs
+            ggml_tensor * moe_out = nullptr;
+
+            for (int i = 0; i < n_expert_used; ++i) {
+                ggml_tensor * cur_expert;
+
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
+                cb(cur_up, "ffn_moe_up", il);
+
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
+                cb(cur_gate, "ffn_moe_gate", il);
+
+                // GeLU
+                cur_gate = ggml_gelu(ctx0, cur_gate);
+                cb(cur_gate, "ffn_moe_gelu", il);
+
+                cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
+                cb(cur_expert, "ffn_moe_gate_par", il);
+
+                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cb(cur_expert, "ffn_moe_down", il);
+
+                cur_expert = ggml_mul(ctx0, cur_expert,
+                        ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                cb(cur_expert, "ffn_moe_weighted", il);
+
+                if (i == 0) {
+                    moe_out = cur_expert;
+                } else {
+                    moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                    cb(moe_out, "ffn_moe_out", il);
+                }
+            }
+
+            cur = moe_out;
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_starcoder() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
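
Note (illustrative, not part of the diff): the MoE block above is expressed with ggml ops (ggml_soft_max, ggml_top_k, ggml_get_rows, ggml_div, then a loop of ggml_mul_mat_id calls). A plain scalar sketch of the same routing for a single token may help when reading that graph; the function below is hypothetical and is not the ggml execution path:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Scalar sketch of top-k MoE routing for one token: softmax over router logits,
    // keep the n_expert_used largest, renormalize their weights (the ggml_div by
    // weights_sum above), then sum the weighted expert outputs.
    static std::vector<float> moe_route_one_token(
            const std::vector<float> & logits, // [n_expert] router logits for this token
            int n_expert_used,
            const std::function<std::vector<float>(int)> & run_expert) { // hypothetical per-expert FFN
        const int n_expert = (int) logits.size();

        // softmax over experts
        std::vector<float> probs(n_expert);
        const float max_logit = *std::max_element(logits.begin(), logits.end());
        float sum = 0.0f;
        for (int e = 0; e < n_expert; ++e) { probs[e] = std::exp(logits[e] - max_logit); sum += probs[e]; }
        for (float & p : probs) { p /= sum; }

        // indices of the top n_expert_used probabilities
        std::vector<int> ids(n_expert);
        std::iota(ids.begin(), ids.end(), 0);
        std::partial_sort(ids.begin(), ids.begin() + n_expert_used, ids.end(),
                [&](int a, int b) { return probs[a] > probs[b]; });

        // renormalize the selected weights so they sum to 1
        float wsum = 0.0f;
        for (int k = 0; k < n_expert_used; ++k) { wsum += probs[ids[k]]; }

        // weighted sum of expert outputs
        std::vector<float> out;
        for (int k = 0; k < n_expert_used; ++k) {
            const std::vector<float> expert_out = run_expert(ids[k]);
            if (out.empty()) { out.assign(expert_out.size(), 0.0f); }
            const float w = probs[ids[k]] / wsum;
            for (size_t j = 0; j < out.size(); ++j) { out[j] += w * expert_out[j]; }
        }
        return out;
    }

In the graph, run_expert corresponds to cur_up * gelu(cur_gate) followed by the down projection, performed per selected expert via ggml_mul_mat_id.
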
@@ -9719,6 +9972,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                result = llm.build_dbrx();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -14525,6 +14782,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_XVERSE:
         case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_DBRX: // FIXME REVIEW @ggerganov I am not sure what to put here
            return LLAMA_ROPE_TYPE_NORM;
 
        // the pairs of head values are offset by n_rot/2