diff --git a/common/common.h b/common/common.h
index 5fe149ff8c991..197108be0ebba 100644
--- a/common/common.h
+++ b/common/common.h
@@ -565,70 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
-//
-// Batch utils
-//
-
-// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
-// this is meant to be temporary
-struct common_batch {
-    llama_batch_ext_ptr batch;
-    struct batch_token {
-        llama_token token;
-        llama_seq_id seq_id; // only support single seq for now
-        bool logits;
-    };
-    std::vector<batch_token> tokens;
-    int n_outputs = 0;
-    common_batch() = default;
-    common_batch(int32_t n_tokens, int32_t n_seq_max) {
-        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
-        tokens.reserve(n_tokens);
-    }
-    void clear() {
-        llama_batch_ext_clear(batch.get());
-        tokens.clear();
-    }
-    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
-        if (logits) {
-            n_outputs++;
-        }
-    }
-    void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
-        tokens.push_back({token, seq_ids[0], logits});
-        if (logits) {
-            n_outputs++;
-        }
-    }
-    void set_logits_last() {
-        if (!tokens.empty()) {
-            llama_batch_ext_set_output_last(batch.get());
-            tokens.back().logits = true;
-        }
-    }
-    int32_t get_n_tokens() const {
-        return (int32_t)tokens.size();
-    }
-    llama_batch_ext * get() {
-        return batch.get();
-    }
-    common_batch get_view(int32_t offset, int32_t n_tokens) {
-        common_batch view;
-        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
-        view.tokens.reserve(n_tokens);
-        for (int32_t i = 0; i < n_tokens; i++) {
-            view.tokens.push_back(tokens[offset + i]);
-            if (tokens[offset + i].logits) {
-                view.n_outputs++;
-            }
-        }
-        return view;
-    }
-};
-
 //
 // Token utils
 //
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 80daec9792e79..b99059511e7e7 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1224,7 +1224,7 @@ struct server_slot {
     // only used for completion/embedding/infill/rerank
     server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
 
-    common_batch batch_spec;
+    llama_batch_ext_ptr batch_spec;
 
     llama_context * ctx = nullptr;
     llama_context * ctx_dft = nullptr;
@@ -1248,7 +1248,7 @@ struct server_slot {
     int32_t n_past = 0;
     int32_t n_decoded = 0;
     int32_t n_remaining = -1;
-    int32_t i_batch = -1;
+    int32_t i_batch = -1; // TODO: remove and use only sequence-based sampling
     int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
 
     // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
@@ -1796,7 +1796,7 @@ struct server_context {
 
     llama_context_params cparams_dft;
 
-    common_batch batch;
+    llama_batch_ext_ptr batch;
 
     bool clean_kv_cache = true;
     bool add_bos_token = true;
@@ -1922,7 +1922,7 @@ struct server_context {
             slot.n_predict = params_base.n_predict;
 
             if (model_dft) {
-                slot.batch_spec = common_batch(params_base.speculative.n_max + 1, 1);
+                slot.batch_spec.reset(llama_batch_ext_init(params_base.speculative.n_max + 1, 1));
 
                 slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
                 if (slot.ctx_dft == nullptr) {
@@ -1958,7 +1958,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = common_batch(std::max(n_batch, params_base.n_parallel), 1);
+            batch.reset(llama_batch_ext_init(std::max(n_batch, params_base.n_parallel), 1));
         }
 
         metrics.init();
@@ -2093,7 +2093,7 @@ struct server_context {
         }
 
         if (slot.ctx_dft) {
-            slot.batch_spec = common_batch(slot.params.speculative.n_max + 1, 1);
+            slot.batch_spec.reset(llama_batch_ext_init(slot.params.speculative.n_max + 1, 1));
         }
 
         slot.state = SLOT_STATE_STARTED;
@@ -2401,7 +2401,7 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_embedding(const server_slot & slot, common_batch & batch) {
+    void send_embedding(const server_slot & slot) {
         auto res = std::make_unique<server_task_result_embd>();
         res->id    = slot.id_task;
         res->index = slot.index;
@@ -2410,34 +2410,40 @@ struct server_context {
 
         const int n_embd = llama_model_n_embd(model);
 
-        std::vector<float> embd_res(n_embd, 0.0f);
+        const llama_seq_id seq_id = slot.id;
 
-        for (int i = 0; i < batch.get_n_tokens(); ++i) {
-            auto tok = batch.tokens[i];
-            if (!tok.logits || tok.seq_id != slot.id) {
-                continue;
-            }
+        std::vector<float> embd_res(n_embd, 0.0f);
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
-            if (embd == NULL) {
-                embd = llama_get_embeddings_ith(ctx, i);
-            }
+        if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+            const float * embd = llama_get_embeddings_seq(ctx, seq_id);
 
             if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
+                SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id);
 
                 res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
-                continue;
             }
-
             // normalize only when there is pooling
             // TODO: configurable
-            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
-                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
-                res->embedding.push_back(embd_res);
-            } else {
-                res->embedding.push_back({ embd, embd + n_embd });
-            }
+            common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+            res->embedding.push_back(embd_res);
+        } else {
+            GGML_ABORT("embeddings without pooling is not supported yet");
+            //for (int i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); ++i) {
+            //    auto tok = batch.tokens[i];
+            //    if (!tok.logits || tok.seq_id != slot.id) {
+            //        continue;
+            //    }
+
+            //    const float * embd = llama_get_embeddings_ith(ctx, tok.seq_id);
+            //    if (embd == NULL) {
+            //        SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
+
+            //        res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
+            //        continue;
+            //    }
+
+            //    res->embedding.push_back({ embd, embd + n_embd });
+            //}
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
@@ -2445,30 +2451,20 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    void send_rerank(const server_slot & slot, common_batch & batch) {
+    void send_rerank(const server_slot & slot) {
         auto res = std::make_unique<server_task_result_rerank>();
         res->id    = slot.id_task;
         res->index = slot.index;
         res->n_tokens = slot.n_prompt_tokens;
 
-        for (int i = 0; i < batch.get_n_tokens(); ++i) {
-            auto tok = batch.tokens[i];
-            if (!tok.logits || tok.seq_id != slot.id) {
-                continue;
-            }
+        const llama_seq_id seq_id = slot.id;
 
-            const float * embd = llama_get_embeddings_seq(ctx, tok.seq_id);
-            if (embd == NULL) {
-                embd = llama_get_embeddings_ith(ctx, i);
-            }
-
-            if (embd == NULL) {
-                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id);
-
-                res->score = -1e6;
-                continue;
-            }
+        const float * embd = llama_get_embeddings_seq(ctx, seq_id);
+        if (embd == NULL) {
+            SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id);
+            res->score = -1e6;
+        } else {
             res->score = embd[0];
         }
 
@@ -2854,7 +2850,7 @@ struct server_context {
         }
 
         // start populating the batch for this iteration
-        batch.clear();
+        llama_batch_ext_clear(batch.get());
 
         // track if given slot can be batched with slots already in the batch
         server_slot * slot_batched = nullptr;
@@ -2876,9 +2872,9 @@ struct server_context {
                 continue;
             }
 
-            slot.i_batch = batch.get_n_tokens();
+            slot.i_batch = llama_batch_ext_get_n_tokens(batch.get());
 
-            batch.add_text(slot.sampled, slot.n_past, slot.id, true);
+            llama_batch_ext_add_text(batch.get(), slot.sampled, slot.n_past, &slot.id, 1, true);
 
             slot.n_past += 1;
 
@@ -2895,7 +2891,7 @@ struct server_context {
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
         // next, batch any pending prompts without exceeding n_batch
-        if (params_base.cont_batching || batch.get_n_tokens() == 0) {
+        if (params_base.cont_batching || llama_batch_ext_get_n_tokens(batch.get()) == 0) {
             for (auto & slot : slots) {
                 // check if we can batch this slot with the previous one
                 if (slot.is_processing()) {
@@ -3061,7 +3057,7 @@ struct server_context {
                    // non-causal tasks require to fit the entire prompt in the physical batch
                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
-                        if (batch.get_n_tokens() + slot.n_prompt_tokens > n_batch) {
+                        if (llama_batch_ext_get_n_tokens(batch.get()) + slot.n_prompt_tokens > n_batch) {
                            continue;
                        }
                    }
@@ -3081,11 +3077,11 @@ struct server_context {
                    slot.cache_tokens.resize(slot.n_past);
 
                    // add prompt tokens for processing in the current batch
-                    while (slot.n_past < slot.n_prompt_tokens && batch.get_n_tokens() < n_batch) {
+                    while (slot.n_past < slot.n_prompt_tokens && llama_batch_ext_get_n_tokens(batch.get()) < n_batch) {
                        // without pooling, we want to output the embeddings for all the tokens in the batch
                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
 
-                        batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd);
+                        llama_batch_ext_add_text(batch.get(), prompt_tokens[slot.n_past], slot.n_past, &slot.id, 1, need_embd);
 
                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -3095,13 +3091,14 @@ struct server_context {
                        slot.n_past++;
                    }
 
-                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.get_n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n",
+                            slot.n_past, llama_batch_ext_get_n_tokens(batch.get()), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
                        slot.state = SLOT_STATE_DONE_PROMPT;
 
-                        GGML_ASSERT(batch.get_n_tokens() > 0);
+                        GGML_ASSERT(llama_batch_ext_get_n_tokens(batch.get()) > 0);
 
                        common_sampler_reset(slot.smpl);
 
@@ -3111,27 +3108,27 @@ struct server_context {
                        }
 
                        // extract the logits only for the last token
-                        batch.set_logits_last();
+                        llama_batch_ext_set_output_last(batch.get());
 
                        slot.n_decoded = 0;
-                        slot.i_batch = batch.get_n_tokens() - 1;
+                        slot.i_batch = llama_batch_ext_get_n_tokens(batch.get()) - 1;
 
-                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.get_n_tokens());
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, llama_batch_ext_get_n_tokens(batch.get()));
                    }
                }
 
-                if (batch.get_n_tokens() >= n_batch) {
+                if (llama_batch_ext_get_n_tokens(batch.get()) >= n_batch) {
                    break;
                }
            }
        }
 
-        if (batch.get_n_tokens() == 0) {
+        if (llama_batch_ext_get_n_tokens(batch.get()) == 0) {
            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
-        SRV_DBG("decoding batch, n_tokens = %d\n", batch.get_n_tokens());
+        SRV_DBG("decoding batch, n_tokens = %d\n", llama_batch_ext_get_n_tokens(batch.get()));
 
        if (slot_batched) {
            // make sure we're in the right embedding mode
@@ -3141,10 +3138,10 @@ struct server_context {
        }
 
        // process the created batch of tokens
-        for (int32_t i = 0; i < batch.get_n_tokens(); i += n_batch) {
-            const int32_t n_tokens = std::min(n_batch, batch.get_n_tokens() - i);
+        for (int32_t i = 0; i < llama_batch_ext_get_n_tokens(batch.get()); i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, llama_batch_ext_get_n_tokens(batch.get()) - i);
 
-            common_batch batch_view = batch.get_view(i, n_tokens);
+            llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens));
 
            const int ret = llama_decode_ext(ctx, batch_view.get());
            metrics.on_decoded(slots);
@@ -3177,14 +3174,14 @@ struct server_context {
                if (slot.state == SLOT_STATE_DONE_PROMPT) {
                    if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                        // prompt evaluated for embedding
-                        send_embedding(slot, batch_view);
+                        send_embedding(slot);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
                    }
 
                    if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
-                        send_rerank(slot, batch_view);
+                        send_rerank(slot);
                        slot.release();
                        slot.i_batch = -1;
                        continue; // continue loop of slots
@@ -3281,14 +3278,14 @@ struct server_context {
            }
 
            // construct the speculation batch
-            slot.batch_spec.clear();
-            slot.batch_spec.add_text(id, slot.n_past, slot.id, true);
+            llama_batch_ext_clear(slot.batch_spec.get());
+            llama_batch_ext_add_text(slot.batch_spec.get(), id, slot.n_past, &slot.id, 1, true);
 
            for (size_t i = 0; i < draft.size(); ++i) {
-                slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true);
+                llama_batch_ext_add_text(slot.batch_spec.get(), draft[i], slot.n_past + 1 + i, &slot.id, 1, true);
            }
 
-            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.get_n_tokens());
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", llama_batch_ext_get_n_tokens(slot.batch_spec.get()));
 
            llama_decode_ext(ctx, slot.batch_spec.get());
 
@@ -4147,6 +4144,11 @@ int main(int argc, char ** argv) {
            return;
        }
 
+        if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not yet supported. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
        // for the shape of input/content, see tokenize_input_prompts()
        json prompt;
        if (body.count("input") != 0) {
@@ -4241,6 +4243,11 @@ int main(int argc, char ** argv) {
            return;
        }
 
+        if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' cannot be used with reranking. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
        const json body = json::parse(req.body);
 
        // TODO: implement
diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py
index 8b0eb42b0926f..889a759aea934 100644
--- a/examples/server/tests/unit/test_embedding.py
+++ b/examples/server/tests/unit/test_embedding.py
@@ -88,13 +88,19 @@ def test_embedding_pooling_none():
     res = server.make_request("POST", "/embeddings", data={
         "input": "hello hello hello",
     })
-    assert res.status_code == 200
-    assert 'embedding' in res.body[0]
-    assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
 
-    # make sure embedding vector is not normalized
-    for x in res.body[0]['embedding']:
-        assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
+    # /embeddings does not support pooling type 'none'
+    assert res.status_code == 400
+    assert "error" in res.body
+
+    # TODO: re-enable when we figure out how to support pooling type 'none'
+    #assert res.status_code == 200
+    #assert 'embedding' in res.body[0]
+    #assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special
+
+    ## make sure embedding vector is not normalized
+    #for x in res.body[0]['embedding']:
+    #    assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON
 
 
 def test_embedding_pooling_none_oai():
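
For reference only, and not part of the patch: a minimal sketch of the batching pattern this change migrates the server to, built solely from the llama_batch_ext calls that appear in the diff above. The context `ctx`, the token vector `prompt`, and the capacity `n_batch_max` are illustrative placeholders; error handling and multi-sequence bookkeeping are omitted.

    // single sequence; request output only for the last token, as the server does after prompt processing
    llama_seq_id seq_id = 0;
    llama_batch_ext_ptr batch(llama_batch_ext_init(n_batch_max, 1));

    llama_batch_ext_clear(batch.get());
    for (int32_t i = 0; i < (int32_t) prompt.size() && i < n_batch_max; ++i) {
        // the final argument selects whether logits/embeddings are produced for this token
        llama_batch_ext_add_text(batch.get(), prompt[i], /*pos =*/ i, &seq_id, 1, false);
    }
    llama_batch_ext_set_output_last(batch.get());

    if (llama_decode_ext(ctx, batch.get()) != 0) {
        // decode failed, e.g. the batch did not fit the context
    }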