Skip to content
Merged
100 changes: 54 additions & 46 deletions test/builtin_dataset_mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
import pathlib
import pickle
import random
import unittest.mock
import xml.etree.ElementTree as ET
from collections import defaultdict, Counter

import numpy as np
import PIL.Image
import pytest
import torch
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor
from torchvision.prototype.datasets._api import find
from torchvision.prototype import datasets
from torchvision.prototype.utils._internal import sequence_to_str

make_tensor = functools.partial(_make_tensor, device="cpu")
Expand All @@ -30,13 +31,11 @@


class DatasetMock:
def __init__(self, name, mock_data_fn):
self.dataset = find(name)
self.info = self.dataset.info
self.name = self.info.name

def __init__(self, name, *, mock_data_fn, configs):
# FIXME: error handling for unknown names
self.name = name
self.mock_data_fn = mock_data_fn
self.configs = self.info._configs
self.configs = configs

def _parse_mock_info(self, mock_info):
if mock_info is None:
Expand Down Expand Up @@ -65,10 +64,13 @@ def prepare(self, home, config):
root = home / self.name
root.mkdir(exist_ok=True)

mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config))
mock_info = self._parse_mock_info(self.mock_data_fn(root, config))

with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"):
required_file_names = {
resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources()
}
available_file_names = {path.name for path in root.glob("*")}
required_file_names = {resource.file_name for resource in self.dataset.resources(config)}
missing_file_names = required_file_names - available_file_names
if missing_file_names:
raise pytest.UsageError(
Expand Down Expand Up @@ -123,10 +125,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None):
DATASET_MOCKS = {}


def register_mock(fn):
name = fn.__name__.replace("_", "-")
DATASET_MOCKS[name] = DatasetMock(name, fn)
return fn
def register_mock(name=None, *, configs):
    """Decorator factory that registers a mock-data function in DATASET_MOCKS.

    If *name* is not given, the decorated function's ``__name__`` is used as
    the registry key. *configs* is the list of dataset configurations the
    mock supports; it is forwarded to the DatasetMock entry.
    """

    def decorator(mock_data_fn):
        # Rebind the enclosing `name` so the fallback mirrors the explicit case.
        nonlocal name
        if name is None:
            name = mock_data_fn.__name__
        mock = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs)
        DATASET_MOCKS[name] = mock
        # Return the function unchanged so it stays usable/importable as-is.
        return mock_data_fn

    return decorator


class MNISTMockData:
Expand Down Expand Up @@ -204,7 +212,7 @@ def generate(
return num_samples


@register_mock
# @register_mock
def mnist(info, root, config):
train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
Expand All @@ -217,10 +225,10 @@ def mnist(info, root, config):
)


DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})


@register_mock
# @register_mock
def emnist(info, root, config):
# The image sets that merge some lower case letters in their respective upper case variant, still use dense
# labels in the data files. Thus, num_categories != len(categories) there.
Expand All @@ -247,7 +255,7 @@ def emnist(info, root, config):
return num_samples_map[config]


@register_mock
# @register_mock
def qmnist(info, root, config):
num_categories = len(info.categories)
if config.split == "train":
Expand Down Expand Up @@ -324,7 +332,7 @@ def generate(
make_tar(root, name, folder, compression="gz")


@register_mock
# @register_mock
def cifar10(info, root, config):
train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
test_files = ["test_batch"]
Expand All @@ -342,7 +350,7 @@ def cifar10(info, root, config):
return len(train_files if config.split == "train" else test_files)


@register_mock
# @register_mock
def cifar100(info, root, config):
train_files = ["train"]
test_files = ["test"]
Expand All @@ -360,7 +368,7 @@ def cifar100(info, root, config):
return len(train_files if config.split == "train" else test_files)


@register_mock
# @register_mock
def caltech101(info, root, config):
def create_ann_file(root, name):
import scipy.io
Expand Down Expand Up @@ -410,7 +418,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples):
return num_images_per_category * len(info.categories)


@register_mock
# @register_mock
def caltech256(info, root, config):
dir = root / "256_ObjectCategories"
num_images_per_category = 2
Expand All @@ -430,26 +438,26 @@ def caltech256(info, root, config):
return num_images_per_category * len(info.categories)


@register_mock
def imagenet(info, root, config):
@register_mock(configs=combinations_grid(split=("train", "val", "test")))
def imagenet(root, config):
from scipy.io import savemat

categories = info.categories
wnids = [info.extra.category_to_wnid[category] for category in categories]
if config.split == "train":
num_samples = len(wnids)
info = datasets.info("imagenet")

