Commit 5f8a78d

use naming convention from readme.
1 parent ed94667

2 files changed: +32 −32


test/torchaudio_unittest/models_test.py

Lines changed: 8 additions & 8 deletions
@@ -179,15 +179,15 @@ def test_paper_configuration(self, num_sources, model_params):
 class TestDeepSpeech(common_utils.TorchaudioTestCase):
 
     def test_deepspeech(self):
-        batch_size = 2
-        num_features = 1
-        num_channels = 1
-        num_classes = 40
-        input_length = 320
+        n_batch = 2
+        n_feature = 1
+        n_channel = 1
+        n_class = 40
+        n_time = 320
 
-        model = DeepSpeech(in_features=1, num_classes=num_classes)
+        model = DeepSpeech(n_feature=n_feature, n_class=n_class)
 
-        x = torch.rand(batch_size, num_channels, input_length, num_features)
+        x = torch.rand(n_batch, n_channel, n_time, n_feature)
         out = model(x)
 
-        assert out.size() == (batch_size, input_length, num_classes)
+        assert out.size() == (n_batch, n_time, n_class)
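
The renamed test maps one-to-one onto standalone usage. A minimal sketch of the post-commit interface, assuming DeepSpeech is exported from torchaudio.models:

import torch
from torchaudio.models import DeepSpeech  # assumed export path

# Dimension names follow the README convention adopted by this commit.
n_batch, n_channel, n_time, n_feature = 2, 1, 320, 1
n_class = 40

model = DeepSpeech(n_feature=n_feature, n_class=n_class)
x = torch.rand(n_batch, n_channel, n_time, n_feature)  # N x C x T x F
out = model(x)                                         # N x T x n_class log-probabilities
assert out.size() == (n_batch, n_time, n_class)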

torchaudio/models/deepspeech.py

Lines changed: 24 additions & 24 deletions
@@ -7,17 +7,17 @@
 class FullyConnected(nn.Module):
     """
     Args:
-        in_features: Number of input features
-        hidden_size: Internal hidden unit size.
+        n_feature: Number of input features
+        n_hidden: Internal hidden unit size.
     """
 
     def __init__(self,
-                 in_features: int,
-                 hidden_size: int,
+                 n_feature: int,
+                 n_hidden: int,
                  dropout: float,
                  relu_max_clip: int = 20) -> None:
         super(FullyConnected, self).__init__()
-        self.fc = nn.Linear(in_features, hidden_size, bias=True)
+        self.fc = nn.Linear(n_feature, n_hidden, bias=True)
         self.relu_max_clip = relu_max_clip
         self.dropout = dropout
 
@@ -37,32 +37,32 @@ class DeepSpeech(nn.Module):
     <https://arxiv.org/abs/1412.5567> paper.
 
     Args:
-        in_features: Number of input features
-        hidden_size: Internal hidden unit size.
-        num_classes: Number of output classes
+        n_feature: Number of input features
+        n_hidden: Internal hidden unit size.
+        n_class: Number of output classes
     """
 
     def __init__(self,
-                 in_features: int,
-                 hidden_size: int = 2048,
-                 num_classes: int = 40,
+                 n_feature: int,
+                 n_hidden: int = 2048,
+                 n_class: int = 40,
                  dropout: float = 0.0) -> None:
         super(DeepSpeech, self).__init__()
-        self.hidden_size = hidden_size
-        self.fc1 = FullyConnected(in_features, hidden_size, dropout)
-        self.fc2 = FullyConnected(hidden_size, hidden_size, dropout)
-        self.fc3 = FullyConnected(hidden_size, hidden_size, dropout)
+        self.n_hidden = n_hidden
+        self.fc1 = FullyConnected(n_feature, n_hidden, dropout)
+        self.fc2 = FullyConnected(n_hidden, n_hidden, dropout)
+        self.fc3 = FullyConnected(n_hidden, n_hidden, dropout)
         self.bi_rnn = nn.RNN(
-            hidden_size, hidden_size, num_layers=1, nonlinearity='relu', bidirectional=True)
-        self.fc4 = FullyConnected(hidden_size, hidden_size, dropout)
-        self.out = nn.Linear(hidden_size, num_classes)
+            n_hidden, n_hidden, num_layers=1, nonlinearity='relu', bidirectional=True)
+        self.fc4 = FullyConnected(n_hidden, n_hidden, dropout)
+        self.out = nn.Linear(n_hidden, n_class)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Args:
-            x (torch.Tensor): Tensor of dimension (batch_size, num_channels, input_length, num_features).
+            x (torch.Tensor): Tensor of dimension (batch, channel, time, feature).
         Returns:
-            Tensor: Predictor tensor of dimension (batch_size, input_length, number_of_classes).
+            Tensor: Predictor tensor of dimension (batch, time, class).
         """
         # N x C x T x F
         x = self.fc1(x)
@@ -77,14 +77,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # T x N x H
         x, _ = self.bi_rnn(x)
         # The fifth (non-recurrent) layer takes both the forward and backward units as inputs
-        x = x[:, :, :self.hidden_size] + x[:, :, self.hidden_size:]
+        x = x[:, :, :self.n_hidden] + x[:, :, self.n_hidden:]
         # T x N x H
         x = self.fc4(x)
         # T x N x H
         x = self.out(x)
-        # T x N x num_classes
+        # T x N x n_class
         x = x.permute(1, 0, 2)
-        # N x T x num_classes
+        # N x T x n_class
         x = torch.nn.functional.log_softmax(x, dim=2)
-        # T x N x num_classes
+        # N x T x n_class
         return x
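
The summing step in forward relies on nn.RNN concatenating the forward and backward hidden states along the last dimension when bidirectional=True, giving 2 * n_hidden features per step. A minimal sketch of that behavior in isolation (toy sizes, not the model's defaults):

import torch
import torch.nn as nn

n_hidden = 8
rnn = nn.RNN(n_hidden, n_hidden, num_layers=1, nonlinearity='relu', bidirectional=True)

x = torch.rand(5, 2, n_hidden)  # T x N x H layout, matching the comments in forward
out, _ = rnn(x)                 # T x N x 2*n_hidden: forward and backward halves concatenated
summed = out[:, :, :n_hidden] + out[:, :, n_hidden:]  # input to the fifth (non-recurrent) layer
assert summed.shape == (5, 2, n_hidden)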
