Commit 00efbe6

Merge pull request pytorch#105 from jamarshon/T44497670
Migrate audio transform computations into functional.py
2 parents acdedc4 + ec0b29f commit 00efbe6

File tree

2 files changed: +375 −136 lines changed


torchaudio/functional.py

Lines changed: 358 additions & 0 deletions
@@ -0,0 +1,358 @@
import numpy as np
import torch


__all__ = [
    'scale',
    'pad_trim',
    'downmix_mono',
    'LC2CL',
    'spectrogram',
    'create_fb_matrix',
    'mel_scale',
    'spectrogram_to_DB',
    'create_dct',
    'MFCC',
    'BLC2CBL',
    'mu_law_encoding',
    'mu_law_expanding'
]


def scale(tensor, factor):
    # type: (Tensor, int) -> Tensor
    """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
    to a floating point number between -1.0 and 1.0. Note the 16-bit number is
    called the "bit depth" or "precision", not to be confused with "bit rate".

    Inputs:
        tensor (Tensor): Tensor of audio of size (Samples x Channels)
        factor (int): Maximum value of input tensor

    Outputs:
        Tensor: Scaled by the scale factor
    """
    if not tensor.dtype.is_floating_point:
        tensor = tensor.to(torch.float32)

    return tensor / factor

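# Example (hypothetical values): a 16-bit signal loaded as integers, divided
# by 2**15 so that -32768..32767 maps into [-1.0, 1.0].
#   >>> sig = torch.randint(-2 ** 15, 2 ** 15 - 1, (1000, 2))
#   >>> bool(scale(sig, 2 ** 15).abs().max() <= 1.0)
#   True
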
def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
    # type: (Tensor, int, int, int, float) -> Tensor
    """Pad/trim a 2d-Tensor (Signal or Labels)

    Inputs:
        tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
        ch_dim (int): Dimension of channel (not size)
        max_len (int): Length to which the tensor will be padded
        len_dim (int): Dimension of length (not size)
        fill_value (float): Value to fill in

    Outputs:
        Tensor: Padded/trimmed tensor
    """
    if max_len > tensor.size(len_dim):
        # padding is a tuple of (padding_left, padding_right, padding_top,
        # padding_bottom), so pad like an append (right/bottom only) and only
        # along the length dimension; the channel dimension is left untouched.
        padding = [max_len - tensor.size(len_dim)
                   if (i % 2 == 1) and (i // 2 != len_dim)
                   else 0
                   for i in range(4)]
        with torch.no_grad():
            tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
    elif max_len < tensor.size(len_dim):
        tensor = tensor.narrow(len_dim, 0, max_len)
    return tensor

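# Example (hypothetical values): pad a 2-channel clip of 900 samples, stored
# channels-first (c x n), out to 1000 samples with zeros.
#   >>> clip = torch.rand(2, 900)
#   >>> pad_trim(clip, ch_dim=0, max_len=1000, len_dim=1, fill_value=0.).size()
#   torch.Size([2, 1000])
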
def downmix_mono(tensor, ch_dim):
    # type: (Tensor, int) -> Tensor
    """Downmix any stereo signals to mono.

    Inputs:
        tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
        ch_dim (int): Dimension of channel (not size)

    Outputs:
        Tensor: Mono signal
    """
    if not tensor.dtype.is_floating_point:
        tensor = tensor.to(torch.float32)

    tensor = torch.mean(tensor, ch_dim, True)
    return tensor

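# Example (hypothetical values): averaging over the channel dimension keeps
# that dimension with size 1 (keepdim=True).
#   >>> downmix_mono(torch.rand(2, 1000), ch_dim=0).size()
#   torch.Size([1, 1000])
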
def LC2CL(tensor):
    # type: (Tensor) -> Tensor
    """Permute a 2d tensor from samples (n x c) to (c x n)

    Inputs:
        tensor (Tensor): Tensor of audio signal with shape (L x C)

    Outputs:
        Tensor: Tensor of audio signal with shape (C x L)
    """
    return tensor.transpose(0, 1).contiguous()


def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
    # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
    """Create a spectrogram from a raw audio signal

    Inputs:
        sig (Tensor): Tensor of audio of size (c, n)
        pad (int): two sided padding of signal
        window (Tensor): window tensor
        n_fft (int): size of fft
        hop (int): length of hop between STFT windows
        ws (int): window size
        power (int > 0): Exponent for the magnitude spectrogram,
            e.g., 1 for energy, 2 for power, etc.
        normalize (bool): whether to normalize by magnitude after stft

    Outputs:
        Tensor: channels x hops x n_fft (c, l, f), where channels
            is unchanged, hops is the number of hops, and n_fft is the
            number of Fourier bins, which should be the window size divided
            by 2 plus 1.
    """
    assert sig.dim() == 2

    if pad > 0:
        with torch.no_grad():
            sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
    window = window.to(sig.device)

    # default values are consistent with librosa.core.spectrum._spectrogram
    spec_f = torch.stft(sig, n_fft, hop, ws,
                        window, center=True,
                        normalized=False, onesided=True,
                        pad_mode='reflect').transpose(1, 2)
    if normalize:
        spec_f /= window.pow(2).sum().sqrt()
    spec_f = spec_f.pow(power).sum(-1)  # get power of "complex" tensor (c, l, n_fft)
    return spec_f

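# Example (hypothetical values): a power spectrogram of a 1-channel signal
# with a 400-sample Hann window and a hop of 200; with center=True, 16000
# samples give 16000 / 200 + 1 = 81 hops and 400 / 2 + 1 = 201 Fourier bins.
#   >>> sig = torch.rand(1, 16000)
#   >>> win = torch.hann_window(400)
#   >>> spectrogram(sig, 0, win, 400, 200, 400, 2, False).size()
#   torch.Size([1, 81, 201])
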
def create_fb_matrix(n_stft, f_min, f_max, n_mels):
    # type: (int, float, float, int) -> Tensor
    """Create a frequency bin conversion matrix.

    Inputs:
        n_stft (int): number of frequency bins from the spectrogram
        f_min (float): minimum frequency
        f_max (float): maximum frequency
        n_mels (int): number of mel bins

    Outputs:
        Tensor: triangular filter banks (fb matrix)
    """
    def _hertz_to_mel(f):
        # type: (float) -> Tensor
        return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))

    def _mel_to_hertz(mel):
        # type: (Tensor) -> Tensor
        return 700. * (10 ** (mel / 2595.) - 1.)

    # get stft freq bins
    stft_freqs = torch.linspace(f_min, f_max, n_stft)
    # calculate mel freq bins
    m_min = 0. if f_min == 0 else _hertz_to_mel(f_min)
    m_max = _hertz_to_mel(f_max)
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hertz(m_pts)
    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)
    # create overlapping triangles
    z = torch.tensor(0.)
    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
    fb = torch.max(z, torch.min(down_slopes, up_slopes))
    return fb

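# Example (hypothetical values): a filter-bank matrix mapping 201 STFT bins
# onto 40 mel bins; each column is one triangular filter.
#   >>> fb = create_fb_matrix(201, 0., 8000., 40)
#   >>> fb.size()
#   torch.Size([201, 40])
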
def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
    # type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]
    """Turn a normal STFT into a mel frequency STFT, using a conversion
    matrix of triangular filter banks.

    Inputs:
        spec_f (Tensor): normal STFT
        f_min (float): minimum frequency
        f_max (float): maximum frequency
        n_mels (int): number of mel bins
        fb (Optional[Tensor]): triangular filter banks (fb matrix)

    Outputs:
        Tuple[Tensor, Tensor]: triangular filter banks (fb matrix) and mel frequency STFT
    """
    if fb is None:
        fb = create_fb_matrix(spec_f.size(2), f_min, f_max, n_mels).to(spec_f.device)
    else:
        # need to ensure same device for dot product
        fb = fb.to(spec_f.device)
    spec_m = torch.matmul(spec_f, fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
    return fb, spec_m

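# Example (hypothetical values): project a 201-bin spectrogram like the one
# sketched above onto 40 mel bins; the fb matrix is returned for reuse.
#   >>> spec = torch.rand(1, 81, 201)
#   >>> fb, spec_m = mel_scale(spec, 0., 8000., 40)
#   >>> spec_m.size()
#   torch.Size([1, 81, 40])
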
def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
    # type: (Tensor, float, float, float, Optional[float]) -> Tensor
    """Turn a spectrogram from the power/amplitude scale to the decibel scale.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Inputs:
        spec (Tensor): normal STFT
        multiplier (float): use 10. for power and 20. for amplitude
        amin (float): number to clamp spec
        db_multiplier (float): log10(max(reference value and amin))
        top_db (Optional[float]): minimum negative cut-off in decibels. A
            reasonable number is 80.

    Outputs:
        Tensor: spectrogram in dB
    """
    spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
    spec_db -= multiplier * db_multiplier

    if top_db is not None:
        spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - top_db))
    return spec_db

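# Example (hypothetical values): convert a power spectrogram to dB with a
# reference value of 1.0 (so db_multiplier = log10(max(1.0, amin)) = 0.) and
# clamp the dynamic range to 80 dB below the peak.
#   >>> spec = torch.rand(1, 81, 201)
#   >>> spec_db = spectrogram_to_DB(spec, 10., 1e-10, 0., top_db=80.)
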
def create_dct(n_mfcc, n_mels, norm):
    # type: (int, int, str) -> Tensor
    """Create a DCT transformation matrix with shape (n_mels, n_mfcc),
    normalized depending on norm.

    Inputs:
        n_mfcc (int): number of mfc coefficients to retain
        n_mels (int): number of mel bins
        norm (str): norm to use

    Outputs:
        Tensor: The transformation matrix, to be right-multiplied to row-wise data.
    """
    outdim = n_mfcc
    dim = n_mels
    # http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
    n = np.arange(dim)
    k = np.arange(outdim)[:, np.newaxis]
    dct = np.cos(np.pi / dim * (n + 0.5) * k)
    if norm == 'ortho':
        dct[0] *= 1.0 / np.sqrt(2)
        dct *= np.sqrt(2.0 / dim)
    else:
        dct *= 2
    return torch.Tensor(dct.T)

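# Example (hypothetical values): an orthonormal DCT-II basis taking 40 mel
# bins down to 13 cepstral coefficients.
#   >>> dct_mat = create_dct(13, 40, 'ortho')
#   >>> dct_mat.size()
#   torch.Size([40, 13])
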
def MFCC(sig, mel_spect, log_mels, s2db, dct_mat):
    # type: (Tensor, Tensor, bool, SpectrogramToDB, Tensor) -> Tensor
    """Create the Mel-frequency cepstrum coefficients from an audio signal

    By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
    This is not the textbook implementation, but is implemented here to
    give consistency with librosa.

    This output depends on the maximum value in the input spectrogram, and so
    may return different values for an audio clip split into snippets vs. a
    full clip.

    Inputs:
        sig (Tensor): Tensor of audio of size (channels [c], samples [n])
        mel_spect (Tensor): mel spectrogram of sig
        log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
        s2db (SpectrogramToDB): a SpectrogramToDB instance
        dct_mat (Tensor): The transformation matrix (dct matrix), to be
            right-multiplied to row-wise data

    Outputs:
        Tensor: Mel-frequency cepstrum coefficients
    """
    if log_mels:
        log_offset = 1e-6
        mel_spect = torch.log(mel_spect + log_offset)
    else:
        mel_spect = s2db(mel_spect)
    mfcc = torch.matmul(mel_spect, dct_mat.to(mel_spect.device))
    return mfcc

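# Example (hypothetical values): with log_mels=True neither sig nor s2db is
# touched, so a log-mel MFCC can be sketched from the pieces above alone.
#   >>> mel = torch.rand(1, 81, 40)               # stand-in mel spectrogram
#   >>> dct_mat = create_dct(13, 40, 'ortho')
#   >>> MFCC(None, mel, True, None, dct_mat).size()
#   torch.Size([1, 81, 13])
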
def BLC2CBL(tensor):
    # type: (Tensor) -> Tensor
    """Permute a 3d tensor from Bands x Sample length x Channels to
    Channels x Bands x Sample length

    Inputs:
        tensor (Tensor): Tensor of spectrogram with shape (B x L x C)

    Outputs:
        Tensor: Tensor of spectrogram with shape (C x B x L)
    """
    return tensor.permute(2, 0, 1).contiguous()


def mu_law_encoding(x, qc):
    # type: (Tensor or ndarray, int) -> Tensor or ndarray
    """Encode signal based on mu-law companding. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This algorithm assumes the signal has been scaled to between -1 and 1 and
    returns a signal encoded with values from 0 to quantization_channels - 1

    Inputs:
        x (Tensor or ndarray): Input signal
        qc (int): Number of channels (i.e. quantization channels)

    Outputs:
        Tensor or ndarray: Input after mu-law companding
    """
    mu = qc - 1.
    if isinstance(x, np.ndarray):
        x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
        x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
    elif isinstance(x, torch.Tensor):
        if not x.dtype.is_floating_point:
            x = x.to(torch.float)
        mu = torch.tensor(mu, dtype=x.dtype)
        x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
        x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
    return x_mu

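# Example (hypothetical values): 256 quantization channels map a signal in
# [-1., 1.] to integer codes in [0, 255].
#   >>> x = torch.rand(1000) * 2 - 1
#   >>> x_mu = mu_law_encoding(x, 256)
#   >>> bool(x_mu.min() >= 0) and bool(x_mu.max() <= 255)
#   True
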
def mu_law_expanding(x_mu, qc):
    # type: (Tensor or ndarray, int) -> Tensor or ndarray
    """Decode mu-law encoded signal. For more info see the
    `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_

    This expects an input with values between 0 and quantization_channels - 1
    and returns a signal scaled between -1 and 1.

    Inputs:
        x_mu (Tensor or ndarray): Input signal
        qc (int): Number of channels (i.e. quantization channels)

    Outputs:
        Tensor or ndarray: Input after decoding
    """
    mu = qc - 1.
    if isinstance(x_mu, np.ndarray):
        x = ((x_mu) / mu) * 2 - 1.
        x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
    elif isinstance(x_mu, torch.Tensor):
        if not x_mu.dtype.is_floating_point:
            x_mu = x_mu.to(torch.float)
        mu = torch.tensor(mu, dtype=x_mu.dtype)
        x = ((x_mu) / mu) * 2 - 1.
        x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
    return x
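
# Example (hypothetical values): decoding inverts the companding up to
# quantization error, so a round trip stays close to the original signal.
#   >>> x = torch.rand(1000) * 2 - 1
#   >>> x_rt = mu_law_expanding(mu_law_encoding(x, 256), 256)
#   >>> bool((x - x_rt).abs().max() < 0.05)
#   True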

0 commit comments