diff --git a/llama.cpp b/llama.cpp index 3f5d663cf1ed3..9dae1f4e2c42a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -5855,7 +5855,11 @@ static int llama_decode_internal( { auto & logits_out = lctx.logits; - if (batch.logits) { + if (lctx.logits_all) { + logits_out.resize(n_vocab * n_tokens); + memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); + } else { + GGML_ASSERT(batch.logits); logits_out.resize(n_vocab * n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { if (batch.logits[i] == 0) { @@ -5863,12 +5867,6 @@ static int llama_decode_internal( } memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); } - } else if (lctx.logits_all) { - logits_out.resize(n_vocab * n_tokens); - memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); - } else { - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); } }