Skip to content

Adds video reading / saving functionalities #1039

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 2, 2019
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ before_install:
- pip install future
- pip install pytest pytest-cov codecov
- pip install mock
- conda install av -c conda-forge


install:
Expand Down
83 changes: 83 additions & 0 deletions test/test_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import os
import tempfile
import torch
import torchvision.io as io
import unittest


try:
import av
except ImportError:
av = None


class Tester(unittest.TestCase):
    """Round-trip tests for the video helpers exposed by ``torchvision.io``."""

    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    def _create_video_frames(self, num_frames, height, width):
        """
        Builds a synthetic clip showing a Gaussian blob whose centre moves
        a little on every frame.

        Arguments:
            num_frames (int): number of frames to generate
            height (int): frame height in pixels
            width (int): frame width in pixels

        Returns:
            Tensor[num_frames, height, width, 3]: uint8 frames where the three
                channels are identical copies of the same intensity map
        """
        y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
        data = []
        for i in range(num_frames):
            xc = float(i) / num_frames
            yc = 1 - float(i) / (2 * num_frames)
            d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
            data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())

        return torch.stack(data, 0)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_write_read_video(self):
        """Frames written to disk are read back within TOLERANCE, with the same fps."""
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            lv, _, info = io.read_video(f.name)

            self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE)
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_timestamps(self):
        """read_video_timestamps matches the pts derived from the stream metadata."""
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            pts = io.read_video_timestamps(f.name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f.name)
            try:
                stream = container.streams[0]
                pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
                num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            finally:
                # Release the file handle even if reading the metadata fails.
                container.close()
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_partial_video(self):
        """Reading a pts sub-range returns exactly the frames covering that range."""
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            pts = io.read_video_timestamps(f.name)

            for start in range(5):
                for length in range(1, 4):
                    lv, _, _ = io.read_video(f.name, pts[start], pts[start + length - 1])
                    s_data = data[start:(start + length)]
                    self.assertEqual(len(lv), length)
                    self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

            # A start that falls between two frames must still include the
            # frame that is being displayed at that instant (frame 4 here).
            lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

    # TODO add tests for audio


# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
1 change: 1 addition & 0 deletions torchvision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from torchvision import ops
from torchvision import transforms
from torchvision import utils
from torchvision import io

try:
from .version import __version__ # noqa: F401
Expand Down
6 changes: 6 additions & 0 deletions torchvision/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .video import write_video, read_video, read_video_timestamps


# Names exported by `from torchvision.io import *`; all three come from .video.
__all__ = [
    'write_video', 'read_video', 'read_video_timestamps'
]
171 changes: 171 additions & 0 deletions torchvision/io/video.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
import gc
import torch
import numpy as np

try:
import av
except ImportError:
av = None


def _check_av_available():
    """
    Ensures PyAV was imported successfully.

    Returns ``None`` when the module-level ``av`` is set.

    Raises:
        ImportError: if ``av`` is ``None``, with a message that points to the
            PyAV installation instructions
    """
    if av is not None:
        return

    message = """\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
"""
    raise ImportError(message)


# PyAV has some reference cycles, so _read_from_stream forces a garbage
# collection every few calls.
# Number of times _read_from_stream has been called so far.
_CALLED_TIMES = 0
# gc.collect() runs once per this many calls to _read_from_stream.
_GC_COLLECTION_INTERVAL = 20


