@@ -183,6 +183,9 @@ def test_combined_targets(self):
183183 ), "Type of the combined target does not match the type of the corresponding individual target: "
184184 f"{ actual } is not { expected } " ,
185185
186+ def test_transforms_v2_wrapper (self ):
187+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (target_type = "category" ))
188+
186189
187190class Caltech256TestCase (datasets_utils .ImageDatasetTestCase ):
188191 DATASET_CLASS = datasets .Caltech256
@@ -203,6 +206,9 @@ def inject_fake_data(self, tmpdir, config):
203206
204207 return num_images_per_category * len (categories )
205208
209+ def test_transforms_v2_wrapper (self ):
210+ datasets_utils .check_transforms_v2_wrapper (self )
211+
206212
207213class WIDERFaceTestCase (datasets_utils .ImageDatasetTestCase ):
208214 DATASET_CLASS = datasets .WIDERFace
@@ -258,6 +264,9 @@ def inject_fake_data(self, tmpdir, config):
258264
259265 return split_to_num_examples [config ["split" ]]
260266
267+ def test_transforms_v2_wrapper (self ):
268+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
269+
261270
262271class CityScapesTestCase (datasets_utils .ImageDatasetTestCase ):
263272 DATASET_CLASS = datasets .Cityscapes
@@ -382,6 +391,10 @@ def test_feature_types_target_polygon(self):
382391 assert isinstance (polygon_img , PIL .Image .Image )
383392 (polygon_target , info ["expected_polygon_target" ])
384393
394+ def test_transforms_v2_wrapper (self ):
395+ for target_type in ["instance" , "semantic" , ["instance" , "semantic" ]]:
396+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (target_type = target_type ))
397+
385398
386399class ImageNetTestCase (datasets_utils .ImageDatasetTestCase ):
387400 DATASET_CLASS = datasets .ImageNet
@@ -413,6 +426,9 @@ def inject_fake_data(self, tmpdir, config):
413426 torch .save ((wnid_to_classes , None ), tmpdir / "meta.bin" )
414427 return num_examples
415428
429+ def test_transforms_v2_wrapper (self ):
430+ datasets_utils .check_transforms_v2_wrapper (self )
431+
416432
417433class CIFAR10TestCase (datasets_utils .ImageDatasetTestCase ):
418434 DATASET_CLASS = datasets .CIFAR10
@@ -470,6 +486,9 @@ def test_class_to_idx(self):
470486 actual = dataset .class_to_idx
471487 assert actual == expected
472488
489+ def test_transforms_v2_wrapper (self ):
490+ datasets_utils .check_transforms_v2_wrapper (self )
491+
473492
474493class CIFAR100 (CIFAR10TestCase ):
475494 DATASET_CLASS = datasets .CIFAR100
@@ -484,6 +503,9 @@ class CIFAR100(CIFAR10TestCase):
484503 categories_key = "fine_label_names" ,
485504 )
486505
506+ def test_transforms_v2_wrapper (self ):
507+ datasets_utils .check_transforms_v2_wrapper (self )
508+
487509
488510class CelebATestCase (datasets_utils .ImageDatasetTestCase ):
489511 DATASET_CLASS = datasets .CelebA
@@ -607,6 +629,10 @@ def test_images_names_split(self):
607629
608630 assert merged_imgs_names == all_imgs_names
609631
632+ def test_transforms_v2_wrapper (self ):
633+ for target_type in ["identity" , "bbox" , ["identity" , "bbox" ]]:
634+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (target_type = target_type ))
635+
610636
611637class VOCSegmentationTestCase (datasets_utils .ImageDatasetTestCase ):
612638 DATASET_CLASS = datasets .VOCSegmentation
@@ -694,6 +720,9 @@ def add_bndbox(obj, bndbox=None):
694720
695721 return data
696722
723+ def test_transforms_v2_wrapper (self ):
724+ datasets_utils .check_transforms_v2_wrapper (self )
725+
697726
698727class VOCDetectionTestCase (VOCSegmentationTestCase ):
699728 DATASET_CLASS = datasets .VOCDetection
@@ -714,6 +743,10 @@ def test_annotations(self):
714743
715744 assert object == info ["annotation" ]
716745
746+ def test_transforms_v2_wrapper (self ):
747+ for target_type in ["identity" , "bbox" , ["identity" , "bbox" ]]:
748+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
749+
717750
718751class CocoDetectionTestCase (datasets_utils .ImageDatasetTestCase ):
719752 DATASET_CLASS = datasets .CocoDetection
@@ -784,6 +817,9 @@ def _create_json(self, root, name, content):
784817 json .dump (content , fh )
785818 return file
786819
820+ def test_transforms_v2_wrapper (self ):
821+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
822+
787823
788824class CocoCaptionsTestCase (CocoDetectionTestCase ):
789825 DATASET_CLASS = datasets .CocoCaptions
@@ -800,6 +836,11 @@ def test_captions(self):
800836 _ , captions = dataset [0 ]
801837 assert tuple (captions ) == tuple (info ["captions" ])
802838
839+ def test_transforms_v2_wrapper (self ):
840+ # We need to define this method, because otherwise the test from the super class will
841+ # be run
842+ pytest .skip ("CocoCaptions is currently not supported by the v2 wrapper." )
843+
803844
804845class UCF101TestCase (datasets_utils .VideoDatasetTestCase ):
805846 DATASET_CLASS = datasets .UCF101
@@ -860,6 +901,9 @@ def _create_annotation_file(self, root, name, video_files):
860901 with open (pathlib .Path (root ) / name , "w" ) as fh :
861902 fh .writelines (f"{ str (file ).replace (os .sep , '/' )} \n " for file in sorted (video_files ))
862903
904+ def test_transforms_v2_wrapper (self ):
905+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (output_format = "TCHW" ))
906+
863907
864908class LSUNTestCase (datasets_utils .ImageDatasetTestCase ):
865909 DATASET_CLASS = datasets .LSUN
@@ -966,6 +1010,9 @@ def inject_fake_data(self, tmpdir, config):
9661010 )
9671011 return num_videos_per_class * len (classes )
9681012
1013+ def test_transforms_v2_wrapper (self ):
1014+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (output_format = "TCHW" ))
1015+
9691016
9701017class HMDB51TestCase (datasets_utils .VideoDatasetTestCase ):
9711018 DATASET_CLASS = datasets .HMDB51
@@ -1026,6 +1073,9 @@ def _create_split_files(self, root, video_files, fold, train):
10261073
10271074 return num_train_videos if train else (num_videos - num_train_videos )
10281075
1076+ def test_transforms_v2_wrapper (self ):
1077+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (output_format = "TCHW" ))
1078+
10291079
10301080class OmniglotTestCase (datasets_utils .ImageDatasetTestCase ):
10311081 DATASET_CLASS = datasets .Omniglot
@@ -1193,6 +1243,9 @@ def _create_segmentation(self, size):
11931243 def _file_stem (self , idx ):
11941244 return f"2008_{ idx :06d} "
11951245
1246+ def test_transforms_v2_wrapper (self ):
1247+ datasets_utils .check_transforms_v2_wrapper (self , config = dict (mode = "segmentation" ))
1248+
11961249
11971250class FakeDataTestCase (datasets_utils .ImageDatasetTestCase ):
11981251 DATASET_CLASS = datasets .FakeData
@@ -1434,6 +1487,9 @@ def _magic(self, dtype, dims):
14341487 def _encode (self , v ):
14351488 return torch .tensor (v , dtype = torch .int32 ).numpy ().tobytes ()[::- 1 ]
14361489
1490+ def test_transforms_v2_wrapper (self ):
1491+ datasets_utils .check_transforms_v2_wrapper (self )
1492+
14371493
14381494class FashionMNISTTestCase (MNISTTestCase ):
14391495 DATASET_CLASS = datasets .FashionMNIST
@@ -1585,6 +1641,9 @@ def test_classes(self, config):
15851641 assert len (dataset .classes ) == len (info ["classes" ])
15861642 assert all ([a == b for a , b in zip (dataset .classes , info ["classes" ])])
15871643
1644+ def test_transforms_v2_wrapper (self ):
1645+ datasets_utils .check_transforms_v2_wrapper (self )
1646+
15881647
15891648class ImageFolderTestCase (datasets_utils .ImageDatasetTestCase ):
15901649 DATASET_CLASS = datasets .ImageFolder
@@ -1606,6 +1665,9 @@ def test_classes(self, config):
16061665 assert len (dataset .classes ) == len (info ["classes" ])
16071666 assert all ([a == b for a , b in zip (dataset .classes , info ["classes" ])])
16081667
1668+ def test_transforms_v2_wrapper (self ):
1669+ datasets_utils .check_transforms_v2_wrapper (self )
1670+
16091671
16101672class KittiTestCase (datasets_utils .ImageDatasetTestCase ):
16111673 DATASET_CLASS = datasets .Kitti
@@ -1642,6 +1704,9 @@ def inject_fake_data(self, tmpdir, config):
16421704
16431705 return split_to_num_examples [config ["train" ]]
16441706
1707+ def test_transforms_v2_wrapper (self ):
1708+ datasets_utils .check_transforms_v2_wrapper (self , supports_target_keys = True )
1709+
16451710
16461711class SvhnTestCase (datasets_utils .ImageDatasetTestCase ):
16471712 DATASET_CLASS = datasets .SVHN
@@ -2516,6 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx):
25162581 breed_id = "-1"
25172582 return (image_id , class_id , species , breed_id )
25182583
2584+ def test_transforms_v2_wrapper (self ):
2585+ datasets_utils .check_transforms_v2_wrapper (self )
2586+
25192587
25202588class StanfordCarsTestCase (datasets_utils .ImageDatasetTestCase ):
25212589 DATASET_CLASS = datasets .StanfordCars
0 commit comments