Compatibility layer between stable datasets and prototype transforms #6663

Merged 39 commits on Feb 10, 2023
Commits (39 total; the diff shown below covers changes from 10 commits)
d6786ac
PoC
pmeier Sep 21, 2022
3a916c8
Merge branch 'main'
pmeier Sep 27, 2022
d77ef0b
Merge branch 'main' into dataset-wrappers
pmeier Sep 28, 2022
63e1148
cleanup
pmeier Sep 28, 2022
13a820c
Merge branch 'main' into dataset-wrappers
pmeier Sep 28, 2022
fb600a7
Merge branch 'main' into dataset-wrappers
pmeier Oct 4, 2022
cae3e71
Merge branch 'main' into dataset-wrappers
pmeier Jan 30, 2023
dbfac05
refactor
pmeier Jan 31, 2023
2dba1c7
handle None label for test set use case
pmeier Jan 31, 2023
bcd7620
minor cleanup
pmeier Jan 31, 2023
f72ed86
Merge branch 'main' into dataset-wrappers
pmeier Feb 1, 2023
fe6be60
minor refactorings
pmeier Feb 1, 2023
cff9092
minor cache refactoring for COCO
pmeier Feb 1, 2023
9965492
remove GenericDatapoint for now
pmeier Feb 1, 2023
a588686
Merge branch 'main' into dataset-wrappers
pmeier Feb 2, 2023
d64e1a9
add all detection and segmentation datasets
pmeier Feb 2, 2023
49cc8e7
add Image/DatasetFolder
pmeier Feb 2, 2023
8e12bad
add video datasets
pmeier Feb 2, 2023
7a9f083
nuke annotations
pmeier Feb 2, 2023
7f7efd5
reinstate transform and target_transform disabling
pmeier Feb 2, 2023
e6f2b68
address minor comments
pmeier Feb 3, 2023
4c3860e
Merge branch 'main' into dataset-wrappers
pmeier Feb 6, 2023
58f21f4
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
22288ce
remove categories and refactor wrapping architecture
pmeier Feb 9, 2023
a88aec3
add tests
pmeier Feb 9, 2023
ce740c1
cleanup
pmeier Feb 9, 2023
edad790
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
3398822
remove GenericDatapoint
pmeier Feb 9, 2023
b565426
Merge branch 'dataset-wrappers' of https://github.com/pmeier/vision i…
pmeier Feb 9, 2023
a236f9c
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
331a66d
move wrapper instantiation into the class
pmeier Feb 9, 2023
48405b8
use decorator registering everywhere
pmeier Feb 9, 2023
0286238
hard depend on wrapper in stable tests
pmeier Feb 9, 2023
be42cc9
remove target type wrapping default
pmeier Feb 9, 2023
e3c4d50
make test more strict
pmeier Feb 9, 2023
351becb
fix cityscapes instance return
pmeier Feb 9, 2023
8ed41ba
add comment for two stage design
pmeier Feb 9, 2023
f0e1af7
Merge branch 'main' into dataset-wrappers
pmeier Feb 9, 2023
dbebe40
Merge branch 'main' into dataset-wrappers
pmeier Feb 10, 2023
4 changes: 3 additions & 1 deletion torchvision/prototype/datapoints/__init__.py
@@ -1,6 +1,8 @@
 from ._bounding_box import BoundingBox, BoundingBoxFormat
-from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT
+from ._datapoint import FillType, FillTypeJIT, GenericDatapoint, InputType, InputTypeJIT
 from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT
 from ._label import Label, OneHotLabel
 from ._mask import Mask
 from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT
+
+from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip
4 changes: 4 additions & 0 deletions torchvision/prototype/datapoints/_datapoint.py
@@ -274,3 +274,7 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N

 InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint]
 InputTypeJIT = torch.Tensor
+
+
+class GenericDatapoint(Datapoint):
+    pass
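For context, a minimal sketch (not part of this diff) of how the wrappers below use the new GenericDatapoint: targets that have no dedicated datapoint type, such as the Caltech101 "annotation" target or the SBDataset "boundaries" target, are wrapped so that the v2 transforms pass them through unchanged. The tensor below is a placeholder standing in for such a target.

import torch
from torchvision.prototype import datapoints

boundaries = torch.zeros(20, 224, 224)  # placeholder for an SBDataset "boundaries" target
target = datapoints.GenericDatapoint(boundaries)  # passed through unchanged by the v2 transforms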
250 changes: 250 additions & 0 deletions torchvision/prototype/datapoints/_dataset_wrapper.py
@@ -0,0 +1,250 @@
# type: ignore

from __future__ import annotations

import contextlib

import functools
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple

import PIL.Image
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

__all__ = ["wrap_dataset_for_transforms_v2"]

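# Maps stable dataset classes to functions that wrap a raw (untransformed) sample into datapoints.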
_WRAPPERS = {}


# TODO: naming!
def wrap_dataset_for_transforms_v2(dataset: datasets.VisionDataset) -> _VisionDatasetDatapointWrapper:
    wrapper = _WRAPPERS.get(type(dataset))
    if wrapper is None:
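        # no wrapper is registered for this dataset type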
        raise TypeError
    return _VisionDatasetDatapointWrapper(dataset, wrapper)


class _VisionDatasetDatapointWrapper(Dataset):
    def __init__(self, dataset: datasets.VisionDataset, wrapper) -> None:
        self.vision_dataset = dataset
        self.wrapper = wrapper

        # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply the
        # transforms
        self.transform, dataset.transform = dataset.transform, None
        self.target_transform, dataset.target_transform = dataset.target_transform, None
        self.transforms, dataset.transforms = dataset.transforms, None

    def __getattr__(self, item: str) -> Any:
        with contextlib.suppress(AttributeError):
            return object.__getattribute__(self, item)

        return getattr(self.vision_dataset, item)

    def __getitem__(self, idx: int) -> Any:
        # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
        # of this class
        sample = self.vision_dataset[idx]

        sample = self.wrapper(self.vision_dataset, sample)

        # We don't need to care about `transform` and `target_transform` here since `VisionDataset` joins them into a
        # `transforms` internally:
        # https://github.com/pytorch/vision/blob/2d92728341bbd3dc1e0f1e86c6a436049bbb3403/torchvision/datasets/vision.py#L52-L54
        if self.transforms is not None:
            sample = self.transforms(*sample)

        return sample

    def __len__(self) -> int:
        return len(self.vision_dataset)


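# The lookup below is cached per dataset instance, so the categories are only computed once even
# though the classification wrapper calls this for every sample.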
@functools.lru_cache(maxsize=None)
def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]:
    categories_fn = {
        datasets.Caltech256: lambda dataset: [name.rsplit(".", 1)[1] for name in dataset.categories],
        datasets.CIFAR10: lambda dataset: dataset.classes,
        datasets.CIFAR100: lambda dataset: dataset.classes,
        datasets.FashionMNIST: lambda dataset: dataset.classes,
        datasets.ImageNet: lambda dataset: [", ".join(names) for names in dataset.classes],
    }.get(type(dataset))
    return categories_fn(dataset) if categories_fn is not None else None


def classification_wrapper(
    dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, Optional[int]]
) -> Tuple[PIL.Image.Image, Optional[datapoints.Label]]:
    image, label = sample
    if label is not None:
        label = datapoints.Label(label, categories=get_categories(dataset))
    return image, label


for dataset_type in [
    datasets.Caltech256,
    datasets.CIFAR10,
    datasets.CIFAR100,
    datasets.ImageNet,
    datasets.MNIST,
    datasets.FashionMNIST,
]:
    _WRAPPERS[dataset_type] = classification_wrapper


def segmentation_wrapper(
    dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, PIL.Image.Image]
) -> Tuple[PIL.Image.Image, datapoints.Mask]:
    image, mask = sample
    return image, datapoints.Mask(F.to_image_tensor(mask))


for dataset_type in [
    datasets.VOCSegmentation,
]:
    _WRAPPERS[dataset_type] = segmentation_wrapper


def caltech101_wrapper(
    dataset: datasets.Caltech101, sample: Tuple[PIL.Image.Image, Any]
) -> Tuple[PIL.Image.Image, Any]:
    image, target = sample

    target_type_wrapper_map: Dict[str, Callable] = {
        "category": lambda label: datapoints.Label(label, categories=dataset.categories),
        "annotation": datapoints.GenericDatapoint,
    }
    if len(dataset.target_type) == 1:
        target = target_type_wrapper_map[dataset.target_type[0]](target)
    else:
        target = tuple(target_type_wrapper_map[typ](item) for typ, item in zip(dataset.target_type, target))

    return image, target


_WRAPPERS[datasets.Caltech101] = caltech101_wrapper


def coco_dectection_wrapper(
    dataset: datasets.CocoDetection, sample: Tuple[PIL.Image.Image, List[Dict[str, Any]]]
) -> Tuple[PIL.Image.Image, Dict[str, List[Any]]]:
    idx_to_category = {idx: cat["name"] for idx, cat in dataset.coco.cats.items()}
    idx_to_category[0] = "__background__"
    for idx in set(range(91)) - idx_to_category.keys():
        idx_to_category[idx] = "N/A"

    categories = [category for _, category in sorted(idx_to_category.items())]

    def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> torch.Tensor:
        from pycocotools import mask

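        # Polygon annotations come in as a list of coordinate lists and are merged into a single
        # run-length encoding; crowd annotations are already an (uncompressed) RLE dict and only
        # need the frPyObjects conversion.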
        segmentation = (
            mask.frPyObjects(segmentation, *spatial_size)
            if isinstance(segmentation, dict)
            else mask.merge(mask.frPyObjects(segmentation, *spatial_size))
        )
        return torch.from_numpy(mask.decode(segmentation))

    image, target = sample

    # Originally, CocoDetection returns a list of dicts, where each dict represents one object instance in the image.
    # However, our transforms and models expect all instance annotations to be grouped together, where applicable as
    # tensors with a batch dimension. Thus, we change the target into a dict of lists here.
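    # For example (illustrative values): [{"bbox": b0, "category_id": c0}, {"bbox": b1, "category_id": c1}]
    # becomes {"bbox": [b0, b1], "category_id": [c0, c1]}.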
    batched_target = defaultdict(list)
    for object in target:
        for key, value in object.items():
            batched_target[key].append(value)

    spatial_size = tuple(F.get_spatial_size(image))
    batched_target = dict(
        batched_target,
        boxes=datapoints.BoundingBox(
            batched_target["bbox"],
            format=datapoints.BoundingBoxFormat.XYXY,
            spatial_size=spatial_size,
        ),
        masks=datapoints.Mask(
            torch.stack(
                [
                    segmentation_to_mask(segmentation, spatial_size=spatial_size)
                    for segmentation in batched_target["segmentation"]
                ]
            ),
        ),
        labels=datapoints.Label(batched_target["category_id"], categories=categories),
    )

    return image, batched_target


_WRAPPERS[datasets.CocoDetection] = coco_dectection_wrapper
_WRAPPERS[datasets.CocoCaptions] = lambda dataset, sample: sample


def voc_detection_wrapper(
    dataset: datasets.VOCDetection, sample: Tuple[PIL.Image.Image, Any]
) -> Tuple[PIL.Image.Image, Any]:
    categories = [
        "__background__",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    ]
    categories_to_idx = dict(zip(categories, range(len(categories))))

    image, target = sample

    batched_instances = defaultdict(list)
    for object in target["annotation"]["object"]:
        for key, value in object.items():
            batched_instances[key].append(value)

    target["boxes"] = datapoints.BoundingBox(
        [[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bndbox in batched_instances["bndbox"]],
        format=datapoints.BoundingBoxFormat.XYXY,
        spatial_size=tuple(int(target["annotation"]["size"][dim]) for dim in ("height", "width")),
    )
    target["labels"] = datapoints.Label(
        [categories_to_idx[category] for category in batched_instances["name"]],
        categories=categories,
    )

    return image, target


_WRAPPERS[datasets.VOCDetection] = voc_detection_wrapper


def sbd_wrapper(dataset: datasets.SBDataset, sample: Tuple[PIL.Image.Image, Any]) -> Tuple[PIL.Image.Image, Any]:
    image, target = sample

    if dataset.mode == "boundaries":
        target = datapoints.GenericDatapoint(target)
    else:
        target = datapoints.Mask(F.to_image_tensor(target))

    return image, target


_WRAPPERS[datasets.SBDataset] = sbd_wrapper
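
Taken together, wrapping is a single call on an already instantiated stable dataset. A minimal usage sketch (the paths and the chosen dataset are placeholders, not part of this PR):

from torchvision import datasets
from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2

dataset = datasets.CocoDetection("path/to/images", "path/to/annotations.json")
dataset = wrap_dataset_for_transforms_v2(dataset)

image, target = dataset[0]
# `target` is now a dict of lists that additionally carries "boxes", "masks", and "labels"
# entries wrapped as datapoints, so the sample can be fed through the prototype (v2) transforms.

Note that any `transform`, `target_transform`, or `transforms` passed to the stable dataset is still applied, but only after the sample has been wrapped, which is why the wrapper class above temporarily disables them on the underlying dataset.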