remove decoding from prototype datasets #5287

Merged
90 changes: 47 additions & 43 deletions test/builtin_dataset_mocks.py
@@ -431,50 +431,52 @@ def caltech256(info, root, config):

@register_mock
def imagenet(info, root, config):
wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train":
images_root = root / "ILSVRC2012_img_train"
from scipy.io import savemat
Review comment (author): While working on the fix for the validation split, I realized that the data setup was slightly wrong. A sketch of the resulting train-archive layout follows this hunk.


categories = info.categories
wnids = [info.extra.category_to_wnid[category] for category in categories]
if config.split == "train":
num_samples = len(wnids)
archive_name = "ILSVRC2012_img_train.tar"

files = []
for wnid in wnids:
files = create_image_folder(
root=images_root,
create_image_folder(
root=root,
name=wnid,
file_name_fn=lambda image_idx: f"{wnid}_{image_idx:04d}.JPEG",
num_examples=1,
)
make_tar(images_root, f"{wnid}.tar", files[0].parent)
files.append(make_tar(root, f"{wnid}.tar"))
elif config.split == "val":
num_samples = 3
files = create_image_folder(
root=root,
name="ILSVRC2012_img_val",
file_name_fn=lambda image_idx: f"ILSVRC2012_val_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
images_root = files[0].parent
else: # config.split == "test"
images_root = root / "ILSVRC2012_img_test_v10102019"
archive_name = "ILSVRC2012_img_val.tar"
files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)]

num_samples = 3
devkit_root = root / "ILSVRC2012_devkit_t12"
data_root = devkit_root / "data"
data_root.mkdir(parents=True)

create_image_folder(
root=images_root,
name="test",
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
make_tar(root, f"{images_root.name}.tar", images_root)
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")

num_children = 0
synsets = [
(idx, wnid, category, "", num_children, [], 0, 0)
for idx, (category, wnid) in enumerate(zip(categories, wnids), 1)
]
num_children = 1
synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5))
savemat(data_root / "meta.mat", dict(synsets=synsets))

make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz")
else: # config.split == "test"
num_samples = 5
archive_name = "ILSVRC2012_img_test_v10102019.tar"
files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)]

devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
make_tar(root, archive_name, *files)

return num_samples
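
For context, a rough sketch of the nested archive layout the reworked train-split mock aims to produce. This is inferred from the mock code above rather than stated in the PR, and the concrete WNIDs are placeholders; the real ILSVRC2012 training archive wraps one tar per WNID, each holding that synset's JPEGs.

# Illustrative layout only (inferred); the mock may nest the JPEGs one folder deeper
# inside each per-WNID tar.
#
# ILSVRC2012_img_train.tar
# ├── n01440764.tar              <- from files.append(make_tar(root, f"{wnid}.tar"))
# │   └── n01440764_0000.JPEG
# ├── n01443537.tar
# │   └── n01443537_0000.JPEG
# └── ...                        <- bundled at the end via make_tar(root, archive_name, *files)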

@@ -666,14 +668,15 @@ def sbd(info, root, config):
@register_mock
def semeion(info, root, config):
num_samples = 3
num_categories = len(info.categories)

images = torch.rand(num_samples, 256)
labels = one_hot(torch.randint(len(info.categories), size=(num_samples,)))
labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories)
with open(root / "semeion.data", "w") as fh:
for image, one_hot_label in zip(images, labels):
image_columns = " ".join([f"{pixel.item():.4f}" for pixel in image])
labels_columns = " ".join([str(label.item()) for label in one_hot_label])
fh.write(f"{image_columns} {labels_columns}\n")
fh.write(f"{image_columns} {labels_columns} \n")

return num_samples
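
As a quick illustration (not part of the diff), a line of the mocked semeion.data splits back into 256 pixel columns followed by the one-hot label columns, mirroring the real SEMEION layout; the file name and the hard-coded 256 below are taken from the mock above.

# Sketch: parse one line of the mocked semeion.data back into pixels and a label.
with open("semeion.data") as fh:
    tokens = fh.readline().split()

pixels = [float(v) for v in tokens[:256]]       # 256 pixel columns written as "%.4f"
one_hot_label = [int(v) for v in tokens[256:]]  # remaining columns form the one-hot label
label = one_hot_label.index(1)                  # position of the single "1"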

@@ -728,32 +731,33 @@ def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples):
def _make_detection_ann_file(cls, root, name):
def add_child(parent, name, text=None):
child = ET.SubElement(parent, name)
child.text = text
child.text = str(text)
return child

def add_name(obj, name="dog"):
add_child(obj, "name", name)
return name

def add_bndbox(obj, bndbox=None):
if bndbox is None:
bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}
def add_size(obj):
obj = add_child(obj, "size")
size = {"width": 0, "height": 0, "depth": 3}
Review comment (author): VOC provides the image size together with the annotations. Since the reworked BoundingBox requires the image size, we need to add it to the mock data. A parsing sketch follows this hunk.

for name, text in size.items():
add_child(obj, name, text)

def add_bndbox(obj):
obj = add_child(obj, "bndbox")
bndbox = {"xmin": 1, "xmax": 2, "ymin": 3, "ymax": 4}
for name, text in bndbox.items():
add_child(obj, name, text)

return bndbox

annotation = ET.Element("annotation")
add_size(annotation)
obj = add_child(annotation, "object")
data = dict(name=add_name(obj), bndbox=add_bndbox(obj))
add_name(obj)
add_bndbox(obj)

with open(root / name, "wb") as fh:
fh.write(ET.tostring(annotation))

return data

@classmethod
def generate(cls, root, *, year, trainval):
archive_folder = root
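
To illustrate why the <size> element matters, here is a sketch (not part of the diff; the annotation file name is hypothetical) that parses the mocked annotation back and recovers both the image size and the box coordinates, which is the pairing the reworked BoundingBox feature needs. Note the mock uses placeholder zeros for width and height.

# Sketch: read a mocked VOC annotation file and pull out image size and box coordinates.
import xml.etree.ElementTree as ET

annotation = ET.parse("2007_000001.xml").getroot()  # hypothetical file name

size = annotation.find("size")
image_size = (int(size.find("height").text), int(size.find("width").text))  # (0, 0) in the mock

bndbox = annotation.find("object").find("bndbox")
coords = [int(bndbox.find(tag).text) for tag in ("xmin", "ymin", "xmax", "ymax")]
# image_size plus coords is what a BoundingBox feature carrying an image size can be built from.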
21 changes: 20 additions & 1 deletion test/test_prototype_builtin_datasets.py
@@ -1,16 +1,23 @@
import functools
import io
from pathlib import Path

import pytest
import torch
from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision.prototype import transforms, datasets
from torchvision.prototype.utils._internal import sequence_to_str


assert_samples_equal = functools.partial(
assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
)


@pytest.fixture
def test_home(mocker, tmp_path):
mocker.patch("torchvision.prototype.datasets._api.home", return_value=str(tmp_path))
@@ -92,6 +99,7 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config):
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
)

@pytest.mark.xfail
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
@@ -137,6 +145,17 @@ def scan(graph):
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")

@parametrize_dataset_mocks(DATASET_MOCKS)
def test_save_load(self, test_home, dataset_mock, config):
dataset_mock.prepare(test_home, config)
dataset = datasets.load(dataset_mock.name, **config)
sample = next(iter(dataset))

with io.BytesIO() as buffer:
torch.save(sample, buffer)
buffer.seek(0)
assert_samples_equal(torch.load(buffer), sample)


@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST:
@@ -171,5 +190,5 @@ def test_label_matches_path(self, test_home, dataset_mock, config):
dataset = datasets.load(dataset_mock.name, **config)

for sample in dataset:
label_from_path = int(Path(sample["image_path"]).parent.name)
label_from_path = int(Path(sample["path"]).parent.name)
assert sample["label"] == label_from_path
15 changes: 3 additions & 12 deletions test/test_prototype_datasets_api.py
@@ -5,8 +5,8 @@
from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch


def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs)
def make_minimal_dataset_info(name="name", categories=None, **kwargs):
return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)


class TestFrozenMapping:
@@ -176,7 +176,7 @@ def resources(self, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass

def _make_datapipe(self, resource_dps, *, config, decoder):
def _make_datapipe(self, resource_dps, *, config):
# This method is just defined to appease the ABC, but will be overwritten at instantiation
pass

@@ -229,12 +229,3 @@ def test_resources(self, mocker):

(call_args, _) = dataset._make_datapipe.call_args
assert call_args[0][0] is sentinel

def test_decoder(self):
dataset = self.DatasetMock()

sentinel = object()
dataset.load("", decoder=sentinel)

(_, call_kwargs) = dataset._make_datapipe.call_args
assert call_kwargs["decoder"] is sentinel
61 changes: 0 additions & 61 deletions test/test_prototype_transforms.py

This file was deleted.

2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/__init__.py
@@ -7,7 +7,7 @@
"Note that you cannot install it with `pip install torchdata`, since this is another package."
) from error

from . import decoder, utils
from . import utils
from ._home import home

# Load this last, since some parts depend on the above being loaded first
21 changes: 3 additions & 18 deletions torchvision/prototype/datasets/_api.py
@@ -1,12 +1,9 @@
import io
import os
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List

import torch
from torch.utils.data import IterDataPipe
from torchvision.prototype.datasets import home
from torchvision.prototype.datasets.decoder import raw, pil
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
from torchvision.prototype.utils._internal import add_suggestion

from . import _builtin
@@ -49,27 +46,15 @@ def info(name: str) -> DatasetInfo:
return find(name).info


DEFAULT_DECODER = object()

DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
DatasetType.RAW: raw,
DatasetType.IMAGE: pil,
}


def load(
name: str,
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER, # type: ignore[assignment]
skip_integrity_check: bool = False,
**options: Any,
) -> IterDataPipe[Dict[str, Any]]:
dataset = find(name)

if decoder is DEFAULT_DECODER:
decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)

config = dataset.info.make_config(**options)
root = os.path.join(home(), dataset.name)

return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
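
A hypothetical usage sketch of the API after this change (the dataset name and its local availability are assumptions, not part of the diff): load() no longer accepts a decoder, so samples come back with the image still in its encoded form and decoding is left to downstream transforms.

from torchvision.prototype import datasets

# No decoder argument anymore; options such as split are forwarded to make_config().
dataset = datasets.load("imagenet", split="val")  # assumes the ImageNet files are available locally
sample = next(iter(dataset))
print(sorted(sample.keys()))  # e.g. an encoded "image" plus "label", "path", ...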