
llama : (proposal) return enum for llama_decode and llama_encode #9434

Closed
wants to merge 1 commit into from
19 changes: 17 additions & 2 deletions include/llama.h
@@ -253,6 +253,21 @@ extern "C" {
llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_batch;

enum llama_decode_result {
LLAMA_DECODE_RESULT_OK = 0,
LLAMA_DECODE_RESULT_ERR_ALLOC_KV = 1,
LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT = -1,
LLAMA_DECODE_RESULT_INVALID_BATCH = -2,
};

enum llama_encode_result {
LLAMA_ENCODE_RESULT_OK = 0,
LLAMA_ENCODE_RESULT_ERR_ALLOC_KV = 1,
LLAMA_ENCODE_RESULT_ERR_NO_ENCODER = 2,
LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT = -1,
LLAMA_ENCODE_RESULT_INVALID_BATCH = -2,
};

enum llama_model_kv_override_type {
LLAMA_KV_OVERRIDE_TYPE_INT,
LLAMA_KV_OVERRIDE_TYPE_FLOAT,
@@ -801,15 +816,15 @@ extern "C" {
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// 0 - success
// < 0 - error
LLAMA_API int32_t llama_encode(
LLAMA_API enum llama_encode_result llama_encode(
struct llama_context * ctx,
struct llama_batch batch);

// Positive return values do not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error
LLAMA_API int32_t llama_decode(
LLAMA_API enum llama_decode_result llama_decode(
struct llama_context * ctx,
struct llama_batch batch);

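A minimal caller-side sketch (not part of this diff) of how the proposed enum could be used instead of raw integer comparisons. try_decode is a hypothetical helper; the context and batch are assumed to be set up elsewhere, and only names introduced in the header above are relied on:

#include <stdbool.h>
#include "llama.h"

// returns true on success, false if the caller should retry or abort
static bool try_decode(struct llama_context * ctx, struct llama_batch batch) {
    switch (llama_decode(ctx, batch)) {
        case LLAMA_DECODE_RESULT_OK:
            return true;
        case LLAMA_DECODE_RESULT_ERR_ALLOC_KV:
            // warning, not fatal: no KV slot was found - retry with a smaller
            // batch or a larger context
            return false;
        case LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT:
        case LLAMA_DECODE_RESULT_INVALID_BATCH:
            // fatal errors
            return false;
    }
    return false;
}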
33 changes: 18 additions & 15 deletions src/llama.cpp
@@ -16064,7 +16064,7 @@ static void llama_graph_compute(
// return positive int on warning
// return negative int on error
//
static int llama_decode_internal(
static enum llama_decode_result llama_decode_internal(
llama_context & lctx,
llama_batch batch_all) { // TODO: rename back to batch

@@ -16073,13 +16073,13 @@ static int llama_decode_internal(

if (n_tokens_all == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
return -1;
return LLAMA_DECODE_RESULT_INVALID_BATCH;
}

for (uint32_t i = 0; i < n_tokens_all; ++i) {
if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= lctx.model.vocab.n_vocab) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]);
return -1;
return LLAMA_DECODE_RESULT_INVALID_BATCH;
}
}

@@ -16132,7 +16132,7 @@ static int llama_decode_internal(
// reserve output buffer
if (llama_output_reserve(lctx, n_outputs) < n_outputs) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs);
return -2;
return LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT;
};

while (lctx.sbatch.n_tokens > 0) {
@@ -16184,7 +16184,7 @@ static int llama_decode_internal(
}

if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
return 1;
return LLAMA_DECODE_RESULT_ERR_ALLOC_KV;
}

if (!kv_self.recurrent) {
@@ -16350,7 +16350,7 @@ static int llama_decode_internal(
// overlap with device computation.
ggml_backend_sched_reset(lctx.sched);

return 0;
return LLAMA_DECODE_RESULT_OK;
}

// encode a batch of tokens by evaluating the encoder part of the transformer
@@ -16362,23 +16362,26 @@ static int llama_decode_internal(
// return positive int on warning
// return negative int on error
//
static int llama_encode_internal(
static enum llama_encode_result llama_encode_internal(
llama_context & lctx,
llama_batch batch) {
if (!llama_model_has_encoder(&lctx.model)) {
return LLAMA_ENCODE_RESULT_ERR_NO_ENCODER;
}

lctx.is_encoding = true;

const uint32_t n_tokens = batch.n_tokens;

if (n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__);
return -1;
return LLAMA_ENCODE_RESULT_INVALID_BATCH;
}

for (uint32_t i = 0; i < n_tokens; ++i) {
if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= lctx.model.vocab.n_vocab) {
LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]);
return -1;
return LLAMA_ENCODE_RESULT_INVALID_BATCH;
}
}

@@ -16406,7 +16409,7 @@ static int llama_encode_internal(
// reserve output buffer
if (llama_output_reserve(lctx, n_tokens) < n_tokens) {
LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
return -2;
return LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT;
};

for (uint32_t i = 0; i < n_tokens; ++i) {
@@ -16516,7 +16519,7 @@ static int llama_encode_internal(
// overlap with device computation.
ggml_backend_sched_reset(lctx.sched);

return 0;
return LLAMA_ENCODE_RESULT_OK;
}

// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
@@ -20038,21 +20041,21 @@ void llama_batch_free(struct llama_batch batch) {
if (batch.logits) free(batch.logits);
}

int32_t llama_encode(
enum llama_encode_result llama_encode(
struct llama_context * ctx,
struct llama_batch batch) {
const int ret = llama_encode_internal(*ctx, batch);
const enum llama_encode_result ret = llama_encode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
}

return ret;
}

int32_t llama_decode(
enum llama_decode_result llama_decode(
struct llama_context * ctx,
struct llama_batch batch) {
const int ret = llama_decode_internal(*ctx, batch);
const enum llama_decode_result ret = llama_decode_internal(*ctx, batch);
if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
}
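For the encoder path, a similar sketch under the same assumptions: the new LLAMA_ENCODE_RESULT_ERR_NO_ENCODER value lets callers detect a decoder-only model explicitly instead of treating it as a generic failure. run_encoder is a hypothetical helper, not part of this PR; the other values are assumed to keep the 0 / positive / negative convention noted in the header comments.

#include <stdio.h>
#include "llama.h"

// returns 0 on success, non-zero otherwise
static int run_encoder(struct llama_context * ctx, struct llama_batch batch) {
    const enum llama_encode_result res = llama_encode(ctx, batch);

    if (res == LLAMA_ENCODE_RESULT_OK) {
        return 0;
    }
    if (res == LLAMA_ENCODE_RESULT_ERR_NO_ENCODER) {
        fprintf(stderr, "model has no encoder - this path requires an encoder-decoder model\n");
        return 1;
    }
    // remaining results: > 0 is a warning (e.g. no KV slot), < 0 is an error
    fprintf(stderr, "llama_encode returned %d\n", (int) res);
    return (int) res;
}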