Skip to content

Commit 8a5be3b

Browse files
llama : sanity checks for access to logits (#4274)
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 88ae895 commit 8a5be3b

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

llama.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,6 +1505,10 @@ struct llama_context {
15051505

15061506
// decode output (2-dimensional array: [n_tokens][n_vocab])
15071507
std::vector<float> logits;
1508+
#ifndef NDEBUG
1509+
// guard against access to unset logits
1510+
std::vector<bool> logits_valid;
1511+
#endif
15081512
bool logits_all = false;
15091513

15101514
// input embedding (1-dimensional array: [n_embd])
@@ -6150,20 +6154,37 @@ static int llama_decode_internal(
61506154
{
61516155
auto & logits_out = lctx.logits;
61526156

6157+
#ifndef NDEBUG
6158+
auto & logits_valid = lctx.logits_valid;
6159+
logits_valid.clear();
6160+
logits_valid.resize(n_tokens);
6161+
6162+
logits_out.clear();
6163+
#endif
6164+
61536165
if (batch.logits) {
61546166
logits_out.resize(n_vocab * n_tokens);
61556167
for (uint32_t i = 0; i < n_tokens; i++) {
61566168
if (batch.logits[i] == 0) {
61576169
continue;
61586170
}
61596171
memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab);
6172+
#ifndef NDEBUG
6173+
logits_valid[i] = true;
6174+
#endif
61606175
}
61616176
} else if (lctx.logits_all) {
61626177
logits_out.resize(n_vocab * n_tokens);
61636178
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens);
6179+
#ifndef NDEBUG
6180+
std::fill(logits_valid.begin(), logits_valid.end(), true);
6181+
#endif
61646182
} else {
61656183
logits_out.resize(n_vocab);
61666184
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab);
6185+
#ifndef NDEBUG
6186+
logits_valid[n_tokens - 1] = true;
6187+
#endif
61676188
}
61686189
}
61696190

@@ -10052,6 +10073,7 @@ float * llama_get_logits(struct llama_context * ctx) {
1005210073
}
1005310074

1005410075
// Returns a pointer to the start of row `i` of the context's logits buffer
// (the buffer is documented on the context as [n_tokens][n_vocab]).
//
// In builds without NDEBUG, the call first asserts that row `i` was actually
// written by the last decode; `.at()` additionally range-checks the index.
float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
    // logits_valid is only declared when NDEBUG is not defined, so every
    // reference to it has to live inside the assert expression itself
    assert(ctx->logits_valid.at(i));

    // each row holds n_vocab floats; skip over the i rows that precede it
    const auto row_offset = i*ctx->model.hparams.n_vocab;
    float * const first_row = ctx->logits.data();

    return first_row + row_offset;
}
1005710079

0 commit comments

Comments
 (0)