Skip to content

Commit c1f4a78

Browse files
committed
correct output_id for llama-cpp header
1 parent 8ec0ff9 commit c1f4a78

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

include/llama-cpp.h

Lines changed: 9 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
5454
llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false);
5555
}
5656
if (output_last) {
57+
// TODO: somehow return the output ID
5758
llama_batch_ext_set_output_last(batch);
5859
}
5960
return llama_batch_ext_ptr(batch);
@@ -85,18 +86,20 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
8586

8687
// Wrapper to add a single token to the batch, support multiple sequence IDs
8788
int32_t add_text(llama_token token, llama_pos pos, const std::vector<llama_seq_id> & seq_id, bool output_last) {
88-
int32_t output_id = llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false);
89+
int32_t output_id = -1;
90+
llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false);
8991
if (output_last) {
90-
llama_batch_ext_set_output_last(this->get());
92+
output_id = llama_batch_ext_set_output_last(this->get());
9193
}
9294
return output_id;
9395
}
9496

9597
// Wrapper to add a single token to the batch (single sequence ID)
9698
int32_t add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) {
97-
int32_t output_id = llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false);
99+
int32_t output_id = -1;
100+
llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false);
98101
if (output_last) {
99-
llama_batch_ext_set_output_last(this->get());
102+
output_id = llama_batch_ext_set_output_last(this->get());
100103
}
101104
return output_id;
102105
}
@@ -105,10 +108,10 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
105108
int32_t add_seq(std::vector<llama_token> & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) {
106109
int32_t output_id = -1;
107110
for (size_t i = 0; i < tokens.size(); i++) {
108-
output_id = llama_batch_ext_add_text(this->get(), tokens[i], pos0 + i, &seq_id, 1, false);
111+
llama_batch_ext_add_text(this->get(), tokens[i], pos0 + i, &seq_id, 1, false);
109112
}
110113
if (output_last) {
111-
llama_batch_ext_set_output_last(this->get());
114+
output_id = llama_batch_ext_set_output_last(this->get());
112115
}
113116
return output_id;
114117
}

0 commit comments

Comments
 (0)