Skip to content

Commit 61202b9

Browse files
Dan-Flores and Daniel Flores authored
Update Video Encoder and tests for 6 container formats (#913)
Co-authored-by: Daniel Flores <[email protected]>
1 parent e5b2eef commit 61202b9

File tree

8 files changed

+179
-47
lines changed

8 files changed

+179
-47
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#include "src/torchcodec/_core/Encoder.h"
55
#include "torch/types.h"
66

7+
extern "C" {
8+
#include <libavutil/pixdesc.h>
9+
}
10+
711
namespace facebook::torchcodec {
812

913
namespace {
@@ -587,15 +591,6 @@ void VideoEncoder::initializeEncoder(
587591
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
588592
avCodecContext_.reset(avCodecContext);
589593

590-
// Set encoding options
591-
// TODO-VideoEncoder: Allow bitrate to be set
592-
std::optional<int> desiredBitRate = videoStreamOptions.bitRate;
593-
if (desiredBitRate.has_value()) {
594-
TORCH_CHECK(
595-
*desiredBitRate >= 0, "bit_rate=", *desiredBitRate, " must be >= 0.");
596-
}
597-
avCodecContext_->bit_rate = desiredBitRate.value_or(0);
598-
599594
// Store dimension order and input pixel format
600595
// TODO-VideoEncoder: Remove assumption that tensor in NCHW format
601596
auto sizes = frames_.sizes();
@@ -608,9 +603,15 @@ void VideoEncoder::initializeEncoder(
608603
outWidth_ = inWidth_;
609604
outHeight_ = inHeight_;
610605

611-
// Use YUV420P as default output format
612606
// TODO-VideoEncoder: Enable other pixel formats
613-
outPixelFormat_ = AV_PIX_FMT_YUV420P;
607+
// Let FFmpeg choose best pixel format to minimize loss
608+
outPixelFormat_ = avcodec_find_best_pix_fmt_of_list(
609+
getSupportedPixelFormats(*avCodec), // List of supported formats
610+
AV_PIX_FMT_GBRP, // We reorder input to GBRP currently
611+
0, // No alpha channel
612+
nullptr // Discard conversion loss information
613+
);
614+
TORCH_CHECK(outPixelFormat_ != -1, "Failed to find best pix fmt")
614615

615616
// Configure codec parameters
616617
avCodecContext_->codec_id = avCodec->id;
@@ -621,37 +622,39 @@ void VideoEncoder::initializeEncoder(
621622
avCodecContext_->time_base = {1, inFrameRate_};
622623
avCodecContext_->framerate = {inFrameRate_, 1};
623624

624-
// TODO-VideoEncoder: Allow GOP size and max B-frames to be set
625-
if (videoStreamOptions.gopSize.has_value()) {
626-
avCodecContext_->gop_size = *videoStreamOptions.gopSize;
627-
} else {
628-
avCodecContext_->gop_size = 12; // Default GOP size
625+
// Set flag for containers that require extradata to be in the codec context
626+
if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) {
627+
avCodecContext_->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
629628
}
630629

631-
if (videoStreamOptions.maxBFrames.has_value()) {
632-
avCodecContext_->max_b_frames = *videoStreamOptions.maxBFrames;
633-
} else {
634-
avCodecContext_->max_b_frames = 0; // No max B-frames to reduce compression
630+
// Apply videoStreamOptions
631+
AVDictionary* options = nullptr;
632+
if (videoStreamOptions.crf.has_value()) {
633+
av_dict_set(
634+
&options,
635+
"crf",
636+
std::to_string(videoStreamOptions.crf.value()).c_str(),
637+
0);
635638
}
639+
int status = avcodec_open2(avCodecContext_.get(), avCodec, &options);
640+
av_dict_free(&options);
636641

637-
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
638642
TORCH_CHECK(
639643
status == AVSUCCESS,
640644
"avcodec_open2 failed: ",
641645
getFFMPEGErrorStringFromErrorCode(status));
642646

643-
AVStream* avStream = avformat_new_stream(avFormatContext_.get(), nullptr);
644-
TORCH_CHECK(avStream != nullptr, "Couldn't create new stream.");
647+
avStream_ = avformat_new_stream(avFormatContext_.get(), nullptr);
648+
TORCH_CHECK(avStream_ != nullptr, "Couldn't create new stream.");
645649

646650
// Set the stream time base to encode correct frame timestamps
647-
avStream->time_base = avCodecContext_->time_base;
651+
avStream_->time_base = avCodecContext_->time_base;
648652
status = avcodec_parameters_from_context(
649-
avStream->codecpar, avCodecContext_.get());
653+
avStream_->codecpar, avCodecContext_.get());
650654
TORCH_CHECK(
651655
status == AVSUCCESS,
652656
"avcodec_parameters_from_context failed: ",
653657
getFFMPEGErrorStringFromErrorCode(status));
654-
streamIndex_ = avStream->index;
655658
}
656659

657660
void VideoEncoder::encode() {
@@ -694,7 +697,7 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame(
694697
outWidth_,
695698
outHeight_,
696699
outPixelFormat_,
697-
SWS_BILINEAR,
700+
SWS_BICUBIC, // Used by FFmpeg CLI
698701
nullptr,
699702
nullptr,
700703
nullptr));
@@ -757,7 +760,7 @@ void VideoEncoder::encodeFrame(
757760
"Error while sending frame: ",
758761
getFFMPEGErrorStringFromErrorCode(status));
759762

760-
while (true) {
763+
while (status >= 0) {
761764
ReferenceAVPacket packet(autoAVPacket);
762765
status = avcodec_receive_packet(avCodecContext_.get(), packet.get());
763766
if (status == AVERROR(EAGAIN) || status == AVERROR_EOF) {
@@ -776,7 +779,16 @@ void VideoEncoder::encodeFrame(
776779
"Error receiving packet: ",
777780
getFFMPEGErrorStringFromErrorCode(status));
778781

779-
packet->stream_index = streamIndex_;
782+
// The code below is borrowed from torchaudio:
783+
// https://github.com/pytorch/audio/blob/b6a3368a45aaafe05f1a6a9f10c68adc5e944d9e/src/libtorio/ffmpeg/stream_writer/encoder.cpp#L46
784+
// Setting packet->duration to 1 allows the last frame to be properly
785+
// encoded, and needs to be set before calling av_packet_rescale_ts.
786+
if (packet->duration == 0) {
787+
packet->duration = 1;
788+
}
789+
av_packet_rescale_ts(
790+
packet.get(), avCodecContext_->time_base, avStream_->time_base);
791+
packet->stream_index = avStream_->index;
780792

781793
status = av_interleaved_write_frame(avFormatContext_.get(), packet.get());
782794
TORCH_CHECK(

src/torchcodec/_core/Encoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class VideoEncoder {
153153

154154
UniqueEncodingAVFormatContext avFormatContext_;
155155
UniqueAVCodecContext avCodecContext_;
156-
int streamIndex_ = -1;
156+
AVStream* avStream_;
157157
UniqueSwsContext swsContext_;
158158

159159
const torch::Tensor frames_;

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,26 @@ const int* getSupportedSampleRates(const AVCodec& avCodec) {
9090
return supportedSampleRates;
9191
}
9292

93+
// Returns the pixel formats that `avCodec` reports as supported.
//
// On FFmpeg >= 7.1 the list is queried with avcodec_get_supported_config();
// on older versions it is read from the AVCodec::pix_fmts field.
//
// Raises (via TORCH_CHECK) when no list is available, on every FFmpeg
// version. The result is therefore never nullptr, so callers can hand it
// directly to FFmpeg helpers that walk the array (for example
// avcodec_find_best_pix_fmt_of_list) without a null check of their own.
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec) {
  const AVPixelFormat* supportedPixelFormats = nullptr;
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1
  int numPixelFormats = 0;
  int ret = avcodec_get_supported_config(
      nullptr,
      &avCodec,
      AV_CODEC_CONFIG_PIX_FORMAT,
      0,
      reinterpret_cast<const void**>(&supportedPixelFormats),
      &numPixelFormats);
  TORCH_CHECK(ret >= 0, "Couldn't get supported pixel formats from encoder.");
#else
  // Legacy field; it may be null when the codec does not declare its formats.
  supportedPixelFormats = avCodec.pix_fmts;
#endif
  // Checked after the version split so both code paths reject a missing list
  // instead of returning a pointer the caller would dereference.
  TORCH_CHECK(
      supportedPixelFormats != nullptr,
      "Couldn't get supported pixel formats from encoder.");
  return supportedPixelFormats;
}
112+
93113
const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec) {
94114
const AVSampleFormat* supportedSampleFormats = nullptr;
95115
#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7.1

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ void setDuration(const UniqueAVFrame& frame, int64_t duration);
168168

169169
const int* getSupportedSampleRates(const AVCodec& avCodec);
170170
const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec);
171+
const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
171172

172173
int getNumChannels(const UniqueAVFrame& avFrame);
173174
int getNumChannels(const UniqueAVCodecContext& avCodecContext);

src/torchcodec/_core/StreamOptions.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ struct VideoStreamOptions {
4545
std::string_view deviceVariant = "default";
4646

4747
// Encoding options
48-
std::optional<int> bitRate;
49-
std::optional<int> gopSize;
50-
std::optional<int> maxBFrames;
48+
// TODO-VideoEncoder: Consider adding other optional fields here
49+
// (bit rate, gop size, max b frames, preset)
50+
std::optional<int> crf;
5151
};
5252

5353
struct AudioStreamOptions {

src/torchcodec/_core/custom_ops.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3333
m.def(
3434
"encode_audio_to_file(Tensor samples, int sample_rate, str filename, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()");
3535
m.def(
36-
"encode_video_to_file(Tensor frames, int frame_rate, str filename) -> ()");
36+
"encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()");
3737
m.def(
3838
"encode_audio_to_tensor(Tensor samples, int sample_rate, str format, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> Tensor");
3939
m.def(
@@ -501,8 +501,10 @@ OpsAudioFramesOutput get_frames_by_pts_in_range_audio(
501501
void encode_video_to_file(
502502
const at::Tensor& frames,
503503
int64_t frame_rate,
504-
std::string_view file_name) {
504+
std::string_view file_name,
505+
std::optional<int64_t> crf = std::nullopt) {
505506
VideoStreamOptions videoStreamOptions;
507+
videoStreamOptions.crf = crf;
506508
VideoEncoder(
507509
frames,
508510
validateInt64ToInt(frame_rate, "frame_rate"),

src/torchcodec/_core/ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def encode_video_to_file_abstract(
259259
frames: torch.Tensor,
260260
frame_rate: int,
261261
filename: str,
262+
crf: Optional[int] = None,
262263
) -> None:
263264
return
264265

test/test_ops.py

Lines changed: 108 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
import os
1010
from functools import partial
1111

12-
from .utils import in_fbcode
13-
1412
os.environ["TORCH_LOGS"] = "output_code"
1513
import json
1614
import subprocess
@@ -47,6 +45,10 @@
4745
from .utils import (
4846
all_supported_devices,
4947
assert_frames_equal,
48+
assert_tensor_close_on_at_least,
49+
get_ffmpeg_major_version,
50+
in_fbcode,
51+
IS_WINDOWS,
5052
NASA_AUDIO,
5153
NASA_AUDIO_MP3,
5254
NASA_VIDEO,
@@ -55,6 +57,7 @@
5557
SINE_MONO_S32,
5658
SINE_MONO_S32_44100,
5759
SINE_MONO_S32_8000,
60+
TEST_SRC_2_720P,
5861
unsplit_device_str,
5962
)
6063

@@ -1381,24 +1384,117 @@ def decode(self, file_path) -> torch.Tensor:
13811384
frames, *_ = get_frames_in_range(decoder, start=0, stop=60)
13821385
return frames
13831386

1384-
@pytest.mark.parametrize("format", ("mov", "mp4", "avi"))
1385-
# TODO-VideoEncoder: enable additional formats (mkv, webm)
1386-
def test_video_encoder_test_round_trip(self, tmp_path, format):
1387-
# TODO-VideoEncoder: Test with FFmpeg's testsrc2 video
1388-
asset = NASA_VIDEO
1389-
1387+
@pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm"))
1388+
def test_video_encoder_round_trip(self, tmp_path, format):
13901389
# Test that decode(encode(decode(asset))) == decode(asset)
1390+
ffmpeg_version = get_ffmpeg_major_version()
1391+
# In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm.
1392+
# As a result, we skip the round trip test.
1393+
if ffmpeg_version == 6 and format != "webm":
1394+
pytest.skip(
1395+
f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test."
1396+
)
1397+
if format == "webm" and (
1398+
ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7))
1399+
):
1400+
pytest.skip("Codec for webm is not available in this FFmpeg installation.")
1401+
asset = TEST_SRC_2_720P
13911402
source_frames = self.decode(str(asset.path)).data
13921403

13931404
encoded_path = str(tmp_path / f"encoder_output.{format}")
13941405
frame_rate = 30 # Frame rate is fixed with num frames decoded
1395-
encode_video_to_file(source_frames, frame_rate, encoded_path)
1406+
encode_video_to_file(
1407+
frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0
1408+
)
13961409
round_trip_frames = self.decode(encoded_path).data
1397-
1398-
# Check that PSNR for decode(encode(samples)) is above 30
1410+
assert source_frames.shape == round_trip_frames.shape
1411+
assert source_frames.dtype == round_trip_frames.dtype
1412+
1413+
# If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels
1414+
# are within a higher tolerance.
1415+
if ffmpeg_version == 6:
1416+
assert_close = partial(assert_tensor_close_on_at_least, percentage=99)
1417+
atol = 15
1418+
else:
1419+
assert_close = torch.testing.assert_close
1420+
atol = 2
13991421
for s_frame, rt_frame in zip(source_frames, round_trip_frames):
1400-
res = psnr(s_frame, rt_frame)
1422+
assert psnr(s_frame, rt_frame) > 30
1423+
assert_close(s_frame, rt_frame, atol=atol, rtol=0)
1424+
1425+
@pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
@pytest.mark.parametrize(
    "format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif")
)
def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format):
    """Encode the same frames with the ffmpeg CLI and with encode_video_to_file,
    then compare the decoded results.

    Both outputs must have the same number of frames, and every frame pair
    must have PSNR > 30 and pass assert_tensor_close_on_at_least with atol=2.
    """
    major_version = get_ffmpeg_major_version()
    if format == "webm" and major_version == 4:
        pytest.skip("Codec for webm is not available in the FFmpeg4 installation.")
    if format == "webm" and IS_WINDOWS and major_version in (6, 7):
        pytest.skip(
            "Codec for webm is not available in the FFmpeg6/7 installation on Windows."
        )

    reference_frames = self.decode(str(TEST_SRC_2_720P.path)).data
    fps = 30
    crf = 0

    # Dump the frames as raw RGB24 (channels moved last) so the ffmpeg CLI
    # can read them as rawvideo input.
    raw_input = tmp_path / "temp_input.raw"
    raw_input.write_bytes(
        reference_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()
    )

    # Reference encoding produced by the ffmpeg CLI.
    # Some codecs (ex. MPEG4) do not support CRF.
    # Flags not supported by the selected codec will be ignored.
    width, height = reference_frames.shape[3], reference_frames.shape[2]
    cli_output = str(tmp_path / f"ffmpeg_output.{format}")
    command = ["ffmpeg", "-y"]
    command += ["-f", "rawvideo"]
    command += ["-pix_fmt", "rgb24"]
    command += ["-s", f"{width}x{height}"]
    command += ["-r", str(fps)]
    command += ["-i", str(raw_input)]
    command += ["-crf", str(crf)]
    command.append(cli_output)
    subprocess.run(command, check=True)

    # Encoding produced by our video encoder, with the same settings.
    our_output = str(tmp_path / f"encoder_output.{format}")
    encode_video_to_file(
        frames=reference_frames,
        frame_rate=fps,
        filename=our_output,
        crf=crf,
    )

    cli_frames = self.decode(cli_output).data
    our_frames = self.decode(our_output).data
    assert cli_frames.shape[0] == our_frames.shape[0]

    # If FFmpeg selects a codec or pixel format that uses qscale (not crf),
    # the VideoEncoder outputs *slightly* different frames.
    # There may be additional subtle differences in the encoder.
    if major_version == 6 or format == "avi":
        required_percentage = 94
    else:
        required_percentage = 99

    # Both encodings must agree closely, frame by frame.
    for cli_frame, our_frame in zip(cli_frames, our_frames):
        assert psnr(cli_frame, our_frame) > 30
        assert_tensor_close_on_at_least(
            cli_frame, our_frame, percentage=required_percentage, atol=2
        )
14021498

14031499

14041500
if __name__ == "__main__":

0 commit comments

Comments
 (0)