Skip to content

Commit 828176e

Browse files
committed
feat: First (broken) pass at nemotronh model architecture
It generates tokens, just not valid ones! https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <[email protected]>
1 parent c25c149 commit 828176e

File tree

1 file changed

+236
-1
lines changed

1 file changed

+236
-1
lines changed

src/llama-model.cpp

Lines changed: 236 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1544,6 +1544,27 @@ void llama_model::load_hparams(llama_model_loader & ml) {
15441544
default: type = LLM_TYPE_UNKNOWN;
15451545
}
15461546
} break;
1547+
case LLM_ARCH_NEMOTRONH:
1548+
{
1549+
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
1550+
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
1551+
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
1552+
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
1553+
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
1554+
1555+
// A layer is recurrent IFF the n_head_kv value is set to 0 and
1556+
// the n_ff value is set to 0
1557+
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
1558+
hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0);
1559+
}
1560+
1561+
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
1562+
1563+
switch (hparams.n_layer) {
1564+
case 56: type = LLM_TYPE_9B; break;
1565+
default: type = LLM_TYPE_UNKNOWN;
1566+
}
1567+
} break;
15471568
case LLM_ARCH_EXAONE:
15481569
{
15491570
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -4626,6 +4647,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
46264647
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
46274648
}
46284649
} break;
4650+
case LLM_ARCH_NEMOTRONH:
4651+
{
4652+
// mamba2 Mixer SSM params
4653+
// NOTE: int64_t for tensor dimensions
4654+
const int64_t d_conv = hparams.ssm_d_conv;
4655+
const int64_t d_inner = hparams.ssm_d_inner;
4656+
const int64_t d_state = hparams.ssm_d_state;
4657+
const int64_t n_ssm_head = hparams.ssm_dt_rank;
4658+
const int64_t n_group = hparams.ssm_n_group;
4659+
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
4660+
4661+
// embeddings
4662+
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
4663+
4664+
// output
4665+
{
4666+
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
4667+
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
4668+
// if output is NULL, init from the input tok embed, duplicated to allow offloading
4669+
if (output == NULL) {
4670+
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
4671+
}
4672+
}
4673+
4674+
for (int i = 0; i < n_layer; ++i) {
4675+
auto & layer = layers[i];
4676+
4677+
if (hparams.is_recurrent(i)) {
4678+
// ssm layers
4679+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4680+
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
4681+
4682+
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
4683+
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED);
4684+
4685+
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
4686+
4687+
// no "weight" suffix for these
4688+
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
4689+
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
4690+
4691+
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
4692+
4693+
// out_proj
4694+
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
4695+
} else if (hparams.n_ff(i) == 0) {
4696+
// attention layers (with optional bias)
4697+
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4698+
const int64_t n_head_i = hparams.n_head(i);
4699+
const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
4700+
const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
4701+
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
4702+
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
4703+
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
4704+
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
4705+
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4706+
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
4707+
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
4708+
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4709+
} else {
4710+
// mlp layers
4711+
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
4712+
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
4713+
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
4714+
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
4715+
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
4716+
}
4717+
}
4718+
} break;
46294719
case LLM_ARCH_EXAONE:
46304720
{
46314721
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -5800,7 +5890,8 @@ void llama_model::print_info() const {
58005890
arch == LLM_ARCH_JAMBA ||
58015891
arch == LLM_ARCH_FALCON_H1 ||
58025892
arch == LLM_ARCH_PLAMO2 ||
5803-
arch == LLM_ARCH_GRANITE_HYBRID) {
5893+
arch == LLM_ARCH_GRANITE_HYBRID ||
5894+
arch == LLM_ARCH_NEMOTRONH) {
58045895
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
58055896
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
58065897
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
@@ -14070,6 +14161,145 @@ struct llm_build_nemotron : public llm_graph_context {
1407014161
}
1407114162
};
1407214163

14164+
struct llm_build_nemotronh : public llm_graph_context_mamba {
14165+
llm_build_nemotronh(
14166+
const llama_model & model,
14167+
const llm_graph_params & params) :
14168+
llm_graph_context_mamba(params) {
14169+
14170+
const int64_t n_embd_head = hparams.n_embd_head_v;
14171+
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
14172+
14173+
ggml_tensor * cur;
14174+
ggml_tensor * inpL;
14175+
14176+
inpL = build_inp_embd(model.tok_embd);
14177+
14178+
auto * inp = build_inp_mem_hybrid();
14179+
14180+
ggml_tensor * inp_out_ids = build_inp_out_ids();
14181+
14182+
for (int il = 0; il < n_layer; ++il) {
14183+
struct ggml_tensor * inpSA = inpL;
14184+
14185+
// norm
14186+
cur = build_norm(inpL,
14187+
model.layers[il].attn_norm, NULL,
14188+
LLM_NORM_RMS, il);
14189+
cb(cur, "attn_norm", il);
14190+
14191+
if (hparams.is_recurrent(il)) {
14192+
// ssm layer //
14193+
cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
14194+
} else if (hparams.n_ff(il) == 0) {
14195+
// attention layer //
14196+
cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il);
14197+
} else {
14198+
cur = build_ffn_layer(cur, inpSA, model, il);
14199+
}
14200+
14201+
if (il == n_layer - 1 && inp_out_ids) {
14202+
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
14203+
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
14204+
}
14205+
14206+
// input for next layer
14207+
inpL = cur;
14208+
}
14209+
14210+
cur = inpL;
14211+
14212+
cur = build_norm(cur,
14213+
model.output_norm, NULL,
14214+
LLM_NORM_RMS, -1);
14215+
14216+
cb(cur, "result_norm", -1);
14217+
res->t_embd = cur;
14218+
14219+
// lm_head
14220+
cur = build_lora_mm(model.output, cur);
14221+
cb(cur, "result_output", -1);
14222+
res->t_logits = cur;
14223+
14224+
ggml_build_forward_expand(gf, cur);
14225+
}
14226+
14227+
ggml_tensor * build_attention_layer(
14228+
ggml_tensor * cur,
14229+
llm_graph_input_attn_kv * inp_attn,
14230+
const llama_model & model,
14231+
const int64_t n_embd_head,
14232+
const int il) {
14233+
14234+
// compute Q and K and (optionally) RoPE them
14235+
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
14236+
cb(Qcur, "Qcur", il);
14237+
if (model.layers[il].bq) {
14238+
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
14239+
cb(Qcur, "Qcur", il);
14240+
}
14241+
14242+
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
14243+
cb(Kcur, "Kcur", il);
14244+
if (model.layers[il].bk) {
14245+
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
14246+
cb(Kcur, "Kcur", il);
14247+
}
14248+
14249+
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
14250+
cb(Vcur, "Vcur", il);
14251+
if (model.layers[il].bv) {
14252+
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
14253+
cb(Vcur, "Vcur", il);
14254+
}
14255+
14256+
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
14257+
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
14258+
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
14259+
14260+
cb(Qcur, "Qcur", il);
14261+
cb(Kcur, "Kcur", il);
14262+
cb(Vcur, "Vcur", il);
14263+
14264+
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
14265+
cur = build_attn(inp_attn,
14266+
model.layers[il].wo, model.layers[il].bo,
14267+
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
14268+
cb(cur, "attn_out", il);
14269+
return cur;
14270+
}
14271+
14272+
ggml_tensor * build_ffn_layer(
14273+
ggml_tensor * cur,
14274+
ggml_tensor * inpSA,
14275+
const llama_model & model,
14276+
const int il) {
14277+
14278+
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
14279+
cb(ffn_inp, "ffn_inp", il);
14280+
cur = build_norm(ffn_inp,
14281+
model.layers[il].ffn_norm, NULL,
14282+
LLM_NORM_RMS, il);
14283+
cb(cur, "ffn_norm", il);
14284+
14285+
cur = build_ffn(cur,
14286+
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
14287+
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
14288+
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
14289+
NULL,
14290+
LLM_FFN_SILU, LLM_FFN_PAR, il);
14291+
cb(cur, "ffn_out", il);
14292+
14293+
cur = ggml_add(ctx0, cur, ffn_inp);
14294+
cb(cur, "ffn_out", il);
14295+
14296+
cur = build_cvec(cur, il);
14297+
cb(cur, "l_out", il);
14298+
14299+
return cur;
14300+
}
14301+
};
14302+
1407314303
struct llm_build_exaone : public llm_graph_context {
1407414304
llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
1407514305
const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -18418,6 +18648,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
1841818648
{
1841918649
llm = std::make_unique<llm_build_nemotron>(*this, params);
1842018650
} break;
18651+
case LLM_ARCH_NEMOTRONH:
18652+
{
18653+
llm = std::make_unique<llm_build_nemotronh>(*this, params);
18654+
} break;
1842118655
case LLM_ARCH_EXAONE:
1842218656
{
1842318657
llm = std::make_unique<llm_build_exaone>(*this, params);
@@ -18648,6 +18882,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
1864818882
case LLM_ARCH_RWKV7:
1864918883
case LLM_ARCH_ARWKV7:
1865018884
case LLM_ARCH_WAVTOKENIZER_DEC:
18885+
case LLM_ARCH_NEMOTRONH:
1865118886
return LLAMA_ROPE_TYPE_NONE;
1865218887

1865318888
// use what we call a normal RoPE, operating on pairs of consecutive head values

0 commit comments

Comments
 (0)