Skip to content

add HMDB51 and UCF101 datasets as well as prototype for new style video decoding #5335

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -125,3 +125,7 @@ ignore_missing_imports = True
[mypy-h5py.*]

ignore_missing_imports = True

[mypy-rarfile.*]

ignore_missing_imports = True
10 changes: 10 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import bz2
import contextlib
import gzip
import hashlib
import itertools
Expand Down Expand Up @@ -301,6 +302,15 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
".tgz": (".tar", ".gz"),
}

try:
    import rarfile
except ImportError:
    # rarfile is an optional dependency; without it, ".rar" simply stays
    # unregistered and .rar archives cannot be extracted.
    pass
else:

    def _extract_rar(from_path: str, to_path: str, compression: Optional[str]) -> None:
        """Extract the .rar archive at *from_path* into the directory *to_path*."""
        archive = rarfile.RarFile(from_path, "r")
        with archive:
            archive.extractall(to_path)

    _ARCHIVE_EXTRACTORS[".rar"] = _extract_rar


def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.
Expand Down
2 changes: 2 additions & 0 deletions torchvision/prototype/datasets/_builtin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from .dtd import DTD
from .fer2013 import FER2013
from .gtsrb import GTSRB
from .hmdb51 import HMDB51
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
from .oxford_iiit_pet import OxfordIITPet
from .pcam import PCAM
from .sbd import SBD
from .semeion import SEMEION
from .svhn import SVHN
from .ucf101 import UCF101
from .voc import VOC
51 changes: 51 additions & 0 deletions torchvision/prototype/datasets/_builtin/hmdb51.categories
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
brush_hair
cartwheel
catch
chew
clap
climb
climb_stairs
dive
draw_sword
dribble
drink
eat
fall_floor
fencing
flic_flac
golf
handstand
hit
hug
jump
kick
kick_ball
kiss
laugh
pick
pour
pullup
punch
push
pushup
ride_bike
ride_horse
run
shake_hands
shoot_ball
shoot_bow
shoot_gun
sit
situp
smile
smoke
somersault
stand
swing_baseball
sword
sword_exercise
talk
throw
turn
walk
wave
116 changes: 116 additions & 0 deletions torchvision/prototype/datasets/_builtin/hmdb51.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import functools
import pathlib
import re
from typing import Any, Dict, List, Tuple, BinaryIO

from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, IterKeyZipper
from torchvision.prototype.datasets.utils import (
Dataset,
DatasetConfig,
DatasetInfo,
HttpResource,
OnlineResource,
)
from torchvision.prototype.datasets.utils._internal import (
INFINITE_BUFFER_SIZE,
getitem,
path_accessor,
hint_sharding,
hint_shuffling,
)
from torchvision.prototype.features import EncodedVideo, Label


