Adds video reading / saving functionalities #1039
Merged
Commits (15)
cf5d0a0  WIP  (fmassa)
d71c919  WIP  (fmassa)
94ac03c  Add some documentation  (fmassa)
fd0fd47  Improve tests and add GC collection  (fmassa)
3ac0ead  [WIP] add timestamp getter  (fmassa)
3d8c4b9  Bugfixes  (fmassa)
fc6cf38  Improvements and travis  (fmassa)
aad8910  Add audio fine-grained alignment  (fmassa)
159cef4  More doc  (fmassa)
0b1d703  Remove unecessary file  (fmassa)
657eb01  Remove comment  (fmassa)
30ce403  Lazy import av  (fmassa)
6d4bad4  Remove hard-coded constants for the test  (fmassa)
1e1c7e1  Return info stats from read  (fmassa)
38596b0  Fix for Python-2  (fmassa)
@@ -0,0 +1,83 @@
import os
import tempfile
import torch
import torchvision.io as io
import unittest


try:
    import av
except ImportError:
    av = None


class Tester(unittest.TestCase):
    # compression adds artifacts, thus we add a tolerance of
    # 6 in 0-255 range
    TOLERANCE = 6

    def _create_video_frames(self, num_frames, height, width):
        y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width))
        data = []
        for i in range(num_frames):
            xc = float(i) / num_frames
            yc = 1 - float(i) / (2 * num_frames)
            d = torch.exp(-((x - xc) ** 2 + (y - yc) ** 2) / 2) * 255
            data.append(d.unsqueeze(2).repeat(1, 1, 3).byte())

        return torch.stack(data, 0)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_write_read_video(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            lv, _, info = io.read_video(f.name)

            self.assertTrue((data.float() - lv.float()).abs().max() < self.TOLERANCE)
            self.assertEqual(info["video_fps"], 5)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_timestamps(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            pts = io.read_video_timestamps(f.name)

            # note: not all formats/codecs provide accurate information for computing the
            # timestamps. For the format that we use here, this information is available,
            # so we use it as a baseline
            container = av.open(f.name)
            stream = container.streams[0]
            pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
            num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
            expected_pts = [i * pts_step for i in range(num_frames)]

            self.assertEqual(pts, expected_pts)

    @unittest.skipIf(av is None, "PyAV unavailable")
    def test_read_partial_video(self):
        with tempfile.NamedTemporaryFile(suffix='.mp4') as f:
            data = self._create_video_frames(10, 300, 300)
            io.write_video(f.name, data, fps=5)

            pts = io.read_video_timestamps(f.name)

            for start in range(5):
                for l in range(1, 4):
                    lv, _, _ = io.read_video(f.name, pts[start], pts[start + l - 1])
                    s_data = data[start:(start + l)]
                    self.assertEqual(len(lv), l)
                    self.assertTrue((s_data.float() - lv.float()).abs().max() < self.TOLERANCE)

            # a start_pts strictly between two frame timestamps should round
            # up to the next frame
            lv, _, _ = io.read_video(f.name, pts[4] + 1, pts[7])
            self.assertEqual(len(lv), 4)
            self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)

    # TODO add tests for audio


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,6 @@
from .video import write_video, read_video, read_video_timestamps


__all__ = [
    'write_video', 'read_video', 'read_video_timestamps'
]
@@ -0,0 +1,171 @@
import gc
import torch
import numpy as np

try:
    import av
except ImportError:
    av = None


def _check_av_available():
    if av is None:
        raise ImportError("""\
PyAV is not installed, and is necessary for the video operations in torchvision.
See https://github.com/mikeboers/PyAV#installation for instructions on how to
install PyAV on your system.
""")


# PyAV has some reference cycles
_CALLED_TIMES = 0
_GC_COLLECTION_INTERVAL = 20

def write_video(filename, video_array, fps):
    """
    Writes a 4d tensor in [T, H, W, C] format into a video file

    Arguments:
        filename (str): path where the video will be saved
        video_array (Tensor[T, H, W, C]): tensor containing the individual frames,
            as a uint8 tensor in [T, H, W, C] format
        fps (Number): frames per second
    """
    _check_av_available()
    video_array = torch.as_tensor(video_array, dtype=torch.uint8).numpy()

    container = av.open(filename, mode='w')

    stream = container.add_stream('mpeg4', rate=fps)
    stream.width = video_array.shape[2]
    stream.height = video_array.shape[1]
    stream.pix_fmt = 'yuv420p'

    for img in video_array:
        frame = av.VideoFrame.from_ndarray(img, format='rgb24')
        for packet in stream.encode(frame):
            container.mux(packet)

    # Flush stream
    for packet in stream.encode():
        container.mux(packet)

    # Close the file
    container.close()

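For reference, a minimal round-trip sketch of the writer (the output path and the random clip are made up for illustration):

import torch
import torchvision.io as io

# hypothetical clip: 10 random 64x64 RGB frames, uint8, saved at 5 fps
frames = torch.randint(0, 256, (10, 64, 64, 3), dtype=torch.uint8)
io.write_video('out.mp4', frames, fps=5)
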
def _read_from_stream(container, start_offset, end_offset, stream, stream_name):
    global _CALLED_TIMES, _GC_COLLECTION_INTERVAL
    _CALLED_TIMES += 1
    if _CALLED_TIMES % _GC_COLLECTION_INTERVAL == _GC_COLLECTION_INTERVAL - 1:
        gc.collect()

    # seeking lands on the closest keyframe at or before start_offset, so
    # decode from there and discard frames until start_offset is reached
    container.seek(start_offset, any_frame=False, backward=True, stream=stream)
    frames = []
    first_frame = None
    for idx, frame in enumerate(container.decode(**stream_name)):
        if frame.pts < start_offset:
            # remember the last frame before start_offset, in case no frame
            # lands exactly on it
            first_frame = frame
            continue
        if first_frame and first_frame.pts < start_offset:
            if frame.pts != start_offset:
                # no frame has pts == start_offset, so also keep the frame
                # just before it to cover the requested start time
                frames.append(first_frame)
            first_frame = None
        frames.append(frame)
        if frame.pts >= end_offset:
            break
    return frames

def _align_audio_frames(aframes, audio_frames, ref_start, ref_end):
    # decoded audio frames rarely start / end exactly at the requested pts,
    # so trim the samples down to the [ref_start, ref_end] window
    start, end = audio_frames[0].pts, audio_frames[-1].pts
    total_aframes = aframes.shape[1]
    step_per_aframe = (end - start + 1) / total_aframes
    s_idx = 0
    e_idx = total_aframes
    if start < ref_start:
        s_idx = int((ref_start - start) / step_per_aframe)
    if end > ref_end:
        # negative index: drop the excess samples from the end of the slice
        e_idx = int((ref_end - end) / step_per_aframe)
    return aframes[:, s_idx:e_idx]

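To make the alignment arithmetic concrete, a small worked sketch with made-up numbers:

import torch

# pretend the decoded audio frames span pts 0..999 and hold 2000 samples,
# so step_per_aframe = (999 - 0 + 1) / 2000 = 0.5 pts per sample
aframes = torch.zeros(1, 2000)
s_idx = int((100 - 0) / 0.5)    # ref_start=100 trims 200 samples up front
e_idx = int((899 - 999) / 0.5)  # ref_end=899 gives -200: negative slicing
assert aframes[:, s_idx:e_idx].shape[1] == 1600
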
def read_video(filename, start_pts=0, end_pts=None):
    """
    Reads a video from a file, returning both the video frames as well as
    the audio frames

    Arguments:
        filename (str): path to the video file
        start_pts (int, optional): the start presentation time of the video
        end_pts (int, optional): the end presentation time

    Returns:
        vframes (Tensor[T, H, W, C]): the `T` video frames
        aframes (Tensor[K, L]): the audio frames, where `K` is the number of channels
            and `L` is the number of points
        info (Dict): metadata for the video and audio. Can contain the fields
            - video_fps (float)
            - audio_fps (int)
    """
    _check_av_available()

    if end_pts is None:
        end_pts = float("inf")

    if end_pts < start_pts:
        raise ValueError("end_pts should be larger than start_pts, got "
                         "start_pts={} and end_pts={}".format(start_pts, end_pts))

    container = av.open(filename)
    info = {}

    video_frames = []
    if container.streams.video:
        video_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.video[0], {'video': 0})
        info["video_fps"] = float(container.streams.video[0].average_rate)
    audio_frames = []
    if container.streams.audio:
        audio_frames = _read_from_stream(container, start_pts, end_pts,
                                         container.streams.audio[0], {'audio': 0})
        info["audio_fps"] = container.streams.audio[0].rate

    container.close()

    vframes = [frame.to_rgb().to_ndarray() for frame in video_frames]
    aframes = [frame.to_ndarray() for frame in audio_frames]
    vframes = torch.as_tensor(np.stack(vframes))
    if aframes:
        aframes = np.concatenate(aframes, 1)
        aframes = torch.as_tensor(aframes)
        aframes = _align_audio_frames(aframes, audio_frames, start_pts, end_pts)
    else:
        aframes = torch.empty((1, 0), dtype=torch.float32)

    return vframes, aframes, info

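A minimal usage sketch of the reader (the filename is made up for illustration):

import torchvision.io as io

# hypothetical path; any container/codec that PyAV can decode should work
vframes, aframes, info = io.read_video('video.mp4')
print(vframes.shape)  # torch.Size([T, H, W, 3]), uint8
print(aframes.shape)  # torch.Size([K, L]); [1, 0] when there is no audio
print(info)           # e.g. {'video_fps': 30.0, 'audio_fps': 44100}
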
def read_video_timestamps(filename):
    """
    Lists the video frame timestamps.

    Note that the function decodes the whole video frame-by-frame.

    Arguments:
        filename (str): path to the video file

    Returns:
        pts (List[int]): presentation timestamps for each one of the frames
            in the video.
    """
    _check_av_available()
    container = av.open(filename)

    video_frames = []
    if container.streams.video:
        video_frames = _read_from_stream(container, 0, float("inf"),
                                         container.streams.video[0], {'video': 0})
    container.close()
    return [x.pts for x in video_frames]
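
The two functions are meant to compose for clip sampling: read_video_timestamps enumerates the valid pts values, which can then be fed back to read_video to decode only a slice. A sketch (the path and indices are made up):

import torchvision.io as io

pts = io.read_video_timestamps('video.mp4')  # hypothetical file

# decode only frames 10..19 instead of the whole video
clip, _, _ = io.read_video('video.mp4', start_pts=pts[10], end_pts=pts[19])
assert len(clip) == 10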
Conversations

Do we want to add an option to resample audio to a specific SR online?
(note, I have this in the experimental repo)

Great question again.
I'm inclined to always return the audio at a fixed frequency (say 44kHz), so that the results are always consistent.
Thoughts?

"I'm inclined to always return the audio at a fixed frequency"
Definitely, BUT I think the exact value should be left for the user to decide?
I feel like a simple "if sr != user_defined_sr, call resampler" should be enough?

I think that for a reading function, we should try to make it as simple as possible, and add additional transforms for resampling the audio / video if needed.
But you have great points about the stats that should be returned.
I'll modify the implementation to return a third argument, a dict with the fps etc.
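
A minimal sketch of the "if sr != user_defined_sr, call resampler" idea as a standalone transform, along the lines discussed above. The class name and the use of linear interpolation are hypothetical (a real implementation would use a proper low-pass resampling filter):

import torch


class ResampleAudio(object):
    """Hypothetical transform: resample a [K, L] waveform to target_rate."""

    def __init__(self, target_rate):
        self.target_rate = target_rate

    def __call__(self, aframes, audio_fps):
        if audio_fps == self.target_rate:
            return aframes, audio_fps
        new_length = int(aframes.shape[1] * self.target_rate / audio_fps)
        # F.interpolate expects a [N, C, L] batch, hence unsqueeze/squeeze
        resampled = torch.nn.functional.interpolate(
            aframes.float().unsqueeze(0), size=new_length,
            mode='linear', align_corners=False).squeeze(0)
        return resampled, self.target_rate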