
Commit 49b01e3

stephenyan1231 authored and fmassa committed
add metadata to video dataset classes. bug fix. more robustness (#1376)
* add metadata to video dataset classes. bug fix. more robustness
* query video backend within VideoClips class
* Fix tests
* Fix lint
1 parent d02db17 commit 49b01e3

File tree

7 files changed: +82 −26 lines changed

* test/test_datasets_video_utils.py
* torchvision/__init__.py
* torchvision/datasets/hmdb51.py
* torchvision/datasets/kinetics.py
* torchvision/datasets/ucf101.py
* torchvision/datasets/video_utils.py
* torchvision/io/__init__.py
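
The user-visible effect of the change: `VideoClips` now queries `torchvision.get_video_backend()` internally instead of taking a `_backend` argument, the video datasets expose the clip metadata they compute, and a video that fails to decode contributes zero clips instead of crashing clip computation. A minimal usage sketch (the video paths are placeholders; assumes this commit with the PyAV backend available):

```python
import torchvision
from torchvision.datasets.video_utils import VideoClips

# The backend is now configured globally and picked up inside VideoClips,
# rather than being passed per instance via `_backend`.
torchvision.set_video_backend("pyav")

video_paths = ["a.avi", "b.avi"]  # placeholder paths to real video files
clips = VideoClips(video_paths, clip_length_in_frames=5, frames_between_clips=5)
print(clips.num_clips())
print(sorted(clips.metadata.keys()))  # metadata is now actually returned (see video_utils.py below)
```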

test/test_datasets_video_utils.py

Lines changed: 6 additions & 11 deletions
```diff
@@ -6,7 +6,6 @@
 
 from torchvision import io
 from torchvision.datasets.video_utils import VideoClips, unfold
-from torchvision import get_video_backend
 
 from common_utils import get_tmp_dir
 
@@ -62,23 +61,22 @@ def test_unfold(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             self.assertEqual(video_clips.num_clips(), 1 + 2 + 3)
             for i, (v_idx, c_idx) in enumerate([(0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
-            video_clips = VideoClips(video_list, 6, 6, _backend=_backend)
+            video_clips = VideoClips(video_list, 6, 6)
             self.assertEqual(video_clips.num_clips(), 0 + 1 + 2)
             for i, (v_idx, c_idx) in enumerate([(1, 0), (2, 0), (2, 1)]):
                 video_idx, clip_idx = video_clips.get_clip_location(i)
                 self.assertEqual(video_idx, v_idx)
                 self.assertEqual(clip_idx, c_idx)
 
-            video_clips = VideoClips(video_list, 6, 1, _backend=_backend)
+            video_clips = VideoClips(video_list, 6, 1)
             self.assertEqual(video_clips.num_clips(), 0 + (10 - 6 + 1) + (15 - 6 + 1))
             for i, v_idx, c_idx in [(0, 1, 0), (4, 1, 4), (5, 2, 0), (6, 2, 1)]:
                 video_idx, clip_idx = video_clips.get_clip_location(i)
@@ -87,9 +85,8 @@ def test_video_clips(self):
 
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 3 * 3)
             indices = torch.tensor(list(iter(sampler)))
@@ -100,9 +97,8 @@ def test_video_sampler(self):
 
     @unittest.skip("Moved to reference scripts for now")
     def test_video_sampler_unequal(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5, _backend=_backend)
+            video_clips = VideoClips(video_list, 5, 5)
             sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
             self.assertEqual(len(sampler), 2 + 3 + 3)
             indices = list(iter(sampler))
@@ -120,11 +116,10 @@ def test_video_sampler_unequal(self):
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
-        _backend = get_video_backend()
         with get_list_of_videos(num_videos=3, sizes=[12, 12, 12], fps=[3, 4, 6]) as video_list:
             num_frames = 4
             for fps in [1, 3, 4, 10]:
-                video_clips = VideoClips(video_list, num_frames, num_frames, fps, _backend=_backend)
+                video_clips = VideoClips(video_list, num_frames, num_frames, fps)
                 for i in range(video_clips.num_clips()):
                     video, audio, info, video_idx = video_clips.get_clip(i)
                     self.assertEqual(video.shape[0], num_frames)
```
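
The (video, clip) lookup these tests assert is easy to verify by hand. Below is a self-contained sketch of the flat-index arithmetic for the `(6, 1)` case above: an illustrative reimplementation using `bisect` over cumulative clip counts, not necessarily the exact code inside `VideoClips`:

```python
import bisect

# Three videos of 5, 10 and 15 frames; 6-frame clips taken every frame give
# 0, 10 - 6 + 1 = 5 and 15 - 6 + 1 = 10 clips per video.
clips_per_video = [0, 5, 10]
cumulative = []
total = 0
for n in clips_per_video:
    total += n
    cumulative.append(total)  # [0, 5, 15]

def get_clip_location(idx):
    # Which video does flat clip index `idx` land in, and at what offset within it?
    video_idx = bisect.bisect_right(cumulative, idx)
    clip_idx = idx - (cumulative[video_idx - 1] if video_idx > 0 else 0)
    return video_idx, clip_idx

# The same (index, video, clip) triples the test asserts:
assert get_clip_location(0) == (1, 0)
assert get_clip_location(4) == (1, 4)
assert get_clip_location(5) == (2, 0)
assert get_clip_location(6) == (2, 1)
```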

torchvision/__init__.py

Lines changed: 6 additions & 1 deletion
```diff
@@ -1,3 +1,5 @@
+import warnings
+
 from torchvision import models
 from torchvision import datasets
 from torchvision import ops
@@ -57,7 +59,10 @@ def set_video_backend(backend):
         raise ValueError(
             "Invalid video backend '%s'. Options are 'pyav' and 'video_reader'" % backend
         )
-    _video_backend = backend
+    if backend == "video_reader" and not io._HAS_VIDEO_OPT:
+        warnings.warn("video_reader video backend is not available")
+    else:
+        _video_backend = backend
 
 
 def get_video_backend():
```
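
A sketch of the new fallback behaviour (assumes a torchvision build where the compiled video ops are absent, i.e. `io._HAS_VIDEO_OPT` is `False`):

```python
import warnings
import torchvision

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torchvision.set_video_backend("video_reader")

if caught:
    # Without the compiled ops this prints "video_reader video backend is not
    # available" and the backend stays at its previous value ("pyav" by default).
    print(caught[-1].message)
print(torchvision.get_video_backend())
```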

torchvision/datasets/hmdb51.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -1,9 +1,9 @@
 import glob
 import os
 
-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset
 
 
@@ -51,7 +51,8 @@ class HMDB51(VisionDataset):
 
     def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                  frame_rate=None, fold=1, train=True, transform=None,
-                 _precomputed_metadata=None):
+                 _precomputed_metadata=None, num_workers=1, _video_width=0,
+                 _video_height=0, _video_min_dimension=0, _audio_samples=0):
         super(HMDB51, self).__init__(root)
         if not 1 <= fold <= 3:
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -71,11 +72,21 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
+        self.video_clips_metadata = video_clips.metadata
         self.indices = self._select_fold(video_list, annotation_path, fold, train)
         self.video_clips = video_clips.subset(self.indices)
         self.transform = transform
 
+    @property
+    def metadata(self):
+        return self.video_clips_metadata
+
     def _select_fold(self, video_list, annotation_path, fold, train):
         target_tag = 1 if train else 2
         name = "*test_split{}.txt".format(fold)
```

torchvision/datasets/kinetics.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -1,6 +1,6 @@
-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset
 
 
@@ -37,7 +37,9 @@ class Kinetics400(VisionDataset):
     """
 
     def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
-                 extensions=('avi',), transform=None, _precomputed_metadata=None):
+                 extensions=('avi',), transform=None, _precomputed_metadata=None,
+                 num_workers=1, _video_width=0, _video_height=0,
+                 _video_min_dimension=0, _audio_samples=0):
         super(Kinetics400, self).__init__(root)
         extensions = ('avi',)
 
@@ -52,9 +54,18 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
         self.transform = transform
 
+    @property
+    def metadata(self):
+        return self.video_clips.metadata
+
     def __len__(self):
         return self.video_clips.num_clips()
 
```
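
Since the metadata is a plain dict of paths and tensors, it can also be cached to disk so later runs skip the scan entirely. A sketch (the cache filename and dataset root are hypothetical):

```python
import os

import torch
from torchvision.datasets import Kinetics400

cache = "kinetics_metadata.pt"  # hypothetical cache file
metadata = torch.load(cache) if os.path.exists(cache) else None

dataset = Kinetics400("kinetics/train", frames_per_clip=16,
                      _precomputed_metadata=metadata, num_workers=8)
if metadata is None:
    torch.save(dataset.metadata, cache)
```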

torchvision/datasets/ucf101.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -1,9 +1,9 @@
 import glob
 import os
 
-from .video_utils import VideoClips
 from .utils import list_dir
 from .folder import make_dataset
+from .video_utils import VideoClips
 from .vision import VisionDataset
 
 
@@ -44,7 +44,8 @@ class UCF101(VisionDataset):
 
     def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
                  frame_rate=None, fold=1, train=True, transform=None,
-                 _precomputed_metadata=None):
+                 _precomputed_metadata=None, num_workers=1, _video_width=0,
+                 _video_height=0, _video_min_dimension=0, _audio_samples=0):
         super(UCF101, self).__init__(root)
         if not 1 <= fold <= 3:
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -64,11 +65,21 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
             step_between_clips,
             frame_rate,
             _precomputed_metadata,
+            num_workers=num_workers,
+            _video_width=_video_width,
+            _video_height=_video_height,
+            _video_min_dimension=_video_min_dimension,
+            _audio_samples=_audio_samples,
         )
+        self.video_clips_metadata = video_clips.metadata
         self.indices = self._select_fold(video_list, annotation_path, fold, train)
         self.video_clips = video_clips.subset(self.indices)
         self.transform = transform
 
+    @property
+    def metadata(self):
+        return self.video_clips_metadata
+
     def _select_fold(self, video_list, annotation_path, fold, train):
         name = "train" if train else "test"
         name = "{}list{:02d}.txt".format(name, fold)
```

torchvision/datasets/video_utils.py

Lines changed: 29 additions & 6 deletions
```diff
@@ -68,10 +68,18 @@ class VideoClips(object):
             0 means that the data will be loaded in the main process. (default: 0)
     """
     def __init__(self, video_paths, clip_length_in_frames=16, frames_between_clips=1,
-                 frame_rate=None, _precomputed_metadata=None, num_workers=0, _backend="pyav"):
+                 frame_rate=None, _precomputed_metadata=None, num_workers=0,
+                 _video_width=0, _video_height=0, _video_min_dimension=0,
+                 _audio_samples=0):
+        from torchvision import get_video_backend
+
         self.video_paths = video_paths
         self.num_workers = num_workers
-        self._backend = _backend
+        self._backend = get_video_backend()
+        self._video_width = _video_width
+        self._video_height = _video_height
+        self._video_min_dimension = _video_min_dimension
+        self._audio_samples = _audio_samples
 
         if _precomputed_metadata is None:
             self._compute_frame_pts()
@@ -145,6 +153,7 @@ def metadata(self):
             _metadata.update({"video_fps": self.video_fps})
         else:
            _metadata.update({"info": self.info})
+        return _metadata
 
     def subset(self, indices):
         video_paths = [self.video_paths[i] for i in indices]
@@ -162,7 +171,11 @@ def subset(self, indices):
         else:
             metadata.update({"info": info})
         return type(self)(video_paths, self.num_frames, self.step, self.frame_rate,
-                          _precomputed_metadata=metadata)
+                          _precomputed_metadata=metadata, num_workers=self.num_workers,
+                          _video_width=self._video_width,
+                          _video_height=self._video_height,
+                          _video_min_dimension=self._video_min_dimension,
+                          _audio_samples=self._audio_samples)
 
     @staticmethod
     def compute_clips_for_video(video_pts, num_frames, step, fps, frame_rate):
@@ -206,9 +219,15 @@ def compute_clips(self, num_frames, step, frame_rate=None):
                 self.resampling_idxs.append(idxs)
         else:
             for video_pts, info in zip(self.video_pts, self.info):
-                clips, idxs = self.compute_clips_for_video(video_pts, num_frames, step, info["video_fps"], frame_rate)
-                self.clips.append(clips)
-                self.resampling_idxs.append(idxs)
+                if "video_fps" in info:
+                    clips, idxs = self.compute_clips_for_video(
+                        video_pts, num_frames, step, info["video_fps"], frame_rate)
+                    self.clips.append(clips)
+                    self.resampling_idxs.append(idxs)
+                else:
+                    # properly handle the cases where video decoding fails
+                    self.clips.append(torch.zeros(0, num_frames, dtype=torch.int64))
+                    self.resampling_idxs.append(torch.zeros(0, dtype=torch.int64))
         clip_lengths = torch.as_tensor([len(v) for v in self.clips])
         self.cumulative_sizes = clip_lengths.cumsum(0).tolist()
 
@@ -296,8 +315,12 @@ def get_clip(self, idx):
             )
             video, audio, info = _read_video_from_file(
                 video_path,
+                video_width=self._video_width,
+                video_height=self._video_height,
+                video_min_dimension=self._video_min_dimension,
                 video_pts_range=(video_start_pts, video_end_pts),
                 video_timebase=info["video_timebase"],
+                audio_samples=self._audio_samples,
                 audio_pts_range=(audio_start_pts, audio_end_pts),
                 audio_timebase=audio_timebase,
             )
```
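
The `compute_clips` change is the robustness fix from the commit title: a video whose decode failed has no `"video_fps"` key in its info dict, and it now contributes an empty `(0, num_frames)` clip tensor instead of raising `KeyError`. A self-contained sketch of just that bookkeeping (stand-in tensors, not the real decoding path):

```python
import torch

num_frames = 16
infos = [{"video_fps": 30.0}, {}]  # second video: decoding failed

clips = []
for info in infos:
    if "video_fps" in info:
        clips.append(torch.zeros(8, num_frames, dtype=torch.int64))  # stand-in for real clip pts
    else:
        clips.append(torch.zeros(0, num_frames, dtype=torch.int64))  # failed video: zero clips

clip_lengths = torch.as_tensor([len(v) for v in clips])
print(clip_lengths.cumsum(0).tolist())  # [8, 8] -- the failed video adds nothing
```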

torchvision/io/__init__.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -1,8 +1,8 @@
 from .video import write_video, read_video, read_video_timestamps
-from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file
+from ._video_opt import _read_video_from_file, _read_video_timestamps_from_file, _HAS_VIDEO_OPT
 
 
 __all__ = [
     'write_video', 'read_video', 'read_video_timestamps',
-    '_read_video_from_file', '_read_video_timestamps_from_file',
+    '_read_video_from_file', '_read_video_timestamps_from_file', '_HAS_VIDEO_OPT',
 ]
```
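
Exporting `_HAS_VIDEO_OPT` lets callers probe for the compiled video ops themselves. With this commit `set_video_backend` already guards against the missing ops, so this is belt-and-braces:

```python
import torchvision
from torchvision.io import _HAS_VIDEO_OPT

# Prefer the C++ video_reader backend only when it was actually compiled in.
torchvision.set_video_backend("video_reader" if _HAS_VIDEO_OPT else "pyav")
print(torchvision.get_video_backend())
```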