class HMDB51(Dataset):
    """HMDB51 action recognition dataset (prototype datasets API).

    Video clips spanning 51 action categories. Each sample pairs an encoded
    video with the label derived from its parent directory name. The config
    selects a ``split`` ("train"/"test") and one of the three official
    ``split_number`` partitions ("1"/"2"/"3").
    """

    def _make_info(self) -> DatasetInfo:
        # Declares name, homepage, the optional `rarfile` dependency needed to
        # extract the archives, and the valid config axes.
        return DatasetInfo(
            "hmdb51",
            homepage="https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/",
            dependencies=("rarfile",),
            valid_options=dict(
                split=("train", "test"),
                split_number=("1", "2", "3"),
            ),
        )

    def _extract_videos_archive(self, path: pathlib.Path) -> pathlib.Path:
        """Recursively extract the video archive.

        The download is a RAR of per-category RARs, so a single extraction
        pass is not enough: extract the outer archive, then extract each
        inner ``*.rar`` and delete it. Extracting everything up front gives a
        significant speedup over reading videos out of the RARs lazily.
        """
        folder = OnlineResource._extract(path)
        for rar_file in folder.glob("*.rar"):
            OnlineResource._extract(rar_file)
            rar_file.unlink()
        return folder

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        """Return the split-file archive and the video archive resources."""
        url_root = "https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10"

        splits = HttpResource(
            f"{url_root}/test_train_splits.rar",
            sha256="229c94f845720d01eb3946d39f39292ea962d50a18136484aa47c1eba251d2b7",
        )
        videos = HttpResource(
            f"{url_root}/hmdb51_org.rar",
            sha256="9e714a0d8b76104d76e932764a7ca636f929fff66279cda3f2e326fa912a328e",
        )
        # The video archive is a rar of rars, so a plain `extract=True` would
        # not cut it; hook in the recursive extraction defined above instead.
        videos._preprocess = self._extract_videos_archive
        return [splits, videos]

    # Split files are named e.g. "brush_hair_test_split1.txt"; the pattern
    # captures both the category and the split number.
    _SPLIT_FILE_PATTERN = re.compile(r"(?P<category>\w+?)_test_split(?P<split_number>[1-3])[.]txt")

    def _is_split_number(self, data: Tuple[str, Any], *, split_number: str) -> bool:
        # Keep only the split files belonging to the configured split number.
        path = pathlib.Path(data[0])
        return self._SPLIT_FILE_PATTERN.match(path.name)["split_number"] == split_number  # type: ignore[index]

    _SPLIT_ID_TO_NAME = {
        "1": "train",
        "2": "test",
    }

    def _is_split(self, data: Dict[str, Any], *, split: str) -> bool:
        split_id = data["split_id"]

        # The split files also list videos with an id outside this mapping
        # (presumably "0" for clips assigned to neither split of this split
        # number — TODO confirm against the dataset description); drop them.
        if split_id not in self._SPLIT_ID_TO_NAME:
            return False

        return self._SPLIT_ID_TO_NAME[split_id] == split

    def _prepare_sample(self, data: Tuple[List[str], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
        """Turn a (split row, (path, file handle)) pair into a sample dict."""
        _, (path, buffer) = data
        path = pathlib.Path(path)
        return dict(
            # The parent directory of each video is its category name.
            label=Label.from_category(path.parent.name, categories=self.categories),
            video=EncodedVideo.from_file(buffer, path=path),
        )

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
    ) -> IterDataPipe[Dict[str, Any]]:
        splits_dp, videos_dp = resource_dps

        # Narrow the split files to the configured split number, parse the
        # "<filename> <split_id>" rows, and keep only the configured split.
        splits_dp = Filter(splits_dp, functools.partial(self._is_split_number, split_number=config.split_number))
        splits_dp = CSVDictParser(splits_dp, fieldnames=("filename", "split_id"), delimiter=" ")
        splits_dp = Filter(splits_dp, functools.partial(self._is_split, split=config.split))
        splits_dp = hint_sharding(splits_dp)
        splits_dp = hint_shuffling(splits_dp)

        # Join each split row with its video file by file name.
        dp = IterKeyZipper(
            splits_dp,
            videos_dp,
            key_fn=getitem("filename"),
            ref_key_fn=path_accessor("name"),
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def _generate_categories(self, root: pathlib.Path) -> List[str]:
        """Derive the sorted category list from the split file names."""
        config = self.default_config
        resources = self.resources(config)

        dp = resources[0].load(root)
        categories = {
            self._SPLIT_FILE_PATTERN.match(pathlib.Path(path).name)["category"] for path, _ in dp  # type: ignore[index]
        }
        return sorted(categories)
101 changes: 101 additions & 0 deletions torchvision/prototype/datasets/_builtin/ucf101.categories
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
ApplyEyeMakeup
ApplyLipstick
Archery
BabyCrawling
BalanceBeam
BandMarching
BaseballPitch
Basketball
BasketballDunk
BenchPress
Biking
Billiards
BlowDryHair
BlowingCandles
BodyWeightSquats
Bowling
BoxingPunchingBag
BoxingSpeedBag
BreastStroke
BrushingTeeth
CleanAndJerk
CliffDiving
CricketBowling
CricketShot
CuttingInKitchen
Diving
Drumming
Fencing
FieldHockeyPenalty
FloorGymnastics
FrisbeeCatch
FrontCrawl
GolfSwing
Haircut
Hammering
HammerThrow
HandstandPushups
HandstandWalking
HeadMassage
HighJump
HorseRace
HorseRiding
HulaHoop
IceDancing
JavelinThrow
JugglingBalls
JumpingJack
JumpRope
Kayaking
Knitting
LongJump
Lunges
MilitaryParade
Mixing
MoppingFloor
Nunchucks
ParallelBars
PizzaTossing
PlayingCello
PlayingDaf
PlayingDhol
PlayingFlute
PlayingGuitar
PlayingPiano
PlayingSitar
PlayingTabla
PlayingViolin
PoleVault
PommelHorse
PullUps
Punch
PushUps
Rafting
RockClimbingIndoor
RopeClimbing
Rowing
SalsaSpin
ShavingBeard
Shotput
SkateBoarding
Skiing
Skijet
SkyDiving
SoccerJuggling
SoccerPenalty
StillRings
SumoWrestling
Surfing
Swing
TableTennisShot
TaiChi
TennisSwing
ThrowDiscus
TrampolineJumping
Typing
UnevenBars
VolleyballSpiking
WalkingWithDog
WallPushups
WritingOnBoard
YoYo
Loading