
Commit cbb5dd7

change batch.logits to batch.output

1 parent 23fd453
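
In short, the per-token flag that tells llama_decode which positions should produce logits moves from batch.logits to batch.output. Below is a minimal sketch of the post-commit calling pattern, using only calls that appear in the diffs in this commit (llama_batch_init, llama_batch_add, llama_decode, llama_get_logits_ith, llama_batch_free); the context ctx and the prompt_tokens vector are assumed setup, not part of the diff.

    // sketch: request logits only for the last prompt token (field name per this commit)
    llama_batch batch = llama_batch_init(512, 0, 1);

    for (int32_t i = 0; i < (int32_t) prompt_tokens.size(); i++) {
        // final argument false: no output requested for this position
        llama_batch_add(batch, prompt_tokens[i], i, { 0 }, false);
    }

    // flag the last token; llama_decode produces output only where batch.output is set
    batch.output[batch.n_tokens - 1] = true;

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode() failed\n");
    }

    // read back the logits of the one flagged position
    const float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

    llama_batch_free(batch);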

File tree

18 files changed: +60 -60 lines changed


common/common.cpp

Lines changed: 1 addition & 1 deletion
@@ -2666,7 +2666,7 @@ void llama_batch_add(
     for (size_t i = 0; i < seq_ids.size(); ++i) {
         batch.seq_id[batch.n_tokens][i] = seq_ids[i];
     }
-    batch.logits [batch.n_tokens] = logits;
+    batch.output [batch.n_tokens] = logits;

     batch.n_tokens++;
 }
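
For reference, a sketch of the whole helper after this hunk; the assignments above the visible context (token, pos, n_seq_id) are an assumption reconstructed from the function's parameters, not taken from the diff.

    void llama_batch_add(
                     struct llama_batch & batch,
                            llama_token   id,
                              llama_pos   pos,
        const std::vector<llama_seq_id> & seq_ids,
                                   bool   logits) {
        batch.token   [batch.n_tokens] = id;   // assumed, outside the hunk
        batch.pos     [batch.n_tokens] = pos;  // assumed, outside the hunk
        batch.n_seq_id[batch.n_tokens] = seq_ids.size();
        for (size_t i = 0; i < seq_ids.size(); ++i) {
            batch.seq_id[batch.n_tokens][i] = seq_ids[i];
        }
        batch.output  [batch.n_tokens] = logits; // renamed field

        batch.n_tokens++;
    }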

common/log.h

Lines changed: 1 addition & 1 deletion
@@ -686,7 +686,7 @@ inline std::string LOG_BATCH_TOSTR_PRETTY(const C & ctx, const B & batch)
             << ":pos "      << std::to_string(batch.pos[i])
             << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
             << ":seq_id "   << std::to_string(batch.seq_id[i][0])
-            << ":logits "   << std::to_string(batch.logits[i]);
+            << ":logits "   << std::to_string(batch.output[i]);
     }
     buf << " ]";

examples/batched-bench/batched-bench.cpp

Lines changed: 2 additions & 2 deletions
@@ -94,7 +94,7 @@ int main(int argc, char ** argv) {
         batch.pos      + i,
         batch.n_seq_id + i,
         batch.seq_id   + i,
-        batch.logits   + i,
+        batch.output   + i,
         0, 0, 0, // unused
     };

@@ -149,7 +149,7 @@ int main(int argc, char ** argv) {
             llama_batch_add(batch, 0, i, { j }, false);
         }
     }
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     const auto t_pp_start = ggml_time_us();
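
The first hunk sits inside a brace-initialized single-token "view" of the batch: every per-token array pointer is offset by i so llama_decode sees exactly one token. A sketch of the full initializer follows; the leading fields (n_tokens, token, embd) are assumed from the llama_batch struct layout rather than taken from the diff.

    llama_batch batch_view = {
        1,                  // n_tokens: decode one token at a time (assumed)
        batch.token    + i, // assumed, outside the hunk
        nullptr,            // embd unused here (assumed)
        batch.pos      + i,
        batch.n_seq_id + i,
        batch.seq_id   + i,
        batch.output   + i, // renamed field
        0, 0, 0,            // unused
    };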

examples/batched.swift/Sources/main.swift

Lines changed: 3 additions & 3 deletions
@@ -86,11 +86,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }

 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1

 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")

@@ -178,7 +178,7 @@ while n_cur <= n_len {
         if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
             seq_id[0] = Int32(i)
         }
-        batch.logits[Int(batch.n_tokens)] = 1
+        batch.output[Int(batch.n_tokens)] = 1

         i_batch[i] = batch.n_tokens

examples/batched/batched.cpp

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ int main(int argc, char ** argv) {
     }

     // llama_decode will output logits only for the last token of the prompt
-    batch.logits[batch.n_tokens - 1] = true;
+    batch.output[batch.n_tokens - 1] = true;

     if (llama_decode(ctx, batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
     }

     for (int i = 0; i < batch.n_tokens; i++) {
-        if (!batch.logits[i]) {
+        if (!batch.output[i]) {
             continue;
         }
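
After decoding, only positions flagged in batch.output carry results, which is why this loop skips unflagged tokens. A sketch of the read-back side under that convention; llama_get_embeddings_ith is the real llama.cpp accessor, while the destination buffer handling is elided.

    for (int i = 0; i < batch.n_tokens; i++) {
        if (!batch.output[i]) {
            continue; // no output requested for this position
        }
        const float * embd = llama_get_embeddings_ith(ctx, i);
        // ... normalize/copy embd into the caller's output buffer ...
    }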

examples/gritlm/gritlm.cpp

Lines changed: 6 additions & 6 deletions
@@ -102,21 +102,21 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);

-    llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
+    llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);

     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
     int32_t i_current_token = 0;

     while (true) {
-        llama_batch_clear(bat);
+        llama_batch_clear(batch);
         auto n_inputs = (int32_t)inputs.size();
         for (int32_t i = 0; i < n_inputs; i++) {
-            llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
+            llama_batch_add(batch, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
         }
         inputs.clear();

-        llama_decode(ctx, bat);
-        auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+        llama_decode(ctx, batch);
+        auto logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);

         auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
         auto n_candidates = (int32_t)candidates.size();

@@ -145,7 +145,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
         std::printf("\n");
     }

-    llama_batch_free(bat);
+    llama_batch_free(batch);

     return result;
 }

examples/imatrix/imatrix.cpp

Lines changed: 1 addition & 1 deletion
@@ -513,7 +513,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
         tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
     }

-    // TODO: use batch.logits to save computations instead of relying on logits_all == true
+    // TODO: use batch.output to save computations instead of relying on logits_all == true
     if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 3 additions & 3 deletions
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         llama_batch_add(*batch, 0, i, { 0 }, false);
     }

-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
     llama_kv_cache_clear(context);

     const auto t_pp_start = ggml_time_us();

@@ -306,7 +306,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     for (int i = 0; i < n_tokens; ++i) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);

     return reinterpret_cast<jlong>(batch);
 }

@@ -363,7 +363,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }

     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;

     if (llama_decode(context, *batch) != 0) {
         LOGe("llama_decode() failed");
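
Unlike the other call sites, this JNI wrapper allocates the batch arrays by hand instead of calling llama_batch_init, which is why the rename also touches a malloc. A sketch of that allocation path; the token, pos, and n_seq_id allocations are assumed from context outside the hunk.

    llama_batch * batch = new llama_batch{};
    batch->token    = (llama_token *) malloc(sizeof(llama_token) * n_tokens); // assumed
    batch->pos      = (llama_pos   *) malloc(sizeof(llama_pos)   * n_tokens); // assumed
    batch->n_seq_id = (int32_t     *) malloc(sizeof(int32_t)     * n_tokens); // assumed
    batch->seq_id   = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens);
    for (int i = 0; i < n_tokens; ++i) {
        batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
    }
    batch->output   = (int8_t *) malloc(sizeof(int8_t) * n_tokens); // renamed field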

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 3 additions & 3 deletions
@@ -16,7 +16,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
     for i in 0..<seq_ids.count {
         batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
     }
-    batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
+    batch.output [Int(batch.n_tokens)] = logits ? 1 : 0

     batch.n_tokens += 1
 }

@@ -132,7 +132,7 @@ actor LlamaContext {
         let i = Int(i1)
         llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
     }
-    batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+    batch.output[Int(batch.n_tokens) - 1] = 1 // true

     if llama_decode(context, batch) != 0 {
         print("llama_decode() failed")

@@ -214,7 +214,7 @@ actor LlamaContext {
     for i in 0..<n_tokens {
         llama_batch_add(&batch, 0, Int32(i), [0], false)
     }
-    batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+    batch.output[Int(batch.n_tokens) - 1] = 1 // true

     llama_kv_cache_clear(context)