Skip to content

Add additional filter graph option to StreamWriter #3194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions test/torchaudio_unittest/io/stream_writer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,3 +580,57 @@ def write_audio(buffer, bit_rate):
out1_size = dst.tell()

self.assertGreater(out1_size, out0_size)

def test_filter_graph_audio(self):
"""Can apply additional effect with filter graph"""
sample_rate = 8000
num_channels = 2
ext = "wav"
filename = f"test.{ext}"

original = get_audio_chunk("s16", num_channels=num_channels, sample_rate=sample_rate)

dst = self.get_dst(filename)
w = StreamWriter(dst, format=ext)
w.add_audio_stream(sample_rate=8000, num_channels=num_channels, filter_desc="areverse", format="s16")

with w.open():
w.write_audio_chunk(0, original)

# check
if self.test_fileobj:
dst.flush()

reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_audio_stream(-1)
reader.process_all_packets()
(output,) = reader.pop_chunks()

self.assertEqual(output, original.flip(0))

def test_filter_graph_video(self):
"""Can apply additional effect with filter graph"""
rate = 30
num_frames, width, height = 400, 160, 90
ext = "mp4"
filename = f"test.{ext}"

original = torch.zeros((num_frames, 3, height, width), dtype=torch.uint8)

dst = self.get_dst(filename)
w = StreamWriter(dst, format=ext)
w.add_video_stream(frame_rate=rate, format="rgb24", height=height, width=width, filter_desc="framestep=2")

with w.open():
w.write_video_chunk(0, original)

# check
if self.test_fileobj:
dst.flush()

reader = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
reader.add_video_stream(-1)
reader.process_all_packets()
(output,) = reader.pop_chunks()

