|  | 
| 9 | 9 | import os | 
| 10 | 10 | from functools import partial | 
| 11 | 11 | 
 | 
| 12 |  | -from .utils import in_fbcode | 
| 13 |  | - | 
| 14 | 12 | os.environ["TORCH_LOGS"] = "output_code" | 
| 15 | 13 | import json | 
| 16 | 14 | import subprocess | 
|  | 
| 47 | 45 | from .utils import ( | 
| 48 | 46 |     all_supported_devices, | 
| 49 | 47 |     assert_frames_equal, | 
|  | 48 | +    assert_tensor_close_on_at_least, | 
|  | 49 | +    get_ffmpeg_major_version, | 
|  | 50 | +    in_fbcode, | 
|  | 51 | +    IS_WINDOWS, | 
| 50 | 52 |     NASA_AUDIO, | 
| 51 | 53 |     NASA_AUDIO_MP3, | 
| 52 | 54 |     NASA_VIDEO, | 
|  | 
| 55 | 57 |     SINE_MONO_S32, | 
| 56 | 58 |     SINE_MONO_S32_44100, | 
| 57 | 59 |     SINE_MONO_S32_8000, | 
|  | 60 | +    TEST_SRC_2_720P, | 
| 58 | 61 |     unsplit_device_str, | 
| 59 | 62 | ) | 
| 60 | 63 | 
 | 
| @@ -1381,24 +1384,117 @@ def decode(self, file_path) -> torch.Tensor: | 
| 1381 | 1384 |         frames, *_ = get_frames_in_range(decoder, start=0, stop=60) | 
| 1382 | 1385 |         return frames | 
| 1383 | 1386 | 
 | 
| 1384 |  | -    @pytest.mark.parametrize("format", ("mov", "mp4", "avi")) | 
| 1385 |  | -    # TODO-VideoEncoder: enable additional formats (mkv, webm) | 
| 1386 |  | -    def test_video_encoder_test_round_trip(self, tmp_path, format): | 
| 1387 |  | -        # TODO-VideoEncoder: Test with FFmpeg's testsrc2 video | 
| 1388 |  | -        asset = NASA_VIDEO | 
| 1389 |  | - | 
|  | 1387 | +    @pytest.mark.parametrize("format", ("mov", "mp4", "mkv", "webm")) | 
|  | 1388 | +    def test_video_encoder_round_trip(self, tmp_path, format): | 
| 1390 | 1389 |         # Test that decode(encode(decode(asset))) == decode(asset) | 
|  | 1390 | +        ffmpeg_version = get_ffmpeg_major_version() | 
|  | 1391 | +        # In FFmpeg6, the default codec's best pixel format is lossy for all container formats but webm. | 
|  | 1392 | +        # As a result, we skip the round-trip test for those formats (webm still runs). | 
|  | 1393 | +        if ffmpeg_version == 6 and format != "webm": | 
|  | 1394 | +            pytest.skip( | 
|  | 1395 | +                f"FFmpeg6 defaults to lossy encoding for {format}, skipping round-trip test." | 
|  | 1396 | +            ) | 
|  | 1397 | +        if format == "webm" and ( | 
|  | 1398 | +            ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) | 
|  | 1399 | +        ): | 
|  | 1400 | +            pytest.skip("Codec for webm is not available in this FFmpeg installation.") | 
|  | 1401 | +        asset = TEST_SRC_2_720P | 
| 1391 | 1402 |         source_frames = self.decode(str(asset.path)).data | 
| 1392 | 1403 | 
 | 
