From da670f0c38781a4cc4635449f43418aa47f8a265 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Feb 2022 08:50:22 +0100 Subject: [PATCH 1/9] add hmdb51 dataset and prototype for new style video decoding --- torchvision/datasets/utils.py | 10 + .../prototype/datasets/_builtin/__init__.py | 1 + .../datasets/_builtin/hmdb51.categories | 51 +++++ .../prototype/datasets/_builtin/hmdb51.py | 116 ++++++++++ .../prototype/datasets/utils/__init__.py | 1 + .../prototype/datasets/utils/_video.py | 209 ++++++++++++++++++ torchvision/prototype/features/_encoded.py | 31 ++- torchvision/prototype/features/_image.py | 9 +- torchvision/prototype/utils/_internal.py | 19 ++ 9 files changed, 436 insertions(+), 11 deletions(-) create mode 100644 torchvision/prototype/datasets/_builtin/hmdb51.categories create mode 100644 torchvision/prototype/datasets/_builtin/hmdb51.py create mode 100644 torchvision/prototype/datasets/utils/_video.py diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index dbc9cf2a6b4..9ec4501e3b1 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -1,4 +1,5 @@ import bz2 +import contextlib import gzip import hashlib import itertools @@ -301,6 +302,15 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No ".tgz": (".tar", ".gz"), } +with contextlib.suppress(ImportError): + import rarfile + + def _extract_rar(from_path: str, to_path: str, compression: Optional[str]) -> None: + with rarfile.RarFile(from_path, "r") as rar: + rar.extractall(to_path) + + _ARCHIVE_EXTRACTORS[".rar"] = _extract_rar + def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]: """Detect the archive type and/or compression of a file. diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index 9fdfca904f5..bab1cf873ac 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -7,6 +7,7 @@ from .dtd import DTD from .fer2013 import FER2013 from .gtsrb import GTSRB +from .hmdb51 import HMDB51 from .imagenet import ImageNet from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST from .oxford_iiit_pet import OxfordIITPet diff --git a/torchvision/prototype/datasets/_builtin/hmdb51.categories b/torchvision/prototype/datasets/_builtin/hmdb51.categories new file mode 100644 index 00000000000..3217416f524 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/hmdb51.categories @@ -0,0 +1,51 @@ +brush_hair +cartwheel +catch +chew +clap +climb +climb_stairs +dive +draw_sword +dribble +drink +eat +fall_floor +fencing +flic_flac +golf +handstand +hit +hug +jump +kick +kick_ball +kiss +laugh +pick +pour +pullup +punch +push +pushup +ride_bike +ride_horse +run +shake_hands +shoot_ball +shoot_bow +shoot_gun +sit +situp +smile +smoke +somersault +stand +swing_baseball +sword +sword_exercise +talk +throw +turn +walk +wave diff --git a/torchvision/prototype/datasets/_builtin/hmdb51.py b/torchvision/prototype/datasets/_builtin/hmdb51.py new file mode 100644 index 00000000000..e6b453ae31c --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/hmdb51.py @@ -0,0 +1,116 @@ +import functools +import pathlib +import re +from typing import Any, Dict, List, Tuple, BinaryIO + +from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, IterKeyZipper +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + 
OnlineResource, +) +from torchvision.prototype.datasets.utils._internal import ( + INFINITE_BUFFER_SIZE, + getitem, + path_accessor, + hint_sharding, + hint_shuffling, +) +from torchvision.prototype.features import EncodedVideo, Label + + +class HMDB51(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "hmdb51", + homepage="https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/", + valid_options=dict( + split=("train", "test"), + split_number=("1", "2", "3"), + ), + ) + + def _extract_videos_archive(self, path: pathlib.Path) -> pathlib.Path: + folder = OnlineResource._extract(path) + for rar_file in folder.glob("*.rar"): + OnlineResource._extract(rar_file) + rar_file.unlink() + return folder + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + url_root = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10" + + splits = HttpResource( + f"{url_root}/test_train_splits.rar", + sha256="229c94f845720d01eb3946d39f39292ea962d50a18136484aa47c1eba251d2b7", + ) + videos = HttpResource( + f"{url_root}/hmdb51_org.rar", + sha256="9e714a0d8b76104d76e932764a7ca636f929fff66279cda3f2e326fa912a328e", + ) + videos._preprocess = self._extract_videos_archive + return [splits, videos] + + _SPLIT_FILE_PATTERN = re.compile(r"(?P\w+?)_test_split(?P[1-3])[.]txt") + + def _is_split_number(self, data: Tuple[str, Any], *, split_number: str) -> bool: + path = pathlib.Path(data[0]) + return self._SPLIT_FILE_PATTERN.match(path.name)["split_number"] == split_number # type: ignore[union-attr] + + _SPLIT_ID_TO_NAME = { + "1": "train", + "2": "test", + } + + def _is_split(self, data: Dict[str, Any], *, split: str) -> bool: + split_id = data["split_id"] + + # TODO: explain + if split_id not in self._SPLIT_ID_TO_NAME: + return False + + return self._SPLIT_ID_TO_NAME[split_id] == split + + def _prepare_sample(self, data: Tuple[List[str], Tuple[str, BinaryIO]]) -> Dict[str, Any]: + _, (path, buffer) = data + path = pathlib.Path(path) + return dict( + label=Label.from_category(path.parent.name, categories=self.categories), + video=EncodedVideo.from_file(buffer, path=path), + ) + + def _make_datapipe( + self, + resource_dps: List[IterDataPipe], + *, + config: DatasetConfig, + ) -> IterDataPipe[Dict[str, Any]]: + splits_dp, videos_dp = resource_dps + + splits_dp = Filter(splits_dp, functools.partial(self._is_split_number, split_number=config.split_number)) + splits_dp = CSVDictParser(splits_dp, fieldnames=("filename", "split_id"), delimiter=" ") + splits_dp = Filter(splits_dp, functools.partial(self._is_split, split=config.split)) + splits_dp = hint_sharding(splits_dp) + splits_dp = hint_shuffling(splits_dp) + + dp = IterKeyZipper( + splits_dp, + videos_dp, + key_fn=getitem("filename"), + ref_key_fn=path_accessor("name"), + buffer_size=INFINITE_BUFFER_SIZE, + ) + return Mapper(dp, self._prepare_sample) + + def _generate_categories(self, root: pathlib.Path) -> List[str]: + config = self.default_config + resources = self.resources(config) + + dp = resources[0].load(root) + categories = { + self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)["category"] # type: ignore[union-attr] + for path, _ in dp + } + return sorted(categories) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 9423b65a8ee..89668a3b3d2 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -2,3 +2,4 @@ from ._dataset import DatasetConfig, DatasetInfo, 
Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource +from ._video import KeyframeDecoder, RandomFrameDecoder, ClipDecoder diff --git a/torchvision/prototype/datasets/utils/_video.py b/torchvision/prototype/datasets/utils/_video.py new file mode 100644 index 00000000000..1cd2650f206 --- /dev/null +++ b/torchvision/prototype/datasets/utils/_video.py @@ -0,0 +1,209 @@ +import random +from typing import Any, Dict, Iterator, BinaryIO, Optional, Tuple + +import av +import numpy as np +import torch +from torchdata.datapipes.iter import IterDataPipe +from torchvision.io import video, _video_opt +from torchvision.prototype.features import Image, EncodedVideo +from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer, query_recursively + + +class _VideoDecoder(IterDataPipe): + def __init__(self, datapipe: IterDataPipe, *, inline: bool = True) -> None: + self.datapipe = datapipe + self._inline = inline + + def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + raise NotImplementedError + + def _find_encoded_video(self, id: Tuple[Any, ...], obj: Any) -> Optional[Tuple[Any, ...]]: + if isinstance(obj, EncodedVideo): + return id, obj + else: + return None + + def _integrate_data(self, sample: Any, id: Tuple[Any, ...], data: Dict[str, Any]) -> Any: + if not self._inline: + return sample, data + elif not id: + return data + + grand_parent = None + parent = sample + for item in id[:-1]: + grand_parent = parent + parent = parent[item] + + if not isinstance(parent, dict): + raise TypeError( + f"Could not inline the decoded video data, " + f"since the container at item {''.join(str([item]) for item in id[:-1])} " + f"that holds the `EncodedVideo` at item {[id[-1]]} is not a 'dict' but a '{type(parent).__name__}'. " + f"If you don't want to automatically inline the decoded video data, construct the decoder with " + f"{type(self).__name__}(..., inline=False). This will change the return type to a tuple of the input " + f"and the decoded video data for each iteration." 
+ ) + + parent = parent.copy() + del parent[id[-1]] + parent.update(data) + + if not grand_parent: + return parent + + grand_parent[id[-2]] = parent + return sample + + def __iter__(self) -> Iterator[Any]: + for sample in self.datapipe: + ids_and_videos = list(query_recursively(self._find_encoded_video, sample)) + if not ids_and_videos: + raise TypeError("no encoded video") + elif len(ids_and_videos) > 1: + raise ValueError("more than one encoded video") + id, video = ids_and_videos[0] + + buffer = ReadOnlyTensorBuffer(video) + for data in self._decode(buffer, video.meta.copy()): + yield self._integrate_data(sample, id, data) + + +class KeyframeDecoder(_VideoDecoder): + def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + with av.open(buffer, metadata_errors="ignore") as container: + stream = container.streams.video[0] + stream.codec_context.skip_frame = "NONKEY" + for frame in container.decode(stream): + yield dict( + frame=Image.from_pil(frame.to_image()), + pts=frame.pts, + video_meta=dict( + meta, + time_base=float(frame.time_base), + guessed_fps=float(stream.guessed_rate), + ), + ) + + +class RandomFrameDecoder(_VideoDecoder): + def __init__(self, datapipe: IterDataPipe, *, num_samples: int = 1, inline: bool = True) -> None: + super().__init__(datapipe, inline=inline) + self.num_sampler = num_samples + + def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + with av.open(buffer, metadata_errors="ignore") as container: + stream = container.streams.video[0] + # duration is given in time_base units as int + duration = stream.duration + # seek to a random frame + seek_idxs = random.sample(list(range(duration)), self.num_samples) + for i in seek_idxs: + container.seek(i, any_frame=True, stream=stream) + frame = next(container.decode(stream)) + yield dict( + frame=Image.from_pil(frame.to_image()), + pts=frame.pts, + video_meta=dict( + time_base=float(frame.time_base), + guessed_fps=float(stream.guessed_rate), + ), + ) + + +class ClipDecoder(_VideoDecoder): + def __init__( + self, + datapipe: IterDataPipe, + *, + num_frames_per_clip: int = 8, + num_clips_per_video: int = 1, + step_between_clips: int = 1, + inline: bool = True, + ) -> None: + super().__init__(datapipe, inline=inline) + self.num_frames_per_clip = num_frames_per_clip + self.num_clips_per_video = num_clips_per_video + self.step_between_clips = step_between_clips + + def _unfold(self, tensor: torch.Tensor, dilation: int = 1) -> torch.Tensor: + """ + similar to tensor.unfold, but with the dilation + and specialized for 1d tensors + Returns all consecutive windows of `self.num_frames_per_clip` elements, with + `self.step_between_clips` between windows. The distance between each element + in a window is given by `dilation`. 
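+
+        Example: with num_frames_per_clip=3, step_between_clips=2 and dilation=1,
+        a 1d tensor of length 8 yields the windows [0, 1, 2], [2, 3, 4] and [4, 5, 6].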
+ """ + assert tensor.dim() == 1 + o_stride = tensor.stride(0) + numel = tensor.numel() + new_stride = (self.step_between_clips * o_stride, dilation * o_stride) + new_size = ( + (numel - (dilation * (self.num_frames_per_clip - 1) + 1)) // self.step_between_clips + 1, + self.num_frames_per_clip, + ) + if new_size[0] < 1: + new_size = (0, self.num_frames_per_clip) + return torch.as_strided(tensor, new_size, new_stride) + + def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + with av.open(buffer, metadata_errors="ignore") as container: + stream = container.streams.video[0] + time_base = stream.time_base + + # duration is given in time_base units as int + duration = stream.duration + + # get video_stream timestramps + # with a tolerance for pyav imprecission + _ptss = torch.arange(duration - 7) + _ptss = self._unfold(_ptss) + # shuffle the clips + perm = torch.randperm(_ptss.size(0)) + idx = perm[: self.num_clips_per_video] + samples = _ptss[idx] + + for clip_pts in samples: + start_pts = clip_pts[0].item() + end_pts = clip_pts[-1].item() + # video_timebase is the default time_base + pts_unit = "pts" + start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, "pts", time_base) + video_frames = video._read_from_stream( + container, + float(start_pts), + float(end_pts), + pts_unit, + stream, + {"video": 0}, + ) + + vframes_list = [frame.to_ndarray(format="rgb24") for frame in video_frames] + + if vframes_list: + vframes = torch.as_tensor(np.stack(vframes_list)) + # account for rounding errors in conversion + # FIXME: fix this in the code + vframes = vframes[: self.num_frames_per_clip, ...] + + else: + vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8) + print("FAIL") + + # [N,H,W,C] to [N,C,H,W] + vframes = vframes.permute(0, 3, 1, 2) + assert vframes.size(0) == self.num_frames_per_clip + + # TODO: support sampling rates (FPS change) + # TODO: optimization (read all and select) + + yield { + "clip": vframes, + "pts": clip_pts, + "range": (start_pts, end_pts), + "video_meta": { + "time_base": float(stream.time_base), + "guessed_fps": float(stream.guessed_rate), + }, + } diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index ab6b821d673..64a1a5f3326 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -1,6 +1,7 @@ import os +import pathlib import sys -from typing import BinaryIO, Tuple, Type, TypeVar, Union, Optional, Any +from typing import BinaryIO, Tuple, Type, TypeVar, Union, Dict, Any, Optional import PIL.Image import torch @@ -9,23 +10,33 @@ from ._feature import _Feature from ._image import Image -D = TypeVar("D", bound="EncodedData") +E = TypeVar("E", bound="EncodedData") class EncodedData(_Feature): - @classmethod - def _to_tensor(cls, data: Any, *, dtype: Optional[torch.dtype], device: Optional[torch.device]) -> torch.Tensor: - # TODO: warn / bail out if we encounter a tensor with shape other than (N,) or with dtype other than uint8? 
- return super()._to_tensor(data, dtype=dtype, device=device) + meta: Dict[str, Any] + + def __new__( + cls: Type[E], + data: Any, + *, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + **meta: Any, + ) -> E: + encoded_data = super().__new__(cls, data, dtype=dtype, device=device) + encoded_data._metadata.update(dict(meta=meta)) + return encoded_data @classmethod - def from_file(cls: Type[D], file: BinaryIO) -> D: - return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder)) + def from_file(cls: Type[E], file: BinaryIO, **meta: Any) -> E: + return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **meta) @classmethod - def from_path(cls: Type[D], path: Union[str, os.PathLike]) -> D: + def from_path(cls: Type[E], path: Union[str, os.PathLike]) -> E: + path = pathlib.Path(path) with open(path, "rb") as file: - return cls.from_file(file) + return cls.from_file(file, path=path) class EncodedImage(EncodedData): diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 5ecc4cbedb7..cbc8d542121 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -3,9 +3,10 @@ import warnings from typing import Any, Optional, Union, Tuple, cast +import PIL.Image import torch from torchvision.prototype.utils._internal import StrEnum -from torchvision.transforms.functional import to_pil_image +from torchvision.transforms.functional import to_pil_image, pil_to_tensor from torchvision.utils import draw_bounding_boxes from torchvision.utils import make_grid @@ -75,6 +76,12 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace: else: return ColorSpace.OTHER + @classmethod + def from_pil( + cls, image: PIL.Image.Image, *, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None + ) -> "Image": + return cls(pil_to_tensor(image), dtype=dtype, device=device) + def show(self) -> None: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index fe75c19eb75..8536b829df4 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -26,6 +26,7 @@ Union, List, Dict, + Optional, ) import numpy as np @@ -42,6 +43,7 @@ "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", + "query_recursively", ] @@ -324,3 +326,20 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: return mapping else: return fn(obj) + + +def query_recursively( + fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = () +) -> Iterator[D]: + # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: + # "a" == "a"[0][0]... 
+ if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): + for idx, item in enumerate(obj): + yield from query_recursively(fn, item, id=(*id, idx)) + elif isinstance(obj, collections.abc.Mapping): + for key, item in obj.items(): + yield from query_recursively(fn, item, id=(*id, key)) + else: + result = fn(id, obj) + if result is not None: + yield result From 7bcb0089d9e96c7bb550af7a76b97bd7b038a519 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Feb 2022 13:53:37 +0100 Subject: [PATCH 2/9] port UCF101 --- .../prototype/datasets/_builtin/__init__.py | 1 + .../prototype/datasets/_builtin/hmdb51.py | 1 + .../datasets/_builtin/ucf101.categories | 101 ++++++++++++++++++ .../prototype/datasets/_builtin/ucf101.py | 94 ++++++++++++++++ torchvision/prototype/features/_encoded.py | 4 +- 5 files changed, 199 insertions(+), 2 deletions(-) create mode 100644 torchvision/prototype/datasets/_builtin/ucf101.categories create mode 100644 torchvision/prototype/datasets/_builtin/ucf101.py diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index bab1cf873ac..61f6278ece7 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -15,4 +15,5 @@ from .sbd import SBD from .semeion import SEMEION from .svhn import SVHN +from .ucf101 import UCF101 from .voc import VOC diff --git a/torchvision/prototype/datasets/_builtin/hmdb51.py b/torchvision/prototype/datasets/_builtin/hmdb51.py index e6b453ae31c..5e9c905d60a 100644 --- a/torchvision/prototype/datasets/_builtin/hmdb51.py +++ b/torchvision/prototype/datasets/_builtin/hmdb51.py @@ -26,6 +26,7 @@ def _make_info(self) -> DatasetInfo: return DatasetInfo( "hmdb51", homepage="https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/", + dependencies=("rarfile",), valid_options=dict( split=("train", "test"), split_number=("1", "2", "3"), diff --git a/torchvision/prototype/datasets/_builtin/ucf101.categories b/torchvision/prototype/datasets/_builtin/ucf101.categories new file mode 100644 index 00000000000..dd41d095c7c --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/ucf101.categories @@ -0,0 +1,101 @@ +ApplyEyeMakeup +ApplyLipstick +Archery +BabyCrawling +BalanceBeam +BandMarching +BaseballPitch +Basketball +BasketballDunk +BenchPress +Biking +Billiards +BlowDryHair +BlowingCandles +BodyWeightSquats +Bowling +BoxingPunchingBag +BoxingSpeedBag +BreastStroke +BrushingTeeth +CleanAndJerk +CliffDiving +CricketBowling +CricketShot +CuttingInKitchen +Diving +Drumming +Fencing +FieldHockeyPenalty +FloorGymnastics +FrisbeeCatch +FrontCrawl +GolfSwing +Haircut +Hammering +HammerThrow +HandstandPushups +HandstandWalking +HeadMassage +HighJump +HorseRace +HorseRiding +HulaHoop +IceDancing +JavelinThrow +JugglingBalls +JumpingJack +JumpRope +Kayaking +Knitting +LongJump +Lunges +MilitaryParade +Mixing +MoppingFloor +Nunchucks +ParallelBars +PizzaTossing +PlayingCello +PlayingDaf +PlayingDhol +PlayingFlute +PlayingGuitar +PlayingPiano +PlayingSitar +PlayingTabla +PlayingViolin +PoleVault +PommelHorse +PullUps +Punch +PushUps +Rafting +RockClimbingIndoor +RopeClimbing +Rowing +SalsaSpin +ShavingBeard +Shotput +SkateBoarding +Skiing +Skijet +SkyDiving +SoccerJuggling +SoccerPenalty +StillRings +SumoWrestling +Surfing +Swing +TableTennisShot +TaiChi +TennisSwing +ThrowDiscus +TrampolineJumping +Typing +UnevenBars +VolleyballSpiking +WalkingWithDog +WallPushups +WritingOnBoard +YoYo diff --git 
a/torchvision/prototype/datasets/_builtin/ucf101.py b/torchvision/prototype/datasets/_builtin/ucf101.py new file mode 100644 index 00000000000..386527544a9 --- /dev/null +++ b/torchvision/prototype/datasets/_builtin/ucf101.py @@ -0,0 +1,94 @@ +import csv +import pathlib +from typing import Any, Dict, List, Tuple, cast, BinaryIO + +from torch.utils.data import IterDataPipe +from torch.utils.data.datapipes.iter import Filter, Mapper +from torchdata.datapipes.iter import CSVParser, IterKeyZipper +from torchvision.prototype.datasets.utils import ( + Dataset, + DatasetConfig, + DatasetInfo, + HttpResource, + OnlineResource, +) +from torchvision.prototype.datasets.utils._internal import ( + path_accessor, + path_comparator, + hint_sharding, + hint_shuffling, + INFINITE_BUFFER_SIZE, +) +from torchvision.prototype.features import Label, EncodedVideo + +csv.register_dialect("ucf101", delimiter=" ") + + +class UCF101(Dataset): + def _make_info(self) -> DatasetInfo: + return DatasetInfo( + "ucf101", + dependencies=("rarfile",), + valid_options=dict( + split=("train", "test"), + fold=("1", "2", "3"), + ), + homepage="https://www.crcv.ucf.edu/data/UCF101.php", + ) + + def _extract_videos_archive(self, path: pathlib.Path) -> pathlib.Path: + folder = OnlineResource._extract(path) + for rar_file in folder.glob("*.rar"): + OnlineResource._extract(rar_file) + rar_file.unlink() + return folder + + def resources(self, config: DatasetConfig) -> List[OnlineResource]: + url_root = "https://www.crcv.ucf.edu/data/UCF101/" + + splits = HttpResource( + f"{url_root}/UCF101TrainTestSplits-RecognitionTask.zip", + sha256="5c0d1a53b8ed364a2ac830a73f405e51bece7d98ce1254fd19ed4a36b224bd27", + ) + + videos = HttpResource( + f"{url_root}/UCF101.rar", + sha256="ca8dfadb4c891cb11316f94d52b6b0ac2a11994e67a0cae227180cd160bd8e55", + extract=True, + ) + videos._preprocess = self._extract_videos_archive + + return [splits, videos] + + def _prepare_sample(self, data: Tuple[Tuple[str, str], Tuple[str, BinaryIO]]) -> Dict[str, Any]: + _, (path, buffer) = data + path = pathlib.Path(path) + return dict( + label=Label.from_category(path.parent.name, categories=self.categories), + video=EncodedVideo.from_file(buffer, path=path), + ) + + def _make_datapipe( + self, resource_dps: List[IterDataPipe], *, config: DatasetConfig + ) -> IterDataPipe[Dict[str, Any]]: + splits_dp, images_dp = resource_dps + + splits_dp: IterDataPipe[Tuple[str, BinaryIO]] = Filter( + splits_dp, path_comparator("name", f"{config.split}list0{config.fold}.txt") + ) + splits_dp = CSVParser(splits_dp, dialect="ucf101") + splits_dp = hint_sharding(splits_dp) + splits_dp = hint_shuffling(splits_dp) + + dp = IterKeyZipper(splits_dp, images_dp, path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE) + return Mapper(dp, self._prepare_sample) + + def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: + config = self.default_config + resources = self.resources(config) + + dp = resources[0].load(root) + dp = Filter(dp, path_comparator("name", "classInd.txt")) + dp = CSVParser(dp, dialect="ucf101") + _, categories = zip(*dp) + return cast(Tuple[str, ...], categories) diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index 64a1a5f3326..f92d7b82e4c 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -33,10 +33,10 @@ def from_file(cls: Type[E], file: BinaryIO, **meta: Any) -> E: return cls(fromfile(file, dtype=torch.uint8, byte_order=sys.byteorder), **meta) 
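+        # Any extra keyword arguments are collected into the `.meta` dict, e.g.
+        # `EncodedVideo.from_file(file, path=path)` yields `.meta == {"path": path}`,
+        # which downstream consumers such as the video decoders read back.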
@classmethod - def from_path(cls: Type[E], path: Union[str, os.PathLike]) -> E: + def from_path(cls: Type[E], path: Union[str, os.PathLike], **meta: Any) -> E: path = pathlib.Path(path) with open(path, "rb") as file: - return cls.from_file(file, path=path) + return cls.from_file(file, path=path, **meta) class EncodedImage(EncodedData): From 32f813f900707ecc2150d8e016eaca31275bb250 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Feb 2022 15:02:33 +0100 Subject: [PATCH 3/9] appease mypy --- mypy.ini | 4 ++++ torchvision/prototype/datasets/_builtin/hmdb51.py | 5 ++--- torchvision/prototype/datasets/_builtin/ucf101.py | 2 +- torchvision/prototype/datasets/utils/_video.py | 13 ++++++------- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mypy.ini b/mypy.ini index 6d7863b627e..fc2741e28f4 100644 --- a/mypy.ini +++ b/mypy.ini @@ -155,3 +155,7 @@ ignore_missing_imports = True [mypy-h5py.*] ignore_missing_imports = True + +[mypy-rarfile.*] + +ignore_missing_imports = True diff --git a/torchvision/prototype/datasets/_builtin/hmdb51.py b/torchvision/prototype/datasets/_builtin/hmdb51.py index 5e9c905d60a..1f2a404c247 100644 --- a/torchvision/prototype/datasets/_builtin/hmdb51.py +++ b/torchvision/prototype/datasets/_builtin/hmdb51.py @@ -58,7 +58,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _is_split_number(self, data: Tuple[str, Any], *, split_number: str) -> bool: path = pathlib.Path(data[0]) - return self._SPLIT_FILE_PATTERN.match(path.name)["split_number"] == split_number # type: ignore[union-attr] + return self._SPLIT_FILE_PATTERN.match(path.name)["split_number"] == split_number # type: ignore[index] _SPLIT_ID_TO_NAME = { "1": "train", @@ -111,7 +111,6 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]: dp = resources[0].load(root) categories = { - self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)["category"] # type: ignore[union-attr] - for path, _ in dp + self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)["category"] for path, _ in dp # type: ignore[index] } return sorted(categories) diff --git a/torchvision/prototype/datasets/_builtin/ucf101.py b/torchvision/prototype/datasets/_builtin/ucf101.py index 386527544a9..1653475ce59 100644 --- a/torchvision/prototype/datasets/_builtin/ucf101.py +++ b/torchvision/prototype/datasets/_builtin/ucf101.py @@ -88,7 +88,7 @@ def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: resources = self.resources(config) dp = resources[0].load(root) - dp = Filter(dp, path_comparator("name", "classInd.txt")) + dp: IterDataPipe[Tuple[str, BinaryIO]] = Filter(dp, path_comparator("name", "classInd.txt")) dp = CSVParser(dp, dialect="ucf101") _, categories = zip(*dp) return cast(Tuple[str, ...], categories) diff --git a/torchvision/prototype/datasets/utils/_video.py b/torchvision/prototype/datasets/utils/_video.py index 1cd2650f206..9bde430f817 100644 --- a/torchvision/prototype/datasets/utils/_video.py +++ b/torchvision/prototype/datasets/utils/_video.py @@ -1,5 +1,5 @@ import random -from typing import Any, Dict, Iterator, BinaryIO, Optional, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple import av import numpy as np @@ -15,7 +15,7 @@ def __init__(self, datapipe: IterDataPipe, *, inline: bool = True) -> None: self.datapipe = datapipe self._inline = inline - def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: raise 
NotImplementedError def _find_encoded_video(self, id: Tuple[Any, ...], obj: Any) -> Optional[Tuple[Any, ...]]: @@ -65,13 +65,12 @@ def __iter__(self) -> Iterator[Any]: raise ValueError("more than one encoded video") id, video = ids_and_videos[0] - buffer = ReadOnlyTensorBuffer(video) - for data in self._decode(buffer, video.meta.copy()): + for data in self._decode(ReadOnlyTensorBuffer(video), video.meta.copy()): yield self._integrate_data(sample, id, data) class KeyframeDecoder(_VideoDecoder): - def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: with av.open(buffer, metadata_errors="ignore") as container: stream = container.streams.video[0] stream.codec_context.skip_frame = "NONKEY" @@ -92,7 +91,7 @@ def __init__(self, datapipe: IterDataPipe, *, num_samples: int = 1, inline: bool super().__init__(datapipe, inline=inline) self.num_sampler = num_samples - def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: with av.open(buffer, metadata_errors="ignore") as container: stream = container.streams.video[0] # duration is given in time_base units as int @@ -147,7 +146,7 @@ def _unfold(self, tensor: torch.Tensor, dilation: int = 1) -> torch.Tensor: new_size = (0, self.num_frames_per_clip) return torch.as_strided(tensor, new_size, new_stride) - def _decode(self, buffer: BinaryIO, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: with av.open(buffer, metadata_errors="ignore") as container: stream = container.streams.video[0] time_base = stream.time_base From d37df0184b188e398a26c154dcd2748750c3a7d2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Feb 2022 10:16:52 +0100 Subject: [PATCH 4/9] fix resource loading --- .../prototype/datasets/_builtin/ucf101.py | 2 +- .../prototype/datasets/utils/_resource.py | 33 ++++++++++++------- 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/ucf101.py b/torchvision/prototype/datasets/_builtin/ucf101.py index 1653475ce59..30c600ddcd2 100644 --- a/torchvision/prototype/datasets/_builtin/ucf101.py +++ b/torchvision/prototype/datasets/_builtin/ucf101.py @@ -51,10 +51,10 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: sha256="5c0d1a53b8ed364a2ac830a73f405e51bece7d98ce1254fd19ed4a36b224bd27", ) + # The SSL certificate of the server is currently invalid, but downloading "unsafe" data is not supported yet videos = HttpResource( f"{url_root}/UCF101.rar", sha256="ca8dfadb4c891cb11316f94d52b6b0ac2a11994e67a0cae227180cd160bd8e55", - extract=True, ) videos._preprocess = self._extract_videos_archive diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index 294c0c9099b..bc57db22802 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -22,6 +22,7 @@ download_file_from_google_drive, _get_redirect_url, _get_google_drive_file_id, + tqdm, ) @@ -88,20 +89,30 @@ def load( root = pathlib.Path(root) path = root / self.file_name # Instead of the raw file, there might also be files with fewer suffixes after decompression or directories - # with no suffixes at all. 
Thus, we look for all paths that share the same name without suffixes as the raw - # file. - path_candidates = {file for file in path.parent.glob(path.name.replace("".join(path.suffixes), "") + "*")} - # If we don't find anything, we try to download the raw file. - if not path_candidates: - path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)} + # with no suffixes at all. + stem = path.name.replace("".join(path.suffixes), "") + + # In a first step, we check for a folder with the same stem as the raw file. If it exists, we use it since + # extracted files give the best I/O performance. Note that OnlineResource._extract() makes sure that an archive + # is always extracted in a folder with the corresponding file name. + folder_candidate = path.parent / stem + if folder_candidate.exists() and folder_candidate.is_dir(): + return self._loader(path) + + # If there is no folder, we look for all files that share the same stem as the raw file, but might have a + # different suffix. + file_candidates = {file for file in path.parent.glob(stem + ".*")} + # If we don't find anything, we download the raw file. + if not file_candidates: + file_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)} # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps. - if path_candidates == {path}: + if file_candidates == {path}: if self._preprocess is not None: path = self._preprocess(path) - # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority - # that we want. + # Otherwise, we use the path with the fewest suffixes. This gives us the decompressed > raw priority that we + # want for the best I/O performance. else: - path = min(path_candidates, key=lambda path: len(path.suffixes)) + path = min(file_candidates, key=lambda path: len(path.suffixes)) return self._loader(path) @abc.abstractmethod @@ -119,7 +130,7 @@ def download(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool def _check_sha256(self, path: pathlib.Path, *, chunk_size: int = 1024 * 1024) -> None: hash = hashlib.sha256() with open(path, "rb") as file: - for chunk in iter(lambda: file.read(chunk_size), b""): + for chunk in tqdm(iter(lambda: file.read(chunk_size), b"")): hash.update(chunk) sha256 = hash.hexdigest() if sha256 != self.sha256: From 9696a9db096932b619db7b59795e1f892460eead Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Feb 2022 11:48:46 +0100 Subject: [PATCH 5/9] add av to optional prototype requirements in CI --- .circleci/config.yml | 2 +- .circleci/config.yml.in | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index f1ddaf861ac..c3aaefc196e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -351,7 +351,7 @@ jobs: - install_torchvision - install_prototype_dependencies - pip_install: - args: scipy pycocotools h5py + args: scipy pycocotools h5py av descr: Install optional dependencies - run: name: Enable prototype tests diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 4bd2e14147a..bdf4f438f69 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -351,7 +351,7 @@ jobs: - install_torchvision - install_prototype_dependencies - pip_install: - args: scipy pycocotools h5py + args: scipy pycocotools h5py av descr: Install optional dependencies - run: name: Enable prototype tests From 0a4b477648f5125f8741f99580ce3216dd6b356e Mon Sep 17 00:00:00 2001 From: 
Philip Meier Date: Wed, 16 Feb 2022 23:18:38 +0100 Subject: [PATCH 6/9] add tests --- .circleci/config.yml | 2 +- .circleci/config.yml.in | 2 +- test/builtin_dataset_mocks.py | 97 +++++++++++++++++-- .../prototype/datasets/_builtin/hmdb51.py | 13 +-- .../prototype/datasets/_builtin/ucf101.py | 4 +- .../prototype/datasets/utils/_resource.py | 2 +- 6 files changed, 99 insertions(+), 21 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index c3aaefc196e..a5e8217ccce 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -351,7 +351,7 @@ jobs: - install_torchvision - install_prototype_dependencies - pip_install: - args: scipy pycocotools h5py av + args: scipy pycocotools h5py av rarfile descr: Install optional dependencies - run: name: Enable prototype tests diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index bdf4f438f69..fa955aeca0e 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -351,7 +351,7 @@ jobs: - install_torchvision - install_prototype_dependencies - pip_install: - args: scipy pycocotools h5py av + args: scipy pycocotools h5py av rarfile descr: Install optional dependencies - run: name: Enable prototype tests diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 123d8f29d3f..71ed8584b5a 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -9,6 +9,7 @@ import pathlib import pickle import random +import unittest.mock import xml.etree.ElementTree as ET from collections import defaultdict, Counter @@ -16,11 +17,10 @@ import PIL.Image import pytest import torch -from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file +from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, create_video_folder from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor from torchvision.prototype.datasets._api import find -from torchvision.prototype.utils._internal import sequence_to_str make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) @@ -67,14 +67,15 @@ def prepare(self, home, config): mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config)) - available_file_names = {path.name for path in root.glob("*")} - required_file_names = {resource.file_name for resource in self.dataset.resources(config)} - missing_file_names = required_file_names - available_file_names - if missing_file_names: - raise pytest.UsageError( - f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " - f"for {config}, but they were not created by the mock data function." - ) + for resource in self.dataset.resources(config): + with unittest.mock.patch( + "torchvision.prototype.datasets.utils._resource.OnlineResource.download", + side_effect=TypeError( + f"Dataset '{self.name}' requires the file {resource.file_name} for {config}, " + f"but it was not created by the mock data function." 
+ ), + ): + resource.load(root) return mock_info @@ -1344,3 +1345,79 @@ def pcam(info, root, config): compressed_file.write(compressed_data) return num_images + + +@register_mock +def ucf101(info, root, config): + video_folder = root / "UCF101" / "UCF-101" + + categories_and_labels = [ + ("ApplyEyeMakeup", 0), + ("LongJump", 50), + ("YoYo", 100), + ] + + def file_name_fn(cls, idx, clips_per_group=2): + return f"v_{cls}_g{(idx // clips_per_group) + 1:02d}_c{(idx % clips_per_group) + 1:02d}.avi" + + video_files = [ + create_video_folder( + video_folder, category, lambda idx: file_name_fn(category, idx), num_examples=int(torch.randint(1, 6, ())) + ) + for category, _ in categories_and_labels + ] + + splits_folder = root / "ucfTrainTestList" + splits_folder.mkdir() + + with open(splits_folder / "classInd.txt", "w") as file: + file.write("\n".join(f"{label} {category}" for category, label in categories_and_labels) + "\n") + + video_ids = [path.relative_to(video_folder).as_posix() for path in itertools.chain.from_iterable(video_files)] + splits = ("train", "test") + num_samples_map = {} + for fold in range(1, 4): + random.shuffle(video_ids) + for offset, split in enumerate(splits): + video_ids_in_config = video_ids[offset :: len(splits)] + with open(splits_folder / f"{split}list{fold:02d}.txt", "w") as file: + file.write("\n".join(video_ids_in_config) + "\n") + + num_samples_map[info.make_config(split=split, fold=str(fold))] = len(video_ids_in_config) + + make_zip(root, "UCF101TrainTestSplits-RecognitionTask.zip", splits_folder) + + return num_samples_map[config] + + +@register_mock +def hmdb51(info, root, config): + video_folder = root / "hmdb51_org" + + categories = [ + "brush_hair", + "pour", + "wave", + ] + + video_files = { + category: create_video_folder( + video_folder, category, lambda idx: f"{category}_{idx}.avi", num_examples=int(torch.randint(3, 10, ())) + ) + for category in categories + } + + splits_folder = root / "test_train_splits" / "testTrainMulti_7030_splits" + splits_folder.mkdir(parents=True) + + num_samples_map = defaultdict(lambda: 0) + for category, fold in itertools.product(categories, range(1, 4)): + videos = video_files[category] + + with open(splits_folder / f"{category}_test_split{fold}.txt", "w") as file: + file.write("\n".join(f"{path.name} {idx % 3}" for idx, path in enumerate(videos)) + "\n") + + for split, split_id in (("train", 1), ("test", 2)): + num_samples_map[info.make_config(split=split, fold=str(fold))] += len(videos[split_id::3]) + + return num_samples_map[config] diff --git a/torchvision/prototype/datasets/_builtin/hmdb51.py b/torchvision/prototype/datasets/_builtin/hmdb51.py index 1f2a404c247..ce37a0045b5 100644 --- a/torchvision/prototype/datasets/_builtin/hmdb51.py +++ b/torchvision/prototype/datasets/_builtin/hmdb51.py @@ -29,7 +29,7 @@ def _make_info(self) -> DatasetInfo: dependencies=("rarfile",), valid_options=dict( split=("train", "test"), - split_number=("1", "2", "3"), + fold=("1", "2", "3"), ), ) @@ -54,11 +54,11 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: videos._preprocess = self._extract_videos_archive return [splits, videos] - _SPLIT_FILE_PATTERN = re.compile(r"(?P\w+?)_test_split(?P[1-3])[.]txt") + _SPLIT_FILE_PATTERN = re.compile(r"(?P\w+?)_test_split(?P[1-3])[.]txt") - def _is_split_number(self, data: Tuple[str, Any], *, split_number: str) -> bool: + def _is_fold(self, data: Tuple[str, Any], *, fold: str) -> bool: path = pathlib.Path(data[0]) - return self._SPLIT_FILE_PATTERN.match(path.name)["split_number"] 
== split_number # type: ignore[index] + return self._SPLIT_FILE_PATTERN.match(path.name)["fold"] == fold # type: ignore[index] _SPLIT_ID_TO_NAME = { "1": "train", @@ -68,7 +68,8 @@ def _is_split_number(self, data: Tuple[str, Any], *, split_number: str) -> bool: def _is_split(self, data: Dict[str, Any], *, split: str) -> bool: split_id = data["split_id"] - # TODO: explain + # In addition to split id 1 and 2 corresponding to the train and test splits, some videos are annotated with + # split id 0, which indicates that the video is not included in either split if split_id not in self._SPLIT_ID_TO_NAME: return False @@ -90,7 +91,7 @@ def _make_datapipe( ) -> IterDataPipe[Dict[str, Any]]: splits_dp, videos_dp = resource_dps - splits_dp = Filter(splits_dp, functools.partial(self._is_split_number, split_number=config.split_number)) + splits_dp = Filter(splits_dp, functools.partial(self._is_fold, fold=config.fold)) splits_dp = CSVDictParser(splits_dp, fieldnames=("filename", "split_id"), delimiter=" ") splits_dp = Filter(splits_dp, functools.partial(self._is_split, split=config.split)) splits_dp = hint_sharding(splits_dp) diff --git a/torchvision/prototype/datasets/_builtin/ucf101.py b/torchvision/prototype/datasets/_builtin/ucf101.py index 30c600ddcd2..dfc7652468b 100644 --- a/torchvision/prototype/datasets/_builtin/ucf101.py +++ b/torchvision/prototype/datasets/_builtin/ucf101.py @@ -71,7 +71,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, str], Tuple[str, BinaryIO]]) -> def _make_datapipe( self, resource_dps: List[IterDataPipe], *, config: DatasetConfig ) -> IterDataPipe[Dict[str, Any]]: - splits_dp, images_dp = resource_dps + splits_dp, videos_dp = resource_dps splits_dp: IterDataPipe[Tuple[str, BinaryIO]] = Filter( splits_dp, path_comparator("name", f"{config.split}list0{config.fold}.txt") @@ -80,7 +80,7 @@ def _make_datapipe( splits_dp = hint_sharding(splits_dp) splits_dp = hint_shuffling(splits_dp) - dp = IterKeyZipper(splits_dp, images_dp, path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE) + dp = IterKeyZipper(splits_dp, videos_dp, path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE) return Mapper(dp, self._prepare_sample) def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py index bc57db22802..7edd130c414 100644 --- a/torchvision/prototype/datasets/utils/_resource.py +++ b/torchvision/prototype/datasets/utils/_resource.py @@ -97,7 +97,7 @@ def load( # is always extracted in a folder with the corresponding file name. folder_candidate = path.parent / stem if folder_candidate.exists() and folder_candidate.is_dir(): - return self._loader(path) + return self._loader(folder_candidate) # If there is no folder, we look for all files that share the same stem as the raw file, but might have a # different suffix. 
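With this fix, the lookup order that `OnlineResource.load()` applies is: extracted folder, then the on-disk file with the fewest suffixes (decompressed beats raw), then a fresh download. Condensed into a standalone sketch (simplified: the real method also runs the `_preprocess` hook after downloading and wraps the result in `_loader`; `resolve` is an illustrative stand-in, not part of the API):

    import pathlib

    def resolve(root: pathlib.Path, file_name: str) -> pathlib.Path:
        path = root / file_name
        stem = path.name.replace("".join(path.suffixes), "")

        # 1. An already extracted folder wins, since it gives the best I/O performance.
        folder_candidate = path.parent / stem
        if folder_candidate.is_dir():
            return folder_candidate

        # 2. Otherwise prefer the most-processed file on disk: fewest suffixes first.
        file_candidates = set(path.parent.glob(stem + ".*"))
        if file_candidates:
            return min(file_candidates, key=lambda p: len(p.suffixes))

        # 3. Nothing on disk: this is the point where load() downloads the raw file.
        raise FileNotFoundError(str(path))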
From 673fc5a185a024120f02062b663d0f5e35a8b925 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Thu, 17 Feb 2022 14:34:43 +0000 Subject: [PATCH 7/9] basic changes to support various video backends --- .../prototype/datasets/utils/_video.py | 173 ++++++++++-------- 1 file changed, 100 insertions(+), 73 deletions(-) diff --git a/torchvision/prototype/datasets/utils/_video.py b/torchvision/prototype/datasets/utils/_video.py index 9bde430f817..37703ec7262 100644 --- a/torchvision/prototype/datasets/utils/_video.py +++ b/torchvision/prototype/datasets/utils/_video.py @@ -1,17 +1,20 @@ import random +import warnings from typing import Any, Dict, Iterator, Optional, Tuple import av import numpy as np import torch from torchdata.datapipes.iter import IterDataPipe -from torchvision.io import video, _video_opt +from torchvision import get_video_backend +from torchvision.io import video, _video_opt, VideoReader from torchvision.prototype.features import Image, EncodedVideo from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer, query_recursively class _VideoDecoder(IterDataPipe): def __init__(self, datapipe: IterDataPipe, *, inline: bool = True) -> None: + # TODO: add gpu support self.datapipe = datapipe self._inline = inline @@ -71,6 +74,9 @@ def __iter__(self) -> Iterator[Any]: class KeyframeDecoder(_VideoDecoder): def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: + if get_video_backend() == "video_reader": + warnings.warn("Video reader API not implemented for keyframes yet, reverting to PyAV") + with av.open(buffer, metadata_errors="ignore") as container: stream = container.streams.video[0] stream.codec_context.skip_frame = "NONKEY" @@ -92,24 +98,42 @@ def __init__(self, datapipe: IterDataPipe, *, num_samples: int = 1, inline: bool self.num_sampler = num_samples def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]: - with av.open(buffer, metadata_errors="ignore") as container: - stream = container.streams.video[0] - # duration is given in time_base units as int - duration = stream.duration - # seek to a random frame - seek_idxs = random.sample(list(range(duration)), self.num_samples) + if get_video_backend() == "video_reader": + vid = VideoReader(buffer, device=self.device) + # seek and return frames + metadata = vid.get_metadata()["video"] + duration = metadata["duration"][0] if self.device == "cpu" else metadata["duration"] + fps = metadata["fps"][0] if self.device == "cpu" else metadata["fps"] + max_seek = duration - (self.clip_len / fps + 0.1) # FIXME: random param + seek_idxs = random.sample(list(range(max_seek)), self.num_samples) for i in seek_idxs: - container.seek(i, any_frame=True, stream=stream) - frame = next(container.decode(stream)) + vid.seek(i) + frame = vid.next() yield dict( - frame=Image.from_pil(frame.to_image()), - pts=frame.pts, + frame=frame['data'], + pts = frame['pts'], video_meta=dict( - time_base=float(frame.time_base), - guessed_fps=float(stream.guessed_rate), + guessed_fps=fps, ), ) - + else: + with av.open(buffer, metadata_errors="ignore") as container: + stream = container.streams.video[0] + # duration is given in time_base units as int + duration = stream.duration + # seek to a random frame + seek_idxs = random.sample(list(range(duration)), self.num_samples) + for i in seek_idxs: + container.seek(i, any_frame=True, stream=stream) + frame = next(container.decode(stream)) + yield dict( + frame=Image.from_pil(frame.to_image()), + pts=frame.pts, + 
                        video_meta=dict(
+                            time_base=float(frame.time_base),
+                            guessed_fps=float(stream.guessed_rate),
+                        ),
+                    )
 
 
 class ClipDecoder(_VideoDecoder):
     def __init__(
@@ -147,62 +171,65 @@ def _unfold(self, tensor: torch.Tensor, dilation: int = 1) -> torch.Tensor:
         return torch.as_strided(tensor, new_size, new_stride)
 
     def _decode(self, buffer: ReadOnlyTensorBuffer, meta: Dict[str, Any]) -> Iterator[Dict[str, Any]]:
-        with av.open(buffer, metadata_errors="ignore") as container:
-            stream = container.streams.video[0]
-            time_base = stream.time_base
-
-            # duration is given in time_base units as int
-            duration = stream.duration
-
-            # get video_stream timestramps
-            # with a tolerance for pyav imprecission
-            _ptss = torch.arange(duration - 7)
-            _ptss = self._unfold(_ptss)
-            # shuffle the clips
-            perm = torch.randperm(_ptss.size(0))
-            idx = perm[: self.num_clips_per_video]
-            samples = _ptss[idx]
-
-            for clip_pts in samples:
-                start_pts = clip_pts[0].item()
-                end_pts = clip_pts[-1].item()
-                # video_timebase is the default time_base
-                pts_unit = "pts"
-                start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, "pts", time_base)
-                video_frames = video._read_from_stream(
-                    container,
-                    float(start_pts),
-                    float(end_pts),
-                    pts_unit,
-                    stream,
-                    {"video": 0},
-                )
-
-                vframes_list = [frame.to_ndarray(format="rgb24") for frame in video_frames]
-
-                if vframes_list:
-                    vframes = torch.as_tensor(np.stack(vframes_list))
-                    # account for rounding errors in conversion
-                    # FIXME: fix this in the code
-                    vframes = vframes[: self.num_frames_per_clip, ...]
-
-                else:
-                    vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
-                    print("FAIL")
-
-                # [N,H,W,C] to [N,C,H,W]
-                vframes = vframes.permute(0, 3, 1, 2)
-                assert vframes.size(0) == self.num_frames_per_clip
-
-                # TODO: support sampling rates (FPS change)
-                # TODO: optimization (read all and select)
-
-                yield {
-                    "clip": vframes,
-                    "pts": clip_pts,
-                    "range": (start_pts, end_pts),
-                    "video_meta": {
-                        "time_base": float(stream.time_base),
-                        "guessed_fps": float(stream.guessed_rate),
-                    },
-                }
+        if get_video_backend() == "video_reader":
+            pass
+        else:
+            with av.open(buffer, metadata_errors="ignore") as container:
+                stream = container.streams.video[0]
+                time_base = stream.time_base
+
+                # duration is given in time_base units as int
+                duration = stream.duration
+
+                # get video_stream timestamps
+                # with a tolerance for pyav imprecision
+                _ptss = torch.arange(duration - 7)
+                _ptss = self._unfold(_ptss)
+                # shuffle the clips
+                perm = torch.randperm(_ptss.size(0))
+                idx = perm[: self.num_clips_per_video]
+                samples = _ptss[idx]
+
+                for clip_pts in samples:
+                    start_pts = clip_pts[0].item()
+                    end_pts = clip_pts[-1].item()
+                    # video_timebase is the default time_base
+                    pts_unit = "pts"
+                    start_pts, end_pts, pts_unit = _video_opt._convert_to_sec(start_pts, end_pts, "pts", time_base)
+                    video_frames = video._read_from_stream(
+                        container,
+                        float(start_pts),
+                        float(end_pts),
+                        pts_unit,
+                        stream,
+                        {"video": 0},
+                    )
+
+                    vframes_list = [frame.to_ndarray(format="rgb24") for frame in video_frames]
+
+                    if vframes_list:
+                        vframes = torch.as_tensor(np.stack(vframes_list))
+                        # account for rounding errors in conversion
+                        # FIXME: fix this in the code
+                        vframes = vframes[: self.num_frames_per_clip, ...]
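+                        # (the pts-to-seconds conversion above rounds, so _read_from_stream()
+                        # may return one frame more than requested; the slice restores the
+                        # expected clip length)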
+
+                    else:
+                        vframes = torch.empty((0, 1, 1, 3), dtype=torch.uint8)
+                        warnings.warn("failed to decode any frames for the requested clip")
+
+                    # [N,H,W,C] to [N,C,H,W]
+                    vframes = vframes.permute(0, 3, 1, 2)
+                    assert vframes.size(0) == self.num_frames_per_clip
+
+                    # TODO: support sampling rates (FPS change)
+                    # TODO: optimization (read all and select)
+
+                    yield {
+                        "clip": vframes,
+                        "pts": clip_pts,
+                        "range": (start_pts, end_pts),
+                        "video_meta": {
+                            "time_base": float(stream.time_base),
+                            "guessed_fps": float(stream.guessed_rate),
+                        },
+                    }

From 0e7dcc9ed5ac480fdd61be161410b50254109f6f Mon Sep 17 00:00:00 2001
From: Bruno Korbar
Date: Thu, 17 Feb 2022 14:55:19 +0000
Subject: [PATCH 8/9] skeleton for implementing videoutils tests

---
 test/test_prototype_videoutils.py | 83 +++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)
 create mode 100644 test/test_prototype_videoutils.py

diff --git a/test/test_prototype_videoutils.py b/test/test_prototype_videoutils.py
new file mode 100644
index 00000000000..73a6611ab28
--- /dev/null
+++ b/test/test_prototype_videoutils.py
@@ -0,0 +1,83 @@
+import math
+import os
+
+import pytest
+import torch
+from torchvision.io import _HAS_VIDEO_DECODER, _HAS_VIDEO_OPT, VideoReader
+
+try:
+    import av
+except ImportError:
+    av = None
+
+VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
+
+
+@pytest.mark.skipif(av is None, reason="PyAV unavailable")
+class TestVideoDatasetUtils:
+    # TODO: atm we separate backends in order to allow for testing on different systems;
+    # once we have things packaged we should add this as test parametrisation
+    # (this also applies for GPU decoding as well)
+
+    @pytest.mark.parametrize(
+        "video_file",
+        [
+            "RATRACE_wave_f_nm_np1_fr_goo_37.avi",
+            "TrumanShow_wave_f_nm_np1_fr_med_26.avi",
+            "v_SoccerJuggling_g23_c01.avi",
+            "v_SoccerJuggling_g24_c01.avi",
+            "R6llTwEh07w.mp4",
+            "SOX5yA1l24A.mp4",
+            "WUzgd7C1pWA.mp4",
+        ],
+    )
+    def test_random_decoder_av(self, video_file):
+        """Read a sequence of random frames from a video
+        Checks that files are valid video frames and no error is thrown during decoding.
+        """
+        pass
+
+    def test_random_decoder_cpu(self, video_file):
+        """Read a sequence of random frames from a video using CPU backend
+        Checks that files are valid video frames and no error is thrown during decoding,
+        and compares them to `pyav` output.
+        """
+        pass
+
+    def test_random_decoder_GPU(self, video_file):
+        """Read a sequence of random frames from a video using GPU backend
+        Checks that files are valid video frames and no error is thrown during decoding,
+        and compares them to `pyav` output.
+        """
+        pass
+
+    def test_keyframe_decoder_av(self, video_file):
+        """Read all keyframes from a video;
+        Compare the output to naive keyframe reading with `pyav`
+        """
+        pass
+
+    def test_keyframe_decoder_cpu(self, video_file):
+        """Read all keyframes from a video using CPU backend;
+        ATM should raise a warning and default to `pyav`
+        TODO: should we fail or default to a working backend
+        """
+        pass
+
+    def test_keyframe_decoder_GPU(self, video_file):
+        """Read all keyframes from a video using GPU backend;
+        ATM should raise a warning and default to `pyav`
+        TODO: should we fail or default to a working backend
+        """
+        pass
+
+    def test_clip_decoder(self, video_file):
+        """ATM very crude test:
+        check only if it fails, or if the clip sampling is correct;
+        don't bother with the content just yet.
+ """ + pass + + +if __name__ == "__main__": + pytest.main([__file__]) From cb79fbca64548a1309247dc5a24890f9b3128ae8 Mon Sep 17 00:00:00 2001 From: Bruno Korbar Date: Fri, 25 Feb 2022 13:34:39 +0000 Subject: [PATCH 9/9] add testing barebones --- test/test_prototype_videoutils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_videoutils.py b/test/test_prototype_videoutils.py index 73a6611ab28..a59c453786a 100644 --- a/test/test_prototype_videoutils.py +++ b/test/test_prototype_videoutils.py @@ -4,7 +4,9 @@ import pytest import torch from torchvision.io import _HAS_VIDEO_DECODER, _HAS_VIDEO_OPT, VideoReader - +from torchvision.prototype.features import EncodedData +from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer +from torchvision.prototype.datasets.utils._video import KeyframeDecoder, RandomFrameDecoder try: import av except ImportError: @@ -35,6 +37,9 @@ def test_random_decoder_av(self, video_file): """Read a sequence of random frames from a video Checks that files are valid video frames and no error is thrown during decoding. """ + video_file = os.path.join(VIDEO_DIR, video_file) + video = ReadOnlyTensorBuffer(EncodedData.from_path(video_file)) + print(next(video)) pass def test_random_decoder_cpu(self, video_file):