
Commit bd51d63

Merge pull request #16 from ggml-org/xsn/private_batch_api_pooling_none
server : avoid common_batch
2 parents 76fd7d6 + b8b1732

3 files changed (+86, -137 lines)


common/common.h

Lines changed: 0 additions & 64 deletions
@@ -565,70 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
 // clear LoRA adapters from context, then apply new list of adapters
 void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);
 
-//
-// Batch utils
-//
-
-// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
-// this is meant to be temporary
-struct common_batch {
-    llama_batch_ext_ptr batch;
-    struct batch_token {
-        llama_token  token;
-        llama_seq_id seq_id; // only support single seq for now
-        bool         logits;
-    };
-    std::vector<batch_token> tokens;
-    int n_outputs = 0;
-    common_batch() = default;
-    common_batch(int32_t n_tokens, int32_t n_seq_max) {
-        batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
-        tokens.reserve(n_tokens);
-    }
-    void clear() {
-        llama_batch_ext_clear(batch.get());
-        tokens.clear();
-    }
-    void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
-        tokens.push_back({token, seq_id, logits});
-        if (logits) {
-            n_outputs++;
-        }
-    }
-    void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
-        llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
-        tokens.push_back({token, seq_ids[0], logits});
-        if (logits) {
-            n_outputs++;
-        }
-    }
-    void set_logits_last() {
-        if (!tokens.empty()) {
-            llama_batch_ext_set_output_last(batch.get());
-            tokens.back().logits = true;
-        }
-    }
-    int32_t get_n_tokens() const {
-        return (int32_t)tokens.size();
-    }
-    llama_batch_ext * get() {
-        return batch.get();
-    }
-    common_batch get_view(int32_t offset, int32_t n_tokens) {
-        common_batch view;
-        view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
-        view.tokens.reserve(n_tokens);
-        for (int32_t i = 0; i < n_tokens; i++) {
-            view.tokens.push_back(tokens[offset + i]);
-            if (tokens[offset + i].logits) {
-                view.n_outputs++;
-            }
-        }
-        return view;
-    }
-};
-
 //
 // Token utils
 //
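
With the common_batch wrapper removed, callers are expected to drive llama_batch_ext directly. Below is a minimal sketch of that replacement pattern for a single-sequence prompt; it uses only the calls the removed wrapper forwarded to, and the variable names (prompt_tokens, the batch capacity of 512) are illustrative assumptions, not taken from the server code in this commit:

    // illustrative migration: build a batch for one prompt without common_batch
    // (prompt_tokens and the capacity are assumptions for this example)
    llama_seq_id seq_id = 0;
    llama_batch_ext_ptr batch(llama_batch_ext_init(/*n_tokens=*/512, /*n_seq_max=*/1));

    for (size_t i = 0; i < prompt_tokens.size(); i++) {
        // no per-token output requested here; the last token is flagged below
        llama_batch_ext_add_text(batch.get(), prompt_tokens[i], (llama_pos) i, &seq_id, 1, false);
    }

    // request output for the final token, as the wrapper's set_logits_last() did
    llama_batch_ext_set_output_last(batch.get());

The per-token bookkeeping that common_batch carried (the tokens vector and the n_outputs counter) moves into the caller, which is presumably what the server-side changes in the other two files of this commit take over.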
