diff --git a/docs/source/models.rst b/docs/source/models.rst
new file mode 100644
index 0000000000..5b40482976
--- /dev/null
+++ b/docs/source/models.rst
@@ -0,0 +1,17 @@
+.. role:: hidden
+    :class: hidden-section
+
+torchaudio.models
+======================
+
+.. currentmodule:: torchaudio.models
+
+The models subpackage contains definitions of models for addressing common audio tasks.
+
+
+:hidden:`Wav2Letter`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: Wav2Letter
+
+  .. automethod:: forward
diff --git a/test/test_models.py b/test/test_models.py
new file mode 100644
index 0000000000..fe8df9a167
--- /dev/null
+++ b/test/test_models.py
@@ -0,0 +1,30 @@
+import pytest
+
+import torch
+from torchaudio.models import Wav2Letter
+
+
+class TestWav2Letter:
+    @pytest.mark.parametrize('batch_size', [2])
+    @pytest.mark.parametrize('num_features', [1])
+    @pytest.mark.parametrize('num_classes', [40])
+    @pytest.mark.parametrize('input_length', [320])
+    def test_waveform(self, batch_size, num_features, num_classes, input_length):
+        model = Wav2Letter()
+
+        x = torch.rand(batch_size, num_features, input_length)
+        out = model(x)
+
+        assert out.size() == (batch_size, num_classes, 2)
+
+    @pytest.mark.parametrize('batch_size', [2])
+    @pytest.mark.parametrize('num_features', [13])
+    @pytest.mark.parametrize('num_classes', [40])
+    @pytest.mark.parametrize('input_length', [2])
+    def test_mfcc(self, batch_size, num_features, num_classes, input_length):
+        model = Wav2Letter(input_type="mfcc", num_features=13)
+
+        x = torch.rand(batch_size, num_features, input_length)
+        out = model(x)
+
+        assert out.size() == (batch_size, num_classes, 2)
diff --git a/torchaudio/models/__init__.py b/torchaudio/models/__init__.py
new file mode 100644
index 0000000000..1abdac6271
--- /dev/null
+++ b/torchaudio/models/__init__.py
@@ -0,0 +1 @@
+from .wav2letter import *
diff --git a/torchaudio/models/wav2letter.py b/torchaudio/models/wav2letter.py
new file mode 100644
index 0000000000..3466e42dd2
--- /dev/null
+++ b/torchaudio/models/wav2letter.py
@@ -0,0 +1,74 @@
+from typing import Optional
+
+from torch import Tensor
+from torch import nn
+
+__all__ = ["Wav2Letter"]
+
+
+class Wav2Letter(nn.Module):
+    r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System"
+    <https://arxiv.org/abs/1609.03193>`_ paper.
+
+    :math:`\text{padding} = \text{ceil}\left(\frac{\text{kernel} - \text{stride}}{2}\right)`
+
+    Args:
+        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
+        input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
+            or ``mfcc`` (Default: ``waveform``).
+        num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
+ """ + + def __init__(self, num_classes: int = 40, + input_type: str = "waveform", + num_features: int = 1) -> None: + super(Wav2Letter, self).__init__() + + acoustic_num_features = 250 if input_type == "waveform" else num_features + acoustic_model = nn.Sequential( + nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True), + nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0), + nn.ReLU(inplace=True) + ) + + if input_type == "waveform": + waveform_model = nn.Sequential( + nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45), + nn.ReLU(inplace=True) + ) + self.acoustic_model = nn.Sequential(waveform_model, acoustic_model) + + if input_type in ["power_spectrum", "mfcc"]: + self.acoustic_model = acoustic_model + + def forward(self, x: Tensor) -> Tensor: + r""" + Args: + x (Tensor): Tensor of dimension (batch_size, num_features, input_length). + + Returns: + Tensor: Predictor tensor of dimension (batch_size, number_of_classes, input_length). + """ + + x = self.acoustic_model(x) + x = nn.functional.log_softmax(x, dim=1) + return x