Skip to content

Commit 39cfad0

Browse files
authored
whisper : add support for new distilled Whisper models (#1424)
* whisper : add support for new distilled Whisper models * whisper : print log when using distilled models
1 parent 6d4d0b5 commit 39cfad0

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

whisper.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3940,6 +3940,7 @@ static void whisper_process_logits(
39403940
// suppress task tokens
39413941
logits[vocab.token_translate] = -INFINITY;
39423942
logits[vocab.token_transcribe] = -INFINITY;
3943+
logits[vocab.token_prev] = -INFINITY;
39433944

39443945
if (params.logits_filter_callback) {
39453946
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
@@ -4558,6 +4559,7 @@ int whisper_full_with_state(
45584559

45594560
// these tokens determine the task that will be performed
45604561
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
4562+
45614563
if (whisper_is_multilingual(ctx)) {
45624564
const int lang_id = whisper_lang_id(params.language);
45634565
state->lang_id = lang_id;
@@ -4569,6 +4571,17 @@ int whisper_full_with_state(
45694571
}
45704572
}
45714573

4574+
{
4575+
const bool is_distil = ctx->model.hparams.n_text_layer == 2;
4576+
4577+
// distilled models require the "no_timestamps" token
4578+
// TODO: add input parameter (#1229)
4579+
if (is_distil) {
4580+
log("%s: using distilled model - forcing no_timestamps\n", __func__);
4581+
prompt_init.push_back(whisper_token_not(ctx));
4582+
}
4583+
}
4584+
45724585
int seek = seek_start;
45734586

45744587
std::vector<whisper_token> prompt;

0 commit comments

Comments
 (0)