diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index b6790d49cce..a9cb7ab50ef 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -6,7 +6,6 @@
 
 from torchvision import io
 from torchvision.datasets.video_utils import VideoClips, unfold
-from torchvision import get_video_backend
 
 from common_utils import get_tmp_dir
 
@@ -62,23 +61,22 @@ def test_unfold(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
             for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
-            video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
+            video_clips = VideoClips(video_list, 6, 6)
             self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
             for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
-            video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
+            video_clips = VideoClips(video_list, 6, 1)
             self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
             for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
                 video_idx, clip_idx = video_clips.get_clip_location(i)
@@ -87,9 +85,8 @@ def test_video_clips(self):
 
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 3 * 3)
             indices = torch.tensor(list(iter(sampler)))
@@ -100,9 +97,8 @@ def test_video_sampler(self):
 
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler_unequal(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 2 + 3 + 3)
             indices = list(iter(sampler))
@@ -120,11 +116,10 @@ def test_video_sampler_unequal(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
             num_frames = 4
             for fps in [1, 3, 4, 10]:
-                video_clips = VideoClips(video_list, num_frames, num_frames, fps, _backend=_backend)
+                video_clips = VideoClips(video_list, num_frames, num_frames, fps)
                 for i in range(video_clips.num_clips()):
                     video, audio, info, video_idx = video_clips.get_clip(i)
                     self.assertEqual(video.shape[0], num_frames)
diff --git a/torchvision/__init__.py b/torchvision/__init__.py
index 2fd780497d3..ca155712671 100644
--- a/torchvision/__init__.py
+++ b/torchvision/__init__.py
@@ -1,3 +1,5 @@
+import warnings
+
 from torchvision import models
 from torchvision import datasets
 from torchvision import ops
@@ -57,7 +59,10 @@ def set_video_backend(backend):
         raise ValueError(
             "Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
         )
-    _video_backend = backend
+    if backend == "video_reader" and not io._HAS_VIDEO_OPT:
+        warnings.warn("video_reader video backend is not available")
+    else:
+        _video_backend = backend
 
 
 def get_video_backend():
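Usage sketch for the guarded backend switch above (not part of the patch; it assumes a build where the C++ video_reader ops are absent, i.e. io._HAS_VIDEO_OPT is False):

import warnings
import torchvision

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torchvision.set_video_backend("video_reader")

# Without the video_reader extension, the call warns and keeps the previous
# backend instead of silently selecting an unusable one.
print(torchvision.get_video_backend())   # still "pyav"
print([str(w.message) for w in caught])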
diff --git a/torchvision/datasets/hmdb51.py b/torchvision/datasets/hmdb51.py
index 438da54c30e..b5fad588785 100644
--- a/torchvision/datasets/hmdb51.py
+++ b/torchvision/datasets/hmdb51.py
@@ -1,9 +1,9 @@
 import glob
 import os
 
-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset
 
 
@@ -51,7 +51,8 @@ class HMDB51(VisionDataset):
 
     def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                  frame_rate=None, fold=1, train=True, transform=None,
-                 _precomputed_metadata=None):
+                 _precomputed_metadata=None, num_workers=1, _video_width=0,
+                 _video_height=0, _video_min_dimension=0, _audio_samples=0):
         super(HMDB51, self).__init__(root)
         if not 1 <= fold <= 3:
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -71,11 +72,21 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
+        self.video_clips_metadata = video_clips.metadata
         self.indices = self._select_fold(video_list, annotation_path, fold, train)
         self.video_clips = video_clips.subset(self.indices)
         self.transform = transform
 
+    @property
+    def metadata(self):
+        return self.video_clips_metadata
+
     def _select_fold(self, video_list, annotation_path, fold, train):
         target_tag = 1 if train else 2
         name = "*test_split{}.txt".format(fold)
diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py
index 0a7ce8d41f3..2c8faa343d6 100644
--- a/torchvision/datasets/kinetics.py
+++ b/torchvision/datasets/kinetics.py
@@ -1,6 +1,6 @@
-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset
 
 
@@ -37,7 +37,9 @@ class Kinetics400(VisionDataset):
     """
 
     def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
-                 extensions=('avi',), transform=None, _precomputed_metadata=None):
+                 extensions=('avi',), transform=None, _precomputed_metadata=None,
+                 num_workers=1, _video_width=0, _video_height=0,
+                 _video_min_dimension=0, _audio_samples=0):
         super(Kinetics400, self).__init__(root)
         extensions = ('avi',)
 
@@ -52,9 +54,18 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
         self.transform = transform
 
+    @property
+    def metadata(self):
+        return self.video_clips.metadata
+
     def __len__(self):
         return self.video_clips.num_clips()
 
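Usage sketch for the new dataset keyword arguments and the metadata property (not part of the patch; the dataset root is a placeholder): num_workers parallelizes the initial timestamp scan, and the scanned metadata can be fed back through _precomputed_metadata so a second instance over the same files does not re-decode every video.

from torchvision.datasets import Kinetics400

train_set = Kinetics400("/data/kinetics400/train", frames_per_clip=16,
                        step_between_clips=8, num_workers=8)

# Reuse the cached per-video timestamps for another view of the same videos,
# e.g. with a different clip length; only the clip index is recomputed.
train_set_long = Kinetics400("/data/kinetics400/train", frames_per_clip=32,
                             step_between_clips=8, num_workers=8,
                             _precomputed_metadata=train_set.metadata)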
diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py
index 6abef63b7e2..50734c4b469 100644
--- a/torchvision/datasets/ucf101.py
+++ b/torchvision/datasets/ucf101.py
@@ -1,9 +1,9 @@
 import glob
 import os
 
-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset
 
 
@@ -44,7 +44,8 @@ class UCF101(VisionDataset):
 
     def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                  frame_rate=None, fold=1, train=True, transform=None,
-                 _precomputed_metadata=None):
+                 _precomputed_metadata=None, num_workers=1, _video_width=0,
+                 _video_height=0, _video_min_dimension=0, _audio_samples=0):
         super(UCF101, self).__init__(root)
         if not 1 <= fold <= 3:
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -64,11 +65,21 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
+        self.video_clips_metadata = video_clips.metadata
         self.indices = self._select_fold(video_list, annotation_path, fold, train)
         self.video_clips = video_clips.subset(self.indices)
         self.transform = transform
 
+    @property
+    def metadata(self):
+        return self.video_clips_metadata
+
     def _select_fold(self, video_list, annotation_path, fold, train):
         name = "train" if train else "test"
         name = "{}list{:02d}.txt".format(name, fold)
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index b5c19d07088..b23ded6f1d9 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -68,10 +68,18 @@ class VideoClips(object):
             0 means that the data will be loaded in the main process. (default: 0)
     """
     def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
-                 frame_rate=None, _precomputed_metadata=None, num_workers=0, _backend="pyav"):
+                 frame_rate=None, _precomputed_metadata=None, num_workers=0,
+                 _video_width=0, _video_height=0, _video_min_dimension=0,
+                 _audio_samples=0):
+        from torchvision import get_video_backend
+
         self.video_paths = video_paths
         self.num_workers = num_workers
-        self._backend = _backend
+        self._backend = get_video_backend()
+        self._video_width = _video_width
+        self._video_height = _video_height
+        self._video_min_dimension = _video_min_dimension
+        self._audio_samples = _audio_samples
 
         if _precomputed_metadata is None:
             self._compute_frame_pts()
@@ -145,6 +153,7 @@ def metadata(self):
             _metadata.update({"video_fps": self.video_fps})
         else:
             _metadata.update({"info": self.info})
+        return _metadata
 
     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
@@ -162,7 +171,11 @@ def subset(self, indices):
         else:
             metadata.update({"info": info})
         return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
-                          _precomputed_metadata=metadata)
+                          _precomputed_metadata=metadata, num_workers=self.num_workers,
+                          _video_width=self._video_width,
+                          _video_height=self._video_height,
+                          _video_min_dimension=self._video_min_dimension,
+                          _audio_samples=self._audio_samples)
 
     @staticmethod
     def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
@@ -206,9 +219,15 @@ def compute_clips(self, num_frames, step, frame_rate=None):
                 self.resampling_idxs.append(idxs)
         else:
             for video_pts, info in zip(self.video_pts, self.info):
-                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, info["video_fps"], frame_rate)
-                self.clips.append(clips)
-                self.resampling_idxs.append(idxs)
+                if "video_fps" in info:
+                    clips, idxs = self.compute_clips_for_video(
+                        video_pts, num_frames, step, info["video_fps"], frame_rate)
+                    self.clips.append(clips)
+                    self.resampling_idxs.append(idxs)
+                else:
+                    # properly handle the cases where video decoding fails
+                    self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
+                    self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
 
@@ -296,8 +315,12 @@ def get_clip(self, idx):
                 )
             video, audio, info = _read_video_from_file(
                 video_path,
+                video_width=self._video_width,
+                video_height=self._video_height,
+                video_min_dimension=self._video_min_dimension,
                 video_pts_range=(video_start_pts, video_end_pts),
                 video_timebase=info["video_timebase"],
+                audio_samples=self._audio_samples,
                 audio_pts_range=(audio_start_pts, audio_end_pts),
                 audio_timebase=audio_timebase,
             )
diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py
index 978ac31555a..36be33884b8 100644
--- a/torchvision/io/__init__.py
+++ b/torchvision/io/__init__.py
@@ -1,8 +1,8 @@
 from .video import write_video, read_video, read_video_timestamps
-from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file
+from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file, _HAS_VIDEO_OPT
 
 
 __all__ = [
     'write_video', 'read_video', 'read_video_timestamps',
-    '_read_video_from_file', '_read_video_timestamps_from_file',
+    '_read_video_from_file', '_read_video_timestamps_from_file', '_HAS_VIDEO_OPT',
 ]
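Usage sketch for the VideoClips changes (not part of the patch; file paths are placeholders): the decoding backend is now taken from torchvision.get_video_backend() instead of a _backend argument, and the _video_width/_video_height/_video_min_dimension/_audio_samples knobs are forwarded only on the video_reader decoding path.

import torchvision
from torchvision.datasets.video_utils import VideoClips

torchvision.set_video_backend("pyav")  # or "video_reader" when the ops are built

clips = VideoClips(["/data/clips/a.mp4", "/data/clips/b.mp4"],
                   clip_length_in_frames=16, frames_between_clips=16,
                   num_workers=4)
video, audio, info, video_idx = clips.get_clip(0)  # video: (16, H, W, C) uint8 tensor
print(clips.num_clips(), video.shape)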