Skip to content

Add vanilla DeepSpeech model #1399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ The models subpackage contains definitions of models for addressing common audio
.. automethod:: forward


:hidden:`DeepSpeech`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: DeepSpeech

.. automethod:: forward


:hidden:`Wav2Letter`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down
19 changes: 18 additions & 1 deletion test/torchaudio_unittest/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from parameterized import parameterized
from torchaudio.models import ConvTasNet, Wav2Letter, WaveRNN
from torchaudio.models import ConvTasNet, DeepSpeech, Wav2Letter, WaveRNN
from torchaudio.models.wavernn import MelResNet, UpsampleNetwork
from torchaudio_unittest import common_utils

Expand Down Expand Up @@ -174,3 +174,20 @@ def test_paper_configuration(self, num_sources, model_params):
output = model(tensor)

assert output.shape == (batch_size, num_sources, num_frames)


class TestDeepSpeech(common_utils.TorchaudioTestCase):
    """Shape check for the DeepSpeech acoustic model."""

    def test_deepspeech(self):
        """A (batch, channel, time, feature) input yields (batch, time, class)."""
        # Input dimensions; the single channel is squeezed away inside the model.
        n_batch, n_channel, n_time, n_feature = 2, 1, 320, 1
        n_class = 40

        model = DeepSpeech(n_feature=n_feature, n_class=n_class)
        out = model(torch.rand(n_batch, n_channel, n_time, n_feature))

        expected_shape = (n_batch, n_time, n_class)
        assert out.size() == expected_shape
2 changes: 2 additions & 0 deletions torchaudio/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech

__all__ = [
'Wav2Letter',
'WaveRNN',
'ConvTasNet',
'DeepSpeech',
]
92 changes: 92 additions & 0 deletions torchaudio/models/deepspeech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import torch

__all__ = ["DeepSpeech"]


class FullyConnected(torch.nn.Module):
    """Linear projection followed by a clipped ReLU and optional dropout.

    Args:
        n_feature: Number of input features.
        n_hidden: Internal hidden unit size.
        dropout: Dropout probability applied after the activation. A falsy
            value (e.g. ``0.0``) skips the dropout call entirely.
        relu_max_clip: Upper bound at which the ReLU activation is clipped.
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int,
        dropout: float,
        relu_max_clip: int = 20,
    ) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(n_feature, n_hidden, bias=True)
        self.relu_max_clip = relu_max_clip
        self.dropout = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer to ``x`` along its last dimension.

        Args:
            x (torch.Tensor): Tensor whose last dimension has ``n_feature`` entries.
        Returns:
            Tensor: Same leading dimensions as ``x`` with ``n_hidden`` entries
            in the last dimension, each value in ``[0, relu_max_clip]`` before
            dropout is applied.
        """
        projected = self.fc(x)
        rectified = torch.nn.functional.relu(projected)
        # hardtanh with a lower bound of 0 caps the rectified values at relu_max_clip.
        clipped = torch.nn.functional.hardtanh(rectified, 0, self.relu_max_clip)
        if self.dropout:
            # Dropout is only active while the module is in training mode.
            clipped = torch.nn.functional.dropout(clipped, self.dropout, self.training)
        return clipped


class DeepSpeech(torch.nn.Module):
    """
    DeepSpeech model architecture from the
    `"Deep Speech: Scaling up end-to-end speech recognition"
    <https://arxiv.org/abs/1412.5567>`_ paper.

    The network stacks three clipped-ReLU fully connected layers, one
    bidirectional ReLU RNN layer, one more fully connected layer and a final
    linear classifier followed by a log-softmax.

    Args:
        n_feature: Number of input features
        n_hidden: Internal hidden unit size.
        n_class: Number of output classes
        dropout: Dropout probability used by every fully connected layer.
    """

    def __init__(
        self,
        n_feature: int,
        n_hidden: int = 2048,
        n_class: int = 40,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.n_hidden = n_hidden

        # Three non-recurrent layers applied independently to every time step.
        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)

        # Recurrent layer; its output concatenates both directions (2 * n_hidden).
        self.bi_rnn = torch.nn.RNN(
            n_hidden,
            n_hidden,
            num_layers=1,
            nonlinearity="relu",
            bidirectional=True,
        )

        # Non-recurrent layer on the summed directions, then the classifier.
        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
        self.out = torch.nn.Linear(n_hidden, n_class)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
        Returns:
            Tensor: Predictor tensor of dimension (batch, time, class).
        """
        # (N, C, T, F) -> (N, C, T, H) through the three feed-forward layers.
        hidden = self.fc3(self.fc2(self.fc1(x)))

        # Drop the channel axis when it has size one: (N, T, H).
        hidden = torch.squeeze(hidden, 1)

        # The RNN is not batch_first, so move time to the front: (T, N, H).
        hidden = hidden.transpose(0, 1)
        rnn_out, _ = self.bi_rnn(hidden)

        # The fifth (non-recurrent) layer takes both the forward and backward
        # units as inputs: sum the two halves of the last axis -> (T, N, H).
        forward_units = rnn_out[:, :, :self.n_hidden]
        backward_units = rnn_out[:, :, self.n_hidden:]
        merged = forward_units + backward_units

        # (T, N, H) -> (T, N, n_class)
        logits = self.out(self.fc4(merged))

        # Back to batch-major layout: (N, T, n_class).
        logits = logits.transpose(0, 1)

        # Log-probabilities over the class axis (the last of the three axes).
        return torch.nn.functional.log_softmax(logits, dim=-1)