| 1393 | 1404 |         encoded_path = str(tmp_path / f"encoder_output.{format}") | 
| 1394 | 1405 |         frame_rate = 30  # Frame rate is fixed with num frames decoded | 
| 1395 |  | -        encode_video_to_file(source_frames, frame_rate, encoded_path) | 
|  | 1406 | +        encode_video_to_file( | 
|  | 1407 | +            frames=source_frames, frame_rate=frame_rate, filename=encoded_path, crf=0 | 
|  | 1408 | +        ) | 
| 1396 | 1409 |         round_trip_frames = self.decode(encoded_path).data | 
| 1397 |  | - | 
| 1398 |  | -        # Check that PSNR for decode(encode(samples)) is above 30 | 
|  | 1410 | +        assert source_frames.shape == round_trip_frames.shape | 
|  | 1411 | +        assert source_frames.dtype == round_trip_frames.dtype | 
|  | 1412 | + | 
|  | 1413 | +        # If FFmpeg selects a codec or pixel format that does lossy encoding, assert 99% of pixels | 
|  | 1414 | +        # are within a higher tolerance. | 
|  | 1415 | +        if ffmpeg_version == 6: | 
|  | 1416 | +            assert_close = partial(assert_tensor_close_on_at_least, percentage=99) | 
|  | 1417 | +            atol = 15 | 
|  | 1418 | +        else: | 
|  | 1419 | +            assert_close = torch.testing.assert_close | 
|  | 1420 | +            atol = 2 | 
| 1399 | 1421 |         for s_frame, rt_frame in zip(source_frames, round_trip_frames): | 
| 1400 |  | -            res = psnr(s_frame, rt_frame) | 
|  | 1422 | +            assert psnr(s_frame, rt_frame) > 30 | 
|  | 1423 | +            assert_close(s_frame, rt_frame, atol=atol, rtol=0) | 
|  | 1424 | + | 
|  | 1425 | +    @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") | 
|  | 1426 | +    @pytest.mark.parametrize( | 
|  | 1427 | +        "format", ("mov", "mp4", "avi", "mkv", "webm", "flv", "gif") | 
|  | 1428 | +    ) | 
|  | 1429 | +    def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): | 
|  | 1430 | +        ffmpeg_version = get_ffmpeg_major_version() | 
|  | 1431 | +        if format == "webm": | 
|  | 1432 | +            if ffmpeg_version == 4: | 
|  | 1433 | +                pytest.skip( | 
|  | 1434 | +                    "Codec for webm is not available in the FFmpeg4 installation." | 
|  | 1435 | +                ) | 
|  | 1436 | +            if IS_WINDOWS and ffmpeg_version in (6, 7): | 
|  | 1437 | +                pytest.skip( | 
|  | 1438 | +                    "Codec for webm is not available in the FFmpeg6/7 installation on Windows." | 
|  | 1439 | +                ) | 
|  | 1440 | +        asset = TEST_SRC_2_720P | 
|  | 1441 | +        source_frames = self.decode(str(asset.path)).data | 
|  | 1442 | +        frame_rate = 30 | 
|  | 1443 | + | 
|  | 1444 | +        # Encode with FFmpeg CLI | 
|  | 1445 | +        temp_raw_path = str(tmp_path / "temp_input.raw") | 
|  | 1446 | +        with open(temp_raw_path, "wb") as f: | 
|  | 1447 | +            f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) | 
|  | 1448 | + | 
|  | 1449 | +        ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") | 
|  | 1450 | +        crf = 0 | 
|  | 1451 | +        quality_params = ["-crf", str(crf)] | 
|  | 1452 | +        # Some codecs (e.g. MPEG4) do not support CRF. | 
|  | 1453 | +        # Flags not supported by the selected codec will be ignored. | 
|  | 1454 | +        ffmpeg_cmd = [ | 
|  | 1455 | +            "ffmpeg", | 
|  | 1456 | +            "-y", | 
|  | 1457 | +            "-f", | 
|  | 1458 | +            "rawvideo", | 
|  | 1459 | +            "-pix_fmt", | 
|  | 1460 | +            "rgb24", | 
|  | 1461 | +            "-s", | 
|  | 1462 | +            f"{source_frames.shape[3]}x{source_frames.shape[2]}", | 
|  | 1463 | +            "-r", | 
|  | 1464 | +            str(frame_rate), | 
|  | 1465 | +            "-i", | 
|  | 1466 | +            temp_raw_path, | 
|  | 1467 | +            *quality_params, | 
|  | 1468 | +            ffmpeg_encoded_path, | 
|  | 1469 | +        ] | 
|  | 1470 | +        subprocess.run(ffmpeg_cmd, check=True) | 
|  | 1471 | + | 
|  | 1472 | +        # Encode with our video encoder | 
|  | 1473 | +        encoder_output_path = str(tmp_path / f"encoder_output.{format}") | 
|  | 1474 | +        encode_video_to_file( | 
|  | 1475 | +            frames=source_frames, | 
|  | 1476 | +            frame_rate=frame_rate, | 
|  | 1477 | +            filename=encoder_output_path, | 
|  | 1478 | +            crf=crf, | 
|  | 1479 | +        ) | 
|  | 1480 | + | 
|  | 1481 | +        ffmpeg_frames = self.decode(ffmpeg_encoded_path).data | 
|  | 1482 | +        encoder_frames = self.decode(encoder_output_path).data | 
|  | 1483 | + | 
|  | 1484 | +        assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] | 
|  | 1485 | + | 
|  | 1486 | +        # If FFmpeg selects a codec or pixel format that uses qscale (not crf), | 
|  | 1487 | +        # the VideoEncoder outputs *slightly* different frames. | 
|  | 1488 | +        # There may be additional subtle differences in the encoder. | 
|  | 1489 | +        percentage = 94 if ffmpeg_version == 6 or format == "avi" else 99 | 
|  | 1490 | + | 
|  | 1491 | +        # Check that PSNR between both encoded versions is high | 
|  | 1492 | +        for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): | 
|  | 1493 | +            res = psnr(ff_frame, enc_frame) | 
| 1401 | 1494 |             assert res > 30 | 
|  | 1495 | +            assert_tensor_close_on_at_least( | 
|  | 1496 | +                ff_frame, enc_frame, percentage=percentage, atol=2 | 
|  | 1497 | +            ) | 
| 1402 | 1498 | 
 | 
| 1403 | 1499 | 
 | 
| 1404 | 1500 | if __name__ == "__main__": | 
|  | 
0 commit comments