@@ -649,37 +649,58 @@ def test_call_consistency(config, args_kwargs):
649
649
)
650
650
651
651
652
- @pytest .mark .parametrize (
653
- "config" ,
654
- [config for config in CONSISTENCY_CONFIGS if hasattr (config .legacy_cls , "get_params" )],
655
- ids = lambda config : config .legacy_cls .__name__ ,
652
+ get_params_parametrization = pytest .mark .parametrize (
653
+ ("config" , "get_params_args_kwargs" ),
654
+ [
655
+ pytest .param (
656
+ next (config for config in CONSISTENCY_CONFIGS if config .prototype_cls is transform_cls ),
657
+ get_params_args_kwargs ,
658
+ id = transform_cls .__name__ ,
659
+ )
660
+ for transform_cls , get_params_args_kwargs in [
661
+ (prototype_transforms .RandomResizedCrop , ArgsKwargs (make_image (), scale = [0.3 , 0.7 ], ratio = [0.5 , 1.5 ])),
662
+ (prototype_transforms .RandomErasing , ArgsKwargs (make_image (), scale = (0.3 , 0.7 ), ratio = (0.5 , 1.5 ))),
663
+ (prototype_transforms .ColorJitter , ArgsKwargs (brightness = None , contrast = None , saturation = None , hue = None )),
664
+ (prototype_transforms .ElasticTransform , ArgsKwargs (alpha = [15.3 , 27.2 ], sigma = [2.5 , 3.9 ], size = [17 , 31 ])),
665
+ (prototype_transforms .GaussianBlur , ArgsKwargs (0.3 , 1.4 )),
666
+ (
667
+ prototype_transforms .RandomAffine ,
668
+ ArgsKwargs (degrees = [- 20.0 , 10.0 ], translate = None , scale_ranges = None , shears = None , img_size = [15 , 29 ]),
669
+ ),
670
+ (prototype_transforms .RandomCrop , ArgsKwargs (make_image (size = (61 , 47 )), output_size = (19 , 25 ))),
671
+ (prototype_transforms .RandomPerspective , ArgsKwargs (23 , 17 , 0.5 )),
672
+ (prototype_transforms .RandomRotation , ArgsKwargs (degrees = [- 20.0 , 10.0 ])),
673
+ (prototype_transforms .AutoAugment , ArgsKwargs (5 )),
674
+ ]
675
+ ],
656
676
)
657
- def test_get_params_alias (config ):
677
+
678
+
679
+ @get_paramsl_parametrization
680
+ def test_get_params_alias (config , get_params_args_kwargs ):
658
681
assert config .prototype_cls .get_params is config .legacy_cls .get_params
659
682
683
+ if not config .args_kwargs :
684
+ return
685
+ args , kwargs = config .args_kwargs [0 ]
686
+ legacy_transform = config .legacy_cls (* args , ** kwargs )
687
+ prototype_transform = config .prototype_cls (* args , ** kwargs )
660
688
661
- @pytest .mark .parametrize (
662
- ("transform_cls" , "args_kwargs" ),
663
- [
664
- (prototype_transforms .RandomResizedCrop , ArgsKwargs (make_image (), scale = [0.3 , 0.7 ], ratio = [0.5 , 1.5 ])),
665
- (prototype_transforms .RandomErasing , ArgsKwargs (make_image (), scale = (0.3 , 0.7 ), ratio = (0.5 , 1.5 ))),
666
- (prototype_transforms .ColorJitter , ArgsKwargs (brightness = None , contrast = None , saturation = None , hue = None )),
667
- (prototype_transforms .ElasticTransform , ArgsKwargs (alpha = [15.3 , 27.2 ], sigma = [2.5 , 3.9 ], size = [17 , 31 ])),
668
- (prototype_transforms .GaussianBlur , ArgsKwargs (0.3 , 1.4 )),
669
- (
670
- prototype_transforms .RandomAffine ,
671
- ArgsKwargs (degrees = [- 20.0 , 10.0 ], translate = None , scale_ranges = None , shears = None , img_size = [15 , 29 ]),
672
- ),
673
- (prototype_transforms .RandomCrop , ArgsKwargs (make_image (size = (61 , 47 )), output_size = (19 , 25 ))),
674
- (prototype_transforms .RandomPerspective , ArgsKwargs (23 , 17 , 0.5 )),
675
- (prototype_transforms .RandomRotation , ArgsKwargs (degrees = [- 20.0 , 10.0 ])),
676
- (prototype_transforms .AutoAugment , ArgsKwargs (5 )),
677
- ],
678
- )
679
- def test_get_params_jit (transform_cls , args_kwargs ):
680
- args , kwargs = args_kwargs
689
+ assert prototype_transform .get_params is legacy_transform .get_params
690
+
691
+
692
+ @get_paramsl_parametrization
693
+ def test_get_params_jit (config , get_params_args_kwargs ):
694
+ get_params_args , get_params_kwargs = get_params_args_kwargs
695
+
696
+ torch .jit .script (config .prototype_cls .get_params )(* get_params_args , ** get_params_kwargs )
697
+
698
+ if not config .args_kwargs :
699
+ return
700
+ args , kwargs = config .args_kwargs [0 ]
701
+ transform = config .prototype_cls (* args , ** kwargs )
681
702
682
- torch .jit .script (transform_cls .get_params )(* args , ** kwargs )
703
+ torch .jit .script (transform .get_params )(* get_params_args , ** get_params_kwargs )
683
704
684
705
685
706
@pytest .mark .parametrize (
0 commit comments