From cf5d0a0a7867d8ede0125a3e6716a8769d1932af Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 12 Jun 2019 16:28:49 +0200 Subject: [PATCH 01/15] WIP --- torchvision/io/video.py | 25 ++++ torchvision/io/video_reader.py | 224 +++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 torchvision/io/video.py create mode 100644 torchvision/io/video_reader.py diff --git a/torchvision/io/video.py b/torchvision/io/video.py new file mode 100644 index 00000000000..92ea2127c25 --- /dev/null +++ b/torchvision/io/video.py @@ -0,0 +1,25 @@ +import av +import torch + + +def write_video(filename, video_array, fps): + video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() + + container = av.open(filename, mode='w') + + stream = container.add_stream('mpeg4', rate=fps) + stream.width = video_array.shape[2] + stream.height = video_array.shape[1] + stream.pix_fmt = 'yuv420p' + + for img in video_array: + frame = av.VideoFrame.from_ndarray(img, format='rgb24') + for packet in stream.encode(frame): + container.mux(packet) + + # Flush stream + for packet in stream.encode(): + container.mux(packet) + + # Close the file + container.close() diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py new file mode 100644 index 00000000000..f17be948dc8 --- /dev/null +++ b/torchvision/io/video_reader.py @@ -0,0 +1,224 @@ +import av +import gc +import warnings + + + +_CALLED_TIMES = 0 +_GC_COLLECTION_INTERVAL = 20 + + +# remove warnings +av.logging.set_level(av.logging.ERROR) + + + +class VideoReader(object): + """ + Simple wrapper around PyAV that exposes a few useful functions for + dealing with video reading. + """ + def __init__(self, video_path, sampling_rate=1, decode_lossy=False, audio_resample_rate=None): + """ + Arguments: + video_path (str): path of the video to be loaded + """ + self.container = av.open(video_path) + self.sampling_rate = sampling_rate + self.resampler = None + if audio_resample_rate is not None: + self.resampler = av.AudioResampler(rate=audio_resample_rate) + + + if self.container.streams.video: + # enable multi-threaded video decoding + if decode_lossy: + warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning) + self.container.streams.video[0].thread_type = 'AUTO' + self.video_stream = self.container.streams.video[0] + else: + self.video_stream = None + + def seek(self, offset, backward=True, any_frame=False): + stream = self.video_stream + self.container.seek(offset, any_frame=any_frame, backward=backward, stream=stream) + + def _occasional_gc(self): + # there are a lot of reference cycles in PyAV, so need to manually call + # the garbage collector from time to time + global _CALLED_TIMES, _GC_COLLECTION_INTERVAL + _CALLED_TIMES += 1 + if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: + gc.collect() + + def _read_video(self, offset, num_frames): + self._occasional_gc() + self.seek(offset) + video_frames = [] + count = 0 + for idx, frame in enumerate(self.container.decode(video=0)): + if frame.pts < offset: + continue + video_frames.append(frame) + if count >= num_frames - 1: + break + count += 1 + return video_frames + + def _resample_audio_frame(self, frame): + curr_pts = frame.pts + frame.pts = None + frame = self.resampler.resample(frame) + frame.pts = curr_pts + return frame + + + def _read_audio(self, offset, end_offset): + self._occasional_gc() + if not self.container.streams.audio: + return [] + + self.container.seek(offset, backward=True, any_frame=False, stream=self.container.streams.audio[0]) + + audio_frames = [] + first_frame = None + for idx, frame in enumerate(self.container.decode(audio=0)): + if frame.pts < offset: + first_frame = frame + continue + if first_frame and first_frame.pts < offset: + if self.resampler is not None: + first_frame = self._resample_audio_frame(first_frame) + audio_frames.append(first_frame) + first_frame = None + # if we want to resample audio to a different framerate + if self.resampler is not None: + frame = self._resample_audio_frame(frame) + audio_frames.append(frame) + if frame.pts > end_offset: + break + return audio_frames + + def read(self, offset, num_frames): + """ + Reads video frames and audio frames starting from offset. + The number of video frames read is given by num_frames. + The number of audio frames read is defined by the start and + end time of the first and last video frames, respectively + Arguments: + offset (int): the start time from the read + num_frames (int): the number of video frames to be read + Returns: + video_frames (List[av.VideoFrame]) + audio_frames (List[av.AudioFrame]) + """ + if self.container is None: + return [], [] + + num_frames = self.sampling_rate * num_frames + video_frames = self._read_video(offset, num_frames) + if len(video_frames) < 1: + end_offset = offset + elif len(video_frames) < 2: + end_offset = video_frames[-1].pts + else: + step = video_frames[-1].pts - video_frames[-2].pts + end_offset = video_frames[-1].pts + step - 1 + try: + audio_frames = self._read_audio(offset, end_offset) + except av.AVError: + audio_frames = [] + return video_frames, audio_frames + + def list_keyframes(self): + """ + Returns a list of start times for all the keyframes in the video + Returns: + keyframes (List[int]) + """ + keyframes = [] + if self.video_stream is None: + return [] + pts = -1 + while True: + try: + self.seek(pts + 1, backward=False) + except av.AVError: + break + packet = next(self.container.demux(video=0)) + pts = packet.pts + #TODO: double check if this is needed + if pts is None: + # should we simply return []? + return keyframes + + if packet.is_keyframe: + keyframes.append(pts) + return keyframes + + def _compute_end_video_pts(self): + self.seek(self.container.duration, any_frame=True) + end_step = next(self.container.demux(video=0)).pts + if end_step is None: + self.seek(self.container.duration, any_frame=False) + gen = self.container.demux(video=0) + last_pts = 0 + while True: + last_pts = next(gen).pts + if last_pts is None: + break + end_step = last_pts + return end_step + + def _compute_start_video_pts(self): + self.seek(0) + start = next(self.container.demux(video=0)).pts + return start + + def _compute_step_pts(self): + self.seek(0) + pts = [] + num = 11 + gen = self.container.demux(video=0) + for _ in range(num): + next(gen) + for _ in range(num): + pts.append(next(gen).pts) + print(pts) + steps = [p1 - p2 for p1, p2 in zip(pts[1:], pts[:-1])] + print(steps) + steps = max(set(steps), key=steps.count) + return int(steps) + + def _compute_step_pts(self): + frames = self._read_video(0, 2) + steps = frames[1].pts - frames[0].pts + return steps + + def list_every(self, n_frames): + step = 1 / float(self.video_stream.average_rate * self.video_stream.time_base) + end = self._compute_end_video_pts() + start = self._compute_start_video_pts() + step = self._compute_step_pts() + """ + orig_step = int(step) + for i in range(-10, 10): + if (end - start) % (orig_step + i) == 0: + step = orig_step + i + break + """ + return list(range(start, end + 1, int(step)))[::n_frames] + + + def _decode_every(self): + """ + A function used for truly decoding every single frame. + This should not be used outside of the dataset indexing step + Returns: + timestamp of every frame within the video (List[int]) + """ + if self.video_stream is None or self.container is None: + return [] + self.seek(0, backward=False) + d = [p for p in self.container.decode(video=0)] + return [x.pts for x in d[::1]] From d71c91944b635509ac26f35099bb1dd4cfd7df22 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 14 Jun 2019 19:26:42 +0200 Subject: [PATCH 02/15] WIP --- test/test_io.py | 21 +++++++++++ torchvision/__init__.py | 1 + torchvision/io/__init__.py | 6 +++ torchvision/io/video.py | 76 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 104 insertions(+) create mode 100644 test/test_io.py create mode 100644 torchvision/io/__init__.py diff --git a/test/test_io.py b/test/test_io.py new file mode 100644 index 00000000000..b5357031bcb --- /dev/null +++ b/test/test_io.py @@ -0,0 +1,21 @@ +import os +import tempfile +import torch +import torchvision.io as io +import unittest + + +class Tester(unittest.TestCase): + + def test_write_read_video(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + data = torch.randint(0, 255, (10, 300, 300, 3), dtype=torch.uint8) + io.write_video(f.name, data, fps=5) + + lv, _ = io.read_video(f.name) + print((data.float() - lv.float()).abs().max()) + + + +if __name__ == '__main__': + unittest.main() diff --git a/torchvision/__init__.py b/torchvision/__init__.py index 82ba966dd5a..68361bfb029 100644 --- a/torchvision/__init__.py +++ b/torchvision/__init__.py @@ -3,6 +3,7 @@ from torchvision import ops from torchvision import transforms from torchvision import utils +from torchvision import io try: from .version import __version__ # noqa: F401 diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py new file mode 100644 index 00000000000..dc9c181a655 --- /dev/null +++ b/torchvision/io/__init__.py @@ -0,0 +1,6 @@ +from .video import write_video, read_video + + +__all__ = [ + 'write_video', 'read_video' +] diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 92ea2127c25..cd6a20d5713 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -1,5 +1,7 @@ import av import torch +import numpy as np +import math def write_video(filename, video_array, fps): @@ -23,3 +25,77 @@ def write_video(filename, video_array, fps): # Close the file container.close() + + +def _read_from_stream(container, start_offset, end_offset, stream, stream_name): + container.seek(start_offset, any_frame=False, backward=True, stream=stream) + frames = [] + first_frame = None + for idx, frame in enumerate(container.decode(**stream_name)): + if frame.pts < start_offset: + first_frame = frame + continue + if first_frame and first_frame.pts < start_offset: + audio_frames.append(first_frame) + first_frame = None + frames.append(frame) + if frame.pts > end_offset: + break + return frames + + +def read_video(filename, start_pts=0, end_pts=math.inf): + container = av.open(filename) + + video_frames = [] + if container.streams.video: + video_frames = _read_from_stream(container, start_pts, end_pts, container.streams.video[0], {'video': 0}) + audio_frames = [] + if container.streams.audio: + audio_frames = _read_from_stream(container, start_pts, end_pts, container.streams.audio[0], {'audio': 0}) + + container.close() + + vframes = [frame.to_rgb().to_ndarray() for frame in video_frames] + aframes = [frame.to_ndarray() for frame in audio_frames] + vframes = torch.as_tensor(np.stack(vframes)) + if aframes: + aframes = np.concatenate(aframes, 1) + aframes = torch.as_tensor(aframes) + else: + aframes = torch.empty((1, 0), dtype=torch.float32) + + # return video_frames, audio_frames + return vframes, aframes + + +def _read_video(filename, start_offset, end_offset): + container = av.open(filename) + + # video + container.seek(start_offset, any_frame=False, backward=True, stream=container.streams.video[0]) + video_frames = [] + for idx, frame in enumerate(container.decode(video=0)): + if frame.pts < start_offset: + continue + if frame.pts > end_offset: + break + video_frames.append(frame) + + # audio + container.seek(start_offset, backward=True, any_frame=False, stream=container.streams.audio[0]) + audio_frames = [] + first_frame = None + for idx, frame in enumerate(container.decode(audio=0)): + if frame.pts < start_offset: + first_frame = frame + continue + if first_frame and first_frame.pts < start_offset: + audio_frames.append(first_frame) + first_frame = None + audio_frames.append(frame) + if frame.pts > end_offset: + break + + container.close() + return video_frames, audio_frames From 94ac03c7ed008b38696bc5411d484b8c7f360150 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 19 Jun 2019 14:08:18 +0200 Subject: [PATCH 03/15] Add some documentation --- torchvision/io/video.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index cd6a20d5713..1dcb8f35aab 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -5,6 +5,15 @@ def write_video(filename, video_array, fps): + """ + Writes a 4d tensor in a video file + + Arguments: + filename (str): path where the video will be saved + video_array (Tensor[T, H, W, C]): tensor containing the individual frames, + as a uint8 tensor in [T, H, W, C] format + fps (Number): frames per second + """ video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() container = av.open(filename, mode='w') @@ -45,6 +54,20 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): def read_video(filename, start_pts=0, end_pts=math.inf): + """ + Reads a video from a file, returning both the video frames as well as + the audio frames + + Arguments: + filename (str): path to the video file + start_pts (int, optional): the start presentation time of the video + end_pts (int, optional): the end presentation time + + Returns: + vframes (Tensor[T, H, W, C]): the `T` video frames + aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels + and `L` is the number of points + """ container = av.open(filename) video_frames = [] From fd0fd47634da8ba7ee3ffbfb3829b83ada19e8bb Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 19 Jun 2019 18:07:32 +0200 Subject: [PATCH 04/15] Improve tests and add GC collection --- test/test_io.py | 19 ++++++++++++--- torchvision/io/video.py | 51 +++++++++++++---------------------------- 2 files changed, 32 insertions(+), 38 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index b5357031bcb..6df58aebcfc 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -6,15 +6,28 @@ class Tester(unittest.TestCase): - + + def _create_video_frames(self, num_frames, height, width): + y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) + data = [] + for i in range(num_frames): + xc = float(i) / num_frames + yc = 1 - float(i) / (2 * num_frames) + d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255 + data.append(d.unsqueeze(2).repeat(1, 1, 3).byte()) + + return torch.stack(data, 0) + def test_write_read_video(self): with tempfile.NamedTemporaryFile(suffix='.mp4') as f: - data = torch.randint(0, 255, (10, 300, 300, 3), dtype=torch.uint8) + data = self._create_video_frames(10, 300, 300) io.write_video(f.name, data, fps=5) lv, _ = io.read_video(f.name) - print((data.float() - lv.float()).abs().max()) + # compression adds artifacts, thus we add a tolerance of + # 5 in 0-255 range + self.assertTrue((data.float() - lv.float()).abs().max() < 5) if __name__ == '__main__': diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 1dcb8f35aab..3a58e890e29 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -1,12 +1,18 @@ import av +import gc import torch import numpy as np import math +# PyAV has some reference cycles +_CALLED_TIMES = 0 +_GC_COLLECTION_INTERVAL = 20 + + def write_video(filename, video_array, fps): """ - Writes a 4d tensor in a video file + Writes a 4d tensor in [T, H, W, C] format in a video file Arguments: filename (str): path where the video will be saved @@ -37,6 +43,11 @@ def write_video(filename, video_array, fps): def _read_from_stream(container, start_offset, end_offset, stream, stream_name): + global _CALLED_TIMES, _GC_COLLECTION_INTERVAL + _CALLED_TIMES += 1 + if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: + gc.collect() + container.seek(start_offset, any_frame=False, backward=True, stream=stream) frames = [] first_frame = None @@ -72,10 +83,12 @@ def read_video(filename, start_pts=0, end_pts=math.inf): video_frames = [] if container.streams.video: - video_frames = _read_from_stream(container, start_pts, end_pts, container.streams.video[0], {'video': 0}) + video_frames = _read_from_stream(container, start_pts, end_pts, + container.streams.video[0], {'video': 0}) audio_frames = [] if container.streams.audio: - audio_frames = _read_from_stream(container, start_pts, end_pts, container.streams.audio[0], {'audio': 0}) + audio_frames = _read_from_stream(container, start_pts, end_pts, + container.streams.audio[0], {'audio': 0}) container.close() @@ -90,35 +103,3 @@ def read_video(filename, start_pts=0, end_pts=math.inf): # return video_frames, audio_frames return vframes, aframes - - -def _read_video(filename, start_offset, end_offset): - container = av.open(filename) - - # video - container.seek(start_offset, any_frame=False, backward=True, stream=container.streams.video[0]) - video_frames = [] - for idx, frame in enumerate(container.decode(video=0)): - if frame.pts < start_offset: - continue - if frame.pts > end_offset: - break - video_frames.append(frame) - - # audio - container.seek(start_offset, backward=True, any_frame=False, stream=container.streams.audio[0]) - audio_frames = [] - first_frame = None - for idx, frame in enumerate(container.decode(audio=0)): - if frame.pts < start_offset: - first_frame = frame - continue - if first_frame and first_frame.pts < start_offset: - audio_frames.append(first_frame) - first_frame = None - audio_frames.append(frame) - if frame.pts > end_offset: - break - - container.close() - return video_frames, audio_frames From 3ac0ead32a75229fc167410e34d2003264f7e361 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 20 Jun 2019 11:36:57 +0200 Subject: [PATCH 05/15] [WIP] add timestamp getter --- test/test_io.py | 12 ++++++++++++ torchvision/io/__init__.py | 4 ++-- torchvision/io/video.py | 11 +++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index 6df58aebcfc..4cd763dafd2 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -29,6 +29,18 @@ def test_write_read_video(self): # 5 in 0-255 range self.assertTrue((data.float() - lv.float()).abs().max() < 5) + def test_read_timestamps(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + data = self._create_video_frames(10, 300, 300) + io.write_video(f.name, data, fps=5) + + lv = io.read_video_timestamps(f.name) + print(lv) + import av + container = av.open(f.name) + from IPython import embed; embed() + + if __name__ == '__main__': unittest.main() diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index dc9c181a655..3f7b9ab258b 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -1,6 +1,6 @@ -from .video import write_video, read_video +from .video import write_video, read_video, read_video_timestamps __all__ = [ - 'write_video', 'read_video' + 'write_video', 'read_video', 'read_video_timestamps' ] diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 3a58e890e29..a7abc010ba0 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -103,3 +103,14 @@ def read_video(filename, start_pts=0, end_pts=math.inf): # return video_frames, audio_frames return vframes, aframes + + +def read_video_timestamps(filename): + container = av.open(filename) + + video_frames = [] + if container.streams.video: + video_frames = _read_from_stream(container, 0, math.inf, + container.streams.video[0], {'video': 0}) + container.close() + return [x.pts for x in video_frames] From 3d8c4b999eee33b224d43696cdd636f71c7652f2 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 21 Jun 2019 15:39:23 +0200 Subject: [PATCH 06/15] Bugfixes --- test/test_io.py | 29 +++++++++++++++++++++-------- torchvision/io/video.py | 7 ++++--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index 4cd763dafd2..e50df619910 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -6,6 +6,9 @@ class Tester(unittest.TestCase): + # compression adds artifacts, thus we add a tolerance of + # 5 in 0-255 range + TOLERANCE = 5 def _create_video_frames(self, num_frames, height, width): y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) @@ -25,22 +28,32 @@ def test_write_read_video(self): lv, _ = io.read_video(f.name) - # compression adds artifacts, thus we add a tolerance of - # 5 in 0-255 range - self.assertTrue((data.float() - lv.float()).abs().max() < 5) + self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE) def test_read_timestamps(self): with tempfile.NamedTemporaryFile(suffix='.mp4') as f: data = self._create_video_frames(10, 300, 300) io.write_video(f.name, data, fps=5) - lv = io.read_video_timestamps(f.name) - print(lv) - import av - container = av.open(f.name) - from IPython import embed; embed() + pts = io.read_video_timestamps(f.name) + self.assertEqual(pts, [0, 2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432]) + def test_read_partial_video(self): + with tempfile.NamedTemporaryFile(suffix='.mp4') as f: + data = self._create_video_frames(10, 300, 300) + io.write_video(f.name, data, fps=5) + + pts = io.read_video_timestamps(f.name) + + lv, _ = io.read_video(f.name, pts[4], pts[7]) + self.assertEqual(len(lv), 4) + self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) + + lv, _ = io.read_video(f.name, pts[4] + 1, pts[7]) + self.assertEqual(len(lv), 4) + self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) + if __name__ == '__main__': unittest.main() diff --git a/torchvision/io/video.py b/torchvision/io/video.py index a7abc010ba0..d8ecf0508cf 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -55,11 +55,12 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): if frame.pts < start_offset: first_frame = frame continue - if first_frame and first_frame.pts < start_offset: - audio_frames.append(first_frame) + if first_frame and first_frame.pts < start_offset:# and frame.pts > start_offset: + if frame.pts != start_offset: + frames.append(first_frame) first_frame = None frames.append(frame) - if frame.pts > end_offset: + if frame.pts >= end_offset: break return frames From fc6cf3869ee0bd9bf928bade7fc45e04a745cb14 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 21 Jun 2019 16:51:05 +0200 Subject: [PATCH 07/15] Improvements and travis --- .travis.yml | 1 + test/test_io.py | 9 ++++++--- torchvision/io/video.py | 4 ++++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 07e1e8900a0..e86e7bc9a3a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,6 +35,7 @@ before_install: - pip install future - pip install pytest pytest-cov codecov - pip install mock + - conda install av -c conda-forge install: diff --git a/test/test_io.py b/test/test_io.py index e50df619910..34c1d961fee 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -46,9 +46,12 @@ def test_read_partial_video(self): pts = io.read_video_timestamps(f.name) - lv, _ = io.read_video(f.name, pts[4], pts[7]) - self.assertEqual(len(lv), 4) - self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) + for start in range(5): + for l in range(1, 4): + lv, _ = io.read_video(f.name, pts[start], pts[start + l - 1]) + s_data = data[start:(start + l)] + self.assertEqual(len(lv), l) + self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE) lv, _ = io.read_video(f.name, pts[4] + 1, pts[7]) self.assertEqual(len(lv), 4) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index d8ecf0508cf..5c4aed06ac3 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -80,6 +80,10 @@ def read_video(filename, start_pts=0, end_pts=math.inf): aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points """ + if end_pts < start_pts: + raise ValueError("end_pts should be larger than start_pts, got " + "start_pts={} and end_pts={}".format(start_pts, end_pts)) + container = av.open(filename) video_frames = [] From aad891026cdfcd9692a9cf77e3aaf35b6dcbbf57 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 21 Jun 2019 17:05:57 +0200 Subject: [PATCH 08/15] Add audio fine-grained alignment --- test/test_io.py | 3 ++- torchvision/io/video.py | 15 ++++++++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index 34c1d961fee..ff631505a1f 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -38,7 +38,6 @@ def test_read_timestamps(self): pts = io.read_video_timestamps(f.name) self.assertEqual(pts, [0, 2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432]) - def test_read_partial_video(self): with tempfile.NamedTemporaryFile(suffix='.mp4') as f: data = self._create_video_frames(10, 300, 300) @@ -57,6 +56,8 @@ def test_read_partial_video(self): self.assertEqual(len(lv), 4) self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) + # TODO add tests for audio + if __name__ == '__main__': unittest.main() diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 5c4aed06ac3..b252fe1f27a 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -65,6 +65,19 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): return frames +def _align_audio_frames(aframes, audio_frames, ref_start, ref_end): + start, end = audio_frames[0].pts, audio_frames[-1].pts + total_aframes = aframes.shape[1] + step_per_aframe = (end - start + 1) / total_aframes + s_idx = 0 + e_idx = total_aframes + if start < ref_start: + s_idx = int((ref_start - start) / step_per_aframe) + if end > ref_end: + e_idx = int((ref_end - end) / step_per_aframe) + return aframes[:, s_idx:e_idx] + + def read_video(filename, start_pts=0, end_pts=math.inf): """ Reads a video from a file, returning both the video frames as well as @@ -103,10 +116,10 @@ def read_video(filename, start_pts=0, end_pts=math.inf): if aframes: aframes = np.concatenate(aframes, 1) aframes = torch.as_tensor(aframes) + aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts) else: aframes = torch.empty((1, 0), dtype=torch.float32) - # return video_frames, audio_frames return vframes, aframes From 159cef491dcd8d9d8bd464832480182e4aad9eb5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 21 Jun 2019 17:27:25 +0200 Subject: [PATCH 09/15] More doc --- torchvision/io/video.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index b252fe1f27a..2c2a346dc66 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -124,6 +124,18 @@ def read_video(filename, start_pts=0, end_pts=math.inf): def read_video_timestamps(filename): + """ + List the video frames timestamps. + + Note that the function decodes the whole video frame-by-frame. + + Arguments: + filename (str): path to the video file + + Returns: + pts (List[int]): presentation timestamps for each one of the frames + in the video. + """ container = av.open(filename) video_frames = [] From 0b1d7037d1c4bf531f7bb28ea96372740ad52ca9 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 21 Jun 2019 17:32:26 +0200 Subject: [PATCH 10/15] Remove unecessary file --- torchvision/io/video_reader.py | 224 --------------------------------- 1 file changed, 224 deletions(-) delete mode 100644 torchvision/io/video_reader.py diff --git a/torchvision/io/video_reader.py b/torchvision/io/video_reader.py deleted file mode 100644 index f17be948dc8..00000000000 --- a/torchvision/io/video_reader.py +++ /dev/null @@ -1,224 +0,0 @@ -import av -import gc -import warnings - - - -_CALLED_TIMES = 0 -_GC_COLLECTION_INTERVAL = 20 - - -# remove warnings -av.logging.set_level(av.logging.ERROR) - - - -class VideoReader(object): - """ - Simple wrapper around PyAV that exposes a few useful functions for - dealing with video reading. - """ - def __init__(self, video_path, sampling_rate=1, decode_lossy=False, audio_resample_rate=None): - """ - Arguments: - video_path (str): path of the video to be loaded - """ - self.container = av.open(video_path) - self.sampling_rate = sampling_rate - self.resampler = None - if audio_resample_rate is not None: - self.resampler = av.AudioResampler(rate=audio_resample_rate) - - - if self.container.streams.video: - # enable multi-threaded video decoding - if decode_lossy: - warnings.warn('VideoReader| thread_type==AUTO can yield potential frame dropping!', RuntimeWarning) - self.container.streams.video[0].thread_type = 'AUTO' - self.video_stream = self.container.streams.video[0] - else: - self.video_stream = None - - def seek(self, offset, backward=True, any_frame=False): - stream = self.video_stream - self.container.seek(offset, any_frame=any_frame, backward=backward, stream=stream) - - def _occasional_gc(self): - # there are a lot of reference cycles in PyAV, so need to manually call - # the garbage collector from time to time - global _CALLED_TIMES, _GC_COLLECTION_INTERVAL - _CALLED_TIMES += 1 - if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1: - gc.collect() - - def _read_video(self, offset, num_frames): - self._occasional_gc() - self.seek(offset) - video_frames = [] - count = 0 - for idx, frame in enumerate(self.container.decode(video=0)): - if frame.pts < offset: - continue - video_frames.append(frame) - if count >= num_frames - 1: - break - count += 1 - return video_frames - - def _resample_audio_frame(self, frame): - curr_pts = frame.pts - frame.pts = None - frame = self.resampler.resample(frame) - frame.pts = curr_pts - return frame - - - def _read_audio(self, offset, end_offset): - self._occasional_gc() - if not self.container.streams.audio: - return [] - - self.container.seek(offset, backward=True, any_frame=False, stream=self.container.streams.audio[0]) - - audio_frames = [] - first_frame = None - for idx, frame in enumerate(self.container.decode(audio=0)): - if frame.pts < offset: - first_frame = frame - continue - if first_frame and first_frame.pts < offset: - if self.resampler is not None: - first_frame = self._resample_audio_frame(first_frame) - audio_frames.append(first_frame) - first_frame = None - # if we want to resample audio to a different framerate - if self.resampler is not None: - frame = self._resample_audio_frame(frame) - audio_frames.append(frame) - if frame.pts > end_offset: - break - return audio_frames - - def read(self, offset, num_frames): - """ - Reads video frames and audio frames starting from offset. - The number of video frames read is given by num_frames. - The number of audio frames read is defined by the start and - end time of the first and last video frames, respectively - Arguments: - offset (int): the start time from the read - num_frames (int): the number of video frames to be read - Returns: - video_frames (List[av.VideoFrame]) - audio_frames (List[av.AudioFrame]) - """ - if self.container is None: - return [], [] - - num_frames = self.sampling_rate * num_frames - video_frames = self._read_video(offset, num_frames) - if len(video_frames) < 1: - end_offset = offset - elif len(video_frames) < 2: - end_offset = video_frames[-1].pts - else: - step = video_frames[-1].pts - video_frames[-2].pts - end_offset = video_frames[-1].pts + step - 1 - try: - audio_frames = self._read_audio(offset, end_offset) - except av.AVError: - audio_frames = [] - return video_frames, audio_frames - - def list_keyframes(self): - """ - Returns a list of start times for all the keyframes in the video - Returns: - keyframes (List[int]) - """ - keyframes = [] - if self.video_stream is None: - return [] - pts = -1 - while True: - try: - self.seek(pts + 1, backward=False) - except av.AVError: - break - packet = next(self.container.demux(video=0)) - pts = packet.pts - #TODO: double check if this is needed - if pts is None: - # should we simply return []? - return keyframes - - if packet.is_keyframe: - keyframes.append(pts) - return keyframes - - def _compute_end_video_pts(self): - self.seek(self.container.duration, any_frame=True) - end_step = next(self.container.demux(video=0)).pts - if end_step is None: - self.seek(self.container.duration, any_frame=False) - gen = self.container.demux(video=0) - last_pts = 0 - while True: - last_pts = next(gen).pts - if last_pts is None: - break - end_step = last_pts - return end_step - - def _compute_start_video_pts(self): - self.seek(0) - start = next(self.container.demux(video=0)).pts - return start - - def _compute_step_pts(self): - self.seek(0) - pts = [] - num = 11 - gen = self.container.demux(video=0) - for _ in range(num): - next(gen) - for _ in range(num): - pts.append(next(gen).pts) - print(pts) - steps = [p1 - p2 for p1, p2 in zip(pts[1:], pts[:-1])] - print(steps) - steps = max(set(steps), key=steps.count) - return int(steps) - - def _compute_step_pts(self): - frames = self._read_video(0, 2) - steps = frames[1].pts - frames[0].pts - return steps - - def list_every(self, n_frames): - step = 1 / float(self.video_stream.average_rate * self.video_stream.time_base) - end = self._compute_end_video_pts() - start = self._compute_start_video_pts() - step = self._compute_step_pts() - """ - orig_step = int(step) - for i in range(-10, 10): - if (end - start) % (orig_step + i) == 0: - step = orig_step + i - break - """ - return list(range(start, end + 1, int(step)))[::n_frames] - - - def _decode_every(self): - """ - A function used for truly decoding every single frame. - This should not be used outside of the dataset indexing step - Returns: - timestamp of every frame within the video (List[int]) - """ - if self.video_stream is None or self.container is None: - return [] - self.seek(0, backward=False) - d = [p for p in self.container.decode(video=0)] - return [x.pts for x in d[::1]] From 657eb01a60775e59e5d815ca4073fb7818901697 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 21 Jun 2019 17:33:01 +0200 Subject: [PATCH 11/15] Remove comment --- torchvision/io/video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 2c2a346dc66..3ef9a556e40 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -55,7 +55,7 @@ def _read_from_stream(container, start_offset, end_offset, stream, stream_name): if frame.pts < start_offset: first_frame = frame continue - if first_frame and first_frame.pts < start_offset:# and frame.pts > start_offset: + if first_frame and first_frame.pts < start_offset: if frame.pts != start_offset: frames.append(first_frame) first_frame = None From 30ce4034ddb609c34514ed1d9d2fa99311b6c4fd Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 24 Jun 2019 16:30:21 +0200 Subject: [PATCH 12/15] Lazy import av --- torchvision/io/video.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 3ef9a556e40..34b63c36727 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -1,9 +1,21 @@ -import av import gc import torch import numpy as np import math +try: + import av +except ImportError: + av = None + + +def _check_av_available(): + if av is None: + raise ImportError("""\ +PyAV is not installed, and is necessary for the video operations in torchvision. +See https://github.com/mikeboers/PyAV#installation for instructions on how to +install PyAV on your system. +""") # PyAV has some reference cycles _CALLED_TIMES = 0 @@ -20,6 +32,7 @@ def write_video(filename, video_array, fps): as a uint8 tensor in [T, H, W, C] format fps (Number): frames per second """ + _check_av_available() video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy() container = av.open(filename, mode='w') @@ -93,6 +106,8 @@ def read_video(filename, start_pts=0, end_pts=math.inf): aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points """ + _check_av_available() + if end_pts < start_pts: raise ValueError("end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)) @@ -136,6 +151,7 @@ def read_video_timestamps(filename): pts (List[int]): presentation timestamps for each one of the frames in the video. """ + _check_av_available() container = av.open(filename) video_frames = [] From 6d4bad4a6936d13a051b887349562d0b25f062dc Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 24 Jun 2019 17:03:07 +0200 Subject: [PATCH 13/15] Remove hard-coded constants for the test --- test/test_io.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/test/test_io.py b/test/test_io.py index ff631505a1f..e5458afa092 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -5,6 +5,12 @@ import unittest +try: + import av +except ImportError: + av = None + + class Tester(unittest.TestCase): # compression adds artifacts, thus we add a tolerance of # 5 in 0-255 range @@ -21,6 +27,7 @@ def _create_video_frames(self, num_frames, height, width): return torch.stack(data, 0) + @unittest.skipIf(av is None, "PyAV unavailable") def test_write_read_video(self): with tempfile.NamedTemporaryFile(suffix='.mp4') as f: data = self._create_video_frames(10, 300, 300) @@ -30,14 +37,26 @@ def test_write_read_video(self): self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE) + @unittest.skipIf(av is None, "PyAV unavailable") def test_read_timestamps(self): with tempfile.NamedTemporaryFile(suffix='.mp4') as f: data = self._create_video_frames(10, 300, 300) io.write_video(f.name, data, fps=5) pts = io.read_video_timestamps(f.name) - self.assertEqual(pts, [0, 2048, 4096, 6144, 8192, 10240, 12288, 14336, 16384, 18432]) + # note: not all formats/codecs provide accurate information for computing the + # timestamps. For the format that we use here, this information is available, + # so we use it as a baseline + container = av.open(f.name) + stream = container.streams[0] + pts_step = int(round(float(1 / (stream.average_rate * stream.time_base)))) + num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration))) + expected_pts = [i * pts_step for i in range(num_frames)] + + self.assertEqual(pts, expected_pts) + + @unittest.skipIf(av is None, "PyAV unavailable") def test_read_partial_video(self): with tempfile.NamedTemporaryFile(suffix='.mp4') as f: data = self._create_video_frames(10, 300, 300) From 1e1c7e128aa06da9e4a3c32b2b47534d6c1d5527 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 26 Jun 2019 15:39:19 +0200 Subject: [PATCH 14/15] Return info stats from read --- test/test_io.py | 11 ++++++----- torchvision/io/video.py | 9 ++++++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index e5458afa092..775a00fd9b1 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -13,8 +13,8 @@ class Tester(unittest.TestCase): # compression adds artifacts, thus we add a tolerance of - # 5 in 0-255 range - TOLERANCE = 5 + # 6 in 0-255 range + TOLERANCE = 6 def _create_video_frames(self, num_frames, height, width): y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) @@ -33,9 +33,10 @@ def test_write_read_video(self): data = self._create_video_frames(10, 300, 300) io.write_video(f.name, data, fps=5) - lv, _ = io.read_video(f.name) + lv, _, info = io.read_video(f.name) self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE) + self.assertEqual(info["video_fps"], 5) @unittest.skipIf(av is None, "PyAV unavailable") def test_read_timestamps(self): @@ -66,12 +67,12 @@ def test_read_partial_video(self): for start in range(5): for l in range(1, 4): - lv, _ = io.read_video(f.name, pts[start], pts[start + l - 1]) + lv, _, _ = io.read_video(f.name, pts[start], pts[start + l - 1]) s_data = data[start:(start + l)] self.assertEqual(len(lv), l) self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE) - lv, _ = io.read_video(f.name, pts[4] + 1, pts[7]) + lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7]) self.assertEqual(len(lv), 4) self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 34b63c36727..314597ef04f 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -17,6 +17,7 @@ def _check_av_available(): install PyAV on your system. """) + # PyAV has some reference cycles _CALLED_TIMES = 0 _GC_COLLECTION_INTERVAL = 20 @@ -105,6 +106,9 @@ def read_video(filename, start_pts=0, end_pts=math.inf): vframes (Tensor[T, H, W, C]): the `T` video frames aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels and `L` is the number of points + info (Dict): metadata for the video and audio. Can contain the fields + - video_fps (float) + - audio_fps (int) """ _check_av_available() @@ -113,15 +117,18 @@ def read_video(filename, start_pts=0, end_pts=math.inf): "start_pts={} and end_pts={}".format(start_pts, end_pts)) container = av.open(filename) + info = {} video_frames = [] if container.streams.video: video_frames = _read_from_stream(container, start_pts, end_pts, container.streams.video[0], {'video': 0}) + info["video_fps"] = float(container.streams.video[0].average_rate) audio_frames = [] if container.streams.audio: audio_frames = _read_from_stream(container, start_pts, end_pts, container.streams.audio[0], {'audio': 0}) + info["audio_fps"] = container.streams.audio[0].rate container.close() @@ -135,7 +142,7 @@ def read_video(filename, start_pts=0, end_pts=math.inf): else: aframes = torch.empty((1, 0), dtype=torch.float32) - return vframes, aframes + return vframes, aframes, info def read_video_timestamps(filename): From 38596b09740782ab59222e9ca515425d7c854593 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 26 Jun 2019 16:03:09 +0200 Subject: [PATCH 15/15] Fix for Python-2 --- torchvision/io/video.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/io/video.py b/torchvision/io/video.py index 314597ef04f..f80177b46dd 100644 --- a/torchvision/io/video.py +++ b/torchvision/io/video.py @@ -1,7 +1,6 @@ import gc import torch import numpy as np -import math try: import av @@ -92,7 +91,7 @@ def _align_audio_frames(aframes, audio_frames, ref_start, ref_end): return aframes[:, s_idx:e_idx] -def read_video(filename, start_pts=0, end_pts=math.inf): +def read_video(filename, start_pts=0, end_pts=None): """ Reads a video from a file, returning both the video frames as well as the audio frames @@ -112,6 +111,9 @@ def read_video(filename, start_pts=0, end_pts=math.inf): """ _check_av_available() + if end_pts is None: + end_pts = float("inf") + if end_pts < start_pts: raise ValueError("end_pts should be larger than start_pts, got " "start_pts={} and end_pts={}".format(start_pts, end_pts)) @@ -163,7 +165,7 @@ def read_video_timestamps(filename): video_frames = [] if container.streams.video: - video_frames = _read_from_stream(container, 0, math.inf, + video_frames = _read_from_stream(container, 0, float("inf"), container.streams.video[0], {'video': 0}) container.close() return [x.pts for x in video_frames]