Commit d678357
Add model Wav2Letter (#462)
* add wav2letter model
* add unit_test to model
* add docstrings
* add documentation
* fix minor error, change logic on forward
* update padding same with ceil
* add inline typing and minor fixes to docstrings
* remove python2
* add formula to docstrings, change param name
* add test with mfcc, add pytest
* fix bug, update docstrings
* change parameter name
1 parent 3ecc701 commit d678357

File tree: 4 files changed, +122 −0 lines changed

docs/source/models.rst

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
.. role:: hidden
    :class: hidden-section

torchaudio.models
======================

.. currentmodule:: torchaudio.models

The models subpackage contains definitions of models for addressing common audio tasks.


:hidden:`Wav2Letter`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Wav2Letter

  .. automethod:: forward

test/test_models.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
import pytest

import torch
from torchaudio.models import Wav2Letter


class TestWav2Letter:
    @pytest.mark.parametrize('batch_size', [2])
    @pytest.mark.parametrize('num_features', [1])
    @pytest.mark.parametrize('num_classes', [40])
    @pytest.mark.parametrize('input_length', [320])
    def test_waveform(self, batch_size, num_features, num_classes, input_length):
        model = Wav2Letter()

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        assert out.size() == (batch_size, num_classes, 2)

    @pytest.mark.parametrize('batch_size', [2])
    @pytest.mark.parametrize('num_features', [13])
    @pytest.mark.parametrize('num_classes', [40])
    @pytest.mark.parametrize('input_length', [2])
    def test_mfcc(self, batch_size, num_features, num_classes, input_length):
        model = Wav2Letter(input_type="mfcc", num_features=num_features)

        x = torch.rand(batch_size, num_features, input_length)
        out = model(x)

        assert out.size() == (batch_size, num_classes, 2)
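
Why do both tests expect an output length of 2? A minimal sketch (an illustration, not part of the commit) that traces the sequence length through the layers, assuming only the standard Conv1d length formula L_out = floor((L_in + 2 * padding - kernel) / stride) + 1:

def conv1d_out_len(length, kernel, stride, padding):
    # Standard 1-d convolution output-length formula.
    return (length + 2 * padding - kernel) // stride + 1

# test_waveform: the waveform front end (kernel=250, stride=160, padding=45)
# collapses 320 samples to 2 frames.
length = conv1d_out_len(320, 250, 160, 45)    # -> 2

# Acoustic model: the strided first layer shrinks 2 -> 1, the seven kernel-7
# layers preserve the length, and the kernel-32 layer (padding 16, rounded up)
# grows it back to 2; the two kernel-1 layers change nothing.
length = conv1d_out_len(length, 48, 2, 23)    # -> 1
for _ in range(7):
    length = conv1d_out_len(length, 7, 1, 3)  # -> 1
length = conv1d_out_len(length, 32, 1, 16)    # -> 2
length = conv1d_out_len(length, 1, 1, 0)      # -> 2
length = conv1d_out_len(length, 1, 1, 0)      # -> 2

assert length == 2  # matches the assertions above; the mfcc path
                    # (no front end, input_length=2) also traces to 2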

torchaudio/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .wav2letter import *
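
A side note (an illustration, not part of the commit): because wav2letter.py below defines __all__ = ["Wav2Letter"], the star import re-exports exactly that one name:

# works because wav2letter.py sets __all__ = ["Wav2Letter"]
from torchaudio.models import Wav2Letter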

torchaudio/models/wav2letter.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
from torch import Tensor
from torch import nn

__all__ = ["Wav2Letter"]


class Wav2Letter(nn.Module):
    r"""Wav2Letter model architecture from the `"Wav2Letter: an End-to-End ConvNet-based Speech Recognition System"
    <https://arxiv.org/abs/1609.03193>`_ paper.

    Each convolution uses "same"-style padding, rounded up:
    :math:`\text{padding} = \left\lceil \frac{\text{kernel} - \text{stride}}{2} \right\rceil`

    Args:
        num_classes (int, optional): Number of classes to be classified. (Default: ``40``)
        input_type (str, optional): Wav2Letter can use as input: ``waveform``, ``power_spectrum``
            or ``mfcc`` (Default: ``waveform``).
        num_features (int, optional): Number of input features that the network will receive (Default: ``1``).
    """

    def __init__(self, num_classes: int = 40,
                 input_type: str = "waveform",
                 num_features: int = 1) -> None:
        super(Wav2Letter, self).__init__()

        # Raw waveforms first pass through the front end below, which outputs
        # 250 channels; spectral inputs feed the acoustic model directly.
        acoustic_num_features = 250 if input_type == "waveform" else num_features
        acoustic_model = nn.Sequential(
            nn.Conv1d(in_channels=acoustic_num_features, out_channels=250, kernel_size=48, stride=2, padding=23),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=250, kernel_size=7, stride=1, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=250, out_channels=2000, kernel_size=32, stride=1, padding=16),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=2000, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels=2000, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True)
        )

        if input_type == "waveform":
            # Front end for raw audio: a wide strided convolution that maps
            # the raw samples to 250 feature channels.
            waveform_model = nn.Sequential(
                nn.Conv1d(in_channels=num_features, out_channels=250, kernel_size=250, stride=160, padding=45),
                nn.ReLU(inplace=True)
            )
            self.acoustic_model = nn.Sequential(waveform_model, acoustic_model)

        if input_type in ["power_spectrum", "mfcc"]:
            self.acoustic_model = acoustic_model

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Args:
            x (Tensor): Tensor of dimension (batch_size, num_features, input_length).

        Returns:
            Tensor: Predictor tensor of dimension (batch_size, number_of_classes, output_length),
            where output_length is determined by the strides and paddings above
            (for ``waveform`` input the length shrinks by roughly a factor of 320).
        """

        x = self.acoustic_model(x)
        x = nn.functional.log_softmax(x, dim=1)
        return x
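
A short usage sketch (an illustration, not part of the commit): the model emits log-probabilities over classes, which pairs naturally with CTC training; the shapes, lengths, and nn.CTCLoss wiring here are assumptions for demonstration, not mandated by the commit.

import torch
from torch import nn
from torchaudio.models import Wav2Letter

model = Wav2Letter(num_classes=40, input_type="waveform", num_features=1)
waveforms = torch.rand(2, 1, 16000)   # (batch_size, num_features, input_length)
log_probs = model(waveforms)          # (batch_size, num_classes, output_length)

# nn.CTCLoss expects (output_length, batch_size, num_classes)
log_probs = log_probs.permute(2, 0, 1)
targets = torch.randint(1, 40, (2, 10), dtype=torch.long)  # 0 is the CTC blank
input_lengths = torch.full((2,), log_probs.size(0), dtype=torch.long)
target_lengths = torch.full((2,), 10, dtype=torch.long)
loss = nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths)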
