diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 192babf62dc..74852c2f721 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -11,9 +11,10 @@
 import torchvision
 import torchvision.datasets.video_utils
 from torchvision import transforms
+from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
 import utils
-from sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
+
 from scheduler import WarmupMultiStepLR
 
 import transforms as T
diff --git a/test/test_datasets_samplers.py b/test/test_datasets_samplers.py
new file mode 100644
index 00000000000..f99c63e65d3
--- /dev/null
+++ b/test/test_datasets_samplers.py
@@ -0,0 +1,88 @@
+import contextlib
+import sys
+import os
+import torch
+import unittest
+
+from torchvision import io
+from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
+from torchvision.datasets.video_utils import VideoClips, unfold
+from torchvision import get_video_backend
+
+from common_utils import get_tmp_dir
+
+
+@contextlib.contextmanager
+def get_list_of_videos(num_videos=5, sizes=None, fps=None):
+    with get_tmp_dir() as tmp_dir:
+        names = []
+        for i in range(num_videos):
+            if sizes is None:
+                size = 5 * (i + 1)
+            else:
+                size = sizes[i]
+            if fps is None:
+                f = 5
+            else:
+                f = fps[i]
+            data = torch.randint(0, 255, (size, 300, 400, 3), dtype=torch.uint8)
+            name = os.path.join(tmp_dir, "{}.mp4".format(i))
+            names.append(name)
+            io.write_video(name, data, fps=f)
+
+        yield names
+
+
+@unittest.skipIf(not io.video._av_available(), "this test requires av")
+class Tester(unittest.TestCase):
+    def test_random_clip_sampler(self):
+        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            sampler = RandomClipSampler(video_clips, 3)
+            self.assertEqual(len(sampler), 3 * 3)
+            indices = torch.tensor(list(iter(sampler)))
+            videos = indices // 5
+            v_idxs, count = torch.unique(videos, return_counts=True)
+            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
+            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
+
+    def test_random_clip_sampler_unequal(self):
+        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            sampler = RandomClipSampler(video_clips, 3)
+            self.assertEqual(len(sampler), 2 + 3 + 3)
+            indices = list(iter(sampler))
+            self.assertIn(0, indices)
+            self.assertIn(1, indices)
+            # remove elements of the first video, to simplify testing
+            indices.remove(0)
+            indices.remove(1)
+            indices = torch.tensor(indices) - 2
+            videos = indices // 5
+            v_idxs, count = torch.unique(videos, return_counts=True)
+            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
+            self.assertTrue(count.equal(torch.tensor([3, 3])))
+
+    def test_uniform_clip_sampler(self):
+        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            sampler = UniformClipSampler(video_clips, 3)
+            self.assertEqual(len(sampler), 3 * 3)
+            indices = torch.tensor(list(iter(sampler)))
+            videos = indices // 5
+            v_idxs, count = torch.unique(videos, return_counts=True)
+            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
+            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
+            self.assertTrue(indices.equal(torch.tensor([0, 2, 4, 5, 7, 9, 10, 12, 14])))
+
+    def test_uniform_clip_sampler_insufficient_clips(self):
+        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
+            video_clips = VideoClips(video_list, 5, 5)
+            sampler = UniformClipSampler(video_clips, 3)
+            self.assertEqual(len(sampler), 3 * 3)
+            indices = torch.tensor(list(iter(sampler)))
+            self.assertTrue(indices.equal(torch.tensor([0, 0, 1, 2, 4, 6, 7, 9, 11])))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/test_datasets_video_utils.py b/test/test_datasets_video_utils.py
index a9cb7ab50ef..ccca068d367 100644
--- a/test/test_datasets_video_utils.py
+++ b/test/test_datasets_video_utils.py
@@ -83,36 +83,6 @@ def test_video_clips(self):
             self.assertEqual(video_idx, v_idx)
             self.assertEqual(clip_idx, c_idx)
 
-    @unittest.skip("Moved to reference scripts for now")
-    def test_video_sampler(self):
-        with get_list_of_videos(num_videos=3, sizes=[25, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
-            self.assertEqual(len(sampler), 3 * 3)
-            indices = torch.tensor(list(iter(sampler)))
-            videos = indices // 5
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            self.assertTrue(v_idxs.equal(torch.tensor([0, 1, 2])))
-            self.assertTrue(count.equal(torch.tensor([3, 3, 3])))
-
-    @unittest.skip("Moved to reference scripts for now")
-    def test_video_sampler_unequal(self):
-        with get_list_of_videos(num_videos=3, sizes=[10, 25, 25]) as video_list:
-            video_clips = VideoClips(video_list, 5, 5)
-            sampler = RandomClipSampler(video_clips, 3)  # noqa: F821
-            self.assertEqual(len(sampler), 2 + 3 + 3)
-            indices = list(iter(sampler))
-            self.assertIn(0, indices)
-            self.assertIn(1, indices)
-            # remove elements of the first video, to simplify testing
-            indices.remove(0)
-            indices.remove(1)
-            indices = torch.tensor(indices) - 2
-            videos = indices // 5
-            v_idxs, count = torch.unique(videos, return_counts=True)
-            self.assertTrue(v_idxs.equal(torch.tensor([0, 1])))
-            self.assertTrue(count.equal(torch.tensor([3, 3])))
-
     @unittest.skipIf(not io.video._av_available(), "this test requires av")
     @unittest.skipIf('win' in sys.platform, 'temporarily disabled on Windows')
     def test_video_clips_custom_fps(self):
diff --git a/torchvision/datasets/samplers/__init__.py b/torchvision/datasets/samplers/__init__.py
new file mode 100644
index 00000000000..870322d39b4
--- /dev/null
+++ b/torchvision/datasets/samplers/__init__.py
@@ -0,0 +1,3 @@
+from .clip_sampler import DistributedSampler, UniformClipSampler, RandomClipSampler
+
+__all__ = ('DistributedSampler', 'UniformClipSampler', 'RandomClipSampler')
diff --git a/references/video_classification/sampler.py b/torchvision/datasets/samplers/clip_sampler.py
similarity index 81%
rename from references/video_classification/sampler.py
rename to torchvision/datasets/samplers/clip_sampler.py
index b92dad013c6..3d4c788fc61 100644
--- a/references/video_classification/sampler.py
+++ b/torchvision/datasets/samplers/clip_sampler.py
@@ -60,33 +60,45 @@ def set_epoch(self, epoch):
 
 class UniformClipSampler(torch.utils.data.Sampler):
     """
-    Samples at most `max_video_clips_per_video` clips for each video, equally spaced
+    Sample `num_clips_per_video` clips for each video, equally spaced.
+    When the number of unique clips in the video is fewer than `num_clips_per_video`,
+    the clips are repeated until `num_clips_per_video` clips are collected.
+
     Arguments:
         video_clips (VideoClips): video clips to sample from
-        max_clips_per_video (int): maximum number of clips to be sampled per video
+        num_clips_per_video (int): number of clips to be sampled per video
     """
-    def __init__(self, video_clips, max_clips_per_video):
+    def __init__(self, video_clips, num_clips_per_video):
         if not isinstance(video_clips, torchvision.datasets.video_utils.VideoClips):
             raise TypeError("Expected video_clips to be an instance of VideoClips, "
                             "got {}".format(type(video_clips)))
         self.video_clips = video_clips
-        self.max_clips_per_video = max_clips_per_video
+        self.num_clips_per_video = num_clips_per_video
 
     def __iter__(self):
         idxs = []
         s = 0
-        # select at most max_clips_per_video for each video, uniformly spaced
+        # select num_clips_per_video for each video, uniformly spaced
         for c in self.video_clips.clips:
             length = len(c)
-            step = max(length // self.max_clips_per_video, 1)
-            sampled = torch.arange(length)[::step] + s
+            if length == 0:
+                # corner case where video decoding fails
+                continue
+
+            sampled = (
+                torch.linspace(s, s + length - 1, steps=self.num_clips_per_video)
+                .floor()
+                .to(torch.int64)
+            )
             s += length
             idxs.append(sampled)
         idxs = torch.cat(idxs).tolist()
         return iter(idxs)
 
     def __len__(self):
-        return sum(min(len(c), self.max_clips_per_video) for c in self.video_clips.clips)
+        return sum(
+            self.num_clips_per_video for c in self.video_clips.clips if len(c) > 0
+        )
 
 
 class RandomClipSampler(torch.utils.data.Sampler):
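
Note for reviewers (not part of the patch): a minimal sketch of how the relocated samplers are meant to be consumed from their new import path, mirroring what references/video_classification/train.py does after this change. The video paths and clip parameters below are hypothetical placeholders.

from torchvision.datasets.samplers import RandomClipSampler, UniformClipSampler
from torchvision.datasets.video_utils import VideoClips

# Hypothetical inputs: any list of decodable video file paths works here.
video_paths = ["clips/a.mp4", "clips/b.mp4"]

# Index every 16-frame clip, stepping 16 frames between consecutive clips.
video_clips = VideoClips(video_paths, clip_length_in_frames=16, frames_between_clips=16)

# Training: at most 5 randomly chosen clips per video.
train_sampler = RandomClipSampler(video_clips, 5)

# Evaluation: exactly 5 equally spaced clips per video; clips are repeated
# when a video has fewer than 5 (the new UniformClipSampler behaviour).
eval_sampler = UniformClipSampler(video_clips, 5)

# Both samplers yield indices into the global clip list, so they can be
# passed as `sampler=` to a DataLoader over a clip-level dataset.
for idx in eval_sampler:
    video_idx, clip_idx = video_clips.get_clip_location(idx)

The switch from a stride (torch.arange with step) to torch.linspace means UniformClipSampler now returns a fixed number of indices per video rather than "at most" that many, which keeps per-video clip counts uniform, e.g. for averaging clip-level predictions into video-level ones during evaluation.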