@@ -54,6 +54,7 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
54
54
llama_batch_ext_add_text (batch, tokens[i], pos0 + i, &seq_id, 1 , false );
55
55
}
56
56
if (output_last) {
57
+ // TODO: somehow return the output ID
57
58
llama_batch_ext_set_output_last (batch);
58
59
}
59
60
return llama_batch_ext_ptr (batch);
@@ -85,18 +86,20 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
85
86
86
87
// Wrapper to add a single token to the batch, support multiple sequence IDs
87
88
int32_t add_text (llama_token token, llama_pos pos, const std::vector<llama_seq_id> & seq_id, bool output_last) {
88
- int32_t output_id = llama_batch_ext_add_text (this ->get (), token, pos, seq_id.data (), seq_id.size (), false );
89
+ int32_t output_id = -1 ;
90
+ llama_batch_ext_add_text (this ->get (), token, pos, seq_id.data (), seq_id.size (), false );
89
91
if (output_last) {
90
- llama_batch_ext_set_output_last (this ->get ());
92
+ output_id = llama_batch_ext_set_output_last (this ->get ());
91
93
}
92
94
return output_id;
93
95
}
94
96
95
97
// Wrapper to add a single token to the batch (single sequence ID)
96
98
int32_t add_text (llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) {
97
- int32_t output_id = llama_batch_ext_add_text (this ->get (), token, pos, &seq_id, 1 , false );
99
+ int32_t output_id = -1 ;
100
+ llama_batch_ext_add_text (this ->get (), token, pos, &seq_id, 1 , false );
98
101
if (output_last) {
99
- llama_batch_ext_set_output_last (this ->get ());
102
+ output_id = llama_batch_ext_set_output_last (this ->get ());
100
103
}
101
104
return output_id;
102
105
}
@@ -105,10 +108,10 @@ struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_de
105
108
int32_t add_seq (std::vector<llama_token> & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) {
106
109
int32_t output_id = -1 ;
107
110
for (size_t i = 0 ; i < tokens.size (); i++) {
108
- output_id = llama_batch_ext_add_text (this ->get (), tokens[i], pos0 + i, &seq_id, 1 , false );
111
+ llama_batch_ext_add_text (this ->get (), tokens[i], pos0 + i, &seq_id, 1 , false );
109
112
}
110
113
if (output_last) {
111
- llama_batch_ext_set_output_last (this ->get ());
114
+ output_id = llama_batch_ext_set_output_last (this ->get ());
112
115
}
113
116
return output_id;
114
117
}
0 commit comments