Skip to content

Commit ae428c4

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] add support for instance checks on dataset wrappers (#7239)
Reviewed By: matteobettini Differential Revision: D48642290 fbshipit-source-id: d44279d2024dfb0387f0d70d84d6c8d128b33394
1 parent 379ed22 commit ae428c4

File tree

4 files changed

+17
-11
lines changed

4 files changed

+17
-11
lines changed

references/detection/coco_utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,9 +178,7 @@ def get_coco_api_from_dataset(dataset):
178178
break
179179
if isinstance(dataset, torch.utils.data.Subset):
180180
dataset = dataset.dataset
181-
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
182-
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
183-
):
181+
if isinstance(dataset, torchvision.datasets.CocoDetection):
184182
return dataset.coco
185183
return convert_to_coco_api(dataset)
186184

references/detection/group_by_aspect_ratio.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,7 @@ def compute_aspect_ratios(dataset, indices=None):
164164
if hasattr(dataset, "get_height_and_width"):
165165
return _compute_aspect_ratios_custom_dataset(dataset, indices)
166166

167-
if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
168-
getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
169-
):
167+
if isinstance(dataset, torchvision.datasets.CocoDetection):
170168
return _compute_aspect_ratios_coco_dataset(dataset, indices)
171169

172170
if isinstance(dataset, torchvision.datasets.VOCDetection):

test/datasets_utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ def test_transforms_v2_wrapper(self, config):
571571
from torchvision.datasets import wrap_dataset_for_transforms_v2
572572

573573
try:
574-
with self.create_dataset(config) as (dataset, _):
574+
with self.create_dataset(config) as (dataset, info):
575575
for target_keys in [None, "all"]:
576576
if target_keys is not None and self.DATASET_CLASS not in {
577577
torchvision.datasets.CocoDetection,
@@ -584,8 +584,10 @@ def test_transforms_v2_wrapper(self, config):
584584
continue
585585

586586
wrapped_dataset = wrap_dataset_for_transforms_v2(dataset, target_keys=target_keys)
587-
wrapped_sample = wrapped_dataset[0]
587+
assert isinstance(wrapped_dataset, self.DATASET_CLASS)
588+
assert len(wrapped_dataset) == info["num_examples"]
588589

590+
wrapped_sample = wrapped_dataset[0]
589591
assert tree_any(lambda item: isinstance(item, (Datapoint, PIL.Image.Image)), wrapped_sample)
590592
except TypeError as error:
591593
msg = f"No wrapper exists for dataset class {type(dataset).__name__}"

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from collections import defaultdict
99

1010
import torch
11-
from torch.utils.data import Dataset
1211

1312
from torchvision import datapoints, datasets
1413
from torchvision.transforms.v2 import functional as F
@@ -98,7 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
9897
f"but got {target_keys}"
9998
)
10099

101-
return VisionDatasetDatapointWrapper(dataset, target_keys)
100+
# Imagine we have isinstance(dataset, datasets.ImageNet). This will create a new class with the name
101+
# "WrappedImageNet" at runtime that doubly inherits from VisionDatasetDatapointWrapper (see below) as well as the
102+
# original ImageNet class. This allows the user to do regular isinstance(wrapped_dataset, datasets.ImageNet) checks,
103+
# while we can still inject everything that we need.
104+
wrapped_dataset_cls = type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})
105+
# Since VisionDatasetDatapointWrapper comes before ImageNet in the MRO, calling the class hits
106+
# VisionDatasetDatapointWrapper.__init__ first. Since we are never doing super().__init__(...), the constructor of
107+
# ImageNet is never hit. That is by design, since we don't want to create the dataset instance again, but rather
108+
# have the existing instance as attribute on the new object.
109+
return wrapped_dataset_cls(dataset, target_keys)
102110

103111

104112
class WrapperFactories(dict):
@@ -117,7 +125,7 @@ def decorator(wrapper_factory):
117125
WRAPPER_FACTORIES = WrapperFactories()
118126

119127

120-
class VisionDatasetDatapointWrapper(Dataset):
128+
class VisionDatasetDatapointWrapper:
121129
def __init__(self, dataset, target_keys):
122130
dataset_cls = type(dataset)
123131

0 commit comments

Comments
 (0)