@@ -18,10 +18,14 @@ def _bytes(s, e):
1818 return s .encode (e )
1919
2020
21+ def get_tensor_type_name (tensor ):
22+ return tensor .type ().replace ('torch.' , '' ).replace ('Tensor' , '' )
23+
24+
2125def check_input (src ):
2226 if not torch .is_tensor (src ):
2327 raise TypeError ('Expected a tensor, got %s' % type (src ))
24- if not src .__module__ == 'torch' :
28+ if src .is_cuda :
2529 raise TypeError ('Expected a CPU based tensor, got %s' % type (src ))
2630
2731
@@ -57,7 +61,7 @@ def load(filepath, out=None, normalization=None):
5761 else :
5862 out = torch .FloatTensor ()
5963 # load audio signal
60- typename = type (out ). __name__ . replace ( 'Tensor' , '' )
64+ typename = get_tensor_type_name (out )
6165 func = getattr (th_sox , 'libthsox_{}_read_audio_file' .format (typename ))
6266 sample_rate_p = ffi .new ('int*' )
6367 func (str (filepath ).encode ("utf-8" ), out , sample_rate_p )
@@ -109,7 +113,7 @@ def save(filepath, src, sample_rate):
109113 # save data to file
110114 filename , extension = os .path .splitext (filepath )
111115 check_input (src )
112- typename = type (src ). __name__ . replace ( 'Tensor' , '' )
116+ typename = get_tensor_type_name (src )
113117 func = getattr (th_sox , 'libthsox_{}_write_audio_file' .format (typename ))
114118 func (_bytes (filepath , "utf-8" ), src ,
115119 _bytes (extension [1 :], "utf-8" ), sample_rate )
0 commit comments