diff --git a/.travis.yml b/.travis.yml index 07e1e8900a0..e86e7bc9a3a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,6 +35,7 @@ before_install: - pip install future - pip install pytest pytest-cov codecov - pip install mock + - conda install av -c conda-forge install: diff --git a/test/test_io.py b/test/test_io.py new file mode 100644 index 00000000000..775a00fd9b1 --- /dev/null +++ b/test/test_io.py @@ -0,0 +1,83 @@ +import os +import tempfile +import torch +import torchvision.io as io +import unittest + + +try: + import av +except ImportError: + av = None + + +class Tester(unittest.TestCase): + # compression adds artifacts, thus we add a tolerance of + # 6 in 0-255 range + TOLERANCE = 6 + + def _create_video_frames(self, num_frames, height, width): + y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) + data = [] + for i in range(num_frames): + xc = float(i) / num_frames + yc = 1 - float(i) / (2 * num_frames) + d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 + data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) + + return torch.stack(data, 0) + + @unittest.skipIf(av is None, "PyAV unavailable") + def test_write_read_video(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + data = self._create_video_frames(10, 300, 300) + io.write_video(f.name, data, fps=5) + + lv, _, info = io.read_video(f.name) + + self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE) + self.assertEqual(info["video_fps"], 5) + + @unittest.skipIf(av is None, "PyAV unavailable") + def test_read_timestamps(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + data = self._create_video_frames(10, 300, 300) + io.write_video(f.name, data, fps=5) + + pts = io.read_video_timestamps(f.name) + + # note: not all formats/codecs provide accurate information for computing the + # timestamps. For the format that we use here, this information is available, + # so we use it as a baseline + container = av.open(f.name) + stream = container.streams[0] + pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) + num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) + expected_pts = [i * pts_step for i in range(num_frames)] + + self.assertEqual(pts, expected_pts) + + @unittest.skipIf(av is None, "PyAV unavailable") + def test_read_partial_video(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + data = self._create_video_frames(10, 300, 300) + io.write_video(f.name, data, fps=5) + + pts = io.read_video_timestamps(f.name) + + for start in range(5): + for l in range(1, 4): + lv, _, _ = io.read_video(f.name, pts[start], pts[start + l - 1]) + s_data = data[start:(start + l)] + self.assertEqual(len(lv), l) + self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE) + + lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7]) + self.assertEqual(len(lv), 4) + self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) + + # TODO add tests for audio + + +if __name__ == '__main__': + unittest.main() diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 82ba966dd5a..68361bfb029 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -3,6 +3,7 @@ from torchvision import ops from torchvision import transforms from torchvision import utils +from torchvision import io try: from .version import __version__ # noqa: F401 diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py new file mode 100644 index 00000000000..3f7b9ab258b --- /dev/null +++ b/torchvision/io/__init__.py @@ -0,0 +1,6 @@ +from .video import write_video, read_video, read_video_timestamps + + +__all__ = [ + 'write_video', 'read_video', 'read_video_timestamps' +] diff --git a/torchvision/io/video.py b/torchvision/io/video.py new file mode 100644 index 00000000000..f80177b46dd --- /dev/null +++ b/torchvision/io/video.py @@ -0,0 +1,171 @@ +import gc +import torch +import numpy as np + +try: + import av +except ImportError: + av = None + + +def _check_av_available(): + if av is None: + raise ImportError("""\ +PyAV is not installed, and is necessary for the video operations in torchvision. +See https://github.com/mikeboers/PyAV#installation for instructions on how to +install PyAV on your system. +""") + + +# PyAV has some reference cycles +_CALLED_TIMES = 0 +_GC_COLLECTION_INTERVAL = 20 + + +def write_video(filename, video_array, fps): + """ + Writes a 4d tensor in [T, H, W, C] format in a video file + + Arguments: + filename (str): path where the video will be saved + video_array (Tensor[T, H, W, C]): tensor containing the individual frames, + as a uint8 tensor in [T, H, W, C] format + fps (Number): frames per second + """ + _check_av_available() + video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() + + container = av.open(filename, mode='w') + + stream = container.add_stream('mpeg4', rate=fps) + stream.width = video_array.shape[2] + stream.height = video_array.shape[1] + stream.pix_fmt = 'yuv420p' + + for img in video_array: + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + for packet in stream.encode(frame): + container.mux(packet) + + # Flush stream + for packet in stream.encode(): + container.mux(packet) + + # Close the file + container.close() + + +def _read_from_stream(container, start_offset, end_offset, stream, stream_name): + global _CALLED_TIMES, _GC_COLLECTION_INTERVAL + _CALLED_TIMES += 1 + if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: + gc.collect() + + container.seek(start_offset, any_frame=False, backward=True, stream=stream) + frames = [] + first_frame = None + for idx, frame in enumerate(container.decode(**stream_name)): + if frame.pts < start_offset: + first_frame = frame + continue + if first_frame and first_frame.pts < start_offset: + if frame.pts != start_offset: + frames.append(first_frame) + first_frame = None + frames.append(frame) + if frame.pts >= end_offset: + break + return frames + + +def _align_audio_frames(aframes, audio_frames, ref_start, ref_end): + start, end = audio_frames[0].pts, audio_frames[-1].pts + total_aframes = aframes.shape[1] + step_per_aframe = (end - start + 1) / total_aframes + s_idx = 0 + e_idx = total_aframes + if start < ref_start: + s_idx = int((ref_start - start) / step_per_aframe) + if end > ref_end: + e_idx = int((ref_end - end) / step_per_aframe) + return aframes[:, s_idx:e_idx] + + +def read_video(filename, start_pts=0, end_pts=None): + """ + Reads a video from a file, returning both the video frames as well as + the audio frames + + Arguments: + filename (str): path to the video file + start_pts (int, optional): the start presentation time of the video + end_pts (int, optional): the end presentation time + + Returns: + vframes (Tensor[T, H, W, C]): the `T` video frames + aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels + and `L` is the number of points + info (Dict): metadata for the video and audio. Can contain the fields + - video_fps (float) + - audio_fps (int) + """ + _check_av_available() + + if end_pts is None: + end_pts = float("inf") + + if end_pts < start_pts: + raise ValueError("end_pts should be larger than start_pts, got " + "start_pts={} and end_pts={}".format(start_pts, end_pts)) + + container = av.open(filename) + info = {} + + video_frames = [] + if container.streams.video: + video_frames = _read_from_stream(container, start_pts, end_pts, + container.streams.video[0], {'video': 0}) + info["video_fps"] = float(container.streams.video[0].average_rate) + audio_frames = [] + if container.streams.audio: + audio_frames = _read_from_stream(container, start_pts, end_pts, + container.streams.audio[0], {'audio': 0}) + info["audio_fps"] = container.streams.audio[0].rate + + container.close() + + vframes = [frame.to_rgb().to_ndarray() for frame in video_frames] + aframes = [frame.to_ndarray() for frame in audio_frames] + vframes = torch.as_tensor(np.stack(vframes)) + if aframes: + aframes = np.concatenate(aframes, 1) + aframes = torch.as_tensor(aframes) + aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) + else: + aframes = torch.empty((1, 0), dtype=torch.float32) + + return vframes, aframes, info + + +def read_video_timestamps(filename): + """ + List the video frames timestamps. + + Note that the function decodes the whole video frame-by-frame. + + Arguments: + filename (str): path to the video file + + Returns: + pts (List[int]): presentation timestamps for each one of the frames + in the video. + """ + _check_av_available() + container = av.open(filename) + + video_frames = [] + if container.streams.video: + video_frames = _read_from_stream(container, 0, float("inf"), + container.streams.video[0], {'video': 0}) + container.close() + return [x.pts for x in video_frames]