
Commit 1d6ba97

remove token_info API
1 parent 1170135 commit 1d6ba97
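
This commit removes the llama_batch_ext_token_info struct and the llama_batch_ext_get_token_info() accessor from the public API. The server example now wraps llama_batch_ext_ptr in a small server_batch helper that records each token's id, sequence id and logits flag at the moment it is added, so the metadata can be read back without a C API call. A minimal before/after sketch of the read-back pattern (illustrative only; i and slot_id stand in for the loop index and slot.id used in server.cpp):

    // before: query the batch through the now-removed accessor
    llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
    if (tok.logits && tok.seq_id[0] == slot_id) { /* use tok.token, ... */ }

    // after: read the mirror kept in server_batch::tokens
    auto tok = batch.tokens[i];
    if (tok.logits && tok.seq_id == slot_id) { /* use tok.token, ... */ }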

File tree

    examples/server/server.cpp
    include/llama.h
    src/llama-batch.cpp

3 files changed, +82 -69 lines changed


examples/server/server.cpp

Lines changed: 82 additions & 44 deletions
@@ -1205,14 +1205,55 @@ struct server_task_result_apply_lora : server_task_result {
     }
 };
 
+struct server_batch {
+    llama_batch_ext_ptr batch;
+    struct batch_token {
+        llama_token token;
+        llama_seq_id seq_id;
+        bool logits;
+    };
+    std::vector<batch_token> tokens;
+    server_batch() = default;
+    server_batch(int32_t n_tokens, int32_t n_seq_max) {
+        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
+        tokens.reserve(n_tokens);
+    }
+    void clear() {
+        llama_batch_ext_clear(batch.get());
+        tokens.clear();
+    }
+    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
+        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
+        tokens.push_back({token, seq_id, logits});
+    }
+    void set_logits_last() {
+        if (!tokens.empty()) {
+            llama_batch_ext_set_logits_last(batch.get());
+            tokens.back().logits = true;
+        }
+    }
+    int32_t get_n_tokens() const {
+        return (int32_t)tokens.size();
+    }
+    server_batch get_view(int32_t offset, int32_t n_tokens) {
+        server_batch view;
+        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
+        view.tokens.reserve(n_tokens);
+        for (int32_t i = 0; i < n_tokens; i++) {
+            view.tokens.push_back(tokens[offset + i]);
+        }
+        return view;
+    }
+};
+
 struct server_slot {
     int id;
     int id_task = -1;
 
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    llama_batch_ext_ptr batch_spec;
+    server_batch batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1784,7 +1825,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    llama_batch_ext_ptr batch;
+    server_batch batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1909,7 +1950,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
+                slot.batch_spec = server_batch(params_base.speculative.n_max + 1, 1);
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1945,7 +1986,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
+            batch = server_batch(std::max(n_batch, params_base.n_parallel), 1);
         }
 
         metrics.init();
@@ -2063,7 +2104,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
+            slot.batch_spec = server_batch(slot.params.speculative.n_max + 1, 1);
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2371,7 +2412,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_embedding(const server_slot & slot, server_batch & batch) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id = slot.id_task;
         res->index = slot.index;
@@ -2382,19 +2423,19 @@ struct server_context {
 
         std::vector<float> embd_res(n_embd, 0.0f);
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
@@ -2415,25 +2456,25 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, llama_batch_ext_ptr & batch) {
+    void send_rerank(const server_slot & slot, server_batch & batch) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
-            llama_batch_ext_token_info tok = llama_batch_ext_get_token_info(batch.get(), i);
-            if (!tok.logits || tok.seq_id[0] != slot.id) {
+        for (int i = 0; i < batch.get_n_tokens(); ++i) {
+            auto tok = batch.tokens[i];
+            if (!tok.logits || tok.seq_id != slot.id) {
                 continue;
             }
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id[0]);
+            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
             if (embd == NULL) {
                 embd = llama_get_embeddings_ith(ctx, i);
             }
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id[0]);
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
 
                 res->score = -1e6;
                 continue;
@@ -2824,7 +2865,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        llama_batch_ext_clear(batch.get());
+        batch.clear();
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2846,10 +2887,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
+            slot.i_batch = batch.get_n_tokens();
 
-            std::array<llama_token, 1> seq_id = { slot.id };
-            llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, seq_id.data(), seq_id.size(), true);
+            batch.add_text(slot.sampled, slot.n_past, slot.id, true);
 
             slot.n_past += 1;
 
@@ -2866,7 +2906,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (params_base.cont_batching || batch.get_n_tokens() == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3032,7 +3072,7 @@ struct server_context {
                     // non-causal tasks require to fit the entire prompt in the physical batch
                     if (slot.is_non_causal()) {
                         // cannot fit the prompt in the current batch - will try next iter
-                        if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
+                        if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
                             continue;
                         }
                     }
@@ -3052,12 +3092,11 @@ struct server_context {
                     slot.cache_tokens.resize(slot.n_past);
 
                     // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
                         // without pooling, we want to output the embeddings for all the tokens in the batch
                         const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        std::array<llama_token, 1> seq_id = { slot.id };
-                        llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, seq_id.data(), seq_id.size(), need_embd);
+                        batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
 
                         if (slot.params.cache_prompt) {
                             slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3067,13 +3106,13 @@ struct server_context {
                         slot.n_past++;
                     }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                     // entire prompt has been processed
                     if (slot.n_past == slot.n_prompt_tokens) {
                         slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
+                        GGML_ASSERT(batch.get_n_tokens() > 0);
 
                         common_sampler_reset(slot.smpl);
 
@@ -3083,27 +3122,27 @@ struct server_context {
                         }
 
                         // extract the logits only for the last token
-                        llama_batch_ext_set_logits_last(batch.get());
+                        batch.set_logits_last();
 
                         slot.n_decoded = 0;
-                        slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
+                        slot.i_batch = batch.get_n_tokens() - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
                     }
                 }
 
-                if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
+                if (batch.get_n_tokens() >= n_batch) {
                     break;
                 }
             }
         }
 
-        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
+        if (batch.get_n_tokens() == 0) {
             SRV_WRN("%s", "no tokens to decode\n");
             return;
         }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
+        SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
 
         if (slot_batched) {
             // make sure we're in the right embedding mode
@@ -3113,12 +3152,12 @@ struct server_context {
         }
 
         // process the created batch of tokens
-        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
+        for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
 
-            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
+            server_batch batch_view = batch.get_view(i, n_tokens);
 
-            const int ret = llama_decode_ext(ctx, batch_view.get());
+            const int ret = llama_decode_ext(ctx, batch_view.batch.get());
             metrics.on_decoded(slots);
 
             if (ret != 0) {
@@ -3253,17 +3292,16 @@ struct server_context {
             }
 
             // construct the speculation batch
-            llama_batch_ext_clear(slot.batch_spec.get());
-            std::array<llama_token, 1> seq_id = { slot.id };
-            llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, seq_id.data(), seq_id.size(), true);
+            slot.batch_spec.clear();
+            slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
 
             for (size_t i = 0; i < draft.size(); ++i) {
-                llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1, seq_id.data(), seq_id.size(), true);
+                slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
             }
 
-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
 
-            llama_decode_ext(ctx, slot.batch_spec.get());
+            llama_decode_ext(ctx, slot.batch_spec.batch.get());
 
             // the accepted tokens from the speculation
             const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
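
The server_batch helper keeps the native llama_batch_ext and the C++-side token mirror in sync, which is what lets every llama_batch_ext_get_token_info() call above be replaced with a plain vector read. Note that the speculative batch also switches the draft positions from slot.n_past + 1 to slot.n_past + 1 + i, so each drafted token now gets its own position. A minimal usage sketch of the wrapper (not part of the commit; the token ids and sizes are invented for illustration):

    // build a small single-sequence batch and read the metadata back
    server_batch batch(/*n_tokens=*/8, /*n_seq_max=*/1);

    std::vector<llama_token> prompt = { 1, 2, 3, 4 }; // hypothetical token ids
    for (size_t i = 0; i < prompt.size(); ++i) {
        batch.add_text(prompt[i], /*pos=*/(llama_pos) i, /*seq_id=*/0, /*logits=*/false);
    }
    batch.set_logits_last(); // request logits only for the last prompt token

    for (int32_t i = 0; i < batch.get_n_tokens(); ++i) {
        const auto & tok = batch.tokens[i];
        // tok.token, tok.seq_id and tok.logits replace the old token_info fields
    }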

include/llama.h

Lines changed: 0 additions & 12 deletions
@@ -263,14 +263,6 @@ extern "C" {
     // It can contain text tokens and embeddings for one or many sequences
     struct llama_batch_ext;
 
-    struct llama_batch_ext_token_info {
-        llama_token token;
-        llama_pos pos;
-        int32_t n_seq_id;
-        llama_seq_id * seq_id;
-        int8_t logits;
-    };
-
     enum llama_model_kv_override_type {
         LLAMA_KV_OVERRIDE_TYPE_INT,
         LLAMA_KV_OVERRIDE_TYPE_FLOAT,
@@ -896,10 +888,6 @@ extern "C" {
     // Get the number of tokens in the batch
     LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch);
 
-    LLAMA_API struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
-            struct llama_batch_ext * batch,
-            int32_t i);
-
     // Add text tokens to the batch
     // Return values:
     //  0 : success
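
With the struct and accessor gone from llama.h, out-of-tree callers that relied on llama_batch_ext_get_token_info() presumably have to track per-token metadata themselves, e.g. by recording it at the point llama_batch_ext_add_text() is called, which is exactly what the server_batch wrapper in server.cpp now does.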

src/llama-batch.cpp

Lines changed: 0 additions & 13 deletions
@@ -480,19 +480,6 @@ struct llama_batch_ext * llama_batch_ext_get_view(
     return batch_view;
 }
 
-struct llama_batch_ext_token_info llama_batch_ext_get_token_info(
-        struct llama_batch_ext * batch,
-        int32_t i) {
-    GGML_ASSERT(i >= 0 && i < batch->n_tokens);
-    return llama_batch_ext_token_info{
-        /*token    =*/ batch->token   [i],
-        /*pos      =*/ batch->pos     [i],
-        /*n_seq_id =*/ batch->n_seq_id[i],
-        /*seq_id   =*/ batch->seq_id  [i],
-        /*logits   =*/ batch->logits  [i],
-    };
-}
-
 void llama_batch_ext_free(struct llama_batch_ext * batch) {
     // do not free the members if it's a view
     if (!batch->is_view) {
