Skip to content

Commit 2341bad

Browse files
authored
Merge branch 'main' into prabhat00155/wide_resnet_update
2 parents 157b9e7 + 4715e2e commit 2341bad

File tree

7 files changed

+427
-33
lines changed

7 files changed

+427
-33
lines changed

docs/source/datasets.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ You can also create your own datasets using the provided :ref:`base classes <bas
4343
FashionMNIST
4444
Flickr8k
4545
Flickr30k
46+
FlyingChairs
47+
FlyingThings3D
4648
HMDB51
4749
ImageNet
4850
INaturalist

test/datasets_utils.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import random
99
import shutil
1010
import string
11+
import struct
1112
import tarfile
1213
import unittest
1314
import unittest.mock
@@ -203,7 +204,6 @@ class DatasetTestCase(unittest.TestCase):
203204
``transforms``, or ``download``.
204205
- REQUIRED_PACKAGES (Iterable[str]): Additional dependencies to use the dataset. If these packages are not
205206
available, the tests are skipped.
206-
- EXTRA_PATCHES(set): Additional patches to add for each test, to e.g. mock a specific function
207207
208208
Additionally, you need to overwrite the ``inject_fake_data()`` method that provides the data that the tests rely on.
209209
The fake data should resemble the original data as close as necessary, while containing only few examples. During
@@ -255,8 +255,6 @@ def test_baz(self):
255255
ADDITIONAL_CONFIGS = None
256256
REQUIRED_PACKAGES = None
257257

258-
EXTRA_PATCHES = None
259-
260258
# These keyword arguments are checked by test_transforms in case they are available in DATASET_CLASS.
261259
_TRANSFORM_KWARGS = {
262260
"transform",
@@ -382,17 +380,14 @@ def create_dataset(
382380
if patch_checks:
383381
patchers.update(self._patch_checks())
384382

385-
if self.EXTRA_PATCHES is not None:
386-
patchers.update(self.EXTRA_PATCHES)
387-
388383
with get_tmp_dir() as tmpdir:
389384
args = self.dataset_args(tmpdir, complete_config)
390385
info = self._inject_fake_data(tmpdir, complete_config) if inject_fake_data else None
391386

392387
with self._maybe_apply_patches(patchers), disable_console_output():
393388
dataset = self.DATASET_CLASS(*args, **complete_config, **special_kwargs)
394389

395-
yield dataset, info
390+
yield dataset, info
396391

397392
@classmethod
398393
def setUpClass(cls):
@@ -922,3 +917,26 @@ def create_random_string(length: int, *digits: str) -> str:
922917
digits = "".join(itertools.chain(*digits))
923918

924919
return "".join(random.choice(digits) for _ in range(length))
920+
921+
922+
def make_fake_pfm_file(h, w, file_name):
    """Write a fake color .pfm flow file of size ``h`` x ``w`` to ``file_name``.

    The payload is simply 0, 1, 2, ... so tests can predict the decoded values.
    """
    num_values = 3 * h * w
    # Note: we pack everything in little endian: -1.0, and "<"
    header = f"PF \n{w} {h} \n-1.0\n".encode()
    payload = struct.pack(f"<{num_values}f", *range(num_values))
    with open(file_name, "wb") as f:
        f.write(header + payload)
928+
929+
930+
def make_fake_flo_file(h, w, file_name):
    """Creates a fake flow file in .flo format."""
    # Everything needs to be in little Endian according to
    # https://vision.middlebury.edu/flow/code/flow-code/README.txt
    num_values = 2 * h * w
    magic = struct.pack("<4c", *(c.encode() for c in "PIEH"))  # .flo magic tag
    dims = struct.pack("<2i", w, h)  # width first, then height
    payload = struct.pack(f"<{num_values}f", *range(num_values))
    with open(file_name, "wb") as f:
        f.write(magic + dims + payload)

test/test_datasets.py

Lines changed: 129 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1874,11 +1874,9 @@ def _inject_pairs(self, root, num_pairs, same):
18741874
class SintelTestCase(datasets_utils.ImageDatasetTestCase):
18751875
DATASET_CLASS = datasets.Sintel
18761876
ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "test"), pass_name=("clean", "final"))
1877-
# We patch the flow reader, because this would otherwise force us to generate fake (but readable) .flo files,
1878-
# which is something we want to # avoid.
1879-
_FAKE_FLOW = "Fake Flow"
1880-
EXTRA_PATCHES = {unittest.mock.patch("torchvision.datasets.Sintel._read_flow", return_value=_FAKE_FLOW)}
1881-
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (type(_FAKE_FLOW), type(None)))
1877+
FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))
1878+
1879+
FLOW_H, FLOW_W = 3, 4
18821880

18831881
def inject_fake_data(self, tmpdir, config):
18841882
root = pathlib.Path(tmpdir) / "Sintel"
@@ -1899,14 +1897,13 @@ def inject_fake_data(self, tmpdir, config):
18991897
num_examples=num_images_per_scene,
19001898
)
19011899

1902-
# For the ground truth flow value we just create empty files so that they're properly discovered,
1903-
# see comment above about EXTRA_PATCHES
19041900
flow_root = root / "training" / "flow"
19051901
for scene_id in range(num_scenes):
19061902
scene_dir = flow_root / f"scene_{scene_id}"
19071903
os.makedirs(scene_dir)
19081904
for i in range(num_images_per_scene - 1):
1909-
open(str(scene_dir / f"frame_000{i}.flo"), "a").close()
1905+
file_name = str(scene_dir / f"frame_000{i}.flo")
1906+
datasets_utils.make_fake_flo_file(h=self.FLOW_H, w=self.FLOW_W, file_name=file_name)
19101907

19111908
# with e.g. num_images_per_scene = 3, for a single scene with have 3 images
19121909
# which are frame_0000, frame_0001 and frame_0002
@@ -1920,7 +1917,8 @@ def test_flow(self):
19201917
with self.create_dataset(split="train") as (dataset, _):
19211918
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
19221919
for _, _, flow in dataset:
1923-
assert flow == self._FAKE_FLOW
1920+
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
1921+
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
19241922

19251923
# Make sure flow is always None for test split
19261924
with self.create_dataset(split="test") as (dataset, _):
@@ -1929,11 +1927,11 @@ def test_flow(self):
19291927
assert flow is None
19301928

19311929
def test_bad_input(self):
1932-
with pytest.raises(ValueError, match="split must be either"):
1930+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
19331931
with self.create_dataset(split="bad"):
19341932
pass
19351933

1936-
with pytest.raises(ValueError, match="pass_name must be either"):
1934+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument pass_name"):
19371935
with self.create_dataset(pass_name="bad"):
19381936
pass
19391937

@@ -1993,10 +1991,129 @@ def test_flow_and_valid(self):
19931991
assert valid is None
19941992

19951993
def test_bad_input(self):
1996-
with pytest.raises(ValueError, match="split must be either"):
1994+
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):
19971995
with self.create_dataset(split="bad"):
19981996
pass
19991997

20001998

1999+
class FlyingChairsTestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data tests for the FlyingChairs optical-flow dataset."""

    DATASET_CLASS = datasets.FlyingChairs
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=("train", "val"))
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Dimensions of the fake flow maps written by make_fake_flo_file.
    FLOW_H, FLOW_W = 3, 4

    def _make_split_file(self, root, num_examples):
        # We create a fake split file here, but users are asked to download the real one from the authors website
        split_ids = [1] * num_examples["train"] + [2] * num_examples["val"]
        random.shuffle(split_ids)
        with open(str(root / "FlyingChairs_train_val.txt"), "w+") as split_file:
            split_file.writelines(f"{split_id}\n" for split_id in split_ids)

    def inject_fake_data(self, tmpdir, config):
        root = pathlib.Path(tmpdir) / "FlyingChairs"

        num_examples = {"train": 5, "val": 3}
        num_examples_total = sum(num_examples.values())

        # Each example is a (img1, img2, flow) triple living in the "data" folder.
        for suffix in ("img1", "img2"):
            datasets_utils.create_image_folder(
                root,
                name="data",
                # bind suffix as a default to avoid late-binding closure pitfalls
                file_name_fn=lambda image_idx, suffix=suffix: f"00{image_idx}_{suffix}.ppm",
                num_examples=num_examples_total,
            )
        for example_idx in range(num_examples_total):
            datasets_utils.make_fake_flo_file(
                h=self.FLOW_H, w=self.FLOW_W, file_name=str(root / "data" / f"00{example_idx}_flow.flo")
            )

        self._make_split_file(root, num_examples)

        return num_examples[config["split"]]

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        # Flow must always exist, be aligned 1:1 with the image pairs, and decode to the expected values.
        with self.create_dataset(config=config) as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
                np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
2049+
2050+
2051+
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
    """Fake-data tests for the FlyingThings3D optical-flow dataset."""

    DATASET_CLASS = datasets.FlyingThings3D
    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(
        split=("train", "test"), pass_name=("clean", "final", "both"), camera=("left", "right", "both")
    )
    FEATURE_TYPES = (PIL.Image.Image, PIL.Image.Image, (np.ndarray, type(None)))

    # Dimensions of the fake flow maps written by make_fake_pfm_file.
    FLOW_H, FLOW_W = 3, 4

    def inject_fake_data(self, tmpdir, config):
        root = pathlib.Path(tmpdir) / "FlyingThings3D"

        num_images_per_camera = 3 if config["split"] == "train" else 4
        passes = ("frames_cleanpass", "frames_finalpass")
        splits = ("TRAIN", "TEST")
        letters = ("A", "B", "C")
        subfolders = ("0000", "0001")
        cameras = ("left", "right")

        # Image layout: <pass>/<split>/<letter>/<subfolder>/<camera>/00N.png
        for pass_name, split, letter, subfolder, camera in itertools.product(
            passes, splits, letters, subfolders, cameras
        ):
            datasets_utils.create_image_folder(
                root / pass_name / split / letter / subfolder,
                name=camera,
                file_name_fn=lambda image_idx: f"00{image_idx}.png",
                num_examples=num_images_per_camera,
            )

        # Flow layout: optical_flow/<split>/<letter>/<subfolder>/<direction>/<camera>/N.pfm
        directions = ("into_future", "into_past")
        for split, letter, subfolder, direction, camera in itertools.product(
            splits, letters, subfolders, directions, cameras
        ):
            flow_dir = root / "optical_flow" / split / letter / subfolder / direction / camera
            os.makedirs(str(flow_dir), exist_ok=True)
            for frame_idx in range(num_images_per_camera):
                datasets_utils.make_fake_pfm_file(
                    self.FLOW_H, self.FLOW_W, file_name=str(flow_dir / f"{frame_idx}.pfm")
                )

        # "both" doubles the number of examples along the corresponding axis.
        num_cameras = 2 if config["camera"] == "both" else 1
        num_passes = 2 if config["pass_name"] == "both" else 1
        return (num_images_per_camera - 1) * num_cameras * len(subfolders) * len(letters) * len(splits) * num_passes

    @datasets_utils.test_all_configs
    def test_flow(self, config):
        with self.create_dataset(config=config) as (dataset, _):
            assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
            for _, _, flow in dataset:
                assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
                # We don't check the values because the reshaping and flipping makes it hard to figure out

    def test_bad_input(self):
        # Each argument rejects unknown values with the same error shape.
        for arg_name in ("split", "pass_name", "camera"):
            with pytest.raises(ValueError, match=f"Unknown value 'bad' for argument {arg_name}"):
                with self.create_dataset(**{arg_name: "bad"}):
                    pass
2116+
2117+
20012118
if __name__ == "__main__":
20022119
unittest.main()

torchvision/datasets/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._optical_flow import KittiFlow, Sintel
1+
from ._optical_flow import KittiFlow, Sintel, FlyingChairs, FlyingThings3D
22
from .caltech import Caltech101, Caltech256
33
from .celeba import CelebA
44
from .cifar import CIFAR10, CIFAR100
@@ -74,4 +74,6 @@
7474
"LFWPairs",
7575
"KittiFlow",
7676
"Sintel",
77+
"FlyingChairs",
78+
"FlyingThings3D",
7779
)

0 commit comments

Comments
 (0)