1 change: 1 addition & 0 deletions docs/en/annotators.md
@@ -148,6 +148,7 @@ Additionally, these transformers are available.
{% include templates/anno_table_entry.md path="./transformers" name="UniversalSentenceEncoder" summary="The Universal Sentence Encoder encodes text into high dimensional vectors that can be used for text classification, semantic similarity, clustering and other natural language tasks."%}
{% include templates/anno_table_entry.md path="./transformers" name="ViTForImageClassification" summary="Vision Transformer (ViT) for image classification."%}
{% include templates/anno_table_entry.md path="./transformers" name="Wav2Vec2ForCTC" summary="Wav2Vec2 Model with a language modeling head on top for Connectionist Temporal Classification (CTC)."%}
{% include templates/anno_table_entry.md path="./transformers" name="WhisperForCTC" summary="Whisper Model with a language modeling head on top for Connectionist Temporal Classification (CTC)."%}
{% include templates/anno_table_entry.md path="./transformers" name="XlmRoBertaEmbeddings" summary="XlmRoBerta is a large multi-lingual language model, trained on 2.5TB of filtered CommonCrawl"%}
{% include templates/anno_table_entry.md path="./transformers" name="XlmRoBertaForQuestionAnswering" summary="XlmRoBertaForQuestionAnswering can load XLM-RoBERTa Models with a span classification head on top for extractive question-answering tasks like SQuAD."%}
{% include templates/anno_table_entry.md path="./transformers" name="XlmRoBertaForSequenceClassification" summary="XlmRoBertaForSequenceClassification can load XLM-RoBERTa Models with sequence classification/regression head on top e.g. for multi-class document classification tasks."%}
146 changes: 146 additions & 0 deletions docs/en/transformer_entries/WhisperForCTC.md
@@ -0,0 +1,146 @@
{%- capture title -%}
WhisperForCTC
{%- endcapture -%}

{%- capture description -%}
Whisper Model with a language modeling head on top for Connectionist Temporal Classification
(CTC).

Whisper is an automatic speech recognition (ASR) system trained on 680,000 hours of
multilingual and multitask supervised data collected from the web. It can transcribe speech
in multiple languages, as well as translate from those languages into English.

The audio needs to be provided pre-processed, as an array of floats.
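
Such a float array can be produced from a 16-bit PCM WAV file with only the standard
library. This is a minimal sketch, not part of the Spark NLP API; the mono/16-bit
assumptions are illustrative, and libraries such as librosa or soundfile are common
alternatives that also handle resampling:

```python
import struct
import wave

def wav_to_floats(path):
    """Read a 16-bit mono PCM WAV file and normalize samples to [-1.0, 1.0]."""
    with wave.open(path, "rb") as wf:
        frames = wf.readframes(wf.getnframes())
    # Unpack little-endian signed 16-bit samples and scale to floats
    samples = struct.unpack("<%dh" % (len(frames) // 2), frames)
    return [s / 32768.0 for s in samples]
```

The resulting list (e.g. `rawFloats = wav_to_floats("speech.wav")`, with a hypothetical
file name) can then be placed in a DataFrame column for the `AudioAssembler`, as shown in
the examples.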

Note that at the moment, this annotator only supports greedy search.

For multilingual models, the language and the task (transcribe or translate) can be set with
`setLanguage` and `setTask`.
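
For instance, a multilingual model could be configured to translate French speech into
English roughly as follows. This is a sketch only: the `"<|fr|>"` and `"<|translate|>"`
token values are assumptions, so consult the loaded model's documentation for the exact
values it accepts.

```python
speechToText = WhisperForCTC.pretrained() \
    .setInputCols(["audio_assembler"]) \
    .setOutputCol("text") \
    .setLanguage("<|fr|>") \
    .setTask("<|translate|>")
```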

Pretrained models can be loaded with `pretrained` of the companion object:

```scala
val speechToText = WhisperForCTC.pretrained()
.setInputCols("audio_assembler")
.setOutputCol("text")
```

If no name is provided, the default model is `"asr_whisper_tiny_opt"`.

For available pretrained models please see the [Models Hub](https://sparknlp.org/models).

To see which models are compatible and how to import them, see
https://github.com/JohnSnowLabs/spark-nlp/discussions/5669. For more extended examples, see
[WhisperForCTCTestSpec](https://github.com/JohnSnowLabs/spark-nlp/blob/master/src/test/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTCTest.scala).

**References:**

[Robust Speech Recognition via Large-Scale Weak Supervision](https://arxiv.org/abs/2212.04356)

**Paper Abstract:**

*We study the capabilities of speech processing systems trained simply to predict large
amounts of transcripts of audio on the internet. When scaled to 680,000 hours of multilingual
and multitask supervision, the resulting models generalize well to standard benchmarks and are
often competitive with prior fully supervised results but in a zero-shot transfer setting
without the need for any fine-tuning. When compared to humans, the models approach their
accuracy and robustness. We are releasing models and inference code to serve as a foundation
for further work on robust speech processing.*
{%- endcapture -%}

{%- capture input_anno -%}
AUDIO
{%- endcapture -%}

{%- capture output_anno -%}
DOCUMENT
{%- endcapture -%}

{%- capture python_example -%}
import sparknlp
from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline

# Read the raw audio floats from file (one value per line), as in the Scala example
rawFloats = [
    float(line.split(",")[0].strip())
    for line in open("src/test/resources/audio/txt/librispeech_asr_0.txt")
]

audioAssembler = AudioAssembler() \
.setInputCol("audio_content") \
.setOutputCol("audio_assembler")

speechToText = WhisperForCTC.pretrained() \
.setInputCols(["audio_assembler"]) \
.setOutputCol("text")

pipeline = Pipeline().setStages([audioAssembler, speechToText])
processedAudioFloats = spark.createDataFrame([[rawFloats]]).toDF("audio_content")
result = pipeline.fit(processedAudioFloats).transform(processedAudioFloats)
result.select("text.result").show(truncate = False)
+------------------------------------------------------------------------------------------+
|result |
+------------------------------------------------------------------------------------------+
|[ Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.]|
+------------------------------------------------------------------------------------------+
{%- endcapture -%}

{%- capture scala_example -%}
import spark.implicits._
import com.johnsnowlabs.nlp.base._
import com.johnsnowlabs.nlp.annotators._
import com.johnsnowlabs.nlp.annotators.audio.WhisperForCTC
import org.apache.spark.ml.Pipeline

val audioAssembler: AudioAssembler = new AudioAssembler()
.setInputCol("audio_content")
.setOutputCol("audio_assembler")

val speechToText: WhisperForCTC = WhisperForCTC
.pretrained()
.setInputCols("audio_assembler")
.setOutputCol("text")

val pipeline: Pipeline = new Pipeline().setStages(Array(audioAssembler, speechToText))

val bufferedSource =
scala.io.Source.fromFile("src/test/resources/audio/txt/librispeech_asr_0.txt")

val rawFloats = bufferedSource
.getLines()
.map(_.split(",").head.trim.toFloat)
.toArray
bufferedSource.close

val processedAudioFloats = Seq(rawFloats).toDF("audio_content")

val result = pipeline.fit(processedAudioFloats).transform(processedAudioFloats)
result.select("text.result").show(truncate = false)
+------------------------------------------------------------------------------------------+
|result |
+------------------------------------------------------------------------------------------+
|[ Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.]|
+------------------------------------------------------------------------------------------+

{%- endcapture -%}

{%- capture api_link -%}
[WhisperForCTC](/api/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC)
{%- endcapture -%}

{%- capture python_api_link -%}
[WhisperForCTC](/api/python/reference/autosummary/sparknlp/annotator/audio/whisper_for_ctc/index.html?highlight=whisperforctc#python.sparknlp.annotator.audio.whisper_for_ctc.WhisperForCTC)
{%- endcapture -%}

{%- capture source_link -%}
[WhisperForCTC](https://github.com/JohnSnowLabs/spark-nlp/tree/master/src/main/scala/com/johnsnowlabs/nlp/annotators/audio/WhisperForCTC.scala)
{%- endcapture -%}

{% include templates/anno_template.md
title=title
description=description
input_anno=input_anno
output_anno=output_anno
python_example=python_example
scala_example=scala_example
api_link=api_link
python_api_link=python_api_link
source_link=source_link
%}