@@ -600,7 +600,7 @@ struct server_response {
 };
 
 struct server_context {
-    common_params params;
+    common_params params_base;
 
     llama_model * model = nullptr;
    llama_context * ctx = nullptr;
@@ -662,19 +662,19 @@ struct server_context {
         llama_batch_free(batch);
     }
 
-    bool load_model(const common_params & params_) {
-        SRV_INF("loading model '%s'\n", params_.model.c_str());
+    bool load_model(const common_params & params) {
+        SRV_INF("loading model '%s'\n", params.model.c_str());
 
-        params = params_;
+        params_base = params;
 
-        common_init_result llama_init = common_init_from_params(params);
+        common_init_result llama_init = common_init_from_params(params_base);
 
         model = llama_init.model;
         ctx = llama_init.context;
         loras = llama_init.lora_adapters;
 
         if (model == nullptr) {
-            SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
+            SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;
         }
 
@@ -683,34 +683,34 @@ struct server_context {
         add_bos_token = llama_add_bos_token(model);
         has_eos_token = !llama_add_eos_token(model);
 
-        if (!params.speculative.model.empty()) {
-            SRV_INF("loading draft model '%s'\n", params.speculative.model.c_str());
+        if (!params_base.speculative.model.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
 
-            auto params_dft = params;
+            auto params_dft = params_base;
 
-            params_dft.model = params.speculative.model;
-            params_dft.n_ctx = params.speculative.n_ctx;
-            params_dft.n_gpu_layers = params.speculative.n_gpu_layers;
+            params_dft.model = params_base.speculative.model;
+            params_dft.n_ctx = params_base.speculative.n_ctx;
+            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
 
             common_init_result llama_init_dft = common_init_from_params(params_dft);
 
             model_dft = llama_init_dft.model;
 
             if (model_dft == nullptr) {
-                SRV_ERR("failed to load draft model, '%s'\n", params.speculative.model.c_str());
+                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
                 return false;
             }
 
             if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
-                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params.speculative.model.c_str(), params.model.c_str());
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
 
                 llama_free(llama_init_dft.context);
                 llama_free_model(llama_init_dft.model);
 
                 return false;
             }
 
-            cparams_dft = common_context_params_to_llama(params);
+            cparams_dft = common_context_params_to_llama(params_base);
             cparams_dft.n_batch = llama_n_ctx(llama_init_dft.context);
 
             // the context is not needed - we will create one for each slot
@@ -734,19 +734,19 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx / params.n_parallel;
+        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
 
-        SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
+        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
-        for (int i = 0; i < params.n_parallel; i++) {
+        for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
             slot.id = i;
             slot.n_ctx = n_ctx_slot;
-            slot.n_predict = params.n_predict;
+            slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec = llama_batch_init(params.speculative.n_max + 1, 0, 1);
+                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
 
                 slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -763,7 +763,7 @@ struct server_context {
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            slot.params.sampling = params.sampling;
+            slot.params.sampling = params_base.sampling;
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -783,7 +783,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
         }
 
         metrics.init();
@@ -864,8 +864,8 @@ struct server_context {
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
         // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
         slot_params defaults;
-        defaults.sampling = params.sampling;
-        defaults.speculative = params.speculative;
+        defaults.sampling = params_base.sampling;
+        defaults.speculative = params_base.speculative;
 
         const auto & data = task.data;
 
@@ -915,6 +915,8 @@ struct server_context {
         slot.params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
         slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
 
+        slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+
         if (slot.params.sampling.dry_base < 1.0f) {
             slot.params.sampling.dry_base = defaults.sampling.dry_base;
         }
@@ -1066,7 +1068,7 @@ struct server_context {
 
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
+        const std::string token_str = common_token_to_piece(ctx, result.tok, params_base.special);
         slot.sampled = result.tok;
 
         // search stop word and delete it
@@ -1131,7 +1133,7 @@ struct server_context {
         }
 
         // check the limits
-        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
+        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
             slot.stopped_limit = true;
             slot.has_next_token = false;
 
@@ -1232,7 +1234,7 @@ struct server_context {
         return json {
             {"n_ctx", slot.n_ctx},
             {"n_predict", slot.n_predict}, // Server configured n_predict
-            {"model", params.model_alias},
+            {"model", params_base.model_alias},
             {"seed", slot.params.sampling.seed},
             {"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
             {"temperature", slot.params.sampling.temp},
@@ -1268,6 +1270,10 @@ struct server_context {
             {"min_keep", slot.params.sampling.min_keep},
             {"grammar", slot.params.sampling.grammar},
             {"samplers", samplers},
+            {"speculative", slot.params.speculative.model.empty() ? false : true},
+            {"speculative.n_max", slot.params.speculative.n_max},
+            {"speculative.n_min", slot.params.speculative.n_min},
+            {"speculative.p_min", slot.params.speculative.p_min},
         };
     }
 
@@ -1337,7 +1343,7 @@ struct server_context {
             {"content", !slot.params.stream ? slot.generated_text : ""},
             {"id_slot", slot.id},
             {"stop", true},
-            {"model", params.model_alias},
+            {"model", params_base.model_alias},
             {"tokens_predicted", slot.n_decoded},
             {"tokens_evaluated", slot.n_prompt_tokens},
             {"generation_settings", get_formated_generation(slot)},
@@ -1510,10 +1516,10 @@ struct server_context {
                     data.at("input_prefix"),
                     data.at("input_suffix"),
                     data.at("input_extra"),
-                    params.n_batch,
-                    params.n_predict,
+                    params_base.n_batch,
+                    params_base.n_predict,
                     slots[0].n_ctx, // TODO: there should be a better way
-                    params.spm_infill,
+                    params_base.spm_infill,
                     tokenized_prompts[i]
                 );
                 create_task(data, tokens);
@@ -1886,7 +1892,7 @@ struct server_context {
         // TODO: simplify and improve
         for (server_slot & slot : slots) {
             if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
-                if (!params.ctx_shift) {
+                if (!params_base.ctx_shift) {
                     // this check is redundant (for good)
                     // we should never get here, because generation should already stopped in process_token()
                     slot.release();
@@ -1952,7 +1958,7 @@ struct server_context {
         int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params.cont_batching || batch.n_tokens == 0) {
+        if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
@@ -2005,7 +2011,7 @@ struct server_context {
                         continue;
                     }
                 } else {
-                    if (!params.ctx_shift) {
+                    if (!params_base.ctx_shift) {
                         // if context shift is disabled, we make sure prompt size is smaller than KV size
                         // TODO: there should be a separate parameter that control prompt truncation
                         // context shift should be applied only during the generation phase
@@ -2051,11 +2057,11 @@ struct server_context {
                         slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
 
                         // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                        if (params.n_cache_reuse > 0) {
+                        if (params_base.n_cache_reuse > 0) {
                             size_t head_c = slot.n_past; // cache
                             size_t head_p = slot.n_past; // current prompt
 
-                            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+                            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
 
                             while (head_c < slot.cache_tokens.size() &&
                                    head_p < prompt_tokens.size()) {
@@ -2068,7 +2074,7 @@ struct server_context {
                                     n_match++;
                                 }
 
-                                if (n_match >= (size_t) params.n_cache_reuse) {
+                                if (n_match >= (size_t) params_base.n_cache_reuse) {
                                     SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                     // for (size_t i = head_p; i < head_p + n_match; i++) {
                                     //     SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
@@ -2303,7 +2309,7 @@ struct server_context {
                 // TODO: configurable through requests
                 struct common_speculative_params params_spec;
                 params_spec.n_draft = slot.params.speculative.n_max;
-                params_spec.n_reuse = 256;
+                params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
                 params_spec.p_min = slot.params.speculative.p_min;
 
                 llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
@@ -2847,15 +2853,15 @@ int main(int argc, char ** argv) {
     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
         json data = {
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
-            { "total_slots", ctx_server.params.n_parallel },
+            { "total_slots", ctx_server.params_base.n_parallel },
             { "chat_template", llama_get_chat_template(ctx_server.model) },
         };
 
         res_ok(res, data);
     };
 
     const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params.endpoint_props) {
+        if (!ctx_server.params_base.endpoint_props) {
             res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2868,7 +2874,7 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
+        if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -2974,7 +2980,7 @@ int main(int argc, char ** argv) {
 
     // TODO: maybe merge this function with "handle_completions_generic"
     const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
+        if (ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
@@ -3151,7 +3157,7 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params.reranking || ctx_server.params.embedding) {
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
             res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
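
A standalone sketch of how the speculative-decoding knobs touched by this diff interact, for readers who want to see the arithmetic in isolation: the new clamp keeps n_min from exceeding a per-request n_max, and n_reuse is now derived from the draft context size instead of the fixed 256. This is illustration only, not code from the commit; the variable names, the default values, and n_ctx_dft = 4096 are assumptions.

// sketch.cpp - illustration only; not part of the commit above
#include <algorithm>
#include <cstdio>

int main() {
    // assumed server-side defaults (hypothetical values for this sketch)
    int   n_min = 5;     // minimum number of draft tokens to use
    int   n_max = 16;    // maximum number of draft tokens per step
    float p_min = 0.9f;  // minimum draft token probability

    // a request overrides "speculative.n_max" with a smaller value
    n_max = 4;

    // the patch clamps n_min so it can never exceed the requested n_max
    n_min = std::min(n_max, n_min);

    // the patch derives n_reuse from the draft context instead of a fixed 256
    const int n_ctx_dft = 4096; // hypothetical draft-model context size
    const int n_reuse   = n_ctx_dft - n_max;

    std::printf("n_min = %d, n_max = %d, p_min = %.2f, n_reuse = %d\n",
                n_min, n_max, p_min, n_reuse);
    return 0;
}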