From d6786ac09f98c1ea5c6c380bc7a5880e5e1997d8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 21 Sep 2022 21:56:12 +0200 Subject: [PATCH 01/25] PoC --- torchvision/prototype/features/__init__.py | 4 +- .../prototype/features/_dataset_wrapper.py | 295 ++++++++++++++++++ torchvision/prototype/features/_feature.py | 4 + 3 files changed, 302 insertions(+), 1 deletion(-) create mode 100644 torchvision/prototype/features/_dataset_wrapper.py diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index 42ddd9aec27..f004ee57a76 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -1,6 +1,8 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat from ._encoded import EncodedData, EncodedImage, EncodedVideo -from ._feature import _Feature, DType, is_simple_tensor +from ._feature import _Feature, DType, GenericFeature, is_simple_tensor from ._image import ColorSpace, Image, ImageType from ._label import Label, OneHotLabel from ._mask import Mask + +from ._dataset_wrapper import DatasetFeatureWrapper # usort: skip diff --git a/torchvision/prototype/features/_dataset_wrapper.py b/torchvision/prototype/features/_dataset_wrapper.py new file mode 100644 index 00000000000..db832c2b4da --- /dev/null +++ b/torchvision/prototype/features/_dataset_wrapper.py @@ -0,0 +1,295 @@ +from __future__ import annotations + +import functools +from collections import defaultdict +from typing import Any, Callable, cast, Dict, Optional, Tuple, Type + +import PIL.Image +import torch +from torch.utils._pytree import _get_node_type, LeafSpec, SUPPORTED_NODES, tree_flatten, tree_unflatten + +from torchvision import datasets +from torchvision.prototype import features +from torchvision.transforms import functional as F + + +def tree_flatten_to_spec(pytree, spec): + if isinstance(spec, LeafSpec): + return [pytree] + + node_type = _get_node_type(pytree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + flat = [] + for children_pytree, children_spec in zip(child_pytrees, spec.children_specs): + flat.extend(tree_flatten_to_spec(children_pytree, children_spec)) + + return flat + + +# FIXME: make this a proper dataset +class DatasetFeatureWrapper: + __wrappers_fns__: Dict[ + Type[datasets.VisionDataset], + Callable[[datasets.VisionDataset, bool, Dict[Type[features._Feature], Optional[torch.dtype]]], Any], + ] = {} + + def __init__(self, dataset, wrappers): + self.__dataset__ = dataset + self.__wrappers__ = wrappers + + # FIXME: re-route everything to __dataset__ besides __getitem__ + + def __getitem__(self, idx: int) -> Any: + # Do we wrap before or after the transforms? 
-> most likely after + sample = self.__dataset__[idx] + + wrappers_flat, spec = tree_flatten(self.__wrappers__) + sample_flat = tree_flatten_to_spec(sample, spec) + + wrapped_sample_flat = [wrapper(item) for item, wrapper in zip(sample_flat, wrappers_flat)] + + return tree_unflatten(wrapped_sample_flat, spec) + + @classmethod + def __register_wrappers_fn__(cls, dataset_type: Type[datasets.VisionDataset]): + def foo(wrappers_fn): + cls.__wrappers_fns__[dataset_type] = wrappers_fn + return wrappers_fn + + return foo + + @classmethod + def from_torchvision_dataset( + cls, dataset: datasets.VisionDataset, *, keep_pil_image: bool = False, dtypes=None + ) -> DatasetFeatureWrapper: + dtypes = defaultdict(lambda: None, dtypes or dict()) + wrappers_fn = cls.__wrappers_fns__[type(dataset)] + wrappers = wrappers_fn(dataset, keep_pil_image, dtypes) + return cls(dataset, wrappers) + + +def identity_wrapper(obj): + return obj + + +def generic_feature_wrapper(data): + return features.GenericFeature(data) + + +def wrap_image(image, *, keep_pil_image, dtype): + assert isinstance(image, PIL.Image.Image) + if keep_pil_image: + return image + + image = F.pil_to_tensor(image) + image = F.convert_image_dtype(image, dtype=dtype) + return features.Image(image) + + +def make_image_wrapper(*, keep_pil_image, dtypes): + return functools.partial(wrap_image, keep_pil_image=keep_pil_image, dtype=dtypes[features.Image]) + + +def wrap_label(label, *, categories, dtype): + return features.Label(label, categories=categories, dtype=dtype) + + +def make_label_wrapper(*, categories, dtypes): + return functools.partial(wrap_label, categories=categories, dtype=dtypes[features.Label]) + + +def wrap_segmentation_mask(segmentation_mask, *, dtype): + assert isinstance(segmentation_mask, PIL.Image.Image) + + segmentation_mask = F.pil_to_tensor(segmentation_mask) + segmentation_mask = F.convert_image_dtype(segmentation_mask, dtype=dtype) + return features.Mask(segmentation_mask.squeeze(0)) + + +def make_segmentation_mask_wrapper(*, dtypes): + return functools.partial(wrap_segmentation_mask, dtype=dtypes[features.Mask]) + + +CATEGORIES_GETTER = defaultdict( + lambda: lambda dataset: None, + { + datasets.Caltech256: lambda dataset: [name.rsplit(".", 1)[1] for name in dataset.categories], + datasets.CIFAR10: lambda dataset: dataset.classes, + datasets.CIFAR100: lambda dataset: dataset.classes, + datasets.FashionMNIST: lambda dataset: dataset.classes, + datasets.ImageNet: lambda dataset: [", ".join(names) for names in dataset.classes], + }, +) + + +def classification_wrappers(dataset, keep_pil_image, dtypes): + return ( + make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), + make_label_wrapper(categories=CATEGORIES_GETTER[type(dataset)](dataset), dtypes=dtypes), + ) + + +for dataset_type in [ + datasets.Caltech256, + datasets.CIFAR10, + datasets.CIFAR100, + datasets.ImageNet, + datasets.MNIST, + datasets.FashionMNIST, +]: + DatasetFeatureWrapper.__register_wrappers_fn__(dataset_type)(classification_wrappers) + + +def segmentation_wrappers(dataset, keep_pil_image, dtypes): + return ( + make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), + make_segmentation_mask_wrapper(dtypes=dtypes), + ) + + +for dataset_type in [ + datasets.VOCSegmentation, +]: + DatasetFeatureWrapper.__register_wrappers_fn__(dataset_type)(classification_wrappers) + + +@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.Caltech101) +def caltech101_dectection_wrappers(dataset, keep_pil_image, dtypes): + target_type_wrapper_map = { + 
"category": make_label_wrapper(categories=dataset.categories, dtypes=dtypes), + "annotation": features.GenericFeature, + } + return ( + make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), + [target_type_wrapper_map[target_type] for target_type in dataset.target_type], + ) + + +@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.CocoDetection) +def coco_dectection_wrappers(dataset, keep_pil_image, dtypes): + idx_to_category = {idx: cat["name"] for idx, cat in dataset.coco.cats.items()} + idx_to_category[0] = "__background__" + for idx in set(range(91)) - idx_to_category.keys(): + idx_to_category[idx] = "N/A" + + categories = [category for _, category in sorted(idx_to_category.items())] + + def segmentation_to_mask(segmentation: Any, *, iscrowd: bool, image_size: Tuple[int, int]) -> torch.Tensor: + from pycocotools import mask + + segmentation = ( + mask.frPyObjects(segmentation, *image_size) + if iscrowd + else mask.merge(mask.frPyObjects(segmentation, *image_size)) + ) + return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) + + def sample_wrapper(sample): + image, target = sample + + _, height, width = F.get_dimensions(image) + image_size = height, width + + wrapped_image = wrap_image(image, keep_pil_image=keep_pil_image, dtype=dtypes[features.Image]) + + batched_target = defaultdict(list) + for object in target: + for key, value in object.items(): + batched_target[key].append(value) + + wrapped_target = dict( + batched_target, + segmentation=features.Mask( + torch.stack( + [ + segmentation_to_mask(segmentation, iscrowd=iscrowd, image_size=image_size) + for segmentation, iscrowd in zip(batched_target["segmentation"], batched_target["iscrowd"]) + ] + ), + dtype=dtypes.get(features.Mask), + ), + bbox=features.BoundingBox( + batched_target["bbox"], + format=features.BoundingBoxFormat.XYXY, + image_size=image_size, + dtype=dtypes.get(features.BoundingBox), + ), + labels=features.Label(batched_target.pop("category_id"), categories=categories), + ) + + return wrapped_image, wrapped_target + + return sample_wrapper + + +@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.CocoCaptions) +def coco_captions_wrappers(dataset, keep_pil_image, dtypes): + return make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), identity_wrapper + + +@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.VOCDetection) +def voc_detection_wrappers(dataset, keep_pil_image, dtypes): + categories = [ + "__background__", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", + ] + categories_to_idx = dict(zip(categories, range(len(categories)))) + + def target_wrapper(target): + batched_object = defaultdict(list) + for object in target["annotation"]["object"]: + for key, value in object.items(): + batched_object[key].append(value) + + wrapped_object = dict( + batched_object, + bndbox=features.BoundingBox( + [ + [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] + for bndbox in batched_object["bndbox"] + ], + format="xyxy", + image_size=cast( + Tuple[int, int], tuple(int(target["annotation"]["size"][dim]) for dim in ("height", "width")) + ), + ), + ) + wrapped_object["labels"] = features.Label( + [categories_to_idx[category] for category in batched_object["name"]], + categories=categories, + dtype=dtypes[features.Label], + ) + + target["annotation"]["object"] = wrapped_object + return 
target + + return make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), target_wrapper + + +@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.SBDataset) +def sbd_wrappers(dataset, keep_pil_image, dtypes): + return { + "boundaries": (make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), generic_feature_wrapper), + "segmentation": segmentation_wrappers(dataset, keep_pil_image, dtypes), + }[dataset.mode] diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 9cfccf33e54..73c8370066b 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -232,3 +232,7 @@ def invert(self) -> _Feature: def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = None) -> _Feature: return self + + +class GenericFeature(_Feature): + pass From 63e114889aa6491e82b391b75c00e0827f0db460 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 28 Sep 2022 17:12:28 +0200 Subject: [PATCH 02/25] cleanup --- torchvision/prototype/features/__init__.py | 2 +- .../prototype/features/_dataset_wrapper.py | 217 +++++++++++++----- 2 files changed, 156 insertions(+), 63 deletions(-) diff --git a/torchvision/prototype/features/__init__.py b/torchvision/prototype/features/__init__.py index 2c26aaf26ed..956c66c5a6f 100644 --- a/torchvision/prototype/features/__init__.py +++ b/torchvision/prototype/features/__init__.py @@ -15,4 +15,4 @@ from ._label import Label, OneHotLabel from ._mask import Mask -from ._dataset_wrapper import DatasetFeatureWrapper # usort: skip +from ._dataset_wrapper import VisionDatasetFeatureWrapper # usort: skip diff --git a/torchvision/prototype/features/_dataset_wrapper.py b/torchvision/prototype/features/_dataset_wrapper.py index db832c2b4da..3ad5af29670 100644 --- a/torchvision/prototype/features/_dataset_wrapper.py +++ b/torchvision/prototype/features/_dataset_wrapper.py @@ -1,19 +1,32 @@ from __future__ import annotations +import contextlib + import functools from collections import defaultdict -from typing import Any, Callable, cast, Dict, Optional, Tuple, Type +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union import PIL.Image import torch -from torch.utils._pytree import _get_node_type, LeafSpec, SUPPORTED_NODES, tree_flatten, tree_unflatten +from torch.utils._pytree import ( + _get_node_type, + LeafSpec, + PyTree, + SUPPORTED_NODES, + tree_flatten, + tree_unflatten, + TreeSpec, +) +from torch.utils.data import Dataset from torchvision import datasets from torchvision.prototype import features from torchvision.transforms import functional as F +T = TypeVar("T") -def tree_flatten_to_spec(pytree, spec): + +def tree_flatten_to_spec(pytree: PyTree, spec: TreeSpec) -> List[Any]: if isinstance(spec, LeafSpec): return [pytree] @@ -28,87 +41,125 @@ def tree_flatten_to_spec(pytree, spec): return flat -# FIXME: make this a proper dataset -class DatasetFeatureWrapper: - __wrappers_fns__: Dict[ - Type[datasets.VisionDataset], - Callable[[datasets.VisionDataset, bool, Dict[Type[features._Feature], Optional[torch.dtype]]], Any], - ] = {} +class VisionDatasetFeatureWrapper(Dataset): + _wrappers_fns: Dict[Type[datasets.VisionDataset], PyTree] = {} - def __init__(self, dataset, wrappers): - self.__dataset__ = dataset - self.__wrappers__ = wrappers + def __init__(self, dataset: datasets.VisionDataset, wrappers: PyTree) -> None: + self.vision_dataset = dataset + self.sample_wrappers = wrappers - # FIXME: re-route 
everything to __dataset__ besides __getitem__ + # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply the + # transforms + self.transform, dataset.transform = dataset.transform, None + self.target_transform, dataset.target_transform = dataset.target_transform, None + self.transforms, dataset.transforms = dataset.transforms, None - def __getitem__(self, idx: int) -> Any: - # Do we wrap before or after the transforms? -> most likely after - sample = self.__dataset__[idx] + def __getattr__(self, item: str) -> Any: + with contextlib.suppress(AttributeError): + return object.__getattribute__(self, item) - wrappers_flat, spec = tree_flatten(self.__wrappers__) - sample_flat = tree_flatten_to_spec(sample, spec) + return getattr(self.vision_dataset, item) + def __getitem__(self, idx: int) -> Any: + # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor + # of this class + sample = self.vision_dataset[idx] + + wrappers_flat, spec = tree_flatten(self.sample_wrappers) + # We cannot use `tree_flatten` directly, because the spec of `self.sample_wrappers` and `sample` might differ. + # For example, for `COCODetection` the target is a list of dicts. To be able to wrap this into a dict of lists, + # we need to have access to the whole list, but `tree_flatten` would also flatten it. + sample_flat = tree_flatten_to_spec(sample, spec) wrapped_sample_flat = [wrapper(item) for item, wrapper in zip(sample_flat, wrappers_flat)] + sample = tree_unflatten(wrapped_sample_flat, spec) - return tree_unflatten(wrapped_sample_flat, spec) + # We don't need to care about `transform` and `target_transform` here since `VisionDataset` joins them into a + # `transforms` internally: + # https://github.com/pytorch/vision/blob/2d92728341bbd3dc1e0f1e86c6a436049bbb3403/torchvision/datasets/vision.py#L52-L54 + if self.transforms is not None: + sample = self.transforms(*sample) + + return sample + + def __len__(self) -> int: + return len(self.vision_dataset) @classmethod - def __register_wrappers_fn__(cls, dataset_type: Type[datasets.VisionDataset]): - def foo(wrappers_fn): - cls.__wrappers_fns__[dataset_type] = wrappers_fn + def _register_wrappers_fn(cls, dataset_type: Type[datasets.VisionDataset]) -> Callable[[PyTree], PyTree]: + def register(wrappers_fn: PyTree) -> PyTree: + cls._wrappers_fns[dataset_type] = wrappers_fn return wrappers_fn - return foo + return register @classmethod def from_torchvision_dataset( - cls, dataset: datasets.VisionDataset, *, keep_pil_image: bool = False, dtypes=None - ) -> DatasetFeatureWrapper: - dtypes = defaultdict(lambda: None, dtypes or dict()) - wrappers_fn = cls.__wrappers_fns__[type(dataset)] - wrappers = wrappers_fn(dataset, keep_pil_image, dtypes) + cls, + dataset: datasets.VisionDataset, + *, + keep_pil_image: bool = False, + bounding_box_format: Optional[features.BoundingBoxFormat] = None, + dtypes: Optional[Dict[Type[features._Feature], Optional[torch.dtype]]] = None, + ) -> VisionDatasetFeatureWrapper: + dtypes: Dict[Type[features._Feature], Optional[torch.dtype]] = { + features.Image: torch.uint8, + features.Label: torch.int64, + features.Mask: torch.uint8, + features.BoundingBox: torch.float32, + features.GenericFeature: None, + **(dtypes or dict()), + } + wrappers_fn = cls._wrappers_fns[type(dataset)] + wrappers = wrappers_fn(dataset, keep_pil_image, bounding_box_format, dtypes) return cls(dataset, wrappers) -def identity_wrapper(obj): +def identity_wrapper(obj: T) -> T: 
return obj -def generic_feature_wrapper(data): +def generic_feature_wrapper(data: Any) -> features.GenericFeature: return features.GenericFeature(data) -def wrap_image(image, *, keep_pil_image, dtype): - assert isinstance(image, PIL.Image.Image) - if keep_pil_image: - return image - +def wrap_image(image: PIL.Image.Image, *, dtype: Optional[torch.dtype]) -> features.Image: image = F.pil_to_tensor(image) - image = F.convert_image_dtype(image, dtype=dtype) + if dtype is not None: + image = F.convert_image_dtype(image, dtype=dtype) return features.Image(image) -def make_image_wrapper(*, keep_pil_image, dtypes): - return functools.partial(wrap_image, keep_pil_image=keep_pil_image, dtype=dtypes[features.Image]) +def make_image_wrapper( + *, keep_pil_image: bool, dtypes: Dict[Type[features._Feature], Optional[torch.dtype]] +) -> Callable[[PIL.Image.Image], Union[PIL.Image.Image, features.Image]]: + if keep_pil_image: + return identity_wrapper + + return functools.partial(wrap_image, dtype=dtypes[features.Image]) -def wrap_label(label, *, categories, dtype): +def wrap_label(label: Any, *, categories: Optional[Sequence[str]], dtype: Optional[torch.dtype]) -> features.Label: return features.Label(label, categories=categories, dtype=dtype) -def make_label_wrapper(*, categories, dtypes): +def make_label_wrapper( + *, categories: Optional[Sequence[str]], dtypes: Dict[Type[features._Feature], Optional[torch.dtype]] +) -> Callable[[Any], features.Label]: return functools.partial(wrap_label, categories=categories, dtype=dtypes[features.Label]) -def wrap_segmentation_mask(segmentation_mask, *, dtype): +def wrap_segmentation_mask(segmentation_mask: PIL.Image.Image, *, dtype: Optional[torch.dtype]) -> features.Mask: assert isinstance(segmentation_mask, PIL.Image.Image) segmentation_mask = F.pil_to_tensor(segmentation_mask) - segmentation_mask = F.convert_image_dtype(segmentation_mask, dtype=dtype) + if dtype is not None: + segmentation_mask = F.convert_image_dtype(segmentation_mask, dtype=dtype) return features.Mask(segmentation_mask.squeeze(0)) -def make_segmentation_mask_wrapper(*, dtypes): +def make_segmentation_mask_wrapper( + *, dtypes: Dict[Type[features._Feature], Optional[torch.dtype]] +) -> Callable[[Any], features.Mask]: return functools.partial(wrap_segmentation_mask, dtype=dtypes[features.Mask]) @@ -124,7 +175,12 @@ def make_segmentation_mask_wrapper(*, dtypes): ) -def classification_wrappers(dataset, keep_pil_image, dtypes): +def classification_wrappers( + dataset: datasets.VisionDataset, + keep_pil_image: bool, + bounding_box_format: Optional[features.BoundingBoxFormat], + dtypes: Dict[Type[features._Feature], Optional[torch.dtype]], +) -> Any: return ( make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), make_label_wrapper(categories=CATEGORIES_GETTER[type(dataset)](dataset), dtypes=dtypes), @@ -139,10 +195,15 @@ def classification_wrappers(dataset, keep_pil_image, dtypes): datasets.MNIST, datasets.FashionMNIST, ]: - DatasetFeatureWrapper.__register_wrappers_fn__(dataset_type)(classification_wrappers) + VisionDatasetFeatureWrapper._register_wrappers_fn(dataset_type)(classification_wrappers) -def segmentation_wrappers(dataset, keep_pil_image, dtypes): +def segmentation_wrappers( + dataset: datasets.VisionDataset, + keep_pil_image: bool, + bounding_box_format: Optional[features.BoundingBoxFormat], + dtypes: Dict[Type[features._Feature], Optional[torch.dtype]], +) -> Any: return ( make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), 
make_segmentation_mask_wrapper(dtypes=dtypes), @@ -152,11 +213,16 @@ def segmentation_wrappers(dataset, keep_pil_image, dtypes): for dataset_type in [ datasets.VOCSegmentation, ]: - DatasetFeatureWrapper.__register_wrappers_fn__(dataset_type)(classification_wrappers) + VisionDatasetFeatureWrapper._register_wrappers_fn(dataset_type)(segmentation_wrappers) -@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.Caltech101) -def caltech101_dectection_wrappers(dataset, keep_pil_image, dtypes): +@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.Caltech101) +def caltech101_dectection_wrappers( + dataset: datasets.Caltech101, + keep_pil_image: bool, + bounding_box_format: Optional[features.BoundingBoxFormat], + dtypes: Dict[Type[features._Feature], Optional[torch.dtype]], +) -> Any: target_type_wrapper_map = { "category": make_label_wrapper(categories=dataset.categories, dtypes=dtypes), "annotation": features.GenericFeature, @@ -167,8 +233,13 @@ def caltech101_dectection_wrappers(dataset, keep_pil_image, dtypes): ) -@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.CocoDetection) -def coco_dectection_wrappers(dataset, keep_pil_image, dtypes): +@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.CocoDetection) +def coco_dectection_wrappers( + dataset: datasets.CocoDetection, + keep_pil_image: bool, + bounding_box_format: Optional[features.BoundingBoxFormat], + dtypes: Dict[Type[features._Feature], Optional[torch.dtype]], +) -> Any: idx_to_category = {idx: cat["name"] for idx, cat in dataset.coco.cats.items()} idx_to_category[0] = "__background__" for idx in set(range(91)) - idx_to_category.keys(): @@ -186,13 +257,14 @@ def segmentation_to_mask(segmentation: Any, *, iscrowd: bool, image_size: Tuple[ ) return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) - def sample_wrapper(sample): + def sample_wrapper(sample: Tuple[PIL.Image, List[Dict[str, Any]]]) -> Tuple[features.Image, Dict[str, Any]]: image, target = sample _, height, width = F.get_dimensions(image) image_size = height, width - wrapped_image = wrap_image(image, keep_pil_image=keep_pil_image, dtype=dtypes[features.Image]) + image_wrapper = make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes) + wrapped_image = image_wrapper(image) batched_target = defaultdict(list) for object in target: @@ -218,19 +290,31 @@ def sample_wrapper(sample): ), labels=features.Label(batched_target.pop("category_id"), categories=categories), ) + if bounding_box_format is not None: + wrapped_target["bbox"] = cast(features.BoundingBox, wrapped_target["bbox"]).to_format(bounding_box_format) return wrapped_image, wrapped_target return sample_wrapper -@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.CocoCaptions) -def coco_captions_wrappers(dataset, keep_pil_image, dtypes): +@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.CocoCaptions) +def coco_captions_wrappers( + dataset: datasets.CocoCaptions, + keep_pil_image: bool, + bounding_box_format: Optional[features.BoundingBoxFormat], + dtypes: Dict[Type[features._Feature], Optional[torch.dtype]], +) -> Any: return make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), identity_wrapper -@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.VOCDetection) -def voc_detection_wrappers(dataset, keep_pil_image, dtypes): +@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.VOCDetection) +def voc_detection_wrappers( + dataset: datasets.VOCDetection, + keep_pil_image: bool, + bounding_box_format: Optional[features.BoundingBoxFormat], + 
dtypes: Dict[Type[features._Feature], Optional[torch.dtype]], +) -> Any: categories = [ "__background__", "aeroplane", @@ -256,7 +340,7 @@ def voc_detection_wrappers(dataset, keep_pil_image, dtypes): ] categories_to_idx = dict(zip(categories, range(len(categories)))) - def target_wrapper(target): + def target_wrapper(target: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: batched_object = defaultdict(list) for object in target["annotation"]["object"]: for key, value in object.items(): @@ -269,12 +353,16 @@ def target_wrapper(target): [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bndbox in batched_object["bndbox"] ], - format="xyxy", + format=features.BoundingBoxFormat.XYXY, image_size=cast( Tuple[int, int], tuple(int(target["annotation"]["size"][dim]) for dim in ("height", "width")) ), ), ) + if bounding_box_format is not None: + wrapped_object["bndbox"] = cast(features.BoundingBox, wrapped_object["bndbox"]).to_format( + bounding_box_format + ) wrapped_object["labels"] = features.Label( [categories_to_idx[category] for category in batched_object["name"]], categories=categories, @@ -287,9 +375,14 @@ def target_wrapper(target): return make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), target_wrapper -@DatasetFeatureWrapper.__register_wrappers_fn__(datasets.SBDataset) -def sbd_wrappers(dataset, keep_pil_image, dtypes): +@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.SBDataset) +def sbd_wrappers( + dataset: datasets.SBDataset, + keep_pil_image: bool, + bounding_box_format: Optional[features.BoundingBoxFormat], + dtypes: Dict[Type[features._Feature], Optional[torch.dtype]], +) -> Any: return { "boundaries": (make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), generic_feature_wrapper), - "segmentation": segmentation_wrappers(dataset, keep_pil_image, dtypes), + "segmentation": segmentation_wrappers(dataset, keep_pil_image, bounding_box_format, dtypes), }[dataset.mode] From dbfac05fa75496937ad82ba7341c12a0e4dac565 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 31 Jan 2023 09:04:36 +0100 Subject: [PATCH 03/25] refactor --- torchvision/prototype/datapoints/__init__.py | 2 +- .../prototype/datapoints/_dataset_wrapper.py | 389 ++++++------------ 2 files changed, 129 insertions(+), 262 deletions(-) diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py index d0c1035a76b..ff6c44ab108 100644 --- a/torchvision/prototype/datapoints/__init__.py +++ b/torchvision/prototype/datapoints/__init__.py @@ -5,4 +5,4 @@ from ._mask import Mask from ._video import TensorVideoType, TensorVideoTypeJIT, Video, VideoType, VideoTypeJIT -from ._dataset_wrapper import VisionDatasetFeatureWrapper # usort: skip +from ._dataset_wrapper import wrap_dataset_for_transforms_v2 # type: ignore[attr-defined] # usort: skip diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 64f0d024cdd..c9970e41fbd 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -1,52 +1,40 @@ +# type: ignore + from __future__ import annotations import contextlib import functools from collections import defaultdict -from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar import PIL.Image import torch -from torch.utils._pytree import ( - _get_node_type, - LeafSpec, - PyTree, 
- SUPPORTED_NODES, - tree_flatten, - tree_unflatten, - TreeSpec, -) from torch.utils.data import Dataset - from torchvision import datasets from torchvision.prototype import datapoints -from torchvision.transforms import functional as F +from torchvision.prototype.transforms import functional as F T = TypeVar("T") +D = TypeVar("D", bound=datasets.VisionDataset) +__all__ = ["wrap_dataset_for_transforms_v2"] -def tree_flatten_to_spec(pytree: PyTree, spec: TreeSpec) -> List[Any]: - if isinstance(spec, LeafSpec): - return [pytree] - - node_type = _get_node_type(pytree) - flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - child_pytrees, context = flatten_fn(pytree) - - flat = [] - for children_pytree, children_spec in zip(child_pytrees, spec.children_specs): - flat.extend(tree_flatten_to_spec(children_pytree, children_spec)) +_WRAPPERS = {} - return flat +# TODO: naming! +def wrap_dataset_for_transforms_v2(dataset: datasets.VisionDataset) -> _VisionDatasetDatapointWrapper: + wrapper = _WRAPPERS.get(type(dataset)) + if wrapper is None: + raise TypeError + return _VisionDatasetDatapointWrapper(dataset, wrapper) -class VisionDatasetFeatureWrapper(Dataset): - _wrappers_fns: Dict[Type[datasets.VisionDataset], PyTree] = {} - def __init__(self, dataset: datasets.VisionDataset, wrappers: PyTree) -> None: +class _VisionDatasetDatapointWrapper(Dataset): + def __init__(self, dataset: datasets.VisionDataset, wrapper) -> None: self.vision_dataset = dataset - self.sample_wrappers = wrappers + self.wrapper = wrapper # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply the # transforms @@ -65,13 +53,7 @@ def __getitem__(self, idx: int) -> Any: # of this class sample = self.vision_dataset[idx] - wrappers_flat, spec = tree_flatten(self.sample_wrappers) - # We cannot use `tree_flatten` directly, because the spec of `self.sample_wrappers` and `sample` might differ. - # For example, for `COCODetection` the target is a list of dicts. To be able to wrap this into a dict of lists, - # we need to have access to the whole list, but `tree_flatten` would also flatten it. 
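
A minimal illustration of the spec mismatch the comment above describes, assuming the `tree_flatten_to_spec` helper defined above is in scope; the wrappers and the sample values are made up:

    from torch.utils._pytree import tree_flatten

    wrappers = (lambda image: image, lambda target: target)  # one wrapper per sample component
    sample = ("image", [{"bbox": [0, 0, 1, 1]}, {"bbox": [2, 2, 3, 3]}])  # COCO-style target

    _, spec = tree_flatten(wrappers)  # a 2-tuple spec with two leaves
    tree_flatten_to_spec(sample, spec)
    # -> ['image', [{'bbox': [0, 0, 1, 1]}, {'bbox': [2, 2, 3, 3]}]]
    #    the list of dicts is kept together as a single leaf
    tree_flatten(sample)[0]
    # -> ['image', 0, 0, 1, 1, 2, 2, 3, 3]
    #    plain tree_flatten keeps descending into the list and its dicts
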
- sample_flat = tree_flatten_to_spec(sample, spec) - wrapped_sample_flat = [wrapper(item) for item, wrapper in zip(sample_flat, wrappers_flat)] - sample = tree_unflatten(wrapped_sample_flat, spec) + sample = self.wrapper(self.vision_dataset, sample) # We don't need to care about `transform` and `target_transform` here since `VisionDataset` joins them into a # `transforms` internally: @@ -84,107 +66,28 @@ def __getitem__(self, idx: int) -> Any: def __len__(self) -> int: return len(self.vision_dataset) - @classmethod - def _register_wrappers_fn(cls, dataset_type: Type[datasets.VisionDataset]) -> Callable[[PyTree], PyTree]: - def register(wrappers_fn: PyTree) -> PyTree: - cls._wrappers_fns[dataset_type] = wrappers_fn - return wrappers_fn - - return register - - @classmethod - def from_torchvision_dataset( - cls, - dataset: datasets.VisionDataset, - *, - keep_pil_image: bool = False, - bounding_box_format: Optional[datapoints.BoundingBoxFormat] = None, - dtypes: Optional[Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]]] = None, - ) -> VisionDatasetFeatureWrapper: - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]] = { - datapoints.Image: torch.uint8, - datapoints.Label: torch.int64, - datapoints.Mask: torch.uint8, - datapoints.BoundingBox: torch.float32, - datapoints.GenericDatapoint: None, - **(dtypes or dict()), - } - wrappers_fn = cls._wrappers_fns[type(dataset)] - wrappers = wrappers_fn(dataset, keep_pil_image, bounding_box_format, dtypes) - return cls(dataset, wrappers) - - -def identity_wrapper(obj: T) -> T: - return obj - - -def generic_feature_wrapper(data: Any) -> datapoints.GenericDatapoint: - return datapoints.GenericDatapoint(data) - - -def wrap_image(image: PIL.Image.Image, *, dtype: Optional[torch.dtype]) -> datapoints.Image: - image = F.pil_to_tensor(image) - if dtype is not None: - image = F.convert_image_dtype(image, dtype=dtype) - return datapoints.Image(image) - - -def make_image_wrapper( - *, keep_pil_image: bool, dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]] -) -> Callable[[PIL.Image.Image], Union[PIL.Image.Image, datapoints.Image]]: - if keep_pil_image: - return identity_wrapper - return functools.partial(wrap_image, dtype=dtypes[datapoints.Image]) +def identity_wrapper(sample: T) -> T: + return sample -def wrap_label(label: Any, *, categories: Optional[Sequence[str]], dtype: Optional[torch.dtype]) -> datapoints.Label: - return datapoints.Label(label, categories=categories, dtype=dtype) - - -def make_label_wrapper( - *, categories: Optional[Sequence[str]], dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]] -) -> Callable[[Any], datapoints.Label]: - return functools.partial(wrap_label, categories=categories, dtype=dtypes[datapoints.Label]) - - -def wrap_segmentation_mask(segmentation_mask: PIL.Image.Image, *, dtype: Optional[torch.dtype]) -> datapoints.Mask: - assert isinstance(segmentation_mask, PIL.Image.Image) - - segmentation_mask = F.pil_to_tensor(segmentation_mask) - if dtype is not None: - segmentation_mask = F.convert_image_dtype(segmentation_mask, dtype=dtype) - return datapoints.Mask(segmentation_mask.squeeze(0)) - - -def make_segmentation_mask_wrapper( - *, dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]] -) -> Callable[[Any], datapoints.Mask]: - return functools.partial(wrap_segmentation_mask, dtype=dtypes[datapoints.Mask]) - - -CATEGORIES_GETTER = defaultdict( - lambda: lambda dataset: None, - { +@functools.lru_cache(maxsize=None) +def 
get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: + categories_fn = { datasets.Caltech256: lambda dataset: [name.rsplit(".", 1)[1] for name in dataset.categories], datasets.CIFAR10: lambda dataset: dataset.classes, datasets.CIFAR100: lambda dataset: dataset.classes, datasets.FashionMNIST: lambda dataset: dataset.classes, datasets.ImageNet: lambda dataset: [", ".join(names) for names in dataset.classes], - }, -) - - -def classification_wrappers( - dataset: datasets.VisionDataset, - keep_pil_image: bool, - bounding_box_format: Optional[datapoints.BoundingBoxFormat], - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]], -) -> Any: - return ( - make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), - make_label_wrapper(categories=CATEGORIES_GETTER[type(dataset)](dataset), dtypes=dtypes), - ) + }.get(type(dataset)) + return categories_fn(dataset) if categories_fn is not None else None + + +def classification_wrapper( + dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, int] +) -> Tuple[PIL.Image.Image, datapoints.Label]: + image, label = sample + return image, datapoints.Label(label, categories=get_categories(dataset)) for dataset_type in [ @@ -195,51 +98,45 @@ def classification_wrappers( datasets.MNIST, datasets.FashionMNIST, ]: - VisionDatasetFeatureWrapper._register_wrappers_fn(dataset_type)(classification_wrappers) - - -def segmentation_wrappers( - dataset: datasets.VisionDataset, - keep_pil_image: bool, - bounding_box_format: Optional[datapoints.BoundingBoxFormat], - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]], -) -> Any: - return ( - make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), - make_segmentation_mask_wrapper(dtypes=dtypes), - ) + _WRAPPERS[dataset_type] = classification_wrapper + + +def segmentation_wrapper( + dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, PIL.Image.Image] +) -> Tuple[PIL.Image.Image, datapoints.Mask]: + image, mask = sample + return image, datapoints.Mask(F.to_image_tensor(mask)) for dataset_type in [ datasets.VOCSegmentation, ]: - VisionDatasetFeatureWrapper._register_wrappers_fn(dataset_type)(segmentation_wrappers) - - -@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.Caltech101) -def caltech101_dectection_wrappers( - dataset: datasets.Caltech101, - keep_pil_image: bool, - bounding_box_format: Optional[datapoints.BoundingBoxFormat], - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]], -) -> Any: - target_type_wrapper_map = { - "category": make_label_wrapper(categories=dataset.categories, dtypes=dtypes), + _WRAPPERS[dataset_type] = segmentation_wrapper + + +def caltech101_wrapper( + dataset: datasets.Caltech101, sample: Tuple[PIL.Image.Image, Any] +) -> Tuple[PIL.Image.Image, Any]: + image, target = sample + + target_type_wrapper_map: Dict[str, Callable] = { + "category": lambda label: datapoints.Label(label, categories=dataset.categories), "annotation": datapoints.GenericDatapoint, } - return ( - make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), - [target_type_wrapper_map[target_type] for target_type in dataset.target_type], - ) + if len(dataset.target_type) == 1: + target = target_type_wrapper_map[dataset.target_type[0]](target) + else: + target = tuple(target_type_wrapper_map[typ](item) for typ, item in zip(dataset.target_type, target)) + + return image, target -@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.CocoDetection) -def coco_dectection_wrappers( - dataset: 
datasets.CocoDetection, - keep_pil_image: bool, - bounding_box_format: Optional[datapoints.BoundingBoxFormat], - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]], -) -> Any: +_WRAPPERS[datasets.Caltech101] = caltech101_wrapper + + +def coco_dectection_wrapper( + dataset: datasets.CocoDetection, sample: Tuple[PIL.Image.Image, List[Dict[str, Any]]] +) -> Tuple[PIL.Image.Image, Dict[str, List[Any]]]: idx_to_category = {idx: cat["name"] for idx, cat in dataset.coco.cats.items()} idx_to_category[0] = "__background__" for idx in set(range(91)) - idx_to_category.keys(): @@ -247,74 +144,55 @@ def coco_dectection_wrappers( categories = [category for _, category in sorted(idx_to_category.items())] - def segmentation_to_mask(segmentation: Any, *, iscrowd: bool, spatial_size: Tuple[int, int]) -> torch.Tensor: + def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> torch.Tensor: from pycocotools import mask segmentation = ( mask.frPyObjects(segmentation, *spatial_size) - if iscrowd + if isinstance(segmentation, dict) else mask.merge(mask.frPyObjects(segmentation, *spatial_size)) ) - return torch.from_numpy(mask.decode(segmentation)).to(torch.bool) - - def sample_wrapper(sample: Tuple[PIL.Image, List[Dict[str, Any]]]) -> Tuple[datapoints.Image, Dict[str, Any]]: - image, target = sample - - _, height, width = F.get_dimensions(image) - spatial_size = height, width - - image_wrapper = make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes) - wrapped_image = image_wrapper(image) - - batched_target = defaultdict(list) - for object in target: - for key, value in object.items(): - batched_target[key].append(value) - - wrapped_target = dict( - batched_target, - segmentation=datapoints.Mask( - torch.stack( - [ - segmentation_to_mask(segmentation, iscrowd=iscrowd, spatial_size=spatial_size) - for segmentation, iscrowd in zip(batched_target["segmentation"], batched_target["iscrowd"]) - ] - ), - dtype=dtypes.get(datapoints.Mask), - ), - bbox=datapoints.BoundingBox( - batched_target["bbox"], - format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=spatial_size, - dtype=dtypes.get(datapoints.BoundingBox), + return torch.from_numpy(mask.decode(segmentation)) + + image, target = sample + + # Originally, COCODetection returns a list of dicts in which each dict represents an object instance on the image. + # However, our transforms and models expect all instance annotations grouped together, if applicable as tensor with + # batch dimension. Thus, we are changing the target to a dict of lists here. 
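
To make the regrouping described in the comment above concrete, here is a small sketch with a made-up two-instance COCO-style target (only a few annotation fields shown):

    from collections import defaultdict

    target = [
        {"bbox": [10, 20, 30, 40], "category_id": 1, "iscrowd": 0},
        {"bbox": [50, 60, 70, 80], "category_id": 3, "iscrowd": 0},
    ]

    batched_target = defaultdict(list)
    for obj in target:
        for key, value in obj.items():
            batched_target[key].append(value)

    # batched_target == {"bbox": [[10, 20, 30, 40], [50, 60, 70, 80]],
    #                    "category_id": [1, 3],
    #                    "iscrowd": [0, 0]}
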
+ batched_target = defaultdict(list) + for object in target: + for key, value in object.items(): + batched_target[key].append(value) + + spatial_size = tuple(F.get_spatial_size(image)) + batched_target = dict( + batched_target, + boxes=datapoints.BoundingBox( + batched_target["bbox"], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=spatial_size, + ), + masks=datapoints.Mask( + torch.stack( + [ + segmentation_to_mask(segmentation, spatial_size=spatial_size) + for segmentation in batched_target["segmentation"] + ] ), - labels=datapoints.Label(batched_target.pop("category_id"), categories=categories), - ) - if bounding_box_format is not None: - wrapped_target["bbox"] = cast(datapoints.BoundingBox, wrapped_target["bbox"]).to_format(bounding_box_format) - - return wrapped_image, wrapped_target + ), + labels=datapoints.Label(batched_target["category_id"], categories=categories), + ) - return sample_wrapper + return image, batched_target -@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.CocoCaptions) -def coco_captions_wrappers( - dataset: datasets.CocoCaptions, - keep_pil_image: bool, - bounding_box_format: Optional[datapoints.BoundingBoxFormat], - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]], -) -> Any: - return make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), identity_wrapper +_WRAPPERS[datasets.CocoDetection] = coco_dectection_wrapper +_WRAPPERS[datasets.CocoCaptions] = identity_wrapper -@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.VOCDetection) -def voc_detection_wrappers( - dataset: datasets.VOCDetection, - keep_pil_image: bool, - bounding_box_format: Optional[datapoints.BoundingBoxFormat], - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]], -) -> Any: +def voc_detection_wrapper( + dataset: datasets.VOCDetection, sample: Tuple[PIL.Image.Image, Any] +) -> Tuple[PIL.Image.Image, Any]: categories = [ "__background__", "aeroplane", @@ -340,49 +218,38 @@ def voc_detection_wrappers( ] categories_to_idx = dict(zip(categories, range(len(categories)))) - def target_wrapper(target: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, Any]]: - batched_object = defaultdict(list) - for object in target["annotation"]["object"]: - for key, value in object.items(): - batched_object[key].append(value) + image, target = sample - wrapped_object = dict( - batched_object, - bndbox=datapoints.BoundingBox( - [ - [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] - for bndbox in batched_object["bndbox"] - ], - format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=cast( - Tuple[int, int], tuple(int(target["annotation"]["size"][dim]) for dim in ("height", "width")) - ), - ), - ) - if bounding_box_format is not None: - wrapped_object["bndbox"] = cast(datapoints.BoundingBox, wrapped_object["bndbox"]).to_format( - bounding_box_format - ) - wrapped_object["labels"] = datapoints.Label( - [categories_to_idx[category] for category in batched_object["name"]], - categories=categories, - dtype=dtypes[datapoints.Label], - ) + batched_instances = defaultdict(list) + for object in target["annotation"]["object"]: + for key, value in object.items(): + batched_instances[key].append(value) + + target["boxes"] = datapoints.BoundingBox( + [[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bndbox in batched_instances["bndbox"]], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=tuple(int(target["annotation"]["size"][dim]) for dim in ("height", "width")), + ) + target["labels"] = 
datapoints.Label( + [categories_to_idx[category] for category in batched_instances["name"]], + categories=categories, + ) + + return image, target + + +_WRAPPERS[datasets.VOCDetection] = voc_detection_wrapper + + +def sbd_wrapper(dataset: datasets.SBDataset, sample: Tuple[PIL.Image.Image, Any]) -> Tuple[PIL.Image.Image, Any]: + image, target = sample - target["annotation"]["object"] = wrapped_object - return target + if dataset.mode == "boundaries": + target = datapoints.GenericDatapoint(target) + else: + target = datapoints.Mask(F.to_image_tensor(target)) - return make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), target_wrapper + return image, target -@VisionDatasetFeatureWrapper._register_wrappers_fn(datasets.SBDataset) -def sbd_wrappers( - dataset: datasets.SBDataset, - keep_pil_image: bool, - bounding_box_format: Optional[datapoints.BoundingBoxFormat], - dtypes: Dict[Type[datapoints._datapoint.Datapoint], Optional[torch.dtype]], -) -> Any: - return { - "boundaries": (make_image_wrapper(keep_pil_image=keep_pil_image, dtypes=dtypes), generic_feature_wrapper), - "segmentation": segmentation_wrappers(dataset, keep_pil_image, bounding_box_format, dtypes), - }[dataset.mode] +_WRAPPERS[datasets.SBDataset] = sbd_wrapper From 2dba1c75684bcf0b27bf5d6be6f0ca80fdda3710 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 31 Jan 2023 09:18:07 +0100 Subject: [PATCH 04/25] handle None label for test set use case --- torchvision/prototype/datapoints/_dataset_wrapper.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index c9970e41fbd..20cea062b1e 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -84,10 +84,12 @@ def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: def classification_wrapper( - dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, int] -) -> Tuple[PIL.Image.Image, datapoints.Label]: + dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, Optional[int]] +) -> Tuple[PIL.Image.Image, Optional[datapoints.Label]]: image, label = sample - return image, datapoints.Label(label, categories=get_categories(dataset)) + if label is not None: + label = datapoints.Label(label, categories=get_categories(dataset)) + return image, label for dataset_type in [ From bcd76206942a81837360e6cb78dc67525a8e8dae Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 31 Jan 2023 09:19:16 +0100 Subject: [PATCH 05/25] minor cleanup --- torchvision/prototype/datapoints/_dataset_wrapper.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 20cea062b1e..56fb9df0f99 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -6,7 +6,7 @@ import functools from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar +from typing import Any, Callable, Dict, List, Optional, Tuple import PIL.Image import torch @@ -15,9 +15,6 @@ from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F -T = TypeVar("T") -D = TypeVar("D", bound=datasets.VisionDataset) - __all__ = ["wrap_dataset_for_transforms_v2"] _WRAPPERS = {} @@ -67,10 +64,6 @@ def __len__(self) -> int: 
return len(self.vision_dataset) -def identity_wrapper(sample: T) -> T: - return sample - - @functools.lru_cache(maxsize=None) def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: categories_fn = { @@ -189,7 +182,7 @@ def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> _WRAPPERS[datasets.CocoDetection] = coco_dectection_wrapper -_WRAPPERS[datasets.CocoCaptions] = identity_wrapper +_WRAPPERS[datasets.CocoCaptions] = lambda sample: sample def voc_detection_wrapper( From fe6be600cee004375a528077003aed3218a0ddec Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 1 Feb 2023 16:57:32 +0100 Subject: [PATCH 06/25] minor refactorings --- .../prototype/datapoints/_dataset_wrapper.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 56fb9df0f99..3cec39d2a8b 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -22,35 +22,41 @@ # TODO: naming! def wrap_dataset_for_transforms_v2(dataset: datasets.VisionDataset) -> _VisionDatasetDatapointWrapper: - wrapper = _WRAPPERS.get(type(dataset)) + dataset_cls = type(dataset) + wrapper = _WRAPPERS.get(dataset_cls) if wrapper is None: - raise TypeError + # TODO: If we have documentation on how to do that, put a link in the error message. + msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." + if dataset_cls in datasets.__dict__.values(): + msg = ( + f"{msg} If an automated wrapper for this dataset would be useful for you, " + f"please open an issue at https://github.com/pytorch/vision/issues." + ) + raise ValueError(msg) return _VisionDatasetDatapointWrapper(dataset, wrapper) class _VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset: datasets.VisionDataset, wrapper) -> None: - self.vision_dataset = dataset - self.wrapper = wrapper + self._vision_dataset = dataset + self._wrapper = wrapper # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply the # transforms - self.transform, dataset.transform = dataset.transform, None - self.target_transform, dataset.target_transform = dataset.target_transform, None self.transforms, dataset.transforms = dataset.transforms, None def __getattr__(self, item: str) -> Any: with contextlib.suppress(AttributeError): return object.__getattribute__(self, item) - return getattr(self.vision_dataset, item) + return getattr(self._vision_dataset, item) def __getitem__(self, idx: int) -> Any: # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor # of this class - sample = self.vision_dataset[idx] + sample = self._vision_dataset[idx] - sample = self.wrapper(self.vision_dataset, sample) + sample = self._wrapper(self._vision_dataset, sample) # We don't need to care about `transform` and `target_transform` here since `VisionDataset` joins them into a # `transforms` internally: @@ -61,7 +67,7 @@ def __getitem__(self, idx: int) -> Any: return sample def __len__(self) -> int: - return len(self.vision_dataset) + return len(self._vision_dataset) @functools.lru_cache(maxsize=None) From cff90923f5994437c1a8f37cd13b793e7ff74d7b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 1 Feb 2023 17:06:46 +0100 Subject: [PATCH 07/25] minor cache refactoring for COCO --- 
.../prototype/datapoints/_dataset_wrapper.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 3cec39d2a8b..c60ce7f667e 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -17,6 +17,8 @@ __all__ = ["wrap_dataset_for_transforms_v2"] +cache = functools.partial(functools.lru_cache, max_size=None) + _WRAPPERS = {} @@ -70,7 +72,7 @@ def __len__(self) -> int: return len(self._vision_dataset) -@functools.lru_cache(maxsize=None) +@cache def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: categories_fn = { datasets.Caltech256: lambda dataset: [name.rsplit(".", 1)[1] for name in dataset.categories], @@ -135,16 +137,19 @@ def caltech101_wrapper( _WRAPPERS[datasets.Caltech101] = caltech101_wrapper -def coco_dectection_wrapper( - dataset: datasets.CocoDetection, sample: Tuple[PIL.Image.Image, List[Dict[str, Any]]] -) -> Tuple[PIL.Image.Image, Dict[str, List[Any]]]: +@cache +def get_coco_detection_categories(dataset: datasets.CocoDetection): idx_to_category = {idx: cat["name"] for idx, cat in dataset.coco.cats.items()} idx_to_category[0] = "__background__" for idx in set(range(91)) - idx_to_category.keys(): idx_to_category[idx] = "N/A" - categories = [category for _, category in sorted(idx_to_category.items())] + return [category for _, category in sorted(idx_to_category.items())] + +def coco_dectection_wrapper( + dataset: datasets.CocoDetection, sample: Tuple[PIL.Image.Image, List[Dict[str, Any]]] +) -> Tuple[PIL.Image.Image, Dict[str, List[Any]]]: def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> torch.Tensor: from pycocotools import mask @@ -181,7 +186,7 @@ def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> ] ), ), - labels=datapoints.Label(batched_target["category_id"], categories=categories), + labels=datapoints.Label(batched_target["category_id"], categories=get_coco_detection_categories(dataset)), ) return image, batched_target From 9965492b26f13d6663026758a10691e4cb9e7b69 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 1 Feb 2023 17:16:55 +0100 Subject: [PATCH 08/25] remove GenericDatapoint for now --- .../prototype/datapoints/_dataset_wrapper.py | 36 ++++++++++--------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index c60ce7f667e..dc27ed0ca1a 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -6,7 +6,7 @@ import functools from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import PIL.Image import torch @@ -38,6 +38,17 @@ def wrap_dataset_for_transforms_v2(dataset: datasets.VisionDataset) -> _VisionDa return _VisionDatasetDatapointWrapper(dataset, wrapper) +def raise_missing_functionality(dataset, *params): + msg = f"{type(dataset).__name__}" + if params: + param_msg = ", ".join(f"{param}={getattr(dataset, param)}" for param in params) + msg = f"{msg} with {param_msg}" + raise RuntimeError( + f"{msg} is currently not supported by this wrapper. " + f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." 
+ ) + + class _VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset: datasets.VisionDataset, wrapper) -> None: self._vision_dataset = dataset @@ -120,18 +131,12 @@ def segmentation_wrapper( def caltech101_wrapper( dataset: datasets.Caltech101, sample: Tuple[PIL.Image.Image, Any] ) -> Tuple[PIL.Image.Image, Any]: - image, target = sample + if "annotation" in dataset.target_type: + raise_missing_functionality(dataset, "target_type") - target_type_wrapper_map: Dict[str, Callable] = { - "category": lambda label: datapoints.Label(label, categories=dataset.categories), - "annotation": datapoints.GenericDatapoint, - } - if len(dataset.target_type) == 1: - target = target_type_wrapper_map[dataset.target_type[0]](target) - else: - target = tuple(target_type_wrapper_map[typ](item) for typ, item in zip(dataset.target_type, target)) + image, target = sample - return image, target + return image, datapoints.Label(target, categories=dataset.categories) _WRAPPERS[datasets.Caltech101] = caltech101_wrapper @@ -248,14 +253,11 @@ def voc_detection_wrapper( def sbd_wrapper(dataset: datasets.SBDataset, sample: Tuple[PIL.Image.Image, Any]) -> Tuple[PIL.Image.Image, Any]: - image, target = sample - if dataset.mode == "boundaries": - target = datapoints.GenericDatapoint(target) - else: - target = datapoints.Mask(F.to_image_tensor(target)) + raise_missing_functionality(dataset, "mode") - return image, target + image, target = sample + return image, datapoints.Mask(F.to_image_tensor(target)) _WRAPPERS[datasets.SBDataset] = sbd_wrapper From d64e1a933051712fedb9694d4cb58c90231d2740 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 2 Feb 2023 11:33:15 +0100 Subject: [PATCH 09/25] add all detection and segmentation datasets --- .../prototype/datapoints/_dataset_wrapper.py | 287 +++++++++++++----- 1 file changed, 210 insertions(+), 77 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index dc27ed0ca1a..ae7172eb101 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -3,21 +3,22 @@ from __future__ import annotations import contextlib - import functools -from collections import defaultdict -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import PIL.Image + import torch from torch.utils.data import Dataset + from torchvision import datasets +from torchvision._utils import sequence_to_str from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F __all__ = ["wrap_dataset_for_transforms_v2"] -cache = functools.partial(functools.lru_cache, max_size=None) +cache = functools.partial(functools.lru_cache) _WRAPPERS = {} @@ -38,17 +39,6 @@ def wrap_dataset_for_transforms_v2(dataset: datasets.VisionDataset) -> _VisionDa return _VisionDatasetDatapointWrapper(dataset, wrapper) -def raise_missing_functionality(dataset, *params): - msg = f"{type(dataset).__name__}" - if params: - param_msg = ", ".join(f"{param}={getattr(dataset, param)}" for param in params) - msg = f"{msg} with {param_msg}" - raise RuntimeError( - f"{msg} is currently not supported by this wrapper. " - f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." 
- ) - - class _VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset: datasets.VisionDataset, wrapper) -> None: self._vision_dataset = dataset @@ -83,6 +73,45 @@ def __len__(self) -> int: return len(self._vision_dataset) +def list_of_dicts_to_dict_of_lists(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List]: + if not list_of_dicts: + return {} + + dict_of_lists = {key: [value] for key, value in list_of_dicts[0].items()} + for dct in list_of_dicts[1:]: + for key, value in dct.items(): + dict_of_lists[key].append(value) + return dict_of_lists + + +def wrap_target_by_type( + dataset, target, type_wrappers: Dict[str, Callable], *, fail_on=(), attr_name: str = "target_type" +): + if target is None: + return None + + target_types = getattr(dataset, attr_name) + + if any(target_type in fail_on for target_type in target_types): + raise RuntimeError( + f"{type(dataset).__name__} with target type(s) {sequence_to_str(fail_on, separate_last='or ')} " + f"is currently not supported by this wrapper. " + f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." + ) + + if not isinstance(target, (tuple, list)): + target = [target] + + wrapped_target = tuple( + type_wrappers.get(target_type, lambda item: item)(item) for target_type, item in zip(target_types, target) + ) + + if len(wrapped_target) == 1: + wrapped_target = wrapped_target[0] + + return wrapped_target + + @cache def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: categories_fn = { @@ -111,6 +140,7 @@ def classification_wrapper( datasets.ImageNet, datasets.MNIST, datasets.FashionMNIST, + datasets.GTSRB, ]: _WRAPPERS[dataset_type] = classification_wrapper @@ -119,7 +149,7 @@ def segmentation_wrapper( dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, PIL.Image.Image] ) -> Tuple[PIL.Image.Image, datapoints.Mask]: image, mask = sample - return image, datapoints.Mask(F.to_image_tensor(mask)) + return image, datapoints.Mask(F.to_image_tensor(mask).squeeze(0)) for dataset_type in [ @@ -131,12 +161,13 @@ def segmentation_wrapper( def caltech101_wrapper( dataset: datasets.Caltech101, sample: Tuple[PIL.Image.Image, Any] ) -> Tuple[PIL.Image.Image, Any]: - if "annotation" in dataset.target_type: - raise_missing_functionality(dataset, "target_type") - image, target = sample - - return image, datapoints.Label(target, categories=dataset.categories) + return image, wrap_target_by_type( + dataset, + target, + {"category": lambda item: datapoints.Label(target, categories=dataset.categories)}, + fail_on=["annotation"], + ) _WRAPPERS[datasets.Caltech101] = caltech101_wrapper @@ -167,31 +198,24 @@ def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> image, target = sample - # Originally, COCODetection returns a list of dicts in which each dict represents an object instance on the image. - # However, our transforms and models expect all instance annotations grouped together, if applicable as tensor with - # batch dimension. Thus, we are changing the target to a dict of lists here. 
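# For illustration only, a minimal sketch of what the `list_of_dicts_to_dict_of_lists`
# helper added earlier in this patch does to a COCO-style target (the annotation
# values below are made up, not taken from the patch):
#
#     target = [
#         {"bbox": [1.0, 2.0, 3.0, 4.0], "category_id": 1},
#         {"bbox": [5.0, 6.0, 7.0, 8.0], "category_id": 3},
#     ]
#     list_of_dicts_to_dict_of_lists(target)
#     # -> {"bbox": [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]], "category_id": [1, 3]}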
- batched_target = defaultdict(list) - for object in target: - for key, value in object.items(): - batched_target[key].append(value) + batched_target = list_of_dicts_to_dict_of_lists(target) spatial_size = tuple(F.get_spatial_size(image)) - batched_target = dict( - batched_target, - boxes=datapoints.BoundingBox( - batched_target["bbox"], - format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=spatial_size, - ), - masks=datapoints.Mask( - torch.stack( - [ - segmentation_to_mask(segmentation, spatial_size=spatial_size) - for segmentation in batched_target["segmentation"] - ] - ), + batched_target["boxes"] = datapoints.BoundingBox( + batched_target["bbox"], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=spatial_size, + ) + batched_target["masks"] = datapoints.Mask( + torch.stack( + [ + segmentation_to_mask(segmentation, spatial_size=spatial_size) + for segmentation in batched_target["segmentation"] + ] ), - labels=datapoints.Label(batched_target["category_id"], categories=get_coco_detection_categories(dataset)), + ) + batched_target["labels"] = datapoints.Label( + batched_target["category_id"], categories=get_coco_detection_categories(dataset) ) return image, batched_target @@ -201,49 +225,47 @@ def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> _WRAPPERS[datasets.CocoCaptions] = lambda sample: sample +VOC_DETECTION_CATEGORIES = [ + "__background__", + "aeroplane", + "bicycle", + "bird", + "boat", + "bottle", + "bus", + "car", + "cat", + "chair", + "cow", + "diningtable", + "dog", + "horse", + "motorbike", + "person", + "pottedplant", + "sheep", + "sofa", + "train", + "tvmonitor", +] +VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES)))) + + def voc_detection_wrapper( dataset: datasets.VOCDetection, sample: Tuple[PIL.Image.Image, Any] ) -> Tuple[PIL.Image.Image, Any]: - categories = [ - "__background__", - "aeroplane", - "bicycle", - "bird", - "boat", - "bottle", - "bus", - "car", - "cat", - "chair", - "cow", - "diningtable", - "dog", - "horse", - "motorbike", - "person", - "pottedplant", - "sheep", - "sofa", - "train", - "tvmonitor", - ] - categories_to_idx = dict(zip(categories, range(len(categories)))) - image, target = sample - batched_instances = defaultdict(list) - for object in target["annotation"]["object"]: - for key, value in object.items(): - batched_instances[key].append(value) + batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"]) target["boxes"] = datapoints.BoundingBox( [[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bndbox in batched_instances["bndbox"]], format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=tuple(int(target["annotation"]["size"][dim]) for dim in ("height", "width")), + spatial_size=(image.height, image.width), ) target["labels"] = datapoints.Label( - [categories_to_idx[category] for category in batched_instances["name"]], - categories=categories, + [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]], + categories=VOC_DETECTION_CATEGORIES, ) return image, target @@ -254,10 +276,121 @@ def voc_detection_wrapper( def sbd_wrapper(dataset: datasets.SBDataset, sample: Tuple[PIL.Image.Image, Any]) -> Tuple[PIL.Image.Image, Any]: if dataset.mode == "boundaries": - raise_missing_functionality(dataset, "mode") + raise RuntimeError( + "SBDataset with mode='boundaries' is currently not supported by this wrapper. 
" + "If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." + ) image, target = sample - return image, datapoints.Mask(F.to_image_tensor(target)) + return image, datapoints.Mask(F.to_image_tensor(target).squeeze(0)) _WRAPPERS[datasets.SBDataset] = sbd_wrapper + + +def celeba_wrapper(dataset: datasets.CelebA, sample: Tuple[PIL.Image.Image, Any]) -> Tuple[PIL.Image.Image, Any]: + image, target = sample + return wrap_target_by_type( + dataset, + target, + { + "identity": datapoints.Label, + "bbox": lambda item: datapoints.BoundingBox( + item, format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) + ), + }, + # FIXME: Failing on "attr" here is problematic, since it is the default + fail_on=["attr", "landmarks"], + ) + + +_WRAPPERS[datasets.CelebA] = celeba_wrapper + +KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"] +KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))) + + +def kitti_wrapper(dataset: datasets.Kitti, sample): + image, target = sample + + target = list_of_dicts_to_dict_of_lists(target) + + target["boxes"] = datapoints.BoundingBox( + target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width) + ) + target["labels"] = datapoints.Label( + [KITTI_CATEGORY_TO_IDX[category] for category in target["type"]], categories=KITTI_CATEGORIES + ) + + return image, target + + +_WRAPPERS[datasets.Kitti] = kitti_wrapper + + +def oxford_iiit_pet_wrapper( + dataset: datasets.OxfordIIITPet, sample: Tuple[PIL.Image.Image, List] +) -> Tuple[PIL.Image.Image, List]: + image, target = sample + return image, wrap_target_by_type( + dataset, + target, + { + "category": lambda item: datapoints.Label(item, categories=dataset.classes), + "segmentation": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), + }, + attr_name="_target_types", + ) + + +_WRAPPERS[datasets.OxfordIIITPet] = oxford_iiit_pet_wrapper + + +def cityscapes_wrapper( + dataset: datasets.Cityscapes, sample: Tuple[PIL.Image.Image, List] +) -> Tuple[PIL.Image.Image, List]: + def instance_segmentation_wrapper(mask: PIL.Image.Image) -> datapoints.Mask: + # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21 + data = F.pil_to_tensor(mask).squeeze(0) + masks = [] + labels = [] + for id in data.unique(): + masks.append(data == id) + label = id + if label >= 1_000: + label //= 1_000 + labels.append(label) + masks = datapoints.Mask(torch.stack(masks)) + # FIXME: without the labels, returning just the instance masks is pretty useless. However, we would need to + # return a two-tuple or the like where we originally only had a single PIL image. 
+ labels = datapoints.Label(torch.stack(labels), categories=[cls.name for cls in dataset.classes]) + return masks + + image, target = sample + return image, wrap_target_by_type( + dataset, + target, + { + "instance": instance_segmentation_wrapper, + "semantic": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), + }, + fail_on=["polygon", "color"], + ) + + +_WRAPPERS[datasets.Cityscapes] = cityscapes_wrapper + + +def widerface_wrapper( + dataset: datasets.WIDERFace, sample: Tuple[PIL.Image.Image, Optional[Dict[str, torch.Tensor]]] +) -> Tuple[PIL.Image.Image, Optional[Dict[str, torch.Tensor]]]: + image, target = sample + if target is not None: + # FIXME: all returned values inside this dictionary are tensors, but not images + target["bbox"] = datapoints.BoundingBox( + target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) + ) + return image, target + + +_WRAPPERS[datasets.WIDERFace] = widerface_wrapper From 49cc8e796164c26a91e25f99be30425cbdbc83a4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 2 Feb 2023 11:41:03 +0100 Subject: [PATCH 10/25] add Image/DatasetFolder --- torchvision/prototype/datapoints/_dataset_wrapper.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index ae7172eb101..837c8cd18d7 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -73,6 +73,10 @@ def __len__(self) -> int: return len(self._vision_dataset) +def identity(item): + return item + + def list_of_dicts_to_dict_of_lists(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List]: if not list_of_dicts: return {} @@ -103,7 +107,7 @@ def wrap_target_by_type( target = [target] wrapped_target = tuple( - type_wrappers.get(target_type, lambda item: item)(item) for target_type, item in zip(target_types, target) + type_wrappers.get(target_type, identity)(item) for target_type, item in zip(target_types, target) ) if len(wrapped_target) == 1: @@ -120,6 +124,8 @@ def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: datasets.CIFAR100: lambda dataset: dataset.classes, datasets.FashionMNIST: lambda dataset: dataset.classes, datasets.ImageNet: lambda dataset: [", ".join(names) for names in dataset.classes], + datasets.DatasetFolder: lambda dataset: dataset.classes, + datasets.ImageFolder: lambda dataset: dataset.classes, }.get(type(dataset)) return categories_fn(dataset) if categories_fn is not None else None @@ -141,6 +147,8 @@ def classification_wrapper( datasets.MNIST, datasets.FashionMNIST, datasets.GTSRB, + datasets.DatasetFolder, + datasets.ImageFolder, ]: _WRAPPERS[dataset_type] = classification_wrapper @@ -222,7 +230,7 @@ def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> _WRAPPERS[datasets.CocoDetection] = coco_dectection_wrapper -_WRAPPERS[datasets.CocoCaptions] = lambda sample: sample +_WRAPPERS[datasets.CocoCaptions] = identity VOC_DETECTION_CATEGORIES = [ From 8e12bad8fee3136b0d322b503231ba7e88825627 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 2 Feb 2023 14:29:51 +0100 Subject: [PATCH 11/25] add video datasets --- .../prototype/datapoints/_dataset_wrapper.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 837c8cd18d7..5360bb21f79 100644 --- 
a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -166,6 +166,29 @@ def segmentation_wrapper( _WRAPPERS[dataset_type] = segmentation_wrapper +def video_classification_wrapper(dataset, sample): + if dataset.video_clips.output_format == "THWC": + raise RuntimeError( + f"{type(dataset).__name__} with output_format='THWC' is not supported by this wrapper, " + f"since it is not compatible with the transformations. Please use output_format='TCHW' instead." + ) + + video, audio, label = sample + + video = datapoints.Video(video) + label = datapoints.Label(label, categories=dataset.classes) + + return video, audio, label + + +for dataset_type in [ + datasets.HMDB51, + datasets.Kinetics, + datasets.UCF101, +]: + _WRAPPERS[dataset_type] = video_classification_wrapper + + def caltech101_wrapper( dataset: datasets.Caltech101, sample: Tuple[PIL.Image.Image, Any] ) -> Tuple[PIL.Image.Image, Any]: From 7a9f0837a6a26d91cb943b5274b2852095ef8892 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 2 Feb 2023 14:45:08 +0100 Subject: [PATCH 12/25] nuke annotations --- .../prototype/datapoints/_dataset_wrapper.py | 65 +++++++------------ 1 file changed, 22 insertions(+), 43 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 5360bb21f79..11ed30176a2 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -4,9 +4,6 @@ import contextlib import functools -from typing import Any, Callable, Dict, List, Optional, Tuple - -import PIL.Image import torch from torch.utils.data import Dataset @@ -24,7 +21,7 @@ # TODO: naming! -def wrap_dataset_for_transforms_v2(dataset: datasets.VisionDataset) -> _VisionDatasetDatapointWrapper: +def wrap_dataset_for_transforms_v2(dataset): dataset_cls = type(dataset) wrapper = _WRAPPERS.get(dataset_cls) if wrapper is None: @@ -40,7 +37,7 @@ def wrap_dataset_for_transforms_v2(dataset: datasets.VisionDataset) -> _VisionDa class _VisionDatasetDatapointWrapper(Dataset): - def __init__(self, dataset: datasets.VisionDataset, wrapper) -> None: + def __init__(self, dataset, wrapper): self._vision_dataset = dataset self._wrapper = wrapper @@ -48,13 +45,13 @@ def __init__(self, dataset: datasets.VisionDataset, wrapper) -> None: # transforms self.transforms, dataset.transforms = dataset.transforms, None - def __getattr__(self, item: str) -> Any: + def __getattr__(self, item): with contextlib.suppress(AttributeError): return object.__getattribute__(self, item) return getattr(self._vision_dataset, item) - def __getitem__(self, idx: int) -> Any: + def __getitem__(self, idx): # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor # of this class sample = self._vision_dataset[idx] @@ -69,7 +66,7 @@ def __getitem__(self, idx: int) -> Any: return sample - def __len__(self) -> int: + def __len__(self): return len(self._vision_dataset) @@ -77,7 +74,7 @@ def identity(item): return item -def list_of_dicts_to_dict_of_lists(list_of_dicts: List[Dict[str, Any]]) -> Dict[str, List]: +def list_of_dicts_to_dict_of_lists(list_of_dicts): if not list_of_dicts: return {} @@ -88,9 +85,7 @@ def list_of_dicts_to_dict_of_lists(list_of_dicts: List[Dict[str, Any]]) -> Dict[ return dict_of_lists -def wrap_target_by_type( - dataset, target, type_wrappers: Dict[str, Callable], *, fail_on=(), attr_name: str = "target_type" -): +def 
wrap_target_by_type(dataset, target, type_wrappers, *, fail_on=(), attr_name="target_type"): if target is None: return None @@ -117,7 +112,7 @@ def wrap_target_by_type( @cache -def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: +def get_categories(dataset): categories_fn = { datasets.Caltech256: lambda dataset: [name.rsplit(".", 1)[1] for name in dataset.categories], datasets.CIFAR10: lambda dataset: dataset.classes, @@ -130,9 +125,7 @@ def get_categories(dataset: datasets.VisionDataset) -> Optional[List[str]]: return categories_fn(dataset) if categories_fn is not None else None -def classification_wrapper( - dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, Optional[int]] -) -> Tuple[PIL.Image.Image, Optional[datapoints.Label]]: +def classification_wrapper(dataset, sample): image, label = sample if label is not None: label = datapoints.Label(label, categories=get_categories(dataset)) @@ -153,9 +146,7 @@ def classification_wrapper( _WRAPPERS[dataset_type] = classification_wrapper -def segmentation_wrapper( - dataset: datasets.VisionDataset, sample: Tuple[PIL.Image.Image, PIL.Image.Image] -) -> Tuple[PIL.Image.Image, datapoints.Mask]: +def segmentation_wrapper(dataset, sample): image, mask = sample return image, datapoints.Mask(F.to_image_tensor(mask).squeeze(0)) @@ -189,9 +180,7 @@ def video_classification_wrapper(dataset, sample): _WRAPPERS[dataset_type] = video_classification_wrapper -def caltech101_wrapper( - dataset: datasets.Caltech101, sample: Tuple[PIL.Image.Image, Any] -) -> Tuple[PIL.Image.Image, Any]: +def caltech101_wrapper(dataset, sample): image, target = sample return image, wrap_target_by_type( dataset, @@ -205,7 +194,7 @@ def caltech101_wrapper( @cache -def get_coco_detection_categories(dataset: datasets.CocoDetection): +def get_coco_detection_categories(dataset): idx_to_category = {idx: cat["name"] for idx, cat in dataset.coco.cats.items()} idx_to_category[0] = "__background__" for idx in set(range(91)) - idx_to_category.keys(): @@ -214,10 +203,8 @@ def get_coco_detection_categories(dataset: datasets.CocoDetection): return [category for _, category in sorted(idx_to_category.items())] -def coco_dectection_wrapper( - dataset: datasets.CocoDetection, sample: Tuple[PIL.Image.Image, List[Dict[str, Any]]] -) -> Tuple[PIL.Image.Image, Dict[str, List[Any]]]: - def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> torch.Tensor: +def coco_dectection_wrapper(dataset, sample): + def segmentation_to_mask(segmentation, *, spatial_size): from pycocotools import mask segmentation = ( @@ -282,9 +269,7 @@ def segmentation_to_mask(segmentation: Any, *, spatial_size: Tuple[int, int]) -> VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES)))) -def voc_detection_wrapper( - dataset: datasets.VOCDetection, sample: Tuple[PIL.Image.Image, Any] -) -> Tuple[PIL.Image.Image, Any]: +def voc_detection_wrapper(dataset, sample): image, target = sample batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"]) @@ -305,7 +290,7 @@ def voc_detection_wrapper( _WRAPPERS[datasets.VOCDetection] = voc_detection_wrapper -def sbd_wrapper(dataset: datasets.SBDataset, sample: Tuple[PIL.Image.Image, Any]) -> Tuple[PIL.Image.Image, Any]: +def sbd_wrapper(dataset, sample): if dataset.mode == "boundaries": raise RuntimeError( "SBDataset with mode='boundaries' is currently not supported by this wrapper. 
" @@ -319,7 +304,7 @@ def sbd_wrapper(dataset: datasets.SBDataset, sample: Tuple[PIL.Image.Image, Any] _WRAPPERS[datasets.SBDataset] = sbd_wrapper -def celeba_wrapper(dataset: datasets.CelebA, sample: Tuple[PIL.Image.Image, Any]) -> Tuple[PIL.Image.Image, Any]: +def celeba_wrapper(dataset, sample): image, target = sample return wrap_target_by_type( dataset, @@ -341,7 +326,7 @@ def celeba_wrapper(dataset: datasets.CelebA, sample: Tuple[PIL.Image.Image, Any] KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))) -def kitti_wrapper(dataset: datasets.Kitti, sample): +def kitti_wrapper(dataset, sample): image, target = sample target = list_of_dicts_to_dict_of_lists(target) @@ -359,9 +344,7 @@ def kitti_wrapper(dataset: datasets.Kitti, sample): _WRAPPERS[datasets.Kitti] = kitti_wrapper -def oxford_iiit_pet_wrapper( - dataset: datasets.OxfordIIITPet, sample: Tuple[PIL.Image.Image, List] -) -> Tuple[PIL.Image.Image, List]: +def oxford_iiit_pet_wrapper(dataset, sample): image, target = sample return image, wrap_target_by_type( dataset, @@ -377,10 +360,8 @@ def oxford_iiit_pet_wrapper( _WRAPPERS[datasets.OxfordIIITPet] = oxford_iiit_pet_wrapper -def cityscapes_wrapper( - dataset: datasets.Cityscapes, sample: Tuple[PIL.Image.Image, List] -) -> Tuple[PIL.Image.Image, List]: - def instance_segmentation_wrapper(mask: PIL.Image.Image) -> datapoints.Mask: +def cityscapes_wrapper(dataset, sample): + def instance_segmentation_wrapper(mask): # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21 data = F.pil_to_tensor(mask).squeeze(0) masks = [] @@ -412,9 +393,7 @@ def instance_segmentation_wrapper(mask: PIL.Image.Image) -> datapoints.Mask: _WRAPPERS[datasets.Cityscapes] = cityscapes_wrapper -def widerface_wrapper( - dataset: datasets.WIDERFace, sample: Tuple[PIL.Image.Image, Optional[Dict[str, torch.Tensor]]] -) -> Tuple[PIL.Image.Image, Optional[Dict[str, torch.Tensor]]]: +def widerface_wrapper(dataset, sample): image, target = sample if target is not None: # FIXME: all returned values inside this dictionary are tensors, but not images From 7f7efd51c934f17c63245b5b2dfeff0f4b1a5696 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 2 Feb 2023 14:54:52 +0100 Subject: [PATCH 13/25] reinstate transform and target_transform disabling --- .../prototype/datapoints/_dataset_wrapper.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 11ed30176a2..a460f6d72a8 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -41,8 +41,14 @@ def __init__(self, dataset, wrapper): self._vision_dataset = dataset self._wrapper = wrapper - # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply the - # transforms + # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them. + # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint + # `transforms` + # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54 + # some (if not most) datasets still use `transform` and `target_transform` individually. 
Thus, we need to + # disable all three here to be able to extract the untransformed sample to wrap. + self.transform, dataset.transform = dataset.transform, None + self.target_transform, dataset.target_transform = dataset.target_transform, None self.transforms, dataset.transforms = dataset.transforms, None def __getattr__(self, item): @@ -58,9 +64,8 @@ def __getitem__(self, idx): sample = self._wrapper(self._vision_dataset, sample) - # We don't need to care about `transform` and `target_transform` here since `VisionDataset` joins them into a - # `transforms` internally: - # https://github.com/pytorch/vision/blob/2d92728341bbd3dc1e0f1e86c6a436049bbb3403/torchvision/datasets/vision.py#L52-L54 + # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`) + # or joint (`transforms`), we can access the full functionality through `transforms` if self.transforms is not None: sample = self.transforms(*sample) From e6f2b681d1b5a81d1a4175e4c948a1e46b3d82ba Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 3 Feb 2023 15:01:55 +0100 Subject: [PATCH 14/25] address minor comments --- .../prototype/datapoints/_dataset_wrapper.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index a460f6d72a8..fdafda10d63 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -4,6 +4,7 @@ import contextlib import functools +from collections import defaultdict import torch from torch.utils.data import Dataset @@ -80,21 +81,18 @@ def identity(item): def list_of_dicts_to_dict_of_lists(list_of_dicts): - if not list_of_dicts: - return {} - - dict_of_lists = {key: [value] for key, value in list_of_dicts[0].items()} - for dct in list_of_dicts[1:]: + dict_of_lists = defaultdict(list) + for dct in list_of_dicts: for key, value in dct.items(): dict_of_lists[key].append(value) - return dict_of_lists + return dict(dict_of_lists) -def wrap_target_by_type(dataset, target, type_wrappers, *, fail_on=(), attr_name="target_type"): +def wrap_target_by_type(dataset, target, type_wrappers, *, fail_on=()): if target is None: return None - target_types = getattr(dataset, attr_name) + target_types = next(getattr(dataset, attr) for attr in ["target_type", "_target_types"]) if any(target_type in fail_on for target_type in target_types): raise RuntimeError( @@ -358,7 +356,6 @@ def oxford_iiit_pet_wrapper(dataset, sample): "category": lambda item: datapoints.Label(item, categories=dataset.classes), "segmentation": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), }, - attr_name="_target_types", ) From 22288ce6b871b99c129e4d7268146be6c849abeb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Feb 2023 13:03:35 +0100 Subject: [PATCH 15/25] remove categories and refactor wrapping architecture --- .../prototype/datapoints/_dataset_wrapper.py | 346 +++++++++--------- 1 file changed, 170 insertions(+), 176 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index fdafda10d63..f9debb9f991 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -3,29 +3,23 @@ from __future__ import annotations import contextlib -import functools from collections import defaultdict import torch from torch.utils.data import Dataset 
from torchvision import datasets -from torchvision._utils import sequence_to_str from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F __all__ = ["wrap_dataset_for_transforms_v2"] -cache = functools.partial(functools.lru_cache) - -_WRAPPERS = {} - # TODO: naming! def wrap_dataset_for_transforms_v2(dataset): dataset_cls = type(dataset) - wrapper = _WRAPPERS.get(dataset_cls) - if wrapper is None: + wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls) + if wrapper_factory is None: # TODO: If we have documentation on how to do that, put a link in the error message. msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." if dataset_cls in datasets.__dict__.values(): @@ -34,10 +28,22 @@ def wrap_dataset_for_transforms_v2(dataset): f"please open an issue at https://github.com/pytorch/vision/issues." ) raise ValueError(msg) - return _VisionDatasetDatapointWrapper(dataset, wrapper) + return VisionDatasetDatapointWrapper(dataset, wrapper_factory(dataset)) + + +class WrapperFactories(dict): + def register(self, dataset_cls): + def decorator(wrapper_factory): + self[dataset_cls] = wrapper_factory + return wrapper_factory + + return decorator + +WRAPPER_FACTORIES = WrapperFactories() -class _VisionDatasetDatapointWrapper(Dataset): + +class VisionDatasetDatapointWrapper(Dataset): def __init__(self, dataset, wrapper): self._vision_dataset = dataset self._wrapper = wrapper @@ -63,7 +69,7 @@ def __getitem__(self, idx): # of this class sample = self._vision_dataset[idx] - sample = self._wrapper(self._vision_dataset, sample) + sample = self._wrapper(sample) # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`) # or joint (`transforms`), we can access the full functionality through `transforms` @@ -88,18 +94,9 @@ def list_of_dicts_to_dict_of_lists(list_of_dicts): return dict(dict_of_lists) -def wrap_target_by_type(dataset, target, type_wrappers, *, fail_on=()): - if target is None: - return None - - target_types = next(getattr(dataset, attr) for attr in ["target_type", "_target_types"]) - - if any(target_type in fail_on for target_type in target_types): - raise RuntimeError( - f"{type(dataset).__name__} with target type(s) {sequence_to_str(fail_on, separate_last='or ')} " - f"is currently not supported by this wrapper. " - f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." 
- ) +def wrap_target_by_type(target, *, target_types, type_wrappers=None): + if type_wrappers is None: + type_wrappers = dict() if not isinstance(target, (tuple, list)): target = [target] @@ -114,25 +111,8 @@ def wrap_target_by_type(dataset, target, type_wrappers, *, fail_on=()): return wrapped_target -@cache -def get_categories(dataset): - categories_fn = { - datasets.Caltech256: lambda dataset: [name.rsplit(".", 1)[1] for name in dataset.categories], - datasets.CIFAR10: lambda dataset: dataset.classes, - datasets.CIFAR100: lambda dataset: dataset.classes, - datasets.FashionMNIST: lambda dataset: dataset.classes, - datasets.ImageNet: lambda dataset: [", ".join(names) for names in dataset.classes], - datasets.DatasetFolder: lambda dataset: dataset.classes, - datasets.ImageFolder: lambda dataset: dataset.classes, - }.get(type(dataset)) - return categories_fn(dataset) if categories_fn is not None else None - - -def classification_wrapper(dataset, sample): - image, label = sample - if label is not None: - label = datapoints.Label(label, categories=get_categories(dataset)) - return image, label +def classification_wrapper_factory(dataset): + return identity for dataset_type in [ @@ -146,33 +126,38 @@ def classification_wrapper(dataset, sample): datasets.DatasetFolder, datasets.ImageFolder, ]: - _WRAPPERS[dataset_type] = classification_wrapper + WRAPPER_FACTORIES[dataset_type] = classification_wrapper_factory + +def segmentation_wrapper_factory(dataset): + def wrapper(sample): + image, mask = sample + return image, datapoints.Mask(F.to_image_tensor(mask).squeeze(0)) -def segmentation_wrapper(dataset, sample): - image, mask = sample - return image, datapoints.Mask(F.to_image_tensor(mask).squeeze(0)) + return wrapper for dataset_type in [ datasets.VOCSegmentation, ]: - _WRAPPERS[dataset_type] = segmentation_wrapper + WRAPPER_FACTORIES[dataset_type] = segmentation_wrapper_factory -def video_classification_wrapper(dataset, sample): +def video_classification_wrapper_factory(dataset): if dataset.video_clips.output_format == "THWC": raise RuntimeError( - f"{type(dataset).__name__} with output_format='THWC' is not supported by this wrapper, " - f"since it is not compatible with the transformations. Please use output_format='TCHW' instead." + f"{type(dataset).__name__} with `output_format='THWC'` is not supported by this wrapper, " + f"since it is not compatible with the transformations. Please use `output_format='TCHW'` instead." ) - video, audio, label = sample + def wrapper(sample): + video, audio, label = sample + + video = datapoints.Video(video) - video = datapoints.Video(video) - label = datapoints.Label(label, categories=dataset.classes) + return video, audio, label - return video, audio, label + return wrapper for dataset_type in [ @@ -180,33 +165,31 @@ def video_classification_wrapper(dataset, sample): datasets.Kinetics, datasets.UCF101, ]: - _WRAPPERS[dataset_type] = video_classification_wrapper + WRAPPER_FACTORIES[dataset_type] = video_classification_wrapper_factory -def caltech101_wrapper(dataset, sample): - image, target = sample - return image, wrap_target_by_type( - dataset, - target, - {"category": lambda item: datapoints.Label(target, categories=dataset.categories)}, - fail_on=["annotation"], +def raise_not_supported(description): + raise RuntimeError( + f"{description} is currently not supported by this wrapper. " + f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." 
) -_WRAPPERS[datasets.Caltech101] = caltech101_wrapper +@WRAPPER_FACTORIES.register(datasets.Caltech101) +def caltech101_wrapper_factory(dataset): + if "annotation" in dataset.target_type: + raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`") + def wrapper(sample): + image, target = sample -@cache -def get_coco_detection_categories(dataset): - idx_to_category = {idx: cat["name"] for idx, cat in dataset.coco.cats.items()} - idx_to_category[0] = "__background__" - for idx in set(range(91)) - idx_to_category.keys(): - idx_to_category[idx] = "N/A" + target = wrap_target_by_type(target, target_types=dataset.target_type) - return [category for _, category in sorted(idx_to_category.items())] + return classification_wrapper_factory(dataset) -def coco_dectection_wrapper(dataset, sample): +@WRAPPER_FACTORIES.register(datasets.CocoDetection) +def coco_dectection_wrapper_factory(dataset): def segmentation_to_mask(segmentation, *, spatial_size): from pycocotools import mask @@ -217,33 +200,30 @@ def segmentation_to_mask(segmentation, *, spatial_size): ) return torch.from_numpy(mask.decode(segmentation)) - image, target = sample + def wrapper(sample): + image, target = sample - batched_target = list_of_dicts_to_dict_of_lists(target) + batched_target = list_of_dicts_to_dict_of_lists(target) - spatial_size = tuple(F.get_spatial_size(image)) - batched_target["boxes"] = datapoints.BoundingBox( - batched_target["bbox"], - format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=spatial_size, - ) - batched_target["masks"] = datapoints.Mask( - torch.stack( - [ - segmentation_to_mask(segmentation, spatial_size=spatial_size) - for segmentation in batched_target["segmentation"] - ] - ), - ) - batched_target["labels"] = datapoints.Label( - batched_target["category_id"], categories=get_coco_detection_categories(dataset) - ) - - return image, batched_target + spatial_size = tuple(F.get_spatial_size(image)) + batched_target["boxes"] = datapoints.BoundingBox( + batched_target["bbox"], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=spatial_size, + ) + batched_target["masks"] = datapoints.Mask( + torch.stack( + [ + segmentation_to_mask(segmentation, spatial_size=spatial_size) + for segmentation in batched_target["segmentation"] + ] + ), + ) + batched_target["labels"] = torch.tensor(batched_target["category_id"]) + return image, batched_target -_WRAPPERS[datasets.CocoDetection] = coco_dectection_wrapper -_WRAPPERS[datasets.CocoCaptions] = identity + return wrapper VOC_DETECTION_CATEGORIES = [ @@ -272,97 +252,106 @@ def segmentation_to_mask(segmentation, *, spatial_size): VOC_DETECTION_CATEGORY_TO_IDX = dict(zip(VOC_DETECTION_CATEGORIES, range(len(VOC_DETECTION_CATEGORIES)))) -def voc_detection_wrapper(dataset, sample): - image, target = sample +@WRAPPER_FACTORIES.register(datasets.VOCDetection) +def voc_detection_wrapper_factory(dataset, sample): + def wrapper(sample): + image, target = sample - batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"]) + batched_instances = list_of_dicts_to_dict_of_lists(target["annotation"]["object"]) - target["boxes"] = datapoints.BoundingBox( - [[int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bndbox in batched_instances["bndbox"]], - format=datapoints.BoundingBoxFormat.XYXY, - spatial_size=(image.height, image.width), - ) - target["labels"] = datapoints.Label( - [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]], - categories=VOC_DETECTION_CATEGORIES, - ) - - return 
image, target + target["boxes"] = datapoints.BoundingBox( + [ + [int(bndbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] + for bndbox in batched_instances["bndbox"] + ], + format=datapoints.BoundingBoxFormat.XYXY, + spatial_size=(image.height, image.width), + ) + target["labels"] = torch.tensor( + [VOC_DETECTION_CATEGORY_TO_IDX[category] for category in batched_instances["name"]] + ) + return image, target -_WRAPPERS[datasets.VOCDetection] = voc_detection_wrapper + return wrapper -def sbd_wrapper(dataset, sample): +@WRAPPER_FACTORIES.register(datasets.SBDataset) +def sbd_wrapper(dataset): if dataset.mode == "boundaries": - raise RuntimeError( - "SBDataset with mode='boundaries' is currently not supported by this wrapper. " - "If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." - ) + raise_not_supported("SBDataset with mode='boundaries'") - image, target = sample - return image, datapoints.Mask(F.to_image_tensor(target).squeeze(0)) + return segmentation_wrapper_factory(dataset) -_WRAPPERS[datasets.SBDataset] = sbd_wrapper +@WRAPPER_FACTORIES.register(datasets.CelebA) +def celeba_wrapper_factory(dataset): + if any(target_type in dataset.target_type for target_type in ["attr", "landmarks"]): + raise_not_supported("`CelebA` dataset with `target_type=['attr', 'landmarks', ...]`") + def wrapper(sample): + image, target = sample -def celeba_wrapper(dataset, sample): - image, target = sample - return wrap_target_by_type( - dataset, - target, - { - "identity": datapoints.Label, - "bbox": lambda item: datapoints.BoundingBox( - item, format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) - ), - }, - # FIXME: Failing on "attr" here is problematic, since it is the default - fail_on=["attr", "landmarks"], - ) + target = wrap_target_by_type( + target, + target_types=dataset.target_type, + type_wrappers={ + "bbox": lambda item: datapoints.BoundingBox( + item, format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) + ), + }, + ) + + return image, target + return wrapper -_WRAPPERS[datasets.CelebA] = celeba_wrapper KITTI_CATEGORIES = ["Car", "Van", "Truck", "Pedestrian", "Person_sitting", "Cyclist", "Tram", "Misc", "DontCare"] KITTI_CATEGORY_TO_IDX = dict(zip(KITTI_CATEGORIES, range(len(KITTI_CATEGORIES)))) -def kitti_wrapper(dataset, sample): - image, target = sample +@WRAPPER_FACTORIES.register(datasets.Kitti) +def kitti_wrapper_factory(dataset): + def wrapper(sample): + image, target = sample - target = list_of_dicts_to_dict_of_lists(target) + target = list_of_dicts_to_dict_of_lists(target) - target["boxes"] = datapoints.BoundingBox( - target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width) - ) - target["labels"] = datapoints.Label( - [KITTI_CATEGORY_TO_IDX[category] for category in target["type"]], categories=KITTI_CATEGORIES - ) + target["boxes"] = datapoints.BoundingBox( + target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width) + ) + target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]]) - return image, target + return image, target + return wrapper -_WRAPPERS[datasets.Kitti] = kitti_wrapper +@WRAPPER_FACTORIES.register(datasets.OxfordIIITPet) +def oxford_iiit_pet_wrapper_factor(dataset): + def wrapper(sample): + image, target = sample -def oxford_iiit_pet_wrapper(dataset, sample): - image, target = sample - return image, wrap_target_by_type( - dataset, - 
target, - { - "category": lambda item: datapoints.Label(item, categories=dataset.classes), - "segmentation": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), - }, - ) + if target is not None: + target = wrap_target_by_type( + target, + target_types=dataset._target_types, + type_wrappers={ + "segmentation": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), + }, + ) + + return image, target + return wrapper -_WRAPPERS[datasets.OxfordIIITPet] = oxford_iiit_pet_wrapper +@WRAPPER_FACTORIES.register(datasets.Cityscapes) +def cityscapes_wrapper_factory(dataset): + if any(target_type in dataset.target_type for target_type in ["polygon", "color"]): + raise_not_supported("`Cityscapes` dataset with `target_type=['polygon', 'color', ...]`") -def cityscapes_wrapper(dataset, sample): def instance_segmentation_wrapper(mask): # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21 data = F.pil_to_tensor(mask).squeeze(0) @@ -377,32 +366,37 @@ def instance_segmentation_wrapper(mask): masks = datapoints.Mask(torch.stack(masks)) # FIXME: without the labels, returning just the instance masks is pretty useless. However, we would need to # return a two-tuple or the like where we originally only had a single PIL image. - labels = datapoints.Label(torch.stack(labels), categories=[cls.name for cls in dataset.classes]) + labels = datapoints.Label(torch.stack(labels)) return masks - image, target = sample - return image, wrap_target_by_type( - dataset, - target, - { - "instance": instance_segmentation_wrapper, - "semantic": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), - }, - fail_on=["polygon", "color"], - ) + def wrapper(sample): + image, target = sample + + target = wrap_target_by_type( + target, + target_types=dataset.target_type, + type_wrappers={ + "instance": instance_segmentation_wrapper, + # FIXME: pil_image_to_mask + "semantic": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), + }, + ) + return image, target -_WRAPPERS[datasets.Cityscapes] = cityscapes_wrapper + return wrapper -def widerface_wrapper(dataset, sample): - image, target = sample - if target is not None: - # FIXME: all returned values inside this dictionary are tensors, but not images - target["bbox"] = datapoints.BoundingBox( - target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) - ) - return image, target +@WRAPPER_FACTORIES.register(datasets.WIDERFace) +def widerface_wrapper(dataset): + def wrapper(sample): + image, target = sample + + if target is not None: + target["bbox"] = datapoints.BoundingBox( + target["bbox"], format=datapoints.BoundingBoxFormat.XYWH, spatial_size=(image.height, image.width) + ) + return image, target -_WRAPPERS[datasets.WIDERFace] = widerface_wrapper + return wrapper From a88aec312b1e17c0204a74d57b22486cab4e639e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Feb 2023 15:08:29 +0100 Subject: [PATCH 16/25] add tests --- test/datasets_utils.py | 30 +++++++++++++++++++ test/test_datasets.py | 18 +++++++---- .../prototype/datapoints/_dataset_wrapper.py | 27 +++++++++-------- 3 files changed, 57 insertions(+), 18 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 1f186650ad0..58d596ba0e0 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -581,6 +581,27 @@ def test_transforms(self, config): mock.assert_called() + @test_all_configs + def 
test_transforms_v2_wrapper(self, config): + # This is stable test, so we can't depend on prototype stuff + try: + from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2 + except ImportError: + return + + try: + with self.create_dataset(config) as (dataset, _): + wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) + wrapped_dataset[0] + except ValueError as error: + if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"): + return + raise error + except RuntimeError as error: + if "currently not supported by this wrapper" in str(error): + return + raise error + class ImageDatasetTestCase(DatasetTestCase): """Abstract base class for image dataset testcases. @@ -662,6 +683,15 @@ def wrapper(tmpdir, config): return wrapper + @test_all_configs + def test_transforms_v2_wrapper(self, config): + # `output_format == "THWC"` is not supported by the wrapper. Thus, we skip the `config` if it is set explicitly + # or use the supported `"TCHW"` + if config.setdefault("output_format", "TCHW") == "THWC": + return + + super().test_transforms_v2_wrapper.__wrapped__(self, config) + def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor: r"""Create a random uint8 tensor. diff --git a/test/test_datasets.py b/test/test_datasets.py index bd6d1dcb259..25e9a6ab268 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -763,11 +763,19 @@ def _create_annotation_file(self, root, name, file_names, num_annotations_per_im return info def _create_annotations(self, image_ids, num_annotations_per_image): - annotations = datasets_utils.combinations_grid( - image_id=image_ids, bbox=([1.0, 2.0, 3.0, 4.0],) * num_annotations_per_image - ) - for id, annotation in enumerate(annotations): - annotation["id"] = id + annotations = [] + annotion_id = 0 + for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image): + annotations.append( + dict( + image_id=image_id, + id=annotion_id, + bbox=torch.rand(4).tolist(), + segmentation=[torch.rand(8).tolist()], + category_id=int(torch.randint(91, ())), + ) + ) + annotion_id += 1 return annotations, dict() def _create_json(self, root, name, content): diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index f9debb9f991..2d70e6ccb10 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -82,6 +82,13 @@ def __len__(self): return len(self._vision_dataset) +def raise_not_supported(description): + raise RuntimeError( + f"{description} is currently not supported by this wrapper. " + f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." + ) + + def identity(item): return item @@ -168,13 +175,6 @@ def wrapper(sample): WRAPPER_FACTORIES[dataset_type] = video_classification_wrapper_factory -def raise_not_supported(description): - raise RuntimeError( - f"{description} is currently not supported by this wrapper. " - f"If this would be helpful for you, please open an issue at https://github.com/pytorch/vision/issues." 
- ) - - @WRAPPER_FACTORIES.register(datasets.Caltech101) def caltech101_wrapper_factory(dataset): if "annotation" in dataset.target_type: @@ -253,7 +253,7 @@ def wrapper(sample): @WRAPPER_FACTORIES.register(datasets.VOCDetection) -def voc_detection_wrapper_factory(dataset, sample): +def voc_detection_wrapper_factory(dataset): def wrapper(sample): image, target = sample @@ -316,12 +316,13 @@ def kitti_wrapper_factory(dataset): def wrapper(sample): image, target = sample - target = list_of_dicts_to_dict_of_lists(target) + if target is not None: + target = list_of_dicts_to_dict_of_lists(target) - target["boxes"] = datapoints.BoundingBox( - target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width) - ) - target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]]) + target["boxes"] = datapoints.BoundingBox( + target["bbox"], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(image.height, image.width) + ) + target["labels"] = torch.tensor([KITTI_CATEGORY_TO_IDX[category] for category in target["type"]]) return image, target From ce740c10e7ed9fc67729e982eae137579663b104 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Feb 2023 15:11:58 +0100 Subject: [PATCH 17/25] cleanup --- test/datasets_utils.py | 3 ++- .../prototype/datapoints/_dataset_wrapper.py | 15 ++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 58d596ba0e0..7121813fe91 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -592,7 +592,8 @@ def test_transforms_v2_wrapper(self, config): try: with self.create_dataset(config) as (dataset, _): wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) - wrapped_dataset[0] + wrapped_sample = wrapped_dataset[0] + assert wrapped_sample is not None except ValueError as error: if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"): return diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 2d70e6ccb10..f231c01dc71 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -93,6 +93,10 @@ def identity(item): return item +def pil_image_to_mask(pil_image): + return datapoints.Mask(F.to_image_tensor(pil_image).squeeze(0)) + + def list_of_dicts_to_dict_of_lists(list_of_dicts): dict_of_lists = defaultdict(list) for dct in list_of_dicts: @@ -139,7 +143,7 @@ def classification_wrapper_factory(dataset): def segmentation_wrapper_factory(dataset): def wrapper(sample): image, mask = sample - return image, datapoints.Mask(F.to_image_tensor(mask).squeeze(0)) + return image, pil_image_to_mask(mask) return wrapper @@ -185,7 +189,9 @@ def wrapper(sample): target = wrap_target_by_type(target, target_types=dataset.target_type) - return classification_wrapper_factory(dataset) + return image, target + + return wrapper @WRAPPER_FACTORIES.register(datasets.CocoDetection) @@ -339,7 +345,7 @@ def wrapper(sample): target, target_types=dataset._target_types, type_wrappers={ - "segmentation": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), + "segmentation": pil_image_to_mask, }, ) @@ -378,8 +384,7 @@ def wrapper(sample): target_types=dataset.target_type, type_wrappers={ "instance": instance_segmentation_wrapper, - # FIXME: pil_image_to_mask - "semantic": lambda item: datapoints.Mask(F.pil_to_tensor(item).squeeze(0)), + "semantic": pil_image_to_mask, 
}, ) From 3398822714a075237c9367c5b469a6d617d10a7d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Feb 2023 15:28:29 +0100 Subject: [PATCH 18/25] remove GenericDatapoint --- torchvision/prototype/datapoints/__init__.py | 2 +- torchvision/prototype/datapoints/_datapoint.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/torchvision/prototype/datapoints/__init__.py b/torchvision/prototype/datapoints/__init__.py index ff6c44ab108..554088b912a 100644 --- a/torchvision/prototype/datapoints/__init__.py +++ b/torchvision/prototype/datapoints/__init__.py @@ -1,5 +1,5 @@ from ._bounding_box import BoundingBox, BoundingBoxFormat -from ._datapoint import FillType, FillTypeJIT, GenericDatapoint, InputType, InputTypeJIT +from ._datapoint import FillType, FillTypeJIT, InputType, InputTypeJIT from ._image import Image, ImageType, ImageTypeJIT, TensorImageType, TensorImageTypeJIT from ._label import Label, OneHotLabel from ._mask import Mask diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py index 37b8a4261ca..fbd19ad86f1 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -274,7 +274,3 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N InputType = Union[torch.Tensor, PIL.Image.Image, Datapoint] InputTypeJIT = torch.Tensor - - -class GenericDatapoint(Datapoint): - pass From 331a66d94f072fb4926e7be2b14e81370b7feb1e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Feb 2023 16:51:44 +0100 Subject: [PATCH 19/25] move wrapper instantiation into the class --- .../prototype/datapoints/_dataset_wrapper.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index f231c01dc71..2bbd607e1d5 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -17,18 +17,7 @@ # TODO: naming! def wrap_dataset_for_transforms_v2(dataset): - dataset_cls = type(dataset) - wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls) - if wrapper_factory is None: - # TODO: If we have documentation on how to do that, put a link in the error message. - msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." - if dataset_cls in datasets.__dict__.values(): - msg = ( - f"{msg} If an automated wrapper for this dataset would be useful for you, " - f"please open an issue at https://github.com/pytorch/vision/issues." - ) - raise ValueError(msg) - return VisionDatasetDatapointWrapper(dataset, wrapper_factory(dataset)) + return VisionDatasetDatapointWrapper(dataset) class WrapperFactories(dict): @@ -44,9 +33,21 @@ def decorator(wrapper_factory): class VisionDatasetDatapointWrapper(Dataset): - def __init__(self, dataset, wrapper): - self._vision_dataset = dataset - self._wrapper = wrapper + def __init__(self, dataset): + dataset_cls = type(dataset) + wrapper_factory = WRAPPER_FACTORIES.get(dataset_cls) + if wrapper_factory is None: + # TODO: If we have documentation on how to do that, put a link in the error message. + msg = f"No wrapper exist for dataset class {dataset_cls.__name__}. Please wrap the output yourself." 
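# For illustration, a hedged sketch of how this error surfaces to a user who wraps
# a custom dataset that has no registered wrapper (the dataset class below is
# hypothetical):
#
#     wrap_dataset_for_transforms_v2(MyCustomDataset(...))
#     # TypeError: No wrapper exist for dataset class MyCustomDataset.
#     #            Please wrap the output yourself.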
+ if dataset_cls in datasets.__dict__.values(): + msg = ( + f"{msg} If an automated wrapper for this dataset would be useful for you, " + f"please open an issue at https://github.com/pytorch/vision/issues." + ) + raise TypeError(msg) + + self._dataset = dataset + self._wrapper = wrapper_factory(dataset) # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them. # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint @@ -62,12 +63,12 @@ def __getattr__(self, item): with contextlib.suppress(AttributeError): return object.__getattribute__(self, item) - return getattr(self._vision_dataset, item) + return getattr(self._dataset, item) def __getitem__(self, idx): # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor # of this class - sample = self._vision_dataset[idx] + sample = self._dataset[idx] sample = self._wrapper(sample) @@ -79,7 +80,7 @@ def __getitem__(self, idx): return sample def __len__(self): - return len(self._vision_dataset) + return len(self._dataset) def raise_not_supported(description): From 48405b8e0cd8e77355b4ad55b2a0db0113df04c5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Feb 2023 16:53:20 +0100 Subject: [PATCH 20/25] use decorator registering everywhere --- torchvision/prototype/datapoints/_dataset_wrapper.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py index 2bbd607e1d5..2673f30f82e 100644 --- a/torchvision/prototype/datapoints/_dataset_wrapper.py +++ b/torchvision/prototype/datapoints/_dataset_wrapper.py @@ -127,7 +127,7 @@ def classification_wrapper_factory(dataset): return identity -for dataset_type in [ +for dataset_cls in [ datasets.Caltech256, datasets.CIFAR10, datasets.CIFAR100, @@ -138,7 +138,7 @@ def classification_wrapper_factory(dataset): datasets.DatasetFolder, datasets.ImageFolder, ]: - WRAPPER_FACTORIES[dataset_type] = classification_wrapper_factory + WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory) def segmentation_wrapper_factory(dataset): @@ -149,10 +149,10 @@ def wrapper(sample): return wrapper -for dataset_type in [ +for dataset_cls in [ datasets.VOCSegmentation, ]: - WRAPPER_FACTORIES[dataset_type] = segmentation_wrapper_factory + WRAPPER_FACTORIES.register(dataset_cls)(segmentation_wrapper_factory) def video_classification_wrapper_factory(dataset): @@ -172,12 +172,12 @@ def wrapper(sample): return wrapper -for dataset_type in [ +for dataset_cls in [ datasets.HMDB51, datasets.Kinetics, datasets.UCF101, ]: - WRAPPER_FACTORIES[dataset_type] = video_classification_wrapper_factory + WRAPPER_FACTORIES.register(dataset_cls)(video_classification_wrapper_factory) @WRAPPER_FACTORIES.register(datasets.Caltech101) From 0286238a2d2c27496af845d796e2bc9bc1286cbe Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 9 Feb 2023 16:55:27 +0100 Subject: [PATCH 21/25] hard depend on wrapper in stable tests --- test/datasets_utils.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/test/datasets_utils.py b/test/datasets_utils.py index 7121813fe91..210d2e5ab59 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -583,11 +583,10 @@ def test_transforms(self, config): @test_all_configs def test_transforms_v2_wrapper(self, config): - # This is stable test, so we can't depend on prototype stuff - try: - 
-            from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
-        except ImportError:
-            return
+        # Although this is a stable test, we unconditionally import from `torchvision.prototype` here. The wrapper needs
+        # to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
+        # branch, we screwed up the roll-out
+        from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2

         try:
             with self.create_dataset(config) as (dataset, _):

From be42cc96caaaa4a73289178e057c80edb51dcb1f Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 9 Feb 2023 17:13:33 +0100
Subject: [PATCH 22/25] remove target type wrapping default

---
 .../prototype/datapoints/_dataset_wrapper.py | 14 ++------------
 1 file changed, 2 insertions(+), 12 deletions(-)

diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py
index 2673f30f82e..68d4a230342 100644
--- a/torchvision/prototype/datapoints/_dataset_wrapper.py
+++ b/torchvision/prototype/datapoints/_dataset_wrapper.py
@@ -106,10 +106,7 @@ def list_of_dicts_to_dict_of_lists(list_of_dicts):
     return dict(dict_of_lists)


-def wrap_target_by_type(target, *, target_types, type_wrappers=None):
-    if type_wrappers is None:
-        type_wrappers = dict()
-
+def wrap_target_by_type(target, *, target_types, type_wrappers):
     if not isinstance(target, (tuple, list)):
         target = [target]

@@ -185,14 +182,7 @@ def caltech101_wrapper_factory(dataset):
     if "annotation" in dataset.target_type:
         raise_not_supported("Caltech101 dataset with `target_type=['annotation', ...]`")

-    def wrapper(sample):
-        image, target = sample
-
-        target = wrap_target_by_type(target, target_types=dataset.target_type)
-
-        return image, target
-
-    return wrapper
+    return classification_wrapper_factory(dataset)


 @WRAPPER_FACTORIES.register(datasets.CocoDetection)

From e3c4d50e993bf2ff18aabd67000b134469346a76 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 9 Feb 2023 17:29:28 +0100
Subject: [PATCH 23/25] make test more strict

---
 test/datasets_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/datasets_utils.py b/test/datasets_utils.py
index 210d2e5ab59..598d4408b76 100644
--- a/test/datasets_utils.py
+++ b/test/datasets_utils.py
@@ -25,6 +25,7 @@
 import torchvision.datasets
 import torchvision.io
 from common_utils import disable_console_output, get_tmp_dir
+from torch.utils._pytree import tree_any
 from torchvision.transforms.functional import get_dimensions


@@ -587,13 +588,14 @@ def test_transforms_v2_wrapper(self, config):
         # to be available with the next release when v2 is released. Thus, if this import somehow fails on the release
         # branch, we screwed up the roll-out
         from torchvision.prototype.datapoints import wrap_dataset_for_transforms_v2
+        from torchvision.prototype.datapoints._datapoint import Datapoint

         try:
             with self.create_dataset(config) as (dataset, _):
                 wrapped_dataset = wrap_dataset_for_transforms_v2(dataset)
                 wrapped_sample = wrapped_dataset[0]
-                assert wrapped_sample is not None
-        except ValueError as error:
+                assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
+        except TypeError as error:
             if str(error).startswith(f"No wrapper exist for dataset class {type(dataset).__name__}"):
                 return
             raise error

From 351becbc87e0c83299720d0891ebb119c8d41791 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 9 Feb 2023 17:45:59 +0100
Subject: [PATCH 24/25] fix cityscapes instance return

---
 torchvision/prototype/datapoints/_dataset_wrapper.py | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py
index 68d4a230342..cb980a1dd87 100644
--- a/torchvision/prototype/datapoints/_dataset_wrapper.py
+++ b/torchvision/prototype/datapoints/_dataset_wrapper.py
@@ -352,7 +352,7 @@ def cityscapes_wrapper_factory(dataset):

     def instance_segmentation_wrapper(mask):
         # See https://github.com/mcordts/cityscapesScripts/blob/8da5dd00c9069058ccc134654116aac52d4f6fa2/cityscapesscripts/preparation/json2instanceImg.py#L7-L21
-        data = F.pil_to_tensor(mask).squeeze(0)
+        data = pil_image_to_mask(mask)
         masks = []
         labels = []
         for id in data.unique():
@@ -361,11 +361,7 @@ def instance_segmentation_wrapper(mask):
             if label >= 1_000:
                 label //= 1_000
             labels.append(label)
-        masks = datapoints.Mask(torch.stack(masks))
-        # FIXME: without the labels, returning just the instance masks is pretty useless. However, we would need to
-        # return a two-tuple or the like where we originally only had a single PIL image.
-        labels = datapoints.Label(torch.stack(labels))
-        return masks
+        return dict(masks=datapoints.Mask(torch.stack(masks)), labels=torch.stack(labels))

     def wrapper(sample):
         image, target = sample

From 8ed41ba6dfe1325cc30d40666e887a175460214d Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 9 Feb 2023 17:53:38 +0100
Subject: [PATCH 25/25] add comment for two stage design

---
 torchvision/prototype/datapoints/_dataset_wrapper.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/torchvision/prototype/datapoints/_dataset_wrapper.py b/torchvision/prototype/datapoints/_dataset_wrapper.py
index cb980a1dd87..db96493b020 100644
--- a/torchvision/prototype/datapoints/_dataset_wrapper.py
+++ b/torchvision/prototype/datapoints/_dataset_wrapper.py
@@ -29,6 +29,10 @@ def decorator(wrapper_factory):
         return decorator


+# We need this two-stage design, i.e. a wrapper factory producing the actual wrapper, since some wrappers depend on the
+# dataset instance rather than just the class, since they require the user defined instance attributes. Thus, we can
+# provide a wrapping from the dataset class to the factory here, but can only instantiate the wrapper at runtime when
+# we have access to the dataset instance.
 WRAPPER_FACTORIES = WrapperFactories()
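
The "two stage design" comment added in the last patch is easier to follow with a small, standalone sketch. The snippet below is not the torchvision implementation; `FakeDataset`, the simplified `WrapperFactories`, and the wrapper body are made up for illustration. It shows the pattern the patches above converge on: a factory is registered per dataset class, while the actual per-sample wrapper is only produced once a concrete dataset instance (with user-defined attributes such as `categories` or `target_type`) is available.

class WrapperFactories(dict):
    # registry mapping a dataset class to a factory that builds the per-sample wrapper
    def register(self, dataset_cls):
        def decorator(wrapper_factory):
            self[dataset_cls] = wrapper_factory
            return wrapper_factory

        return decorator


WRAPPER_FACTORIES = WrapperFactories()


class FakeDataset:
    def __init__(self, categories):
        self.categories = categories  # user-defined instance attribute


@WRAPPER_FACTORIES.register(FakeDataset)
def fake_wrapper_factory(dataset):
    # stage 1: runs once per dataset instance, so it can read instance attributes
    categories = dataset.categories

    def wrapper(sample):
        # stage 2: runs once per sample, e.g. inside the dataset wrapper's __getitem__
        image, label = sample
        return image, {"label": label, "category": categories[label]}

    return wrapper


dataset = FakeDataset(categories=["cat", "dog"])
per_sample_wrapper = WRAPPER_FACTORIES[type(dataset)](dataset)
print(per_sample_wrapper(("<image>", 1)))  # ('<image>', {'label': 1, 'category': 'dog'})

Having `register` return the factory also lets the same helper serve both as a decorator and as a plain call, which is how the loops above use it: `WRAPPER_FACTORIES.register(dataset_cls)(classification_wrapper_factory)`.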
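
The Cityscapes fix in PATCH 24 relies on the instance-id encoding referenced in the linked cityscapesScripts file: pixels belonging to an instance are stored as class_id * 1000 + instance_number, while "stuff" pixels keep a plain class id below 1000. The sketch below only illustrates that decoding step on a hand-made id map; it uses plain tensors instead of the `datapoints.Mask` wrapper and is not the code from the patch.

import torch


def decode_instance_ids(id_map: torch.Tensor):
    # split an instance-id map into per-instance boolean masks and their class labels
    masks, labels = [], []
    for instance_id in id_map.unique().tolist():
        masks.append(id_map == instance_id)
        # ids >= 1000 encode class_id * 1000 + instance_number; "stuff" ids stay as-is
        labels.append(instance_id // 1_000 if instance_id >= 1_000 else instance_id)
    return torch.stack(masks), torch.tensor(labels)


# hand-made id map: two instances of class 26 plus a "stuff" region with class id 7
fake_id_map = torch.tensor([[7, 7, 26000], [7, 26001, 26001]])
masks, labels = decode_instance_ids(fake_id_map)
print(masks.shape, labels.tolist())  # torch.Size([3, 2, 3]) [7, 26, 26]

Returning the masks together with their labels, as the patch does via a dict, resolves the removed FIXME about instance masks being of little use without the corresponding labels.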