Skip to content

Commit c844ac6

Browse files
goldsboroughcolesbury
authored andcommitted
Fixes after tensor/variable merge (pytorch#33)
1 parent 4a3e500 commit c844ac6

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

torchaudio/__init__.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@ def _bytes(s, e):
1818
return s.encode(e)
1919

2020

21+
def get_tensor_type_name(tensor):
22+
return tensor.type().replace('torch.', '').replace('Tensor', '')
23+
24+
2125
def check_input(src):
2226
if not torch.is_tensor(src):
2327
raise TypeError('Expected a tensor, got %s' % type(src))
24-
if not src.__module__ == 'torch':
28+
if src.is_cuda:
2529
raise TypeError('Expected a CPU based tensor, got %s' % type(src))
2630

2731

@@ -57,7 +61,7 @@ def load(filepath, out=None, normalization=None):
5761
else:
5862
out = torch.FloatTensor()
5963
# load audio signal
60-
typename = type(out).__name__.replace('Tensor', '')
64+
typename = get_tensor_type_name(out)
6165
func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
6266
sample_rate_p = ffi.new('int*')
6367
func(str(filepath).encode("utf-8"), out, sample_rate_p)
@@ -109,7 +113,7 @@ def save(filepath, src, sample_rate):
109113
# save data to file
110114
filename, extension = os.path.splitext(filepath)
111115
check_input(src)
112-
typename = type(src).__name__.replace('Tensor', '')
116+
typename = get_tensor_type_name(src)
113117
func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
114118
func(_bytes(filepath, "utf-8"), src,
115119
_bytes(extension[1:], "utf-8"), sample_rate)

0 commit comments

Comments
 (0)