Skip to content

Commit cfc06bf

Browse files
authored
whisper : suppress non-speech-related token outputs (#473)
* add non-speech-token suppression * add suppress_non_speech_tokens param
1 parent 2bfe0eb commit cfc06bf

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

whisper.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
29362936
/*.language =*/ "en",
29372937

29382938
/*.suppress_blank =*/ true,
2939+
/*.suppress_non_speech_tokens =*/true,
29392940

29402941
/*.temperature =*/ 0.0f,
29412942
/*.max_initial_ts =*/ 1.0f,
@@ -3077,6 +3078,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
30773078
return res;
30783079
}
30793080

3081+
// Symbols and punctuation sequences that are suppressed during decoding when
// params.suppress_non_speech_tokens is enabled. Each entry is looked up in the
// vocabulary both as-is and with a leading space; logits of the entries that
// exist in the vocabulary are set to -INFINITY.
//
// The list mirrors the non-speech token set of the OpenAI reference tokenizer:
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
//
// NOTE: entries must never be empty - an empty entry turns into a lookup of ""
//       and " " instead of the intended symbol. The source file must be saved
//       as UTF-8 so the non-ASCII entries keep their byte sequences.
static const std::vector<std::string> non_speech_tokens
{
    // single ASCII symbols
    "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^",
    "_", "`", "{", "|", "}", "~",
    // CJK corner brackets
    "「", "」", "『", "』",
    // multi-character sequences
    "<<", ">>", "<<<", ">>>", "--",
    "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪",
    "♪♪♪",
    // musical symbols
    "♩", "♪", "♫", "♬", "♭", "♮", "♯"
};
3088+
30803089
// process the logits for the selected decoder
30813090
// - applies logit filters
30823091
// - computes logprobs and probs
@@ -3137,6 +3146,33 @@ static void whisper_process_logits(
31373146
logits[vocab.token_translate] = -INFINITY;
31383147
logits[vocab.token_transcribe] = -INFINITY;
31393148

3149+
3150+
// suppress non-speech tokens
3151+
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3152+
if (params.suppress_non_speech_tokens)
3153+
{
3154+
for (const std::string &token : non_speech_tokens)
3155+
{
3156+
std::string suppress_tokens[] = {token, " " + token};
3157+
for (const std::string &suppress_token : suppress_tokens)
3158+
{
3159+
if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end())
3160+
{
3161+
logits[vocab.token_to_id.at(suppress_token)] = -INFINITY;
3162+
}
3163+
}
3164+
}
3165+
// allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
3166+
if (vocab.token_to_id.find(" -") != vocab.token_to_id.end())
3167+
{
3168+
logits[vocab.token_to_id.at(" -")] = -INFINITY;
3169+
}
3170+
if (vocab.token_to_id.find(" '") != vocab.token_to_id.end())
3171+
{
3172+
logits[vocab.token_to_id.at(" '")] = -INFINITY;
3173+
}
3174+
}
3175+
31403176
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
31413177
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
31423178
{

whisper.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ extern "C" {
285285

286286
// common decoding parameters:
287287
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
288+
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
288289

289290
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
290291
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97

0 commit comments

Comments
 (0)