Skip to content

Commit 9a020c5

Browse files
committed
[PoC] eliminate double inheritance
1 parent 09c07a0 commit 9a020c5

File tree

1 file changed

+81
-72
lines changed

1 file changed

+81
-72
lines changed

torchvision/datapoints/_dataset_wrapper.py

Lines changed: 81 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
9797
f"but got {target_keys}"
9898
)
9999

100-
return type(f"Wrapped{type(dataset).__name__}", (VisionDatasetDatapointWrapper, type(dataset)), {})(
101-
dataset, target_keys
102-
)
100+
return type(
101+
f"Wrapped{type(dataset).__name__}",
102+
(type(dataset),),
103+
{
104+
"__init__": wrapped_init,
105+
"__getattr__": wrapped_getattr,
106+
"__getitem__": wrapped_getitem,
107+
"__len__": wrapped_len,
108+
},
109+
)(dataset, target_keys)
103110

104111

105112
class WrapperFactories(dict):
@@ -118,77 +125,79 @@ def decorator(wrapper_factory):
118125
WRAPPER_FACTORIES = WrapperFactories()
119126

120127

121-
def wrapped_init(self, dataset, target_keys):
    """Initializer injected into the dynamically created wrapper class.

    Validates that ``dataset`` is a ``torchvision.datasets.VisionDataset`` with a
    registered wrapper factory, stores the dataset and its sample wrapper, and
    disables the dataset's own transforms so the wrapping can be injected before
    the transforms run (see ``wrapped_getitem``).

    Raises:
        TypeError: if ``dataset`` is not a ``VisionDataset`` or no wrapper factory
            is registered for its class.
        ValueError: if ``target_keys`` is given for a dataset that does not
            support it.
    """
    dataset_cls = type(dataset)

    if not isinstance(dataset, datasets.VisionDataset):
        raise TypeError(
            f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
            f"but got a '{dataset_cls.__name__}' instead.\n"
            f"For an example of how to perform the wrapping for custom datasets, see\n\n"
            "https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
        )

    # Walk the MRO so subclasses of a supported dataset reuse its factory.
    for cls in dataset_cls.mro():
        if cls in WRAPPER_FACTORIES:
            wrapper_factory = WRAPPER_FACTORIES[cls]
            if target_keys is not None and cls not in {
                datasets.CocoDetection,
                datasets.VOCDetection,
                datasets.Kitti,
                datasets.WIDERFace,
            }:
                raise ValueError(
                    f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
                    f"and `WIDERFace`, but got {cls.__name__}."
                )
            break
        elif cls is datasets.VisionDataset:
            # TODO: If we have documentation on how to do that, put a link in the error message.
            msg = f"No wrapper exists for dataset class {dataset_cls.__name__}. Please wrap the output yourself."
            if dataset_cls in datasets.__dict__.values():
                msg = (
                    f"{msg} If an automated wrapper for this dataset would be useful for you, "
                    f"please open an issue at https://github.com/pytorch/vision/issues."
                )
            raise TypeError(msg)

    self._dataset = dataset
    self._wrapper = wrapper_factory(dataset, target_keys)

    # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
    # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
    # `transforms`
    # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
    # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
    # disable all three here to be able to extract the untransformed sample to wrap.
    self.transform, dataset.transform = dataset.transform, None
    self.target_transform, dataset.target_transform = dataset.target_transform, None
    self.transforms, dataset.transforms = dataset.transforms, None
175+
176+
177+
def wrapped_getattr(self, item):
    """Resolve *item* on the wrapper instance first; fall back to the wrapped dataset."""
    try:
        # Regular attribute lookup on the wrapper object itself.
        return object.__getattribute__(self, item)
    except AttributeError:
        pass
    # Not found on the wrapper -- delegate to the underlying dataset.
    return getattr(self._dataset, item)
182+
183+
184+
def wrapped_getitem(self, idx):
    """Fetch the raw sample, wrap it into datapoints, then apply the user transforms."""
    # The wrapped dataset's transforms were disabled in the constructor, so this
    # indexing yields the untransformed sample.
    raw_sample = self._dataset[idx]

    wrapped_sample = self._wrapper(idx, raw_sample)

    # Whether the user supplied the transforms individually (`transform` and
    # `target_transform`) or jointly (`transforms`), the joint `transforms`
    # attribute gives access to the full functionality.
    if self.transforms is not None:
        wrapped_sample = self.transforms(*wrapped_sample)

    return wrapped_sample
197+
198+
199+
def wrapped_len(self):
    """Report the length of the underlying dataset."""
    wrapped = self._dataset
    return len(wrapped)
192201

193202

194203
def raise_not_supported(description):

0 commit comments

Comments (0)