Skip to content

Commit d8a37a2

Browse files
mthrokfacebook-github-bot
authored andcommitted
Properly set #samples passed to encoder (#3204)
Summary: Some audio encoders expect specific, exact number of samples described as in `AVCodecContext.frame_size`. The `AVFrame.nb_samples` is set for the frames passed to `AVFilterGraph`, but frames coming out of the graph do not necessarily have the same numbr of frames. This causes issues with encoding OPUS (among others). This commit fixes it by inserting `asetnsamples` to filter graph if a fixed number of samples is requested. Note: It turned out that FFmpeg 4.1 has issue with OPUS encoding. It does not properly discard some sample. We should probably move the minimum required FFmpeg to 4.2, but I am not sure if we can enforce it via ABI. Work around will be to issue an warning if encoding OPUS with 4.1. (follow-up) Pull Request resolved: #3204 Reviewed By: nateanl Differential Revision: D44374668 Pulled By: mthrok fbshipit-source-id: 10ef5333dc0677dfb83c8e40b78edd8ded1b21dc
1 parent 583174a commit d8a37a2

File tree

3 files changed

+72
-18
lines changed

3 files changed

+72
-18
lines changed

test/torchaudio_unittest/io/stream_writer_test.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -318,28 +318,58 @@ def test_video_num_frames(self, framerate, resolution, format):
318318
pass
319319

320320
@nested_params(
321-
["wav", "mp3", "flac"],
321+
["wav", "flac"],
322322
[8000, 16000, 44100],
323323
[1, 2],
324324
)
325-
def test_audio_num_frames(self, ext, sample_rate, num_channels):
326-
""""""
325+
def test_audio_num_frames_lossless(self, ext, sample_rate, num_channels):
326+
"""Lossless format preserves the data"""
327327
filename = f"test.{ext}"
328328

329+
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, dtype="int16", channels_first=False)
330+
329331
# Write data
330332
dst = self.get_dst(filename)
331333
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
332-
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
334+
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, format="s16")
335+
with s.open():
336+
s.write_audio_chunk(0, data)
333337

334-
freq = 300
335-
duration = 60
336-
theta = torch.linspace(0, freq * 2 * 3.14 * duration, sample_rate * duration)
337-
if num_channels == 1:
338-
chunk = torch.sin(theta).unsqueeze(-1)
339-
else:
340-
chunk = torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1)
338+
if self.test_fileobj:
339+
dst.flush()
340+
341+
# Load data
342+
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
343+
s.add_audio_stream(-1)
344+
s.process_all_packets()
345+
(saved,) = s.pop_chunks()
346+
347+
self.assertEqual(saved, data)
348+
349+
@parameterized.expand(
350+
[
351+
("mp3", 1, 8000),
352+
("mp3", 1, 16000),
353+
("mp3", 1, 44100),
354+
("mp3", 2, 8000),
355+
("mp3", 2, 16000),
356+
("mp3", 2, 44100),
357+
("opus", 1, 48000),
358+
]
359+
)
360+
def test_audio_num_frames_lossy(self, ext, num_channels, sample_rate):
361+
"""Saving audio preserves the number of channels and frames"""
362+
filename = f"test.{ext}"
363+
364+
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
365+
366+
# Write data
367+
dst = self.get_dst(filename)
368+
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
369+
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
341370
with s.open():
342-
s.write_audio_chunk(0, chunk)
371+
s.write_audio_chunk(0, data)
372+
343373
if self.test_fileobj:
344374
dst.flush()
345375

@@ -349,9 +379,21 @@ def test_audio_num_frames(self, ext, sample_rate, num_channels):
349379
s.process_all_packets()
350380
(saved,) = s.pop_chunks()
351381

352-
assert saved.shape == chunk.shape
353-
if format in ["wav", "flac"]:
354-
self.assertEqual(saved, chunk)
382+
# This test fails for OPUS if FFmpeg is 4.1, but it passes for 4.2+
383+
# 4.1 produces 48312 samples (extra 312)
384+
# Probably this commit fixes it.
385+
# https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c
386+
# TODO: issue warning if 4.1?
387+
if ext == "opus":
388+
ver = torchaudio.utils.ffmpeg_utils.get_versions()["libavcodec"]
389+
# 5.1 libavcodec 59. 18.100
390+
# 4.4 libavcodec 58.134.100
391+
# 4.3 libavcodec 58. 91.100
392+
# 4.2 libavcodec 58. 54.100
393+
# 4.1 libavcodec 58. 35.100
394+
if ver[0] < 59 and ver[1] < 54:
395+
return
396+
self.assertEqual(saved.shape, data.shape)
355397

356398
def test_preserve_fps(self):
357399
"""Decimal point frame rate is properly saved

torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,19 @@ FilterGraph get_audio_filter(
233233
AVCodecContext* codec_ctx) {
234234
auto desc = [&]() -> std::string {
235235
if (src_fmt == codec_ctx->sample_fmt) {
236-
return "anull";
236+
if (!codec_ctx->frame_size) {
237+
return "anull";
238+
} else {
239+
std::stringstream ss;
240+
ss << "asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
241+
return ss.str();
242+
}
237243
} else {
238244
std::stringstream ss;
239245
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
246+
if (codec_ctx->frame_size) {
247+
ss << ",asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
248+
}
240249
return ss.str();
241250
}
242251
}();

torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,12 @@ void convert_func_(const torch::Tensor& chunk, AVFrame* buffer) {
4040
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.dim() == 2);
4141
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(chunk.size(1) == buffer->channels);
4242

43-
// TODO: make writable
4443
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00334
45-
TORCH_CHECK(av_frame_is_writable(buffer), "frame is not writable.");
44+
if (!av_frame_is_writable(buffer)) {
45+
int ret = av_frame_make_writable(buffer);
46+
TORCH_INTERNAL_ASSERT(
47+
ret >= 0, "Failed to make frame writable: ", av_err2string(ret));
48+
}
4649

4750
auto byte_size = chunk.numel() * chunk.element_size();
4851
memcpy(buffer->data[0], chunk.data_ptr(), byte_size);

0 commit comments

Comments
 (0)