@@ -220,6 +220,7 @@ enum llm_arch {
     LLM_ARCH_MAMBA,
     LLM_ARCH_XVERSE,
     LLM_ARCH_COMMAND_R,
+    LLM_ARCH_DBRX,
     LLM_ARCH_UNKNOWN,
 };

@@ -252,6 +253,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_MAMBA, "mamba" },
     { LLM_ARCH_XVERSE, "xverse" },
     { LLM_ARCH_COMMAND_R, "command-r" },
+    { LLM_ARCH_DBRX, "dbrx" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };

@@ -926,6 +928,23 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_DBRX,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm" },
+            { LLM_TENSOR_OUTPUT,         "output" },
+            { LLM_TENSOR_ATTN_QKV,       "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_NORM_2,    "blk.%d.attn_norm_2" },
+            { LLM_TENSOR_FFN_GATE_INP,   "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,  "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,  "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,    "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1692,6 +1711,7 @@ enum e_model {
     MODEL_40B,
     MODEL_65B,
     MODEL_70B,
+    MODEL_132B,
     MODEL_314B,
     MODEL_SMALL,
     MODEL_MEDIUM,
@@ -3961,6 +3981,15 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DBRX:
+            {
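+                // clamp value applied to the fused QKV activations (maps to DBRX's attn_config.clip_qkv)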
+                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_132B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }

@@ -4635,6 +4664,46 @@ static bool llm_load_tensors(
                     layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
                 }
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                if (n_expert == 0) {
+                    throw std::runtime_error("DBRX model cannot have zero experts");
+                }
+
+                model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+
+                // output
+                {
+                    model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                    model.output      = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, false);
+                    // if output is NULL, init from the input tok embed
+                    if (model.output == NULL) {
+                        model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
+                        ml.n_created--; // artificial tensor
+                        ml.size_data += ggml_nbytes(model.output);
+                    }
+                }
+
+                for (int i = 0; i < n_layer; ++i) {
+                    ggml_context * ctx_layer = ctx_for_layer(i);
+                    ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                    auto & layer = model.layers[i];
+
+                    layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd});
+                    layer.attn_norm_2 = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd});
+
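+                    // fused Q/K/V projection: n_embd output rows for Q plus n_embd_gqa rows each for K and V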
+                    layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                    layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+
+                    layer.ffn_gate_inp  = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert});
+                    layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false);
+                    layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert});
+                    layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff, n_expert});
+
+                    layer.layer_out_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd});
+                }
+            } break;
         case LLM_ARCH_BAICHUAN:
             {
                 model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
@@ -7030,6 +7099,187 @@ struct llm_build_context {
         return gf;
     }
 
+    struct ggml_cgraph * build_dbrx() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        // mutable variable, needed during the last layer of the computation to skip unused tokens
+        int32_t n_tokens = this->n_tokens;
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            struct ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                if (model.layers[il].attn_norm_2) {
+                    // DBRX
+                    cur = llm_build_norm(ctx0, inpL, hparams,
+                            model.layers[il].attn_norm_2,
+                            NULL,
+                            LLM_NORM, cb, il);
+                    cb(cur, "attn_norm_2", il);
+                }
+
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                if (hparams.f_clamp_kqv > 0.0f) {
+                    cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv);
+                    cb(cur, "wqkv_clamped", il);
+                }
+
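+                // split the fused QKV output into Q (n_embd rows), K and V (n_embd_gqa rows each)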
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                Qcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                Kcur = ggml_rope_custom(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+                cb(Vcur, "Vcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                n_tokens = n_outputs;
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // MoE branch
+            cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "ffn_norm", il);
+
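+            // expert routing: the router projects each token to one logit per expert; softmax turns the logits into routing probabilities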
7196
+ cb(logits, "ffn_moe_logits", il);
7197
+
7198
+ ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
7199
+ cb(probs, "ffn_moe_probs", il);
7200
+
7201
+ // select experts
7202
+ ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
7203
+ cb(selected_experts->src[0], "ffn_moe_argsort", il);
7204
+
7205
+ ggml_tensor * weights = ggml_get_rows(ctx0,
7206
+ ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
7207
+ cb(weights, "ffn_moe_weights", il);
7208
+
7209
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
7210
+
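+            // renormalize the selected expert probabilities so they sum to 1 for each token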
+            ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
+            cb(weights_sum, "ffn_moe_weights_sum", il);
+
+            weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
+            cb(weights, "ffn_moe_weights_norm", il);
+
+            // compute expert outputs
+            ggml_tensor * moe_out = nullptr;
+
+            for (int i = 0; i < n_expert_used; ++i) {
+                ggml_tensor * cur_expert;
+
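+                // each expert is a GLU-style FFN: down( act(gate(x)) * up(x) ), scaled by its routing weight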
+                ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
+                cb(cur_up, "ffn_moe_up", il);
+
+                ggml_tensor * cur_gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
+                cb(cur_gate, "ffn_moe_gate", il);
+
+                // GeLU
+                cur_gate = ggml_gelu(ctx0, cur_gate);
+                cb(cur_gate, "ffn_moe_gelu", il);
+
+                cur_expert = ggml_mul(ctx0, cur_up, cur_gate);
+                cb(cur_expert, "ffn_moe_gate_par", il);
+
+                cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
+                cb(cur_expert, "ffn_moe_down", il);
+
+                cur_expert = ggml_mul(ctx0, cur_expert,
+                        ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+                cb(cur_expert, "ffn_moe_weighted", il);
+
+                if (i == 0) {
+                    moe_out = cur_expert;
+                } else {
+                    moe_out = ggml_add(ctx0, moe_out, cur_expert);
+                    cb(moe_out, "ffn_moe_out", il);
+                }
+            }
+
+            cur = moe_out;
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            ggml_tensor * layer_dir = lctx.cvec.tensor_for(il);
+            if (layer_dir != nullptr) {
+                cur = ggml_add(ctx0, cur, layer_dir);
+            }
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     struct ggml_cgraph * build_starcoder() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
@@ -9719,6 +9969,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_command_r();
             } break;
+        case LLM_ARCH_DBRX:
+            {
+                result = llm.build_dbrx();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -14525,6 +14779,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_MINICPM:
         case LLM_ARCH_XVERSE:
         case LLM_ARCH_COMMAND_R:
+        case LLM_ARCH_DBRX: // FIXME REVIEW @ggerganov I am not sure what to put here
             return LLAMA_ROPE_TYPE_NORM;

         // the pairs of head values are offset by n_rot/2