Skip to content

Commit 5f2d4e6

Browse files
authored
ppl : fix n_seq_max for perplexity (#8277)
* ppl : fix n_seq_max for perplexity
* use 1 seq for kl_divergence
1 parent 916248a commit 5f2d4e6

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

examples/perplexity/perplexity.cpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1991,6 +1991,12 @@ int main(int argc, char ** argv) {
19911991
params.n_batch = std::min(params.n_batch, n_kv);
19921992
} else {
19931993
params.n_batch = std::min(params.n_batch, params.n_ctx);
1994+
if (params.kl_divergence) {
1995+
params.n_parallel = 1;
1996+
} else {
1997+
// ensure there's at least enough seq_ids for HellaSwag
1998+
params.n_parallel = std::max(4, params.n_parallel);
1999+
}
19942000
}
19952001

19962002
if (params.ppl_stride > 0) {
@@ -2015,9 +2021,6 @@ int main(int argc, char ** argv) {
20152021
llama_model * model;
20162022
llama_context * ctx;
20172023

2018-
// ensure there's at least enough seq_ids for HellaSwag
2019-
params.n_parallel = std::max(4, params.n_parallel);
2020-
20212024
// load the model and apply lora adapter, if any
20222025
std::tie(model, ctx) = llama_init_from_gpt_params(params);
20232026
if (model == NULL) {

0 commit comments

Comments (0)