@@ -196,6 +196,7 @@ def __call__(self, sig):
         if self.pad > 0:
             with torch.no_grad():
                 sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
+        self.window = self.window.to(sig.device)
         spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
                             self.window, center=False,
                             normalized=True, onesided=True).transpose(1, 2)
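The added line moves the precomputed STFT window onto the signal's device before `torch.stft`; a CUDA input paired with a CPU-resident window would otherwise fail with a device-mismatch RuntimeError. A minimal standalone sketch of the pattern, assuming hypothetical `n_fft`/`hop`/`ws` values (`return_complex=True` is required by current PyTorch and is not part of the original call):

```python
import torch

n_fft, hop, ws = 400, 200, 400
window = torch.hann_window(ws)  # built on the CPU at construction time

def stft_on_any_device(sig):
    # .to() is a no-op when the devices already match, so CPU callers
    # pay nothing; CUDA callers get the window copied over once.
    w = window.to(sig.device)
    return torch.stft(sig, n_fft, hop, ws, w,
                      center=False, normalized=True, onesided=True,
                      return_complex=True)

sig = torch.randn(1, 16000)  # one channel, one second at 16 kHz
if torch.cuda.is_available():
    sig = sig.cuda()
spec_f = stft_on_any_device(sig)
```

In the patch itself the moved window is assigned back to `self.window`, so the copy happens at most once per device change rather than on every call.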
@@ -225,7 +226,7 @@ def __init__(self, n_mels=40, sr=16000, f_max=None, f_min=0., n_stft=None):
 
     def __call__(self, spec_f):
         if self.fb is None:
-            self.fb = self._create_fb_matrix(spec_f.size(2))
+            self.fb = self._create_fb_matrix(spec_f.size(2)).to(spec_f.device)
         spec_m = torch.matmul(spec_f, self.fb)  # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
         return spec_m
 
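Same idea for `MelScale`: the filterbank is created lazily on first call, so `.to(spec_f.device)` keeps it on the spectrogram's device and the `torch.matmul` below it never mixes devices. A sketch of the constraint, with a random `fb` standing in for the real `_create_fb_matrix` output:

```python
import torch

spec_f = torch.randn(1, 79, 201)  # (channels, frames, n_stft)
if torch.cuda.is_available():
    spec_f = spec_f.cuda()

fb = torch.rand(spec_f.size(2), 40)  # (n_stft, n_mels), allocated on CPU
fb = fb.to(spec_f.device)            # without this, a CUDA spec_f makes
                                     # torch.matmul raise a device mismatch
spec_m = torch.matmul(spec_f, fb)    # -> (1, 79, 40)
```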
@@ -280,7 +281,7 @@ def __call__(self, spec):
 
         spec_db = self.multiplier * torch.log10(spec / spec.max())  # power -> dB
         if self.top_db is not None:
-            spec_db = torch.max(spec_db, torch.tensor(self.top_db, dtype=spec_db.dtype))
+            spec_db = torch.max(spec_db, spec_db.new_full((1,), self.top_db))
         return spec_db
 
 
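The last hunk swaps a CPU-allocated `torch.tensor(self.top_db, dtype=...)` for `spec_db.new_full((1,), self.top_db)`, which inherits both the dtype and the device of `spec_db`, so the dB clamp works unchanged on CPU and CUDA. A short sketch, with `-80.0` as a hypothetical `top_db` floor:

```python
import torch

spec_db = torch.randn(1, 79, 40)  # dB-scaled spectrogram
if torch.cuda.is_available():
    spec_db = spec_db.cuda()

floor = spec_db.new_full((1,), -80.0)  # same dtype and device as spec_db
spec_db = torch.max(spec_db, floor)    # elementwise clamp via broadcasting
```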