diff --git a/include/llama.h b/include/llama.h index 405af912c4686..e001a7e97e746 100644 --- a/include/llama.h +++ b/include/llama.h @@ -253,6 +253,21 @@ extern "C" { llama_seq_id all_seq_id; // used if seq_id == NULL } llama_batch; + enum llama_decode_result { + LLAMA_DECODE_RESULT_OK = 0, + LLAMA_DECODE_RESULT_ERR_ALLOC_KV = 1, + LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT = -1, + LLAMA_DECODE_RESULT_INVALID_BATCH = -2, + }; + + enum llama_encode_result { + LLAMA_ENCODE_RESULT_OK = 0, + LLAMA_ENCODE_RESULT_ERR_ALLOC_KV = 1, + LLAMA_ENCODE_RESULT_ERR_NO_ENCODER = 2, + LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT = -1, + LLAMA_ENCODE_RESULT_INVALID_BATCH = -2, + }; + enum llama_model_kv_override_type { LLAMA_KV_OVERRIDE_TYPE_INT, LLAMA_KV_OVERRIDE_TYPE_FLOAT, @@ -801,7 +816,7 @@ extern "C" { // Stores the encoder output internally for later use by the decoder cross-attention layers. // 0 - success // < 0 - error - LLAMA_API int32_t llama_encode( + LLAMA_API enum llama_encode_result llama_encode( struct llama_context * ctx, struct llama_batch batch); @@ -809,7 +824,7 @@ extern "C" { // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // < 0 - error - LLAMA_API int32_t llama_decode( + LLAMA_API enum llama_decode_result llama_decode( struct llama_context * ctx, struct llama_batch batch); diff --git a/src/llama.cpp b/src/llama.cpp index 40db035171127..7a46a7fee6896 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -16064,7 +16064,7 @@ static void llama_graph_compute( // return positive int on warning // return negative int on error // -static int llama_decode_internal( +static enum llama_decode_result llama_decode_internal( llama_context & lctx, llama_batch batch_all) { // TODO: rename back to batch @@ -16073,13 +16073,13 @@ static int llama_decode_internal( if (n_tokens_all == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); - return -1; + return LLAMA_DECODE_RESULT_INVALID_BATCH; } for (uint32_t i = 0; i < n_tokens_all; ++i) { if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= lctx.model.vocab.n_vocab) { LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch_all.token[i]); - return -1; + return LLAMA_DECODE_RESULT_INVALID_BATCH; } } @@ -16132,7 +16132,7 @@ static int llama_decode_internal( // reserve output buffer if (llama_output_reserve(lctx, n_outputs) < n_outputs) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_outputs); - return -2; + return LLAMA_DECODE_RESULT_ERR_RESERVE_OUTPUT; }; while (lctx.sbatch.n_tokens > 0) { @@ -16184,7 +16184,7 @@ static int llama_decode_internal( } if (!llama_kv_cache_find_slot(kv_self, ubatch)) { - return 1; + return LLAMA_DECODE_RESULT_ERR_ALLOC_KV; } if (!kv_self.recurrent) { @@ -16350,7 +16350,7 @@ static int llama_decode_internal( // overlap with device computation. ggml_backend_sched_reset(lctx.sched); - return 0; + return LLAMA_DECODE_RESULT_OK; } // encode a batch of tokens by evaluating the encoder part of the transformer @@ -16362,9 +16362,12 @@ static int llama_decode_internal( // return positive int on warning // return negative int on error // -static int llama_encode_internal( +static enum llama_encode_result llama_encode_internal( llama_context & lctx, llama_batch batch) { + if (!llama_model_has_encoder(&lctx.model)) { + return LLAMA_ENCODE_RESULT_ERR_NO_ENCODER; + } lctx.is_encoding = true; @@ -16372,13 +16375,13 @@ static int llama_encode_internal( if (n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); - return -1; + return LLAMA_ENCODE_RESULT_INVALID_BATCH; } for (uint32_t i = 0; i < n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= lctx.model.vocab.n_vocab) { LLAMA_LOG_ERROR("%s: invalid token[%d] = %d", __func__, i, batch.token[i]); - return -1; + return LLAMA_ENCODE_RESULT_INVALID_BATCH; } } @@ -16406,7 +16409,7 @@ static int llama_encode_internal( // reserve output buffer if (llama_output_reserve(lctx, n_tokens) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); - return -2; + return LLAMA_ENCODE_RESULT_ERR_RESERVE_OUTPUT; }; for (uint32_t i = 0; i < n_tokens; ++i) { @@ -16516,7 +16519,7 @@ static int llama_encode_internal( // overlap with device computation. ggml_backend_sched_reset(lctx.sched); - return 0; + return LLAMA_ENCODE_RESULT_OK; } // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache @@ -20038,10 +20041,10 @@ void llama_batch_free(struct llama_batch batch) { if (batch.logits) free(batch.logits); } -int32_t llama_encode( +enum llama_encode_result llama_encode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_encode_internal(*ctx, batch); + const enum llama_encode_result ret = llama_encode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -20049,10 +20052,10 @@ int32_t llama_encode( return ret; } -int32_t llama_decode( +enum llama_decode_result llama_decode( struct llama_context * ctx, struct llama_batch batch) { - const int ret = llama_decode_internal(*ctx, batch); + const enum llama_decode_result ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); }