Skip to content

Commit b311c4c

Browse files
jph00 authored and soumith committed
Bug fix: Use correct device for MEL2 functions so MEL2 works on CUDA tensors (pytorch#77)
1 parent d62d3c0 commit b311c4c

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

torchaudio/transforms.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def __call__(self, sig):
196196
if self.pad > 0:
197197
with torch.no_grad():
198198
sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
199+
self.window = self.window.to(sig.device)
199200
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
200201
self.window, center=False,
201202
normalized=True, onesided=True).transpose(1, 2)
@@ -225,7 +226,7 @@ def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None):
225226

226227
def __call__(self, spec_f):
227228
if self.fb is None:
228-
self.fb = self._create_fb_matrix(spec_f.size(2))
229+
self.fb = self._create_fb_matrix(spec_f.size(2)).to(spec_f.device)
229230
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
230231
return spec_m
231232

@@ -280,7 +281,7 @@ def __call__(self, spec):
280281

281282
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
282283
if self.top_db is not None:
283-
spec_db = torch.max(spec_db, torch.tensor(self.top_db, dtype=spec_db.dtype))
284+
spec_db = torch.max(spec_db, spec_db.new_full((1,),self.top_db))
284285
return spec_db
285286

286287

0 commit comments

Comments
 (0)