@@ -214,6 +214,7 @@ enum llm_arch {
214
214
LLM_ARCH_GEMMA,
215
215
LLM_ARCH_STARCODER2,
216
216
LLM_ARCH_MAMBA,
217
+ LLM_ARCH_COMMAND_R,
217
218
LLM_ARCH_UNKNOWN,
218
219
};
219
220
@@ -243,6 +244,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
243
244
{ LLM_ARCH_GEMMA, "gemma" },
244
245
{ LLM_ARCH_STARCODER2, "starcoder2" },
245
246
{ LLM_ARCH_MAMBA, "mamba" },
247
+ { LLM_ARCH_COMMAND_R, "command-r" },
246
248
{ LLM_ARCH_UNKNOWN, "(unknown)" },
247
249
};
248
250
@@ -267,6 +269,7 @@ enum llm_kv {
267
269
LLM_KV_EXPERT_COUNT,
268
270
LLM_KV_EXPERT_USED_COUNT,
269
271
LLM_KV_POOLING_TYPE,
272
+ LLM_KV_LOGIT_SCALE,
270
273
271
274
LLM_KV_ATTENTION_HEAD_COUNT,
272
275
LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -330,6 +333,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
330
333
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
331
334
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
332
335
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
336
+ { LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
333
337
334
338
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
335
339
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -836,6 +840,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
836
840
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
837
841
},
838
842
},
843
+ {
844
+ LLM_ARCH_COMMAND_R,
845
+ {
846
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
847
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
848
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
849
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
850
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
851
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
852
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
853
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
854
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
855
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
856
+ },
857
+ },
839
858
{
840
859
LLM_ARCH_UNKNOWN,
841
860
{
@@ -1610,6 +1629,7 @@ enum e_model {
1610
1629
MODEL_20B,
1611
1630
MODEL_30B,
1612
1631
MODEL_34B,
1632
+ MODEL_35B,
1613
1633
MODEL_40B,
1614
1634
MODEL_65B,
1615
1635
MODEL_70B,
@@ -1656,6 +1676,7 @@ struct llama_hparams {
1656
1676
1657
1677
float f_clamp_kqv = 0.0f;
1658
1678
float f_max_alibi_bias = 0.0f;
1679
+ float f_logit_scale = 0.0f;
1659
1680
1660
1681
bool causal_attn = true;
1661
1682
bool need_kq_pos = false;
@@ -3237,6 +3258,7 @@ static const char * llama_model_type_name(e_model type) {
3237
3258
case MODEL_20B: return "20B";
3238
3259
case MODEL_30B: return "30B";
3239
3260
case MODEL_34B: return "34B";
3261
+ case MODEL_35B: return "35B";
3240
3262
case MODEL_40B: return "40B";
3241
3263
case MODEL_65B: return "65B";
3242
3264
case MODEL_70B: return "70B";
@@ -3628,6 +3650,14 @@ static void llm_load_hparams(
3628
3650
default: model.type = e_model::MODEL_UNKNOWN;
3629
3651
}
3630
3652
} break;
3653
+ case LLM_ARCH_COMMAND_R:
3654
+ {
3655
+ ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
3656
+ switch (hparams.n_layer) {
3657
+ case 40: model.type = e_model::MODEL_35B; break;
3658
+ default: model.type = e_model::MODEL_UNKNOWN;
3659
+ }
3660
+ } break;
3631
3661
default: (void)0;
3632
3662
}
3633
3663
@@ -3937,6 +3967,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
3937
3967
LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps);
3938
3968
LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv);
3939
3969
LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias);
3970
+ LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale);
3940
3971
LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff);
3941
3972
LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert);
3942
3973
LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used);
@@ -4910,6 +4941,37 @@ static bool llm_load_tensors(
4910
4941
layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd});
4911
4942
}
4912
4943
} break;
4944
+ case LLM_ARCH_COMMAND_R:
4945
+ {
4946
+ model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4947
+
4948
+ // output
4949
+ {
4950
+ model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
4951
+ // init output from the input tok embed
4952
+ model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
4953
+ ml.n_created--; // artificial tensor
4954
+ ml.size_data += ggml_nbytes(model.output);
4955
+ }
4956
+
4957
+ for (int i = 0; i < n_layer; ++i) {
4958
+ ggml_context * ctx_layer = ctx_for_layer(i);
4959
+ ggml_context * ctx_split = ctx_for_layer_split(i);
4960
+
4961
+ auto & layer = model.layers[i];
4962
+
4963
+ layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
4964
+
4965
+ layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
4966
+ layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
4967
+ layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
4968
+ layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
4969
+
4970
+ layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
4971
+ layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
4972
+ layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
4973
+ }
4974
+ } break;
4913
4975
default:
4914
4976
throw std::runtime_error("unknown architecture");
4915
4977
}
@@ -8302,6 +8364,125 @@ struct llm_build_context {
8302
8364
8303
8365
return gf;
8304
8366
}
8367
+
8368
+ struct ggml_cgraph * build_command_r() {
8369
+
8370
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
8371
+
8372
+ const int64_t n_embd_head = hparams.n_embd_head_v;
8373
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
8374
+ const float f_logit_scale = hparams.f_logit_scale;
8375
+
8376
+ struct ggml_tensor * cur;
8377
+ struct ggml_tensor * inpL;
8378
+
8379
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
8380
+ cb(inpL, "inp_embd", -1);
8381
+
8382
+ // inp_pos - contains the positions
8383
+ struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
8384
+ cb(inp_pos, "inp_pos", -1);
8385
+
8386
+ // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
8387
+ struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
8388
+ cb(KQ_mask, "KQ_mask", -1);
8389
+
8390
+ for (int il = 0; il < n_layer; ++il) {
8391
+
8392
+ // norm
8393
+ cur = llm_build_norm(ctx0, inpL, hparams,
8394
+ model.layers[il].attn_norm, NULL,
8395
+ LLM_NORM, cb, il);
8396
+ cb(cur, "attn_norm", il);
8397
+ struct ggml_tensor * ffn_inp = cur;
8398
+
8399
+ // self-attention
8400
+ {
8401
+ // compute Q and K and RoPE them
8402
+ struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
8403
+ cb(Qcur, "Qcur", il);
8404
+ if (model.layers[il].bq) {
8405
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
8406
+ cb(Qcur, "Qcur", il);
8407
+ }
8408
+
8409
+ struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
8410
+ cb(Kcur, "Kcur", il);
8411
+ if (model.layers[il].bk) {
8412
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
8413
+ cb(Kcur, "Kcur", il);
8414
+ }
8415
+
8416
+ struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
8417
+ cb(Vcur, "Vcur", il);
8418
+ if (model.layers[il].bv) {
8419
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
8420
+ cb(Vcur, "Vcur", il);
8421
+ }
8422
+
8423
+ Qcur = ggml_rope_custom(
8424
+ ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
8425
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
8426
+ ext_factor, attn_factor, beta_fast, beta_slow
8427
+ );
8428
+ cb(Qcur, "Qcur", il);
8429
+
8430
+ Kcur = ggml_rope_custom(
8431
+ ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
8432
+ n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
8433
+ ext_factor, attn_factor, beta_fast, beta_slow
8434
+ );
8435
+ cb(Kcur, "Kcur", il);
8436
+
8437
+ cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
8438
+ model.layers[il].wo, model.layers[il].bo,
8439
+ Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
8440
+ cb(cur, "kqv_out", il);
8441
+ }
8442
+
8443
+ struct ggml_tensor * attn_out = cur;
8444
+
8445
+ // feed-forward network
8446
+ {
8447
+ cur = llm_build_ffn(ctx0, ffn_inp,
8448
+ model.layers[il].ffn_up, NULL,
8449
+ model.layers[il].ffn_gate, NULL,
8450
+ model.layers[il].ffn_down, NULL,
8451
+ NULL,
8452
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
8453
+ cb(cur, "ffn_out", il);
8454
+ }
8455
+
8456
+ // add together residual + FFN + self-attention
8457
+ cur = ggml_add(ctx0, cur, inpL);
8458
+ cur = ggml_add(ctx0, cur, attn_out);
8459
+ cb(cur, "l_out", il);
8460
+
8461
+ // input for next layer
8462
+ inpL = cur;
8463
+ }
8464
+
8465
+ cur = inpL;
8466
+
8467
+ cur = llm_build_norm(ctx0, cur, hparams,
8468
+ model.output_norm, NULL,
8469
+ LLM_NORM, cb, -1);
8470
+ cb(cur, "result_norm", -1);
8471
+
8472
+ // lm_head
8473
+ cur = ggml_mul_mat(ctx0, model.output, cur);
8474
+
8475
+ if (f_logit_scale) {
8476
+ cur = ggml_scale(ctx0, cur, f_logit_scale);
8477
+ }
8478
+
8479
+ cb(cur, "result_output", -1);
8480
+
8481
+ ggml_build_forward_expand(gf, cur);
8482
+
8483
+ return gf;
8484
+
8485
+ }
8305
8486
};
8306
8487
8307
8488
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -8473,6 +8654,10 @@ static struct ggml_cgraph * llama_build_graph(
8473
8654
{
8474
8655
result = llm.build_mamba();
8475
8656
} break;
8657
+ case LLM_ARCH_COMMAND_R:
8658
+ {
8659
+ result = llm.build_command_r();
8660
+ } break;
8476
8661
default:
8477
8662
GGML_ASSERT(false);
8478
8663
}
@@ -13053,6 +13238,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
13053
13238
case LLM_ARCH_ORION:
13054
13239
case LLM_ARCH_INTERNLM2:
13055
13240
case LLM_ARCH_MINICPM:
13241
+ case LLM_ARCH_COMMAND_R:
13056
13242
return LLAMA_ROPE_TYPE_NORM;
13057
13243
13058
13244
// the pairs of head values are offset by n_rot/2
0 commit comments