diff --git a/daras_ai_v2/asr.py b/daras_ai_v2/asr.py index a8b6656da..1bf919e12 100644 --- a/daras_ai_v2/asr.py +++ b/daras_ai_v2/asr.py @@ -283,6 +283,7 @@ class AsrModels(Enum): ghana_nlp_asr_v2 = "Ghana NLP ASR v2" lelapa = "Vulavula (Lelapa AI)" whisper_sunbird_large_v3 = "Sunbird Ugandan Whisper v3 (Sunbird AI)" + whisper_akera_large_v3 = "Akera Whisper v3 (akera)" whisper_swahili_medium_v3 = "Jacaranda Health Swahili Whisper v3 (Jacaranda Health)" mbaza_ctc_large = "Mbaza Conformer LG (MbazaNLP)" @@ -336,6 +337,7 @@ def supports_input_prompt(self) -> bool: asr_model_ids = { + AsrModels.whisper_akera_large_v3: "akera/whisper-large-v3-kik-full_v2", AsrModels.gpt_4_o_audio: "gpt-4o-transcribe", AsrModels.gpt_4_o_mini_audio: "gpt-4o-mini-transcribe", AsrModels.whisper_large_v3: "vaibhavs10/incredibly-fast-whisper:3ab86df6c8f54c11309d4d1f930ac292bad43ace52d10c80d87eb258b3c9f79c", @@ -362,6 +364,7 @@ def supports_input_prompt(self) -> bool: AsrModels.vakyansh_bhojpuri: "bho", AsrModels.nemo_english: "en", AsrModels.nemo_hindi: "hi", + AsrModels.whisper_akera_large_v3: "kik", } asr_supported_languages = { @@ -386,6 +389,7 @@ def supports_input_prompt(self) -> bool: AsrModels.lelapa: LELAPA_ASR_SUPPORTED, AsrModels.whisper_sunbird_large_v3: SUNBIRD_SUPPORTED_LANGUAGES, AsrModels.whisper_swahili_medium_v3: {"sw", "en"}, + AsrModels.whisper_akera_large_v3: {"kik"}, AsrModels.mbaza_ctc_large: {"sw", "rw", "lg"}, } @@ -1286,13 +1290,17 @@ def run_asr( ) # call one of the self-hosted models else: - kwargs = {} + kwargs = {"task": "translate" if speech_translation_target else "transcribe"} if "vakyansh" in selected_model.name: # fixes https://github.com/huggingface/transformers/issues/15275#issuecomment-1624879632 kwargs["decoder_kwargs"] = dict(skip_special_tokens=True) kwargs["chunk_length_s"] = 60 kwargs["stride_length_s"] = (6, 0) kwargs["batch_size"] = 32 + elif selected_model == AsrModels.whisper_akera_large_v3: + # don't pass language or task + kwargs.pop("task", None) + kwargs["max_length"] = 448 elif "whisper" in selected_model.name: forced_lang = forced_asr_languages.get(selected_model) if forced_lang: @@ -1308,7 +1316,6 @@ def run_asr( ), inputs=dict( audio=audio_url, - task="translate" if speech_translation_target else "transcribe", return_timestamps=output_format != AsrOutputFormat.text, **kwargs, ),