@@ -1505,6 +1505,10 @@ struct llama_context {
1505
1505
1506
1506
// decode output (2-dimensional array: [n_tokens][n_vocab])
1507
1507
std::vector<float > logits;
1508
+ #ifndef NDEBUG
1509
+ // guard against access to unset logits
1510
+ std::vector<bool > logits_valid;
1511
+ #endif
1508
1512
bool logits_all = false ;
1509
1513
1510
1514
// input embedding (1-dimensional array: [n_embd])
@@ -6150,20 +6154,37 @@ static int llama_decode_internal(
6150
6154
{
6151
6155
auto & logits_out = lctx.logits ;
6152
6156
6157
+ #ifndef NDEBUG
6158
+ auto & logits_valid = lctx.logits_valid ;
6159
+ logits_valid.clear ();
6160
+ logits_valid.resize (n_tokens);
6161
+
6162
+ logits_out.clear ();
6163
+ #endif
6164
+
6153
6165
if (batch.logits ) {
6154
6166
logits_out.resize (n_vocab * n_tokens);
6155
6167
for (uint32_t i = 0 ; i < n_tokens; i++) {
6156
6168
if (batch.logits [i] == 0 ) {
6157
6169
continue ;
6158
6170
}
6159
6171
memcpy (logits_out.data () + (n_vocab*i), (float *) ggml_get_data (res) + (n_vocab*i), sizeof (float )*n_vocab);
6172
+ #ifndef NDEBUG
6173
+ logits_valid[i] = true ;
6174
+ #endif
6160
6175
}
6161
6176
} else if (lctx.logits_all ) {
6162
6177
logits_out.resize (n_vocab * n_tokens);
6163
6178
memcpy (logits_out.data (), (float *) ggml_get_data (res), sizeof (float )*n_vocab*n_tokens);
6179
+ #ifndef NDEBUG
6180
+ std::fill (logits_valid.begin (), logits_valid.end (), true );
6181
+ #endif
6164
6182
} else {
6165
6183
logits_out.resize (n_vocab);
6166
6184
memcpy (logits_out.data (), (float *) ggml_get_data (res) + (n_vocab*(n_tokens - 1 )), sizeof (float )*n_vocab);
6185
+ #ifndef NDEBUG
6186
+ logits_valid[n_tokens - 1 ] = true ;
6187
+ #endif
6167
6188
}
6168
6189
}
6169
6190
@@ -10052,6 +10073,7 @@ float * llama_get_logits(struct llama_context * ctx) {
10052
10073
}
10053
10074
10054
10075
float * llama_get_logits_ith (struct llama_context * ctx, int32_t i) {
10076
+ assert (ctx->logits_valid .at (i));
10055
10077
return ctx->logits .data () + i*ctx->model .hparams .n_vocab ;
10056
10078
}
10057
10079
0 commit comments