13 changes: 3 additions & 10 deletions test/builtin_dataset_mocks.py
@@ -16,7 +16,6 @@
 from torch.testing import make_tensor as _make_tensor
 from torchdata.datapipes.iter import IterDataPipe
 from torchvision.prototype import datasets
-from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER
 from torchvision.prototype.datasets._api import find
 from torchvision.prototype.utils._internal import add_suggestion
@@ -109,21 +108,15 @@ def _get(self, dataset, config, root):
         self._cache[(name, config)] = mock_resources, mock_info
         return mock_resources, mock_info
 
-    def load(
-        self, name: str, decoder=DEFAULT_DECODER, split="train", **options: Any
-    ) -> Tuple[IterDataPipe, Dict[str, Any]]:
+    def load(self, name: str, **options: Any) -> Tuple[IterDataPipe, Dict[str, Any]]:
         dataset = find(name)
-        config = dataset.info.make_config(split=split, **options)
+        config = dataset.info.make_config(**options)
 
         root = self._tmp_home / name
         root.mkdir(exist_ok=True)
         resources, mock_info = self._get(dataset, config, root)
 
-        datapipe = dataset._make_datapipe(
-            [resource.load(root) for resource in resources],
-            config=config,
-            decoder=DEFAULT_DECODER_MAP.get(dataset.info.type) if decoder is DEFAULT_DECODER else decoder,
-        )
+        datapipe = dataset._make_datapipe([resource.load(root) for resource in resources], config=config)
         return datapipe, mock_info


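With the decoder argument gone, the mock loader builds a config from the remaining keyword options and hands the raw resource datapipes straight to `_make_datapipe`. A minimal usage sketch, assuming a mock is registered for the requested dataset name (as the tests below expect):

    datapipe, mock_info = builtin_dataset_mocks.load("caltech101", split="train")
    sample = next(iter(datapipe))  # features arrive un-decoded
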
40 changes: 34 additions & 6 deletions test/test_prototype_builtin_datasets.py
@@ -1,18 +1,37 @@
+import functools
 import io
 
 import builtin_dataset_mocks
 import pytest
 import torch
+from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair, UnsupportedInputs, ErrorMeta
 from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
 from torch.utils.data.graph import traverse
 from torchdata.datapipes.iter import IterDataPipe, Shuffler
 from torchvision.prototype import datasets, transforms
-from torchvision.prototype.datasets._api import DEFAULT_DECODER
 from torchvision.prototype.utils._internal import sequence_to_str
 
 
-def to_bytes(file):
-    return file.read()
+# TODO: remove this patch after https://github.com/pytorch/pytorch/pull/70304 is merged
+def patch(fn):
+    def wrapper(*args, **kwargs):
+        try:
+            return fn(*args, **kwargs)
+        except ErrorMeta as error:
+            if error.type is not ValueError:
+                raise error
+
+            raise UnsupportedInputs()
+
+    return wrapper
+
+
+TensorLikePair._to_tensor = patch(TensorLikePair._to_tensor)
+
+
+assert_samples_equal = functools.partial(
+    assert_equal, pair_types=(TensorLikePair, ObjectPair), rtol=0, atol=0, equal_nan=True
+)


 def config_id(name, config):
@@ -26,7 +45,7 @@ def config_id(name, config):
return "-".join(parts)


def dataset_parametrization(*names, decoder=to_bytes):
def dataset_parametrization(*names):
if not names:
# TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
names = (
@@ -46,7 +65,7 @@ def dataset_parametrization(*names, decoder=to_bytes):
     return pytest.mark.parametrize(
         ("dataset", "mock_info"),
         [
-            pytest.param(*builtin_dataset_mocks.load(name, decoder=decoder, **config), id=config_id(name, config))
+            pytest.param(*builtin_dataset_mocks.load(name, **config), id=config_id(name, config))
             for name in names
             for config in datasets.info(name)._configs
         ],
@@ -89,7 +108,7 @@ def test_decoding(self, dataset, mock_info):
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
)

@dataset_parametrization(decoder=DEFAULT_DECODER)
@dataset_parametrization()
def test_no_vanilla_tensors(self, dataset, mock_info):
vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor}
if vanilla_tensors:
@@ -120,6 +139,15 @@ def scan(graph):
         else:
             raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
 
+    @dataset_parametrization()
+    def test_save_load(self, dataset, mock_info):
+        sample = next(iter(dataset))
+
+        with io.BytesIO() as buffer:
+            torch.save(sample, buffer)
+            buffer.seek(0)
+            assert_samples_equal(torch.load(buffer), sample)
 
 
 class TestQMNIST:
     @pytest.mark.parametrize(
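The `patch` helper above only translates exception types: an `ErrorMeta` wrapping a `ValueError` is re-raised as `UnsupportedInputs`, which makes `assert_equal` fall through to the next pair type (`ObjectPair`) instead of failing outright. A self-contained sketch of the same idea, using hypothetical names (`translate_error`, `parse_int`) that are not part of this PR:

    def translate_error(fn, *, catch, reraise):
        def wrapper(*args, **kwargs):
            try:
                return fn(*args, **kwargs)
            except catch as error:
                # re-raise as the type the caller's dispatch logic understands
                raise reraise() from error
        return wrapper

    # example: surface int()'s ValueError as a KeyError instead
    parse_int = translate_error(int, catch=ValueError, reraise=KeyError)
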
15 changes: 3 additions & 12 deletions test/test_prototype_datasets_api.py
@@ -5,8 +5,8 @@
 from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch
 
 
-def make_minimal_dataset_info(name="name", type=datasets.utils.DatasetType.RAW, categories=None, **kwargs):
-    return datasets.utils.DatasetInfo(name, type=type, categories=categories or [], **kwargs)
+def make_minimal_dataset_info(name="name", categories=None, **kwargs):
+    return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs)
 
 
 class TestFrozenMapping:
@@ -188,7 +188,7 @@ def resources(self, config):
         # This method is just defined to appease the ABC, but will be overwritten at instantiation
         pass
 
-    def _make_datapipe(self, resource_dps, *, config, decoder):
+    def _make_datapipe(self, resource_dps, *, config):
         # This method is just defined to appease the ABC, but will be overwritten at instantiation
         pass
 
@@ -241,12 +241,3 @@ def test_resources(self, mocker):

         (call_args, _) = dataset._make_datapipe.call_args
         assert call_args[0][0] is sentinel
-
-    def test_decoder(self):
-        dataset = self.DatasetMock()
-
-        sentinel = object()
-        dataset.load("", decoder=sentinel)
-
-        (_, call_kwargs) = dataset._make_datapipe.call_args
-        assert call_kwargs["decoder"] is sentinel
2 changes: 1 addition & 1 deletion torchvision/prototype/datasets/__init__.py
@@ -7,7 +7,7 @@
"Note that you cannot install it with `pip install torchdata`, since this is another package."
) from error

from . import decoder, utils
from . import utils
from ._home import home

# Load this last, since some parts depend on the above being loaded first
23 changes: 3 additions & 20 deletions torchvision/prototype/datasets/_api.py
@@ -1,12 +1,9 @@
-import io
 import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List
 
-import torch
 from torch.utils.data import IterDataPipe
 from torchvision.prototype.datasets import home
-from torchvision.prototype.datasets.decoder import raw, pil
-from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetType
+from torchvision.prototype.datasets.utils import Dataset, DatasetInfo
 from torchvision.prototype.utils._internal import add_suggestion
 
 from . import _builtin
@@ -49,28 +46,14 @@ def info(name: str) -> DatasetInfo:
     return find(name).info
 
 
-DEFAULT_DECODER = object()
-
-DEFAULT_DECODER_MAP: Dict[DatasetType, Callable[[io.IOBase], torch.Tensor]] = {
-    DatasetType.RAW: raw,
-    DatasetType.IMAGE: pil,
-}
-
-
 def load(
     name: str,
     *,
-    decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
     skip_integrity_check: bool = False,
-    split: str = "train",
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
 
-    if decoder is DEFAULT_DECODER:
-        decoder = DEFAULT_DECODER_MAP.get(dataset.info.type)
-
-    config = dataset.info.make_config(split=split, **options)
+    config = dataset.info.make_config(**options)
     root = os.path.join(home(), dataset.name)
 
-    return dataset.load(root, config=config, decoder=decoder, skip_integrity_check=skip_integrity_check)
+    return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check)
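After this change the public entry point no longer accepts a decoder; every remaining option, including `split`, flows into `make_config`. A usage sketch, assuming the dataset name is registered among the builtins:

    from torchvision.prototype import datasets

    dp = datasets.load("caltech101", split="train")
    sample = next(iter(dp))  # a dict whose image/annotation features are not decoded here
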
60 changes: 22 additions & 38 deletions torchvision/prototype/datasets/_builtin/caltech.py
@@ -1,11 +1,8 @@
-import functools
-import io
 import pathlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple, BinaryIO
 
 import numpy as np
-import torch
 from torchdata.datapipes.iter import (
     IterDataPipe,
     Mapper,
@@ -18,7 +15,7 @@
     DatasetInfo,
     HttpResource,
     OnlineResource,
-    DatasetType,
+    RawImage,
 )
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling
 from torchvision.prototype.features import Label, BoundingBox, Feature
@@ -28,7 +25,6 @@ class Caltech101(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech101",
-            type=DatasetType.IMAGE,
             dependencies=("scipy",),
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101",
         )
@@ -81,32 +77,26 @@ def _anns_key_fn(self, data: Tuple[str, Any]) -> Tuple[str, str]:

         return category, id
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[Tuple[str, str], Tuple[Tuple[str, io.IOBase], Tuple[str, io.IOBase]]],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
+    def _decode_ann(self, data: BinaryIO) -> Dict[str, Any]:
+        ann = read_mat(data)
+        return dict(
+            bounding_box=BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy"),
+            contour=Feature(ann["obj_contour"].T),
+        )
+
+    def _prepare_sample(
+        self, data: Tuple[Tuple[str, str], Tuple[Tuple[str, BinaryIO], Tuple[str, BinaryIO]]]
     ) -> Dict[str, Any]:
         key, (image_data, ann_data) = data
         category, _ = key
         image_path, image_buffer = image_data
         ann_path, ann_buffer = ann_data
 
-        label = self.info.categories.index(category)
-
-        image = decoder(image_buffer) if decoder else image_buffer
-
-        ann = read_mat(ann_buffer)
-        bbox = BoundingBox(ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy")
-        contour = Feature(ann["obj_contour"].T)
-
         return dict(
-            category=category,
-            label=label,
-            image=image,
+            self._decode_ann(ann_buffer),
+            label=Label(self.info.categories.index(category), category=category),
             image_path=image_path,
-            bbox=bbox,
-            contour=contour,
+            image=RawImage.fromfile(image_buffer),
             ann_path=ann_path,
         )

@@ -115,7 +105,6 @@ def _make_datapipe(
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         images_dp, anns_dp = resource_dps
 
@@ -133,7 +122,7 @@
             buffer_size=INFINITE_BUFFER_SIZE,
             keep_key=True,
         )
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
@@ -145,7 +134,6 @@ class Caltech256(Dataset):
     def _make_info(self) -> DatasetInfo:
         return DatasetInfo(
             "caltech256",
-            type=DatasetType.IMAGE,
             homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256",
         )
 
@@ -161,32 +149,28 @@ def _is_not_rogue_file(self, data: Tuple[str, Any]) -> bool:
         path = pathlib.Path(data[0])
         return path.name != "RENAME2"
 
-    def _collate_and_decode_sample(
-        self,
-        data: Tuple[str, io.IOBase],
-        *,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
-    ) -> Dict[str, Any]:
+    def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
         path, buffer = data
 
         dir_name = pathlib.Path(path).parent.name
         label_str, category = dir_name.split(".")
-        label = Label(int(label_str), category=category)
 
-        return dict(label=label, image=decoder(buffer) if decoder else buffer)
+        return dict(
+            path=path,
+            image=RawImage.fromfile(buffer),
+            label=Label(int(label_str), category=category),
+        )
 
     def _make_datapipe(
         self,
         resource_dps: List[IterDataPipe],
         *,
         config: DatasetConfig,
-        decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
     ) -> IterDataPipe[Dict[str, Any]]:
         dp = resource_dps[0]
         dp = Filter(dp, self._is_not_rogue_file)
         dp = hint_sharding(dp)
         dp = hint_shuffling(dp)
-        return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
+        return Mapper(dp, self._prepare_sample)
 
     def _generate_categories(self, root: pathlib.Path) -> List[str]:
         dp = self.resources(self.default_config)[0].load(pathlib.Path(root) / self.name)
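Note that `_prepare_sample` now merges the decoded annotation fields into the sample via `dict(mapping, **kwargs)`, which copies the mapping first and then applies the keyword entries. A plain-Python illustration of the idiom (the values are placeholders, not real dataset output):

    ann = dict(bounding_box=(17, 5, 223, 189), contour=[(0, 0), (1, 2)])
    sample = dict(ann, label=3, image_path="image_0001.jpg")
    # sample now contains bounding_box, contour, label, and image_path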