self.assertEqual(output.shape, [num_frames // 2, 3, height, width])
24 changes: 12 additions & 12 deletions torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class StreamReader {
/// (opening source).
explicit StreamReader(
const std::string& src,
const c10::optional<std::string>& format = {},
const c10::optional<OptionDict>& option = {});
const c10::optional<std::string>& format = c10::nullopt,
const c10::optional<OptionDict>& option = c10::nullopt);

/// @cond

Expand All @@ -72,8 +72,8 @@ class StreamReader {
// TODO: Move this to wrapper class
explicit StreamReader(
AVIOContext* io_ctx,
const c10::optional<std::string>& format = {},
const c10::optional<OptionDict>& option = {});
const c10::optional<std::string>& format = c10::nullopt,
const c10::optional<OptionDict>& option = c10::nullopt);

/// @endcond

Expand Down Expand Up @@ -190,9 +190,9 @@ class StreamReader {
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option);
const c10::optional<std::string>& filter_desc = c10::nullopt,
const c10::optional<std::string>& decoder = c10::nullopt,
const c10::optional<OptionDict>& decoder_option = c10::nullopt);
/// Define an output video stream.
///
/// @param i,frames_per_chunk,num_chunks,filter_desc,decoder,decoder_option
Expand All @@ -211,10 +211,10 @@ class StreamReader {
int64_t i,
int64_t frames_per_chunk,
int64_t num_chunks,
const c10::optional<std::string>& filter_desc,
const c10::optional<std::string>& decoder,
const c10::optional<OptionDict>& decoder_option,
const c10::optional<std::string>& hw_accel);
const c10::optional<std::string>& filter_desc = c10::nullopt,
const c10::optional<std::string>& decoder = c10::nullopt,
const c10::optional<OptionDict>& decoder_option = c10::nullopt,
const c10::optional<std::string>& hw_accel = c10::nullopt);
/// Remove an output stream.
///
/// @param i The index of the output stream to be removed.
Expand Down Expand Up @@ -288,7 +288,7 @@ class StreamReader {
/// @param timeout See `process_packet_block()`
/// @param backoff See `process_packet_block()`
int fill_buffer(
const c10::optional<double>& timeout = {},
const c10::optional<double>& timeout = c10::nullopt,
const double backoff = 10.);

///@}
Expand Down
40 changes: 32 additions & 8 deletions torchaudio/csrc/ffmpeg/stream_writer/encode_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -513,19 +513,26 @@ FilterGraph get_audio_filter_graph(
AVSampleFormat src_fmt,
int sample_rate,
uint64_t channel_layout,
const c10::optional<std::string>& filter_desc,
AVSampleFormat enc_fmt,
int nb_samples) {
const std::string filter_desc = [&]() -> const std::string {
const std::string desc = [&]() -> const std::string {
if (src_fmt == enc_fmt) {
if (nb_samples == 0) {
return "anull";
return filter_desc.value_or("anull");
} else {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "asetnsamples=n=" << nb_samples << ":p=0";
return ss.str();
}
} else {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "aformat=" << av_get_sample_fmt_name(enc_fmt);
if (nb_samples > 0) {
ss << ",asetnsamples=n=" << nb_samples << ":p=0";
Expand All @@ -537,7 +544,7 @@ FilterGraph get_audio_filter_graph(
FilterGraph f{AVMEDIA_TYPE_AUDIO};
f.add_audio_src(src_fmt, {1, sample_rate}, sample_rate, channel_layout);
f.add_sink();
f.add_process(filter_desc);
f.add_process(desc);
f.create_filter();
return f;
}
Expand All @@ -547,13 +554,17 @@ FilterGraph get_video_filter_graph(
AVRational rate,
int width,
int height,
const c10::optional<std::string>& filter_desc,
AVPixelFormat enc_fmt,
bool is_cuda) {
auto desc = [&]() -> std::string {
if (src_fmt == enc_fmt || is_cuda) {
return "null";
return filter_desc.value_or("null");
} else {
std::stringstream ss;
if (filter_desc) {
ss << filter_desc.value() << ",";
}
ss << "format=" << av_get_pix_fmt_name(enc_fmt);
return ss.str();
}
Expand Down Expand Up @@ -624,7 +635,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
// 1. Check the source format, rate and channels
const AVSampleFormat src_fmt = get_sample_fmt(format);
TORCH_CHECK(
Expand Down Expand Up @@ -663,7 +675,12 @@ EncodeProcess get_audio_encode_process(

// 5. Build filter graph
FilterGraph filter_graph = get_audio_filter_graph(
src_fmt, src_sample_rate, channel_layout, enc_fmt, codec_ctx->frame_size);
src_fmt,
src_sample_rate,
channel_layout,
filter_desc,
enc_fmt,
codec_ctx->frame_size);

// 6. Instantiate source frame
AVFramePtr src_frame = get_audio_frame(
Expand Down Expand Up @@ -701,7 +718,8 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
// 1. Checkc the source format, rate and resolution
const AVPixelFormat src_fmt = get_pix_fmt(format);
AVRational src_rate = av_d2q(frame_rate, 1 << 24);
Expand Down Expand Up @@ -742,7 +760,13 @@ EncodeProcess get_video_encode_process(

// 5. Build filter graph
FilterGraph filter_graph = get_video_filter_graph(
src_fmt, src_rate, src_width, src_height, enc_fmt, hw_accel.has_value());
src_fmt,
src_rate,
src_width,
src_height,
filter_desc,
enc_fmt,
hw_accel.has_value());

// 6. Instantiate source frame
AVFramePtr src_frame = [&]() {
Expand Down
6 changes: 4 additions & 2 deletions torchaudio/csrc/ffmpeg/stream_writer/encode_process.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ EncodeProcess get_audio_encode_process(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<CodecConfig>& codec_config);
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);

EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx,
Expand All @@ -53,6 +54,7 @@ EncodeProcess get_video_encode_process(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config);
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc);

}; // namespace torchaudio::io
12 changes: 8 additions & 4 deletions torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
Expand All @@ -73,7 +74,8 @@ void StreamWriter::add_audio_stream(
encoder,
encoder_option,
encoder_format,
codec_config));
codec_config,
filter_desc));
}

void StreamWriter::add_video_stream(
Expand All @@ -85,7 +87,8 @@ void StreamWriter::add_video_stream(
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config) {
const c10::optional<CodecConfig>& codec_config,
const c10::optional<std::string>& filter_desc) {
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
Expand All @@ -100,7 +103,8 @@ void StreamWriter::add_video_stream(
encoder_option,
encoder_format,
hw_accel,
codec_config));
codec_config,
filter_desc));
}

void StreamWriter::set_metadata(const OptionDict& metadata) {
Expand Down
36 changes: 21 additions & 15 deletions torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#pragma once

#include <torch/torch.h>
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encode_process.h>
Expand Down Expand Up @@ -34,7 +34,7 @@ class StreamWriter {
/// ``dst``.
explicit StreamWriter(
const std::string& dst,
const c10::optional<std::string>& format = {});
const c10::optional<std::string>& format = c10::nullopt);

/// @cond

Expand All @@ -45,7 +45,7 @@ class StreamWriter {
// TODO: Move this into wrapper class.
explicit StreamWriter(
AVIOContext* io_ctx,
const c10::optional<std::string>& format);
const c10::optional<std::string>& format = c10::nullopt);

/// @endcond

Expand Down Expand Up @@ -100,14 +100,17 @@ class StreamWriter {
/// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command.
/// @param codec_config Codec configuration.
/// @param filter_desc Additional processing to apply before
/// encoding the input data
void add_audio_stream(
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<CodecConfig>& codec_config);
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);

/// Add an output video stream.
///
Expand Down Expand Up @@ -139,16 +142,19 @@ class StreamWriter {
///
/// If `None`, the video chunk Tensor has to be a CPU Tensor.
/// @endparblock
/// @param filter_desc Additional processing to apply before
/// encoding the input data
void add_video_stream(
double frame_rate,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<CodecConfig>& codec_config);
const c10::optional<std::string>& encoder = c10::nullopt,
const c10::optional<OptionDict>& encoder_option = c10::nullopt,
const c10::optional<std::string>& encoder_format = c10::nullopt,
const c10::optional<std::string>& hw_accel = c10::nullopt,
const c10::optional<CodecConfig>& codec_config = c10::nullopt,
const c10::optional<std::string>& filter_desc = c10::nullopt);
/// Set file-level metadata
/// @param metadata metadata.
void set_metadata(const OptionDict& metadata);
Expand All @@ -160,7 +166,7 @@ class StreamWriter {
/// Open the output file / device and write the header.
///
/// @param opt Private options for protocol, device and muxer.
void open(const c10::optional<OptionDict>& opt);
void open(const c10::optional<OptionDict>& opt = c10::nullopt);
/// Close the output file / device and finalize metadata.
void close();

Expand All @@ -182,7 +188,7 @@ class StreamWriter {
void write_audio_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts = {});
const c10::optional<double>& pts = c10::nullopt);
/// Write video data
/// @param i Stream index.
/// @param frames Video/image tensor. Shape: ``(time, channel, height,
Expand All @@ -203,7 +209,7 @@ class StreamWriter {
void write_video_chunk(
int i,
const torch::Tensor& frames,
const c10::optional<double>& pts = {});
const c10::optional<double>& pts = c10::nullopt);
/// Flush the frames from encoders and write the frames to the destination.
void flush();
};
Expand Down
Loading