Load all MoE experts during warmup #11571

Merged
4 commits merged on Mar 14, 2025
Changes from all commits
3 changes: 3 additions & 0 deletions common/common.cpp
@@ -1033,6 +1033,8 @@ struct common_init_result common_init_from_params(common_params & params) {
if (params.warmup) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);

llama_set_warmup(lctx, true);

std::vector<llama_token> tmp;
llama_token bos = llama_vocab_bos(vocab);
llama_token eos = llama_vocab_eos(vocab);
@@ -1063,6 +1065,7 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_kv_self_clear(lctx);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
llama_set_warmup(lctx, false);
}

iparams.model.reset(model);
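
For applications that run their own warmup instead of going through common_init_from_params(), the same bracketing pattern applies. A minimal sketch, assuming the llama.h API as of this change; the warmup() helper is hypothetical, llama_batch_get_one() is taken in its current two-argument form, and the BOS validity check done by the real code is omitted:

```cpp
// Sketch only: bracket a dummy decode with the new warmup flag so that
// MoE expert weights are touched (and therefore loaded/cached) up front.
#include <vector>
#include "llama.h"

static void warmup(llama_context * lctx, const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_set_warmup(lctx, true);               // activate all experts during decode

    std::vector<llama_token> tmp = { llama_vocab_bos(vocab) };
    llama_decode(lctx, llama_batch_get_one(tmp.data(), (int32_t) tmp.size()));

    llama_kv_self_clear(lctx);                  // drop the dummy tokens again
    llama_synchronize(lctx);
    llama_perf_context_reset(lctx);

    llama_set_warmup(lctx, false);              // restore normal top-k routing
}
```
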
4 changes: 4 additions & 0 deletions include/llama.h
@@ -945,6 +945,10 @@ extern "C" {
// If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);

// Set whether the model is in warmup mode or not
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);

// Set abort callback
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);

13 changes: 12 additions & 1 deletion src/llama-context.cpp
@@ -39,6 +39,7 @@ llama_context::llama_context(
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf;
cparams.pooling_type = params.pooling_type;
cparams.warmup = false;

cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
@@ -952,6 +953,12 @@ void llama_context::set_causal_attn(bool value) {
cparams.causal_attn = value;
}

void llama_context::set_warmup(bool value) {
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);

cparams.warmup = value;
}

void llama_context::set_adapter_lora(
llama_adapter_lora * adapter,
float scale) {
@@ -1598,7 +1605,7 @@ void llama_context::output_reorder() {
//

int32_t llama_context::graph_max_nodes() const {
-return std::max<int32_t>(8192, 5*model.n_tensors());
+return std::max<int32_t>(65536, 5*model.n_tensors());
}

ggml_cgraph * llama_context::graph_init() {
@@ -2376,6 +2383,10 @@ void llama_set_causal_attn(llama_context * ctx, bool causal_attn) {
ctx->set_causal_attn(causal_attn);
}

void llama_set_warmup(llama_context * ctx, bool warmup) {
ctx->set_warmup(warmup);
}

void llama_synchronize(llama_context * ctx) {
ctx->synchronize();
}
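
The graph_max_nodes() change only raises the floor of the node budget from 8192 to 65536, presumably to leave headroom for the larger graphs built while every expert is active; the 5*model.n_tensors() term is unchanged. A small illustration of how the two terms interact, with made-up tensor counts:

```cpp
// Illustration only: the patched node-budget formula with hypothetical
// tensor counts. The floor dominates for small models; 5*n_tensors wins
// for very large ones.
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int32_t graph_max_nodes(int32_t n_tensors) {
    return std::max<int32_t>(65536, 5 * n_tensors);   // was max(8192, 5*n_tensors)
}

int main() {
    std::printf("%d\n", graph_max_nodes( 3000));   // 65536: the new floor applies
    std::printf("%d\n", graph_max_nodes(20000));   // 100000: 5*n_tensors exceeds the floor
    return 0;
}
```
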
1 change: 1 addition & 0 deletions src/llama-context.h
@@ -64,6 +64,7 @@ struct llama_context {

void set_embeddings (bool value);
void set_causal_attn(bool value);
void set_warmup(bool value);

void set_adapter_lora(
llama_adapter_lora * adapter,
1 change: 1 addition & 0 deletions src/llama-cparams.h
@@ -29,6 +29,7 @@ struct llama_cparams {
bool offload_kqv;
bool flash_attn;
bool no_perf;
bool warmup;

enum llama_pooling_type pooling_type;

2 changes: 1 addition & 1 deletion src/llama-graph.cpp
@@ -577,7 +577,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
n_embd_head_v (hparams.n_embd_head_v),
n_embd_v_gqa (hparams.n_embd_v_gqa()),
n_expert (hparams.n_expert),
-n_expert_used (hparams.n_expert_used),
+n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used),
freq_base (cparams.rope_freq_base),
freq_scale (cparams.rope_freq_scale),
ext_factor (cparams.yarn_ext_factor),
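
The single-line change in src/llama-graph.cpp is the heart of the PR: while cparams.warmup is set, the graph is built as if every expert were routed to, so each expert's weights participate in the dummy decode and are read (and therefore mmapped or cached) at least once. A conceptual sketch of the selection logic in plain C++; this is not the ggml code path, and select_experts() with its parameters is illustrative only:

```cpp
// Conceptual sketch (not the ggml implementation): top-k expert selection,
// widened to all experts while warmup is active.
#include <algorithm>
#include <numeric>
#include <vector>

// Returns the expert indices one token is routed to.
static std::vector<int> select_experts(const std::vector<float> & router_logits,
                                       int n_expert_used, bool warmup) {
    const int n_expert = (int) router_logits.size();
    // Mirrors: n_expert_used = cparams.warmup ? hparams.n_expert : hparams.n_expert_used
    const int k = warmup ? n_expert : n_expert_used;

    std::vector<int> idx(n_expert);
    std::iota(idx.begin(), idx.end(), 0);
    std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                      [&](int a, int b) { return router_logits[a] > router_logits[b]; });
    idx.resize(k);
    return idx;   // during warmup: every expert, so all expert tensors are used
}
```

Once llama_set_warmup(lctx, false) is called, k falls back to hparams.n_expert_used and normal sparse routing resumes.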