Skip to content

server : avoid common_batch #16

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 0 additions & 64 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -565,70 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);

//
// Batch utils
//

// convenient wrapper around llama_batch_ext, to provide a way to get embeddings positions
// this is meant to be temporary
struct common_batch {
llama_batch_ext_ptr batch;
struct batch_token {
llama_token token;
llama_seq_id seq_id; // only support single seq for now
bool logits;
};
std::vector<batch_token> tokens;
int n_outputs = 0;
common_batch() = default;
common_batch(int32_t n_tokens, int32_t n_seq_max) {
batch.reset(llama_batch_ext_init(n_tokens, n_seq_max));
tokens.reserve(n_tokens);
}
void clear() {
llama_batch_ext_clear(batch.get());
tokens.clear();
}
void add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, &seq_id, 1, logits);
tokens.push_back({token, seq_id, logits});
if (logits) {
n_outputs++;
}
}
void add_text_multi_seq(llama_token token, llama_pos pos, std::vector<llama_seq_id> seq_ids, bool logits) {
llama_batch_ext_add_text(batch.get(), token, pos, seq_ids.data(), seq_ids.size(), logits);
tokens.push_back({token, seq_ids[0], logits});
if (logits) {
n_outputs++;
}
}
void set_logits_last() {
if (!tokens.empty()) {
llama_batch_ext_set_output_last(batch.get());
tokens.back().logits = true;
}
}
int32_t get_n_tokens() const {
return (int32_t)tokens.size();
}
llama_batch_ext * get() {
return batch.get();
}
common_batch get_view(int32_t offset, int32_t n_tokens) {
common_batch view;
view.batch = llama_batch_ext_ptr(llama_batch_ext_get_view(batch.get(), offset, n_tokens));
view.tokens.reserve(n_tokens);
for (int32_t i = 0; i < n_tokens; i++) {
view.tokens.push_back(tokens[offset + i]);
if (tokens[offset + i].logits) {
view.n_outputs++;
}
}
return view;
}
};

//
// Token utils
//
Expand Down
Loading