@@ -97,9 +97,16 @@ def wrap_dataset_for_transforms_v2(dataset, target_keys=None):
9797 f"but got { target_keys } "
9898 )
9999
100- return type (f"Wrapped{ type (dataset ).__name__ } " , (VisionDatasetDatapointWrapper , type (dataset )), {})(
101- dataset , target_keys
102- )
100+ return type (
101+ f"Wrapped{ type (dataset ).__name__ } " ,
102+ (type (dataset ),),
103+ {
104+ "__init__" : wrapped_init ,
105+ "__getattr__" : wrapped_getattr ,
106+ "__getitem__" : wrapped_getitem ,
107+ "__len__" : wrapped_len ,
108+ },
109+ )(dataset , target_keys )
103110
104111
105112class WrapperFactories (dict ):
@@ -118,77 +125,79 @@ def decorator(wrapper_factory):
118125WRAPPER_FACTORIES = WrapperFactories ()
119126
120127
121- class VisionDatasetDatapointWrapper :
122- def __init__ (self , dataset , target_keys ):
123- dataset_cls = type (dataset )
124-
125- if not isinstance (dataset , datasets .VisionDataset ):
126- raise TypeError (
127- f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
128- f"but got a '{ dataset_cls .__name__ } ' instead.\n "
129- f"For an example of how to perform the wrapping for custom datasets, see\n \n "
130- "https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
131- )
128+ def wrapped_init (self , dataset , target_keys ):
129+ dataset_cls = type (dataset )
132130
133- for cls in dataset_cls .mro ():
134- if cls in WRAPPER_FACTORIES :
135- wrapper_factory = WRAPPER_FACTORIES [cls ]
136- if target_keys is not None and cls not in {
137- datasets .CocoDetection ,
138- datasets .VOCDetection ,
139- datasets .Kitti ,
140- datasets .WIDERFace ,
141- }:
142- raise ValueError (
143- f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
144- f"and `WIDERFace`, but got { cls .__name__ } ."
145- )
146- break
147- elif cls is datasets .VisionDataset :
148- # TODO: If we have documentation on how to do that, put a link in the error message.
149- msg = f"No wrapper exists for dataset class { dataset_cls .__name__ } . Please wrap the output yourself."
150- if dataset_cls in datasets .__dict__ .values ():
151- msg = (
152- f"{ msg } If an automated wrapper for this dataset would be useful for you, "
153- f"please open an issue at https://github.com/pytorch/vision/issues."
154- )
155- raise TypeError (msg )
156-
157- self ._dataset = dataset
158- self ._wrapper = wrapper_factory (dataset , target_keys )
159-
160- # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
161- # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
162- # `transforms`
163- # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
164- # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
165- # disable all three here to be able to extract the untransformed sample to wrap.
166- self .transform , dataset .transform = dataset .transform , None
167- self .target_transform , dataset .target_transform = dataset .target_transform , None
168- self .transforms , dataset .transforms = dataset .transforms , None
169-
170- def __getattr__ (self , item ):
171- with contextlib .suppress (AttributeError ):
172- return object .__getattribute__ (self , item )
173-
174- return getattr (self ._dataset , item )
175-
176- def __getitem__ (self , idx ):
177- # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
178- # of this class
179- sample = self ._dataset [idx ]
180-
181- sample = self ._wrapper (idx , sample )
182-
183- # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
184- # or joint (`transforms`), we can access the full functionality through `transforms`
185- if self .transforms is not None :
186- sample = self .transforms (* sample )
187-
188- return sample
131+ if not isinstance (dataset , datasets .VisionDataset ):
132+ raise TypeError (
133+ f"This wrapper is meant for subclasses of `torchvision.datasets.VisionDataset`, "
134+ f"but got a '{ dataset_cls .__name__ } ' instead.\n "
135+ f"For an example of how to perform the wrapping for custom datasets, see\n \n "
136+ "https://pytorch.org/vision/main/auto_examples/plot_datapoints.html#do-i-have-to-wrap-the-output-of-the-datasets-myself"
137+ )
189138
190- def __len__ (self ):
191- return len (self ._dataset )
139+ for cls in dataset_cls .mro ():
140+ if cls in WRAPPER_FACTORIES :
141+ wrapper_factory = WRAPPER_FACTORIES [cls ]
142+ if target_keys is not None and cls not in {
143+ datasets .CocoDetection ,
144+ datasets .VOCDetection ,
145+ datasets .Kitti ,
146+ datasets .WIDERFace ,
147+ }:
148+ raise ValueError (
149+ f"`target_keys` is currently only supported for `CocoDetection`, `VOCDetection`, `Kitti`, "
150+ f"and `WIDERFace`, but got { cls .__name__ } ."
151+ )
152+ break
153+ elif cls is datasets .VisionDataset :
154+ # TODO: If we have documentation on how to do that, put a link in the error message.
155+ msg = f"No wrapper exists for dataset class { dataset_cls .__name__ } . Please wrap the output yourself."
156+ if dataset_cls in datasets .__dict__ .values ():
157+ msg = (
158+ f"{ msg } If an automated wrapper for this dataset would be useful for you, "
159+ f"please open an issue at https://github.com/pytorch/vision/issues."
160+ )
161+ raise TypeError (msg )
162+
163+ self ._dataset = dataset
164+ self ._wrapper = wrapper_factory (dataset , target_keys )
165+
166+ # We need to disable the transforms on the dataset here to be able to inject the wrapping before we apply them.
167+ # Although internally, `datasets.VisionDataset` merges `transform` and `target_transform` into the joint
168+ # `transforms`
169+ # https://github.com/pytorch/vision/blob/135a0f9ea9841b6324b4fe8974e2543cbb95709a/torchvision/datasets/vision.py#L52-L54
170+ # some (if not most) datasets still use `transform` and `target_transform` individually. Thus, we need to
171+ # disable all three here to be able to extract the untransformed sample to wrap.
172+ self .transform , dataset .transform = dataset .transform , None
173+ self .target_transform , dataset .target_transform = dataset .target_transform , None
174+ self .transforms , dataset .transforms = dataset .transforms , None
175+
176+
177+ def wrapped_getattr (self , item ):
178+ with contextlib .suppress (AttributeError ):
179+ return object .__getattribute__ (self , item )
180+
181+ return getattr (self ._dataset , item )
182+
183+
184+ def wrapped_getitem (self , idx ):
185+ # This gets us the raw sample since we disabled the transforms for the underlying dataset in the constructor
186+ # of this class
187+ sample = self ._dataset [idx ]
188+
189+ sample = self ._wrapper (idx , sample )
190+
191+ # Regardless of whether the user has supplied the transforms individually (`transform` and `target_transform`)
192+ # or joint (`transforms`), we can access the full functionality through `transforms`
193+ if self .transforms is not None :
194+ sample = self .transforms (* sample )
195+
196+ return sample
197+
198+
199+ def wrapped_len (self ):
200+ return len (self ._dataset )
192201
193202
194203def raise_not_supported (description ):
0 commit comments