Skip to content

Commit 7b7db0b

Browse files
committed
llama : logits_all has priority over batch->logits
Without this change, the server embeddings tests failed. This was likely a pre-existing problem, but it was only detected here because of an additional assertion.
1 parent 2e4adb4 commit 7b7db0b

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

src/llama.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2898,7 +2898,12 @@ struct llama_sbatch {
28982898
}
28992899
}
29002900
}
2901-
if (batch->logits) {
2901+
if (logits_all) {
2902+
for (size_t i = 0; i < length; ++i) {
2903+
ubatch.output[ubatch.n_tokens + i] = 1;
2904+
out_ids.push_back(ids[seq.offset + i]);
2905+
}
2906+
} else if (batch->logits) {
29022907
if (ubatch.equal_seqs) {
29032908
for (size_t i = 0; i < length; ++i) {
29042909
size_t id = ids[seq.offset + i];
@@ -2913,11 +2918,6 @@ struct llama_sbatch {
29132918
if (ubatch.output[i] != 0) { out_ids.push_back(seq.offset + i); }
29142919
}
29152920
}
2916-
} else if (logits_all) {
2917-
for (size_t i = 0; i < length; ++i) {
2918-
ubatch.output[ubatch.n_tokens + i] = 1;
2919-
out_ids.push_back(ids[seq.offset + i]);
2920-
}
29212921
} else {
29222922
// only get last output
29232923
for (size_t i = 0; i < length; ++i) {
@@ -15088,7 +15088,7 @@ static int llama_decode_internal(
1508815088
};
1508915089

1509015090
while (lctx.sbatch.n_tokens > 0) {
15091-
// For now, only use equal splits for recurrent or hybrid model architectures
15091+
// For now, only use equal splits for recurrent model architectures
1509215092
llama_ubatch u_batch = kv_self.recurrent ? lctx.sbatch.split_equal(n_ubatch) : lctx.sbatch.split_simple(n_ubatch);
1509315093
const uint32_t n_tokens = u_batch.n_tokens;
1509415094

0 commit comments

Comments (0)