@@ -1205,14 +1205,55 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
+struct server_batch {
+    llama_batch_ext_ptr batch;
+    struct batch_token {
+        llama_token  token;
+        llama_seq_id seq_id;
+        bool         logits;
+    };
+    std::vector<batch_token> tokens;
+    server_batch() = default;
+    server_batch(int32_t n_tokens, int32_t n_seq_max) {
+        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
+        tokens.reserve(n_tokens);
+    }
+    void clear() {
+        llama_batch_ext_clear(batch.get());
+        tokens.clear();
+    }
+    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
+        tokens.push_back({token, seq_id, logits});
+    }
+    void set_logits_last() {
+        if (!tokens.empty()) {
+            llama_batch_ext_set_logits_last(batch.get());
+            tokens.back().logits = true;
+        }
+    }
+    int32_t get_n_tokens() const {
+        return (int32_t) tokens.size();
+    }
+    server_batch get_view(int32_t offset, int32_t n_tokens) {
+        server_batch view;
+        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
+        view.tokens.reserve(n_tokens);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            view.tokens.push_back(tokens[offset + i]);
+        }
+        return view;
+    }
+};
+
 struct server_slot {
     int id;
     int id_task = -1;
 
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch_ext_ptr batch_spec;
+    server_batch batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1784,7 +1825,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch_ext_ptr batch;
+    server_batch batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1909,7 +1950,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
+                slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1945,7 +1986,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
+            batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
         }
 
         metrics.init();
@@ -2063,7 +2104,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
+            slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2371,7 +2412,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_embedding(const server_slot & slot, server_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
@@ -2382,19 +2423,19 @@ struct server_context {
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
@@ -2415,25 +2456,25 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_rerank(const server_slot & slot, server_batch & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                 res->score = -1e6;
                 continue;
@@ -2824,7 +2865,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        llama_batch_ext_clear(batch.get());
+        batch.clear();
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2846,10 +2887,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
+            slot.i_batch = batch.get_n_tokens();
 
-            std::array<llama_token, 1> seq_id = { slot.id };
-            llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
+            batch.add_text(slot.sampled, slot.n_past, slot.id, true);
 
             slot.n_past += 1;
 
@@ -2866,7 +2906,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (params_base.cont_batching || batch.get_n_tokens() == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3032,7 +3072,7 @@ struct server_context {
                 // non-causal tasks require to fit the entire prompt in the physical batch
                 if (slot.is_non_causal()) {
                     // cannot fit the prompt in the current batch - will try next iter
-                    if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
+                    if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
                         continue;
                     }
                 }
@@ -3052,12 +3092,11 @@ struct server_context {
                 slot.cache_tokens.resize(slot.n_past);
 
                 // add prompt tokens for processing in the current batch
-                while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
+                while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
                     // without pooling, we want to output the embeddings for all the tokens in the batch
                     const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                    std::array<llama_token, 1> seq_id = { slot.id };
-                    llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
+                    batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
 
                     if (slot.params.cache_prompt) {
                         slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3067,13 +3106,13 @@ struct server_context {
                     slot.n_past++;
                 }
 
-                SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                 // entire prompt has been processed
                 if (slot.n_past == slot.n_prompt_tokens) {
                     slot.state = SLOT_STATE_DONE_PROMPT;
 
-                    GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
+                    GGML_ASSERT(batch.get_n_tokens() > 0);
 
                     common_sampler_reset(slot.smpl);
 
@@ -3083,27 +3122,27 @@ struct server_context {
                     }
 
                     // extract the logits only for the last token
-                    llama_batch_ext_set_logits_last(batch.get());
+                    batch.set_logits_last();
 
                     slot.n_decoded = 0;
-                    slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
+                    slot.i_batch = batch.get_n_tokens() - 1;
 
-                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
+                    SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
                 }
             }
 
-            if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
+            if (batch.get_n_tokens() >= n_batch) {
                 break;
             }
         }
     }
 
-    if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+    if (batch.get_n_tokens() == 0) {
         SRV_WRN("%s", "no tokens to decode\n");
         return;
     }
 
-    SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
+    SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
 
     if (slot_batched) {
         // make sure we're in the right embedding mode
@@ -3113,12 +3152,12 @@ struct server_context {
     }
 
     // process the created batch of tokens
-    for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
-        const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
+    for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
+        const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
 
-        llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
+        server_batch batch_view = batch.get_view(i, n_tokens);
 
-        const int ret = llama_decode_ext(ctx, batch_view.get());
+        const int ret = llama_decode_ext(ctx, batch_view.batch.get());
 
         metrics.on_decoded(slots);
 
         if (ret != 0) {
@@ -3253,17 +3292,16 @@ struct server_context {
             }
 
             // construct the speculation batch
-            llama_batch_ext_clear(slot.batch_spec.get());
-            std::array<llama_token, 1> seq_id = { slot.id };
-            llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
+            slot.batch_spec.clear();
+            slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
 
             for (size_t i = 0; i < draft.size(); ++i) {
-                llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true);
+                slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
             }
 
-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
 
-            llama_decode_ext(ctx, slot.batch_spec.get());
+            llama_decode_ext(ctx, slot.batch_spec.batch.get());
 
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
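
Note (not part of the diff): a minimal usage sketch of the new server_batch wrapper, assuming the server_batch struct above plus ctx, a tokenized prompt, and n_batch set up as elsewhere in server.cpp; decode_prompt_sketch is a hypothetical helper used only for illustration, and it relies solely on the llama_batch_ext_* / llama_decode_ext calls already shown in the diff.

    // hypothetical helper: queue one prompt sequence, request logits for the
    // last token, and decode the batch in n_batch-sized views
    static void decode_prompt_sketch(llama_context * ctx, const std::vector<llama_token> & prompt, int32_t n_batch) {
        server_batch batch(n_batch, /*n_seq_max=*/1);

        batch.clear(); // redundant right after construction, shown to mirror update_slots()
        for (llama_pos pos = 0; pos < (llama_pos) prompt.size() && batch.get_n_tokens() < n_batch; ++pos) {
            batch.add_text(prompt[pos], pos, /*seq_id=*/0, /*logits=*/false);
        }
        batch.set_logits_last();

        for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
            server_batch view = batch.get_view(i, n_tokens);
            if (llama_decode_ext(ctx, view.batch.get()) != 0) {
                // handle decode failure (see the ret != 0 branch above)
            }
        }
    }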