Commit 831e63a

server : add speculative decoding support
ggml-ci
1 parent ccc8f63 commit 831e63a

File tree

1 file changed (+172, -26 lines)

examples/server/server.cpp

Lines changed: 172 additions & 26 deletions
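For context: speculative decoding pairs the server's main (target) model with a small draft model. The draft model cheaply proposes several next tokens, the target model scores the whole proposal in a single batched decode, and only the longest prefix the target agrees with is kept, so output quality is unchanged while throughput can improve when the two models agree often. The snippet below is a self-contained toy illustration of that accept/verify loop with greedy verification; the two "model" functions are made-up stand-ins, not llama.cpp APIs.

// toy illustration of speculative decoding with greedy verification
#include <cstdio>
#include <vector>

using token = int;

// toy target model: the "correct" next token is simply last + 1
static token target_next(const std::vector<token> & ctx) { return ctx.back() + 1; }

// toy draft model: usually agrees with the target, but drifts every 4th token
static token draft_next(const std::vector<token> & ctx) {
    const token t = ctx.back() + 1;
    return (t % 4 == 0) ? t + 7 : t;
}

int main() {
    std::vector<token> ctx = {0};
    const int n_draft = 8;

    while (ctx.size() < 32) {
        // 1) the draft model proposes n_draft tokens autoregressively (cheap)
        std::vector<token> draft;
        std::vector<token> tmp = ctx;
        for (int i = 0; i < n_draft; ++i) {
            draft.push_back(draft_next(tmp));
            tmp.push_back(draft.back());
        }

        // 2) the target model checks every drafted position (in the real server
        //    this is one batched decode) and we keep the longest agreeing prefix
        //    plus one target token, so each step emits at least one token
        std::vector<token> accepted;
        tmp = ctx;
        for (size_t i = 0; i <= draft.size(); ++i) {
            const token t = target_next(tmp);
            accepted.push_back(t);
            tmp.push_back(t);
            if (i >= draft.size() || t != draft[i]) {
                break; // first disagreement (or the bonus token) ends the run
            }
        }

        ctx.insert(ctx.end(), accepted.begin(), accepted.end());
        std::printf("accepted %zu token(s) this step\n", accepted.size());
    }

    return 0;
}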
@@ -2,10 +2,11 @@
 
 #include "arg.h"
 #include "common.h"
-#include "log.h"
-#include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "log.h"
+#include "sampling.h"
+#include "speculative.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
@@ -127,6 +128,12 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    llama_batch batch_spec;
+
+    llama_context * ctx_dft = nullptr;
+
+    common_speculative * spec = nullptr;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -591,11 +598,14 @@ struct server_response {
 };
 
 struct server_context {
+    common_params params;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<common_lora_adapter_container> loras;
 
-    common_params params;
+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
 
     llama_batch batch = {};
 
@@ -628,17 +638,33 @@ struct server_context {
             model = nullptr;
         }
 
+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
             if (slot.smpl != nullptr) {
+                llama_free(slot.ctx_dft);
+                slot.ctx_dft = nullptr;
+
+                common_speculative_free(slot.spec);
+                slot.spec = nullptr;
+
                 common_sampler_free(slot.smpl);
+                slot.smpl = nullptr;
+
+                llama_batch_free(slot.batch_spec);
             }
         }
 
         llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params_) {
+        SRV_INF("loading model '%s'\n", params_.model.c_str());
+
         params = params_;
 
         common_init_result llama_init = common_init_from_params(params);
@@ -657,6 +683,40 @@ struct server_context {
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
 
+        if (!params.model_draft.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_.model_draft.c_str());
+
+            auto params_dft = params;
+
+            params_dft.model        = params.model_draft;
+            params_dft.n_gpu_layers = params.n_gpu_layers_draft;
+
+            if (params.draft_cpuparams.n_threads > 0) {
+                params_dft.cpuparams.n_threads = params.draft_cpuparams.n_threads;
+            }
+
+            params_dft.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
+
+            common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params.model_draft.c_str());
+                return false;
+            }
+
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.model_draft.c_str(), params.model.c_str());
+                return false;
+            }
+
+            cparams_dft = common_context_params_to_llama(params);
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
+
         return true;
     }
 
@@ -685,6 +745,22 @@ struct server_context {
         slot.n_ctx = n_ctx_slot;
         slot.n_predict = params.n_predict;
 
+        if (model_dft) {
+            slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+            if (slot.ctx_dft == nullptr) {
+                SRV_ERR("%s", "failed to create draft context\n");
+                return;
+            }
+
+            slot.spec = common_speculative_init(slot.ctx_dft);
+            if (slot.spec == nullptr) {
+                SRV_ERR("%s", "failed to create speculator\n");
+                return;
+            }
+
+            slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1);
+        }
+
         SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
         slot.sparams = params.sparams;
@@ -2168,38 +2244,108 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                completion_token_output result;
-                const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+                llama_token id;
 
-                common_sampler_accept(slot.smpl, id, true);
+                {
+                    completion_token_output result;
 
-                slot.n_decoded += 1;
-                if (slot.n_decoded == 1) {
-                    slot.t_start_generation = ggml_time_us();
-                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                    metrics.on_prompt_eval(slot);
-                }
+                    id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-                result.tok = id;
+                    common_sampler_accept(slot.smpl, id, true);
 
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                    slot.n_decoded += 1;
+                    if (slot.n_decoded == 1) {
+                        slot.t_start_generation = ggml_time_us();
+                        slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                        metrics.on_prompt_eval(slot);
+                    }
 
-                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
-                    result.probs.push_back({
-                        cur_p->data[i].id,
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
-                }
+                    result.tok = id;
 
-                if (!process_token(result, slot)) {
-                    // release slot because of stop condition
-                    slot.release();
-                    slot.print_timings();
-                    send_final_response(slot);
-                    metrics.on_prediction(slot);
+                    const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+
+                    for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                        result.probs.push_back({
+                            cur_p->data[i].id,
+                            i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                        });
+                    }
+
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                    }
                 }
 
                 slot.i_batch = -1;
+
+                if (slot.ctx_dft) {
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = params.n_draft;
+                    params_spec.n_reuse = 256;
+                    params_spec.p_min   = 0.9f;
+
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+                    if (draft.size() > params.n_draft_min) {
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add(slot.batch_spec, id, slot.n_past++, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + i, { slot.id }, true);
+                        }
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            id = ids[i];
+
+                            common_sampler_accept(slot.smpl, id, true);
+
+                            slot.n_decoded += 1;
+                            if (slot.n_decoded == 1) {
+                                slot.t_start_generation = ggml_time_us();
+                                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                                metrics.on_prompt_eval(slot);
+                            }
+
+                            result.tok = id;
+
+                            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+
+                            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                                result.probs.push_back({
+                                    cur_p->data[i].id,
+                                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                                });
+                            }
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+                        slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+                    }
+                }
             }
         }
 
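The subtle part of the last hunk is the position bookkeeping: the freshly sampled target token and the drafted tokens are decoded together through slot.batch_spec, and every draft position that the target does not confirm has to be evicted from the KV cache via llama_kv_cache_seq_rm. Below is a toy walk-through of that index arithmetic with made-up numbers (plain integers, no llama.cpp calls).

// toy walk-through of the n_past / KV-cache arithmetic used above,
// with made-up numbers instead of a real model
#include <cstdio>
#include <cstddef>

int main() {
    int    n_past   = 100; // tokens already in the slot's KV cache
    size_t n_draft  = 8;   // tokens proposed by the draft model
    size_t n_accept = 5;   // ids.size(): one sampled token per verified position

    // the freshly sampled target token goes in at position n_past, then the
    // drafted tokens follow at consecutive positions
    const int pos_id         = n_past++;                   // 100
    const int pos_last_draft = n_past + (int) n_draft - 1; // 108

    // after verification, only the first n_accept - 1 drafted positions were
    // confirmed, so the cache should end right after them
    n_past += (int) n_accept - 1;                           // 105

    // everything at positions >= n_past is a rejected draft token and gets
    // removed, which is what llama_kv_cache_seq_rm(ctx, slot.id, n_past, -1)
    // does in the patch
    std::printf("sampled token at %d, drafts up to %d, keep < %d, evict >= %d\n",
                pos_id, pos_last_draft, n_past, n_past);

    return 0;
}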