def write_video(filename, video_array, fps):
    """
    Writes a 4d tensor in [T, H, W, C] format in a video file

    The frames are encoded with the ``mpeg4`` codec in ``yuv420p`` pixel
    format. The output container is always closed, even when encoding or
    muxing fails part-way through.

    Arguments:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): frames per second

    Raises:
        ImportError: if PyAV is not installed
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode='w')
    try:
        stream = container.add_stream('mpeg4', rate=fps)
        stream.width = video_array.shape[2]
        stream.height = video_array.shape[1]
        stream.pix_fmt = 'yuv420p'

        for img in video_array:
            frame = av.VideoFrame.from_ndarray(img, format='rgb24')
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream: encode() without a frame drains what the encoder buffered
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close the file, also on failure, so the handle is not leaked
        container.close()


def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
    """
    Decodes the frames of one stream that cover [start_offset, end_offset].

    The container is first asked to seek backwards to ``start_offset``, then
    frames are decoded in order. Frames whose pts is before ``start_offset``
    are skipped, except that the last skipped frame is kept when no frame
    starts exactly at ``start_offset``. Decoding stops after the first frame
    whose pts is >= ``end_offset``; that frame is included in the result.

    Frames are returned as decoded: no resampling is performed here.

    Arguments:
        container: opened PyAV container to read from
        start_offset (Number): first presentation timestamp of interest
        end_offset (Number): last presentation timestamp of interest
        stream: stream object passed to ``container.seek``
        stream_name (Dict): keyword arguments selecting the stream in
            ``container.decode``, e.g. ``{'video': 0}``

    Returns:
        List: the selected frames, in decoding order
    """
    global _CALLED_TIMES
    _CALLED_TIMES += 1
    interval = _GC_COLLECTION_INTERVAL
    # PyAV has some reference cycles, so trigger a collection every `interval` calls
    if _CALLED_TIMES % interval == interval - 1:
        gc.collect()

    container.seek(start_offset, any_frame=False, backward=True, stream=stream)

    selected = []
    pending = None  # most recent frame seen before start_offset, if any
    for decoded in container.decode(**stream_name):
        if decoded.pts < start_offset:
            pending = decoded
            continue

        if pending and pending.pts < start_offset:
            # No frame starts exactly at start_offset, so the previous frame is
            # the one on display at that instant: keep it.
            if decoded.pts != start_offset:
                selected.append(pending)
            pending = None

        selected.append(decoded)
        if decoded.pts >= end_offset:
            break

    return selected


def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
    """
    Trims concatenated audio samples so they cover only [ref_start, ref_end].

    The pts of the first and last decoded frames give the time span of the
    samples; the number of samples to drop at each end is estimated by
    assuming the samples are evenly spaced over that span.

    Arguments:
        aframes (Tensor[K, L]): concatenated samples, one row per channel
        audio_frames (List): decoded audio frames the samples came from; only
            the ``pts`` of the first and last one are used
        ref_start (Number): requested start presentation timestamp
        ref_end (Number): requested end presentation timestamp

    Returns:
        Tensor[K, L']: a slice of ``aframes`` along the sample dimension
    """
    total_aframes = aframes.shape[1]
    if not audio_frames or total_aframes == 0:
        # Nothing to align; this also avoids dividing by zero below.
        return aframes

    start, end = audio_frames[0].pts, audio_frames[-1].pts
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # Count the samples to drop from the end and turn that into an
        # absolute index. Using the (negative) count directly as the slice end
        # would yield an empty result whenever it truncates to 0.
        trailing = int((end - ref_end) / step_per_aframe)
        e_idx = max(total_aframes - trailing, 0)
    return aframes[:, s_idx:e_idx]


def read_video(filename, start_pts=0, end_pts=None):
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Arguments:
        filename (str): path to the video file
        start_pts (int, optional): the start presentation time of the video
        end_pts (int, optional): the end presentation time. ``None`` means
            reading until the end of the file

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames. When no video frame
            was decoded, an empty uint8 tensor of shape [0, 1, 1, 3]
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels
            and `L` is the number of points. When no audio frame was decoded,
            an empty float32 tensor of shape [1, 0]
        info (Dict): metadata for the video and audio. Can contain the fields
            - video_fps (float)
            - audio_fps (int)

    Raises:
        ImportError: if PyAV is not installed
        ValueError: if ``end_pts`` is smaller than ``start_pts``
    """
    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(start_pts, end_pts))

    container = av.open(filename)
    info = {}
    video_frames = []
    audio_frames = []
    try:
        if container.streams.video:
            video_frames = _read_from_stream(container, start_pts, end_pts,
                                             container.streams.video[0], {'video': 0})
            info["video_fps"] = float(container.streams.video[0].average_rate)
        if container.streams.audio:
            audio_frames = _read_from_stream(container, start_pts, end_pts,
                                             container.streams.audio[0], {'audio': 0})
            info["audio_fps"] = container.streams.audio[0].rate
    finally:
        # Release the file handle even if decoding fails
        container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]

    if vframes:
        vframes = torch.as_tensor(np.stack(vframes))
    else:
        # np.stack raises on an empty list (no video stream, or no frame in range)
        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)

    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info


def read_video_timestamps(filename):
    """
    List the video frames timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Arguments:
        filename (str): path to the video file

    Returns:
        pts (List[int]): presentation timestamps for each one of the frames
            in the video. Empty if the file has no video stream.

    Raises:
        ImportError: if PyAV is not installed
    """
    _check_av_available()
    container = av.open(filename)

    video_frames = []
    try:
        if container.streams.video:
            video_frames = _read_from_stream(container, 0, float("inf"),
                                             container.streams.video[0], {'video': 0})
    finally:
        # Release the file handle even if decoding fails
        container.close()
    return [x.pts for x in video_frames]