@@ -2936,6 +2936,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2936
2936
/* .language =*/ " en" ,
2937
2937
2938
2938
/* .suppress_blank =*/ true ,
2939
+ /* .suppress_non_speech_tokens =*/ true ,
2939
2940
2940
2941
/* .temperature =*/ 0 .0f ,
2941
2942
/* .max_initial_ts =*/ 1 .0f ,
@@ -3077,6 +3078,14 @@ static int whisper_wrap_segment(struct whisper_context & ctx, int max_len, bool
3077
3078
return res;
3078
3079
}
3079
3080
3081
+ static const std::vector<std::string> non_speech_tokens
3082
+ {
3083
+ " \" " , " #" , " (" , " )" , " *" , " +" , " /" , " :" , " ;" , " <" , " =" , " >" , " @" , " [" , " \\ " , " ]" , " ^" ,
3084
+ " _" , " `" , " {" , " |" , " }" , " ~" , " 「" , " 」" , " 『" , " 』" , " <<" , " >>" , " <<<" , " >>>" , " --" ,
3085
+ " ---" , " -(" , " -[" , " ('" , " (\" " , " ((" , " ))" , " (((" , " )))" , " [[" , " ]]" , " {{" , " }}" , " ♪♪" ,
3086
+ " ♪♪♪" ," ♩" , " ♪" , " ♫" , " ♬" , " ♭" , " ♮" , " ♯"
3087
+ };
3088
+
3080
3089
// process the logits for the selected decoder
3081
3090
// - applies logit filters
3082
3091
// - computes logprobs and probs
@@ -3137,6 +3146,33 @@ static void whisper_process_logits(
3137
3146
logits[vocab.token_translate ] = -INFINITY;
3138
3147
logits[vocab.token_transcribe ] = -INFINITY;
3139
3148
3149
+
3150
+ // suppress non-speech tokens
3151
+ // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
3152
+ if (params.suppress_non_speech_tokens )
3153
+ {
3154
+ for (const std::string &token : non_speech_tokens)
3155
+ {
3156
+ std::string suppress_tokens[] = {token, " " + token};
3157
+ for (const std::string &suppress_token : suppress_tokens)
3158
+ {
3159
+ if (vocab.token_to_id .find (suppress_token) != vocab.token_to_id .end ())
3160
+ {
3161
+ logits[vocab.token_to_id .at (suppress_token)] = -INFINITY;
3162
+ }
3163
+ }
3164
+ }
3165
+ // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
3166
+ if (vocab.token_to_id .find (" -" ) != vocab.token_to_id .end ())
3167
+ {
3168
+ logits[vocab.token_to_id .at (" -" )] = -INFINITY;
3169
+ }
3170
+ if (vocab.token_to_id .find (" '" ) != vocab.token_to_id .end ())
3171
+ {
3172
+ logits[vocab.token_to_id .at (" '" )] = -INFINITY;
3173
+ }
3174
+ }
3175
+
3140
3176
// timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
3141
3177
// https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424
3142
3178
{
0 commit comments