diff --git a/torchvision/datasets/ucf101.py b/torchvision/datasets/ucf101.py
index 464eb0018f2..18ce07d0801 100644
--- a/torchvision/datasets/ucf101.py
+++ b/torchvision/datasets/ucf101.py
@@ -1,4 +1,3 @@
-import glob
 import os
 
 from .utils import list_dir
@@ -50,17 +49,28 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
         if not 1 <= fold <= 3:
             raise ValueError("fold should be between 1 and 3, got {}".format(fold))
 
-        extensions = ('avi',)
         self.fold = fold
         self.train = train
+        self.transform = transform
+
+        # Create the class-to-index mapping from the sorted class names
+        self.classes = list(sorted(list_dir(root)))
+        class_to_idx = {c: i for i, c in enumerate(self.classes)}
+
+        # Iterate through the root directory to retrieve the path and the
+        # label of each dataset example
+        self.samples = make_dataset(
+            self.root, class_to_idx, ('avi',), is_valid_file=None)
 
-        classes = list(sorted(list_dir(root)))
-        class_to_idx = {classes[i]: i for i in range(len(classes))}
-        self.samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file=None)
-        self.classes = classes
-        video_list = [x[0] for x in self.samples]
-        video_clips = VideoClips(
-            video_list,
+        # Get the video paths that belong to the selected fold and split
+        _video_paths_in_fold = self._fold_paths(annotation_path, fold, train)
+        # Filter the dataset samples so that only the videos belonging to the
+        # selected fold are processed
+        self.samples = [o for o in self.samples if o[0] in _video_paths_in_fold]
+
+        # At this point, only the paths of the needed videos are selected
+        self.video_clips = VideoClips(
+            [x[0] for x in self.samples],
             frames_per_clip,
             step_between_clips,
             frame_rate,
@@ -71,35 +81,30 @@ def __init__(self, root, annotation_path, frames_per_clip, step_between_clips=1,
             _video_min_dimension=_video_min_dimension,
             _audio_samples=_audio_samples,
         )
-        self.video_clips_metadata = video_clips.metadata
-        self.indices = self._select_fold(video_list, annotation_path, fold, train)
-        self.video_clips = video_clips.subset(self.indices)
-        self.transform = transform
+        self.video_clips_metadata = self.video_clips.metadata
 
     @property
     def metadata(self):
         return self.video_clips_metadata
 
-    def _select_fold(self, video_list, annotation_path, fold, train):
-        name = "train" if train else "test"
-        name = "{}list{:02d}.txt".format(name, fold)
+    def _fold_paths(self, annotation_path, fold, train):
+        split = 'train' if train else 'test'
+        name = f'{split}list{fold:02d}.txt'
         f = os.path.join(annotation_path, name)
-        selected_files = []
+
         with open(f, "r") as fid:
-            data = fid.readlines()
-            data = [x.strip().split(" ") for x in data]
-            data = [os.path.join(self.root, x[0]) for x in data]
-            selected_files.extend(data)
-        selected_files = set(selected_files)
-        indices = [i for i in range(len(video_list)) if video_list[i] in selected_files]
-        return indices
+            video_files = fid.readlines()
+            video_files = [o.strip().split(" ")[0] for o in video_files]
+            video_files = [os.path.join(self.root, o) for o in video_files]
+            video_files = set(video_files)
+        return video_files
 
     def __len__(self):
         return self.video_clips.num_clips()
 
     def __getitem__(self, idx):
         video, audio, info, video_idx = self.video_clips.get_clip(idx)
-        label = self.samples[self.indices[video_idx]][1]
+        label = self.samples[video_idx][1]
 
         if self.transform is not None:
             video = self.transform(video)
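
For context, a minimal usage sketch of the refactored dataset (the local paths below are hypothetical placeholders; the public constructor signature and `__getitem__` return values are unchanged by this diff):

```python
from torchvision.datasets import UCF101

# Hypothetical paths: point these at a local UCF101 download and the
# official annotation folder containing trainlist01.txt ... testlist03.txt.
root = "datasets/UCF-101"
annotation_path = "datasets/ucfTrainTestlist"

# Same public API as before the refactor; internally, fold selection now
# filters self.samples by path before VideoClips is built, instead of
# building VideoClips over every video and subsetting by index afterwards.
dataset = UCF101(
    root,
    annotation_path,
    frames_per_clip=16,
    step_between_clips=16,
    fold=1,
    train=True,
)

video, audio, label = dataset[0]
print(video.shape, label)
```

One practical effect of filtering before constructing `VideoClips` is that clip metadata is only computed for the videos in the selected fold and split, rather than for the entire dataset.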