if config["split"] == "train":
num_samples = len(info["wnids"])
archive_name = "ILSVRC2012_img_train.tar"

files = []
for wnid in wnids:
for wnid in info["wnids"]:
create_image_folder(
root=root,
name=wnid,
file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
num_examples=1,
)
files.append(make_tar(root, f"{wnid}.tar"))
elif config.split == "val":
elif config["split"] == "val":
num_samples = 3
archive_name = "ILSVRC2012_img_val.tar"
files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
Expand All @@ -459,20 +467,20 @@ def imagenet(info, root, config):
data_root.mkdir(parents=True)

with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist():
file.write(f"{label}\n")

num_children = 0
synsets = [
(idx, wnid, category, "", num_children, [], 0, 0)
for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1)
]
num_children = 1
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
savemat(data_root / "meta.mat", dict(synsets=synsets))

make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
else: # config.split == "test"
else: # config["split"] == "test"
num_samples = 5
archive_name = "ILSVRC2012_img_test_v10102019.tar"
files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]
Expand Down Expand Up @@ -587,7 +595,7 @@ def generate(
return num_samples


@register_mock
# @register_mock
def coco(info, root, config):
    """Create mock COCO data under *root*; returns the number of samples."""
    return CocoMockData.generate(root, num_samples=5, year=config.year)

Expand Down Expand Up @@ -661,12 +669,12 @@ def generate(cls, root):
return num_samples_map


@register_mock
# @register_mock
def sbd(info, root, config):
    """Create mock SBD data; returns the sample count for the requested split."""
    samples_per_split = SBDMockData.generate(root)
    return samples_per_split[config.split]


@register_mock
# @register_mock
def semeion(info, root, config):
num_samples = 3
num_categories = len(info.categories)
Expand Down Expand Up @@ -779,7 +787,7 @@ def generate(cls, root, *, year, trainval):
return num_samples_map


@register_mock
# @register_mock
def voc(info, root, config):
    """Create mock Pascal VOC data; only non-"test" splits carry trainval data."""
    is_trainval = config.split != "test"
    num_samples_map = VOCMockData.generate(root, year=config.year, trainval=is_trainval)
    return num_samples_map[config.split]
Expand Down Expand Up @@ -873,12 +881,12 @@ def generate(cls, root):
return num_samples_map


@register_mock
# @register_mock
def celeba(info, root, config):
    """Create mock CelebA data; returns the sample count for config.split."""
    split_to_num_samples = CelebAMockData.generate(root)
    return split_to_num_samples[config.split]


@register_mock
# @register_mock
def dtd(info, root, config):
data_folder = root / "dtd"

Expand Down Expand Up @@ -926,7 +934,7 @@ def dtd(info, root, config):
return num_samples_map[config]


@register_mock
# @register_mock
def fer2013(info, root, config):
num_samples = 5 if config.split == "train" else 3

Expand All @@ -951,7 +959,7 @@ def fer2013(info, root, config):
return num_samples


@register_mock
# @register_mock
def gtsrb(info, root, config):
num_examples_per_class = 5 if config.split == "train" else 3
classes = ("00000", "00042", "00012")
Expand Down Expand Up @@ -1021,7 +1029,7 @@ def _make_ann_file(path, num_examples, class_idx):
return num_examples


@register_mock
# @register_mock
def clevr(info, root, config):
data_folder = root / "CLEVR_v1.0"

Expand Down Expand Up @@ -1127,7 +1135,7 @@ def generate(self, root):
return num_samples_map


@register_mock
# @register_mock
def oxford_iiit_pet(info, root, config):
    """Create mock Oxford-IIIT Pet data; returns the sample count per split."""
    split_to_num_samples = OxfordIIITPetMockData.generate(root)
    return split_to_num_samples[config.split]

Expand Down Expand Up @@ -1293,13 +1301,13 @@ def generate(cls, root):
return num_samples_map


@register_mock
# @register_mock
def cub200(info, root, config):
    """Create mock CUB-200 data, dispatching on the dataset year (2010 vs 2011)."""
    if config.year == "2011":
        mock_data_cls = CUB2002011MockData
    else:
        mock_data_cls = CUB2002010MockData
    return mock_data_cls.generate(root)[config.split]


@register_mock
# @register_mock
def svhn(info, root, config):
import scipy.io as sio

Expand All @@ -1319,7 +1327,7 @@ def svhn(info, root, config):
return num_samples


@register_mock
# @register_mock
def pcam(info, root, config):
import h5py

Expand Down
Loading