|
12 | 12 |
|
13 | 13 | VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
|
14 | 14 |
|
15 |
| -test_videos = [ |
16 |
| - "RATRACE_wave_f_nm_np1_fr_goo_37.avi", |
17 |
| - "TrumanShow_wave_f_nm_np1_fr_med_26.avi", |
18 |
| - "v_SoccerJuggling_g23_c01.avi", |
19 |
| - "v_SoccerJuggling_g24_c01.avi", |
20 |
| - "R6llTwEh07w.mp4", |
21 |
| - "SOX5yA1l24A.mp4", |
22 |
| - "WUzgd7C1pWA.mp4", |
23 |
| -] |
24 |
| - |
25 | 15 |
|
26 | 16 | @pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder")
|
27 | 17 | class TestVideoGPUDecoder:
|
28 | 18 | @pytest.mark.skipif(av is None, reason="PyAV unavailable")
|
29 |
| - def test_frame_reading(self): |
30 |
| - for test_video in test_videos: |
31 |
| - full_path = os.path.join(VIDEO_DIR, test_video) |
32 |
| - decoder = VideoReader(full_path, device="cuda:0") |
33 |
| - with av.open(full_path) as container: |
34 |
| - for av_frame in container.decode(container.streams.video[0]): |
35 |
| - av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) |
36 |
| - vision_frames = next(decoder)["data"] |
37 |
| - mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) |
38 |
| - assert mean_delta < 0.75 |
| 19 | + @pytest.mark.parametrize( |
| 20 | + "video_file", |
| 21 | + [ |
| 22 | + "RATRACE_wave_f_nm_np1_fr_goo_37.avi", |
| 23 | + "TrumanShow_wave_f_nm_np1_fr_med_26.avi", |
| 24 | + "v_SoccerJuggling_g23_c01.avi", |
| 25 | + "v_SoccerJuggling_g24_c01.avi", |
| 26 | + "R6llTwEh07w.mp4", |
| 27 | + "SOX5yA1l24A.mp4", |
| 28 | + "WUzgd7C1pWA.mp4", |
| 29 | + ], |
| 30 | + ) |
| 31 | + def test_frame_reading(self, video_file): |
| 32 | + full_path = os.path.join(VIDEO_DIR, video_file) |
| 33 | + decoder = VideoReader(full_path, device="cuda:0") |
| 34 | + with av.open(full_path) as container: |
| 35 | + for av_frame in container.decode(container.streams.video[0]): |
| 36 | + av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) |
| 37 | + vision_frames = next(decoder)["data"] |
| 38 | + mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) |
| 39 | + assert mean_delta < 0.75 |
39 | 40 |
|
40 | 41 | @pytest.mark.skipif(av is None, reason="PyAV unavailable")
|
41 | 42 | @pytest.mark.parametrize("keyframes", [True, False])
|
@@ -65,16 +66,27 @@ def test_seek_reading(self, keyframes, full_path, duration):
|
65 | 66 | assert mean_delta < 0.75
|
66 | 67 |
|
67 | 68 | @pytest.mark.skipif(av is None, reason="PyAV unavailable")
|
68 |
| - def test_metadata(self): |
69 |
| - for test_video in test_videos: |
70 |
| - full_path = os.path.join(VIDEO_DIR, test_video) |
71 |
| - decoder = VideoReader(full_path, device="cuda:0") |
72 |
| - video_metadata = decoder.get_metadata()["video"] |
73 |
| - with av.open(full_path) as container: |
74 |
| - video = container.streams.video[0] |
75 |
| - av_duration = float(video.duration * video.time_base) |
76 |
| - assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2) |
77 |
| - assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2) |
| 69 | + @pytest.mark.parametrize( |
| 70 | + "video_file", |
| 71 | + [ |
| 72 | + "RATRACE_wave_f_nm_np1_fr_goo_37.avi", |
| 73 | + "TrumanShow_wave_f_nm_np1_fr_med_26.avi", |
| 74 | + "v_SoccerJuggling_g23_c01.avi", |
| 75 | + "v_SoccerJuggling_g24_c01.avi", |
| 76 | + "R6llTwEh07w.mp4", |
| 77 | + "SOX5yA1l24A.mp4", |
| 78 | + "WUzgd7C1pWA.mp4", |
| 79 | + ], |
| 80 | + ) |
| 81 | + def test_metadata(self, video_file): |
| 82 | + full_path = os.path.join(VIDEO_DIR, video_file) |
| 83 | + decoder = VideoReader(full_path, device="cuda:0") |
| 84 | + video_metadata = decoder.get_metadata()["video"] |
| 85 | + with av.open(full_path) as container: |
| 86 | + video = container.streams.video[0] |
| 87 | + av_duration = float(video.duration * video.time_base) |
| 88 | + assert math.isclose(video_metadata["duration"], av_duration, rel_tol=1e-2) |
| 89 | + assert math.isclose(video_metadata["fps"], video.base_rate, rel_tol=1e-2) |
78 | 90 |
|
79 | 91 |
|
80 | 92 | if __name__ == "__main__":
|
|
0 commit comments