@@ -2,10 +2,11 @@
 
 #include "arg.h"
 #include "common.h"
-#include "log.h"
-#include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
+#include "log.h"
+#include "sampling.h"
+#include "speculative.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
@@ -127,6 +128,12 @@ struct server_slot {
     int id;
     int id_task = -1;
 
+    llama_batch batch_spec;
+
+    llama_context * ctx_dft = nullptr;
+
+    common_speculative * spec = nullptr;
+
     // the index relative to completion multi-task request
     size_t index = 0;
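Each slot now owns its speculative-decoding state: a draft-model context (`ctx_dft`), a speculator (`spec`), and a batch sized for draft tokens (`batch_spec`), so parallel slots can speculate independently. Below is a minimal sketch of the lifecycle these members imply, mirroring the init/free calls that appear later in this diff; the `init_slot_spec` helper is hypothetical, not part of the change:

```cpp
// Hypothetical helper sketching the per-slot speculative state lifecycle,
// using only the calls that appear elsewhere in this diff.
static bool init_slot_spec(server_slot & slot, llama_model * model_dft,
                           const llama_context_params & cparams_dft, int n_draft) {
    // one draft context per slot, so each slot keeps its own draft KV cache
    slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
    if (slot.ctx_dft == nullptr) {
        return false;
    }

    // the speculator generates draft tokens against that context
    slot.spec = common_speculative_init(slot.ctx_dft);
    if (slot.spec == nullptr) {
        return false;
    }

    // room for the sampled token plus up to n_draft drafted tokens
    slot.batch_spec = llama_batch_init(n_draft + 1, 0, 1);
    return true;
}

// Teardown mirrors construction (see the destructor changes below):
//   llama_free(slot.ctx_dft); common_speculative_free(slot.spec);
//   llama_batch_free(slot.batch_spec);
```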
@@ -591,11 +598,14 @@ struct server_response {
 };
 
 struct server_context {
+    common_params params;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<common_lora_adapter_container> loras;
 
-    common_params params;
+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
 
     llama_batch batch = {};
 
@@ -628,17 +638,33 @@ struct server_context {
             model = nullptr;
         }
 
+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
             if (slot.smpl != nullptr) {
+                llama_free(slot.ctx_dft);
+                slot.ctx_dft = nullptr;
+
+                common_speculative_free(slot.spec);
+                slot.spec = nullptr;
+
                 common_sampler_free(slot.smpl);
+                slot.smpl = nullptr;
+
+                llama_batch_free(slot.batch_spec);
             }
         }
 
         llama_batch_free(batch);
     }
 
     bool load_model(const common_params & params_) {
+        SRV_INF("loading model '%s'\n", params_.model.c_str());
+
         params = params_;
 
         common_init_result llama_init = common_init_from_params(params);
@@ -657,6 +683,40 @@ struct server_context {
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
 
+        if (!params.model_draft.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_.model_draft.c_str());
+
+            auto params_dft = params;
+
+            params_dft.model        = params.model_draft;
+            params_dft.n_gpu_layers = params.n_gpu_layers_draft;
+
+            if (params.draft_cpuparams.n_threads > 0) {
+                params_dft.cpuparams.n_threads = params.draft_cpuparams.n_threads;
+            }
+
+            params_dft.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
+
+            common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params.model_draft.c_str());
+                return false;
+            }
+
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.model_draft.c_str(), params.model.c_str());
+                return false;
+            }
+
+            cparams_dft = common_context_params_to_llama(params);
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
+
         return true;
     }
 
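`common_speculative_are_compatible` gates the feature: speculative decoding only works when draft and target agree on tokenization, since draft token ids are fed straight to the target model. A simplified sketch of that kind of check, using only public `llama.h` calls; the actual implementation in `common/speculative.cpp` performs a more detailed vocabulary comparison, and the helper name here is hypothetical:

```cpp
// Simplified draft/target compatibility check (sketch only; the real
// common_speculative_are_compatible does more thorough vocabulary checks).
static bool draft_is_compatible(llama_context * ctx_tgt, llama_context * ctx_dft) {
    const llama_model * model_tgt = llama_get_model(ctx_tgt);
    const llama_model * model_dft = llama_get_model(ctx_dft);

    // token ids must mean the same thing in both models
    if (llama_n_vocab(model_tgt) != llama_n_vocab(model_dft)) {
        return false;
    }

    // special tokens must line up, otherwise drafts desync at BOS/EOS
    return llama_token_bos(model_tgt) == llama_token_bos(model_dft) &&
           llama_token_eos(model_tgt) == llama_token_eos(model_dft);
}
```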
@@ -685,6 +745,22 @@ struct server_context {
             slot.n_ctx = n_ctx_slot;
             slot.n_predict = params.n_predict;
 
+            if (model_dft) {
+                slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+                if (slot.ctx_dft == nullptr) {
+                    SRV_ERR("%s", "failed to create draft context\n");
+                    return;
+                }
+
+                slot.spec = common_speculative_init(slot.ctx_dft);
+                if (slot.spec == nullptr) {
+                    SRV_ERR("%s", "failed to create speculator\n");
+                    return;
+                }
+
+                slot.batch_spec = llama_batch_init(params.n_draft + 1, 0, 1);
+            }
+
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
             slot.sparams = params.sparams;
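Note the sizing of `batch_spec`: it reserves one entry for the token the target just sampled plus up to `n_draft` drafted tokens, which is exactly the shape the speculative decode step below fills. An annotated restatement of the call, following the `llama_batch_init(n_tokens, embd, n_seq_max)` declaration in `llama.h`:

```cpp
// llama_batch_init(n_tokens, embd, n_seq_max):
//   n_tokens  - capacity in tokens: 1 sampled token + n_draft drafted tokens
//   embd      - 0 means token ids are used, not embeddings
//   n_seq_max - each batch entry belongs to at most one sequence (the slot's id)
slot.batch_spec = llama_batch_init(params.n_draft + 1, /*embd*/ 0, /*n_seq_max*/ 1);
```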
@@ -2168,38 +2244,108 @@ struct server_context {
                     continue; // continue loop of slots
                 }
 
-                completion_token_output result;
-                const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
+                llama_token id;
 
-                common_sampler_accept(slot.smpl, id, true);
+                {
+                    completion_token_output result;
 
-                slot.n_decoded += 1;
-                if (slot.n_decoded == 1) {
-                    slot.t_start_generation = ggml_time_us();
-                    slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
-                    metrics.on_prompt_eval(slot);
-                }
+                    id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-                result.tok = id;
+                    common_sampler_accept(slot.smpl, id, true);
 
-                const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+                    slot.n_decoded += 1;
+                    if (slot.n_decoded == 1) {
+                        slot.t_start_generation = ggml_time_us();
+                        slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                        metrics.on_prompt_eval(slot);
+                    }
 
-                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
-                    result.probs.push_back({
-                        cur_p->data[i].id,
-                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                    });
-                }
+                    result.tok = id;
 
-                if (!process_token(result, slot)) {
-                    // release slot because of stop condition
-                    slot.release();
-                    slot.print_timings();
-                    send_final_response(slot);
-                    metrics.on_prediction(slot);
+                    const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+
+                    for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                        result.probs.push_back({
+                            cur_p->data[i].id,
+                            i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                        });
+                    }
+
+                    if (!process_token(result, slot)) {
+                        // release slot because of stop condition
+                        slot.release();
+                        slot.print_timings();
+                        send_final_response(slot);
+                        metrics.on_prediction(slot);
+                    }
                 }
 
                 slot.i_batch = -1;
+
+                if (slot.ctx_dft) {
+                    struct common_speculative_params params_spec;
+                    params_spec.n_draft = params.n_draft;
+                    params_spec.n_reuse = 256;
+                    params_spec.p_min = 0.9f;
+
+                    llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+                    if (draft.size() > params.n_draft_min) {
+                        common_batch_clear(slot.batch_spec);
+                        common_batch_add(slot.batch_spec, id, slot.n_past++, { slot.id }, true);
+
+                        for (size_t i = 0; i < draft.size(); ++i) {
+                            common_batch_add(slot.batch_spec, draft[i], slot.n_past + i, { slot.id }, true);
+                        }
+
+                        llama_decode(ctx, slot.batch_spec);
+
+                        const auto ids = common_sampler_sample_n(slot.smpl, ctx, draft);
+
+                        slot.n_past += ids.size() - 1;
+
+                        slot.cache_tokens.push_back(id);
+
+                        for (size_t i = 0; i < ids.size(); ++i) {
+                            completion_token_output result;
+
+                            id = ids[i];
+
+                            common_sampler_accept(slot.smpl, id, true);
+
+                            slot.n_decoded += 1;
+                            if (slot.n_decoded == 1) {
+                                slot.t_start_generation = ggml_time_us();
+                                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
+                                metrics.on_prompt_eval(slot);
+                            }
+
+                            result.tok = id;
+
+                            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+
+                            for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                                result.probs.push_back({
+                                    cur_p->data[i].id,
+                                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                                });
+                            }
+
+                            if (!process_token(result, slot)) {
+                                // release slot because of stop condition
+                                slot.release();
+                                slot.print_timings();
+                                send_final_response(slot);
+                                metrics.on_prediction(slot);
+                                break;
+                            }
+                        }
+
+                        llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+                        slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+                    }
+                }
             }
         }
 
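The new `if (slot.ctx_dft)` block is the core of the change: after the target samples one token normally, the draft model proposes a continuation, the target verifies the entire draft in a single `llama_decode`, and `common_sampler_sample_n` keeps the longest prefix the target agrees with; KV-cache entries past the accepted position are then evicted. A condensed sketch of that accept/rollback pattern, under the same API assumptions as the diff (`spec`, `smpl`, `cache_tokens`, `n_past`, `seq_id`, and `n_draft` stand in for the slot's fields; `n_reuse = 256` and `p_min = 0.9f` are the literals the PR hard-codes):

```cpp
// Condensed speculative accept/rollback step for one slot, using the
// common_speculative / common_sampler calls from this diff.
struct common_speculative_params params_spec;
params_spec.n_draft = n_draft; // max tokens to draft per step
params_spec.n_reuse = 256;     // draft-context cache reuse window
params_spec.p_min   = 0.9f;    // stop drafting below this draft-model confidence

// 1. draft a continuation of the current sequence after token `id`
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, cache_tokens, id);

// 2. verify: decode the sampled token plus the whole draft in one target pass
common_batch_clear(batch_spec);
common_batch_add(batch_spec, id, n_past++, { seq_id }, true);
for (size_t i = 0; i < draft.size(); ++i) {
    common_batch_add(batch_spec, draft[i], n_past + i, { seq_id }, true);
}
llama_decode(ctx, batch_spec);

// 3. accept: ids = agreed draft prefix + one token sampled at the divergence
const auto ids = common_sampler_sample_n(smpl, ctx, draft);
n_past += ids.size() - 1;

// 4. rollback: drop KV entries for the rejected tail of the draft
llama_kv_cache_seq_rm(ctx, seq_id, n_past, -1);
```

On acceptance, each token in `ids` still goes through the regular `process_token` path, so streaming, stop conditions, and metrics behave exactly as in non-speculative decoding.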