Skip to content

Commit 85ffd93

Browse files
authored
Expose frame-rate and cache to video datasets (#1356)
1 parent 31fad34 commit 85ffd93

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

torchvision/datasets/hmdb51.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,8 @@ class HMDB51(VisionDataset):
5050
}
5151

5252
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
53-
fold=1, train=True, transform=None):
53+
frame_rate=None, fold=1, train=True, transform=None,
54+
_precomputed_metadata=None):
5455
super(HMDB51, self).__init__(root)
5556
if not 1 <= fold <= 3:
5657
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -64,7 +65,13 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
6465
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
6566
self.classes = classes
6667
video_list = [x[0] for x in self.samples]
67-
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
68+
video_clips = VideoClips(
69+
video_list,
70+
frames_per_clip,
71+
step_between_clips,
72+
frame_rate,
73+
_precomputed_metadata,
74+
)
6875
self.indices = self._select_fold(video_list, annotation_path, fold, train)
6976
self.video_clips = video_clips.subset(self.indices)
7077
self.transform = transform

torchvision/datasets/kinetics.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -36,7 +36,8 @@ class Kinetics400(VisionDataset):
3636
label (int): class of the video clip
3737
"""
3838

39-
def __init__(self, root, frames_per_clip, step_between_clips=1, transform=None):
39+
def __init__(self, root, frames_per_clip, step_between_clips=1, frame_rate=None,
40+
extensions=('avi',), transform=None, _precomputed_metadata=None):
4041
super(Kinetics400, self).__init__(root)
4142
extensions = ('avi',)
4243

@@ -45,7 +46,13 @@ def __init__(self, root, frames_per_clip, step_between_clips=1, transform=None):
4546
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
4647
self.classes = classes
4748
video_list = [x[0] for x in self.samples]
48-
self.video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
49+
self.video_clips = VideoClips(
50+
video_list,
51+
frames_per_clip,
52+
step_between_clips,
53+
frame_rate,
54+
_precomputed_metadata,
55+
)
4956
self.transform = transform
5057

5158
def __len__(self):

torchvision/datasets/ucf101.py

Lines changed: 9 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,8 @@ class UCF101(VisionDataset):
4343
"""
4444

4545
def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
46-
fold=1, train=True, transform=None):
46+
frame_rate=None, fold=1, train=True, transform=None,
47+
_precomputed_metadata=None):
4748
super(UCF101, self).__init__(root)
4849
if not 1 <= fold <= 3:
4950
raise ValueError("fold should be between 1 and 3, got {}".format(fold))
@@ -57,7 +58,13 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
5758
self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
5859
self.classes = classes
5960
video_list = [x[0] for x in self.samples]
60-
video_clips = VideoClips(video_list, frames_per_clip, step_between_clips)
61+
video_clips = VideoClips(
62+
video_list,
63+
frames_per_clip,
64+
step_between_clips,
65+
frame_rate,
66+
_precomputed_metadata,
67+
)
6168
self.indices = self._select_fold(video_list, annotation_path, fold, train)
6269
self.video_clips = video_clips.subset(self.indices)
6370
self.transform = transform

0 commit comments

Comments (0)