Skip to content

Commit 3155d0e

Browse files
mthrok authored and facebook-github-bot committed
Add additional filter graph option to StreamWriter (#3194)
Summary: Pull Request resolved: #3194 Differential Revision: D44283910 Pulled By: mthrok fbshipit-source-id: 9f73bcabdeb01b88220d47809ef860e854878c85
1 parent b1de9f1 commit 3155d0e

File tree

7 files changed

+137
-22
lines changed

7 files changed

+137
-22
lines changed

test/torchaudio_unittest/io/stream_writer_test.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,3 +580,57 @@ def write_audio(buffer, bit_rate):
580580
out1_size = dst.tell()
581581

582582
self.assertGreater(out1_size, out0_size)
583+
584+
def test_filter_graph_audio(self):
585+
"""Can apply additional effect with filter graph"""
586+
sample_rate = 8000
587+
num_channels = 2
588+
ext = "wav"
589+
filename = f"test.{ext}"
590+
591+
original = get_audio_chunk("s16", num_channels=num_channels, sample_rate=sample_rate)
592+
593+
dst = self.get_dst(filename)
594+
w = StreamWriter(dst, format=ext)
595+
w.add_audio_stream(sample_rate=8000, num_channels=num_channels, filter_desc="areverse", format="s16")
596+
597+
with w.open():
598+
w.write_audio_chunk(0, original)
599+
600+
# check
601+
if self.test_fileobj:
602+
dst.flush()
603+
604+
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
605+
reader.add_audio_stream(-1)
606+
reader.process_all_packets()
607+
(output,) = reader.pop_chunks()
608+
609+
self.assertEqual(output, original.flip(0))
610+
611+
def test_filter_graph_video(self):
612+
"""Can apply additional effect with filter graph"""
613+
rate = 30
614+
num_frames, width, height = 400, 160, 90
615+
ext = "mp4"
616+
filename = f"test.{ext}"
617+
618+
original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)
619+
620+
dst = self.get_dst(filename)
621+
w = StreamWriter(dst, format=ext)
622+
w.add_video_stream(frame_rate=rate, format="rgb24", height=height, width=width, filter_desc="framestep=2")
623+
624+
with w.open():
625+
w.write_video_chunk(0, original)
626+
627+
# check
628+
if self.test_fileobj:
629+
dst.flush()
630+
631+
reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
632+
reader.add_video_stream(-1)
633+
reader.process_all_packets()
634+
(output,) = reader.pop_chunks()
635+
636+
self.assertEqual(output.shape, [num_frames // 2, 3, height, width])

torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -513,19 +513,26 @@ FilterGraph get_audio_filter_graph(
513513
AVSampleFormat src_fmt,
514514
int sample_rate,
515515
uint64_t channel_layout,
516+
const c10::optional<std::string>& filter_desc,
516517
AVSampleFormat enc_fmt,
517518
int nb_samples) {
518-
const std::string filter_desc = [&]() -> const std::string {
519+
const std::string desc = [&]() -> const std::string {
519520
if (src_fmt == enc_fmt) {
520521
if (nb_samples == 0) {
521-
return "anull";
522+
return filter_desc.value_or("anull");
522523
} else {
523524
std::stringstream ss;
525+
if (filter_desc) {
526+
ss << filter_desc.value() << ",";
527+
}
524528
ss << "asetnsamples=n=" << nb_samples << ":p=0";
525529
return ss.str();
526530
}
527531
} else {
528532
std::stringstream ss;
533+
if (filter_desc) {
534+
ss << filter_desc.value() << ",";
535+
}
529536
ss << "aformat=" << av_get_sample_fmt_name(enc_fmt);
530537
if (nb_samples > 0) {
531538
ss << ",asetnsamples=n=" << nb_samples << ":p=0";
@@ -537,7 +544,7 @@ FilterGraph get_audio_filter_graph(
537544
FilterGraph f{AVMEDIA_TYPE_AUDIO};
538545
f.add_audio_src(src_fmt, {1, sample_rate}, sample_rate, channel_layout);
539546
f.add_sink();
540-
f.add_process(filter_desc);
547+
f.add_process(desc);
541548
f.create_filter();
542549
return f;
543550
}
@@ -547,13 +554,17 @@ FilterGraph get_video_filter_graph(
547554
AVRational rate,
548555
int width,
549556
int height,
557+
const c10::optional<std::string>& filter_desc,
550558
AVPixelFormat enc_fmt,
551559
bool is_cuda) {
552560
auto desc = [&]() -> std::string {
553561
if (src_fmt == enc_fmt || is_cuda) {
554-
return "null";
562+
return filter_desc.value_or("null");
555563
} else {
556564
std::stringstream ss;
565+
if (filter_desc) {
566+
ss << filter_desc.value() << ",";
567+
}
557568
ss << "format=" << av_get_pix_fmt_name(enc_fmt);
558569
return ss.str();
559570
}
@@ -624,7 +635,8 @@ EncodeProcess get_audio_encode_process(
624635
const c10::optional<std::string>& encoder,
625636
const c10::optional<OptionDict>& encoder_option,
626637
const c10::optional<std::string>& encoder_format,
627-
const c10::optional<CodecConfig>& codec_config) {
638+
const c10::optional<CodecConfig>& codec_config,
639+
const c10::optional<std::string>& filter_desc) {
628640
// 1. Check the source format, rate and channels
629641
const AVSampleFormat src_fmt = get_sample_fmt(format);
630642
TORCH_CHECK(
@@ -663,7 +675,12 @@ EncodeProcess get_audio_encode_process(
663675

664676
// 5. Build filter graph
665677
FilterGraph filter_graph = get_audio_filter_graph(
666-
src_fmt, src_sample_rate, channel_layout, enc_fmt, codec_ctx->frame_size);
678+
src_fmt,
679+
src_sample_rate,
680+
channel_layout,
681+
filter_desc,
682+
enc_fmt,
683+
codec_ctx->frame_size);
667684

668685
// 6. Instantiate source frame
669686
AVFramePtr src_frame = get_audio_frame(
@@ -701,7 +718,8 @@ EncodeProcess get_video_encode_process(
701718
const c10::optional<OptionDict>& encoder_option,
702719
const c10::optional<std::string>& encoder_format,
703720
const c10::optional<std::string>& hw_accel,
704-
const c10::optional<CodecConfig>& codec_config) {
721+
const c10::optional<CodecConfig>& codec_config,
722+
const c10::optional<std::string>& filter_desc) {
705723
// 1. Check the source format, rate and resolution
706724
const AVPixelFormat src_fmt = get_pix_fmt(format);
707725
AVRational src_rate = av_d2q(frame_rate, 1 << 24);
@@ -742,7 +760,13 @@ EncodeProcess get_video_encode_process(
742760

743761
// 5. Build filter graph
744762
FilterGraph filter_graph = get_video_filter_graph(
745-
src_fmt, src_rate, src_width, src_height, enc_fmt, hw_accel.has_value());
763+
src_fmt,
764+
src_rate,
765+
src_width,
766+
src_height,
767+
filter_desc,
768+
enc_fmt,
769+
hw_accel.has_value());
746770

747771
// 6. Instantiate source frame
748772
AVFramePtr src_frame = [&]() {

torchaudio/csrc/ffmpeg/stream_writer/encode_process.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ EncodeProcess get_audio_encode_process(
4141
const c10::optional<std::string>& encoder,
4242
const c10::optional<OptionDict>& encoder_option,
4343
const c10::optional<std::string>& encoder_format,
44-
const c10::optional<CodecConfig>& codec_config);
44+
const c10::optional<CodecConfig>& codec_config,
45+
const c10::optional<std::string>& filter_desc);
4546

4647
EncodeProcess get_video_encode_process(
4748
AVFormatContext* format_ctx,
@@ -53,6 +54,7 @@ EncodeProcess get_video_encode_process(
5354
const c10::optional<OptionDict>& encoder_option,
5455
const c10::optional<std::string>& encoder_format,
5556
const c10::optional<std::string>& hw_accel,
56-
const c10::optional<CodecConfig>& codec_config);
57+
const c10::optional<CodecConfig>& codec_config,
58+
const c10::optional<std::string>& filter_desc);
5759

5860
}; // namespace torchaudio::io

torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ void StreamWriter::add_audio_stream(
6060
const c10::optional<std::string>& encoder,
6161
const c10::optional<OptionDict>& encoder_option,
6262
const c10::optional<std::string>& encoder_format,
63-
const c10::optional<CodecConfig>& codec_config) {
63+
const c10::optional<CodecConfig>& codec_config,
64+
const c10::optional<std::string>& filter_desc) {
6465
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
6566
TORCH_INTERNAL_ASSERT(
6667
pFormatContext->nb_streams == processes.size(),
@@ -73,7 +74,8 @@ void StreamWriter::add_audio_stream(
7374
encoder,
7475
encoder_option,
7576
encoder_format,
76-
codec_config));
77+
codec_config,
78+
filter_desc));
7779
}
7880

7981
void StreamWriter::add_video_stream(
@@ -85,7 +87,8 @@ void StreamWriter::add_video_stream(
8587
const c10::optional<OptionDict>& encoder_option,
8688
const c10::optional<std::string>& encoder_format,
8789
const c10::optional<std::string>& hw_accel,
88-
const c10::optional<CodecConfig>& codec_config) {
90+
const c10::optional<CodecConfig>& codec_config,
91+
const c10::optional<std::string>& filter_desc) {
8992
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
9093
TORCH_INTERNAL_ASSERT(
9194
pFormatContext->nb_streams == processes.size(),
@@ -100,7 +103,8 @@ void StreamWriter::add_video_stream(
100103
encoder_option,
101104
encoder_format,
102105
hw_accel,
103-
codec_config));
106+
codec_config,
107+
filter_desc));
104108
}
105109

106110
void StreamWriter::set_metadata(const OptionDict& metadata) {

torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,17 @@ class StreamWriter {
100100
/// To list supported formats for the encoder, you can use
101101
/// ``ffmpeg -h encoder=<ENCODER>`` command.
102102
/// @param codec_config Codec configuration.
103+
/// @param filter_desc Additional processing to apply before
104+
/// encoding the input data
103105
void add_audio_stream(
104106
int sample_rate,
105107
int num_channels,
106108
const std::string& format,
107109
const c10::optional<std::string>& encoder,
108110
const c10::optional<OptionDict>& encoder_option,
109111
const c10::optional<std::string>& encoder_format,
110-
const c10::optional<CodecConfig>& codec_config);
112+
const c10::optional<CodecConfig>& codec_config,
113+
const c10::optional<std::string>& filter_desc);
111114

112115
/// Add an output video stream.
113116
///
@@ -139,6 +142,8 @@ class StreamWriter {
139142
///
140143
/// If `None`, the video chunk Tensor has to be a CPU Tensor.
141144
/// @endparblock
145+
/// @param filter_desc Additional processing to apply before
146+
/// encoding the input data
142147
void add_video_stream(
143148
double frame_rate,
144149
int width,
@@ -148,7 +153,8 @@ class StreamWriter {
148153
const c10::optional<OptionDict>& encoder_option,
149154
const c10::optional<std::string>& encoder_format,
150155
const c10::optional<std::string>& hw_accel,
151-
const c10::optional<CodecConfig>& codec_config);
156+
const c10::optional<CodecConfig>& codec_config,
157+
const c10::optional<std::string>& filter_desc);
152158
/// Set file-level metadata
153159
/// @param metadata metadata.
154160
void set_metadata(const OptionDict& metadata);

torchaudio/csrc/ffmpeg/stream_writer/tensor_converter.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,12 @@ void write_interlaced_video(
144144
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2) == buffer->width);
145145
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3) == num_channels);
146146

147-
// TODO: writable
148147
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
149-
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
148+
if (!av_frame_is_writable(buffer)) {
149+
int ret = av_frame_make_writable(buffer);
150+
TORCH_INTERNAL_ASSERT(
151+
ret >= 0, "Failed to make frame writable: ", av_err2string(ret));
152+
}
150153

151154
size_t stride = buffer->width * num_channels;
152155
uint8_t* src = frame.data_ptr<uint8_t>();
@@ -191,9 +194,12 @@ void write_planar_video(
191194
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(2), buffer->height);
192195
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(frame.size(3), buffer->width);
193196

194-
// TODO: writable
195197
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00472
196-
TORCH_INTERNAL_ASSERT(av_frame_is_writable(buffer), "frame is not writable.");
198+
if (!av_frame_is_writable(buffer)) {
199+
int ret = av_frame_make_writable(buffer);
200+
TORCH_INTERNAL_ASSERT(
201+
ret >= 0, "Failed to make frame writable: ", av_err2string(ret));
202+
}
197203

198204
for (int j = 0; j < num_colors; ++j) {
199205
uint8_t* src = frame.index({0, j}).data_ptr<uint8_t>();

torchaudio/io/_stream_writer.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,15 @@ def decorator(obj):
5353
Default: ``None``."""
5454

5555

56+
_filter_desc = """Additional processing to apply before encoding the input media.
57+
"""
58+
5659
_format_common_args = _format_doc(
5760
encoder=_encoder,
5861
encoder_option=_encoder_option,
5962
encoder_format=_encoder_format,
6063
codec_config=_codec_config,
64+
filter_desc=_filter_desc,
6165
)
6266

6367

@@ -159,6 +163,7 @@ def add_audio_stream(
159163
encoder_option: Optional[Dict[str, str]] = None,
160164
encoder_format: Optional[str] = None,
161165
codec_config: Optional[CodecConfig] = None,
166+
filter_desc: Optional[str] = None,
162167
):
163168
"""Add an output audio stream.
164169
@@ -186,9 +191,11 @@ def add_audio_stream(
186191
encoder_format (str or None, optional): {encoder_format}
187192
188193
codec_config (CodecConfig or None, optional): {codec_config}
194+
195+
filter_desc (str or None, optional): {filter_desc}
189196
"""
190197
self._s.add_audio_stream(
191-
sample_rate, num_channels, format, encoder, encoder_option, encoder_format, codec_config
198+
sample_rate, num_channels, format, encoder, encoder_option, encoder_format, codec_config, filter_desc
192199
)
193200

194201
@_format_common_args
@@ -203,6 +210,7 @@ def add_video_stream(
203210
encoder_format: Optional[str] = None,
204211
hw_accel: Optional[str] = None,
205212
codec_config: Optional[CodecConfig] = None,
213+
filter_desc: Optional[str] = None,
206214
):
207215
"""Add an output video stream.
208216
@@ -245,9 +253,20 @@ def add_video_stream(
245253
Default: ``None``.
246254
247255
codec_config (CodecConfig or None, optional): {codec_config}
256+
257+
filter_desc (str or None, optional): {filter_desc}
248258
"""
249259
self._s.add_video_stream(
250-
frame_rate, width, height, format, encoder, encoder_option, encoder_format, hw_accel, codec_config
260+
frame_rate,
261+
width,
262+
height,
263+
format,
264+
encoder,
265+
encoder_option,
266+
encoder_format,
267+
hw_accel,
268+
codec_config,
269+
filter_desc,
251270
)
252271

253272
def set_metadata(self, metadata: Dict[str, str]):

0 commit comments

Comments
 (0)