@@ -655,6 +655,39 @@ def test_call_consistency(config, args_kwargs):
655
655
)
656
656
657
657
658
+ @pytest .mark .parametrize (
659
+ "config" ,
660
+ [config for config in CONSISTENCY_CONFIGS if hasattr (config .legacy_cls , "get_params" )],
661
+ ids = lambda config : config .legacy_cls .__name__ ,
662
+ )
663
+ def test_get_params_alias (config ):
664
+ assert config .prototype_cls .get_params is config .legacy_cls .get_params
665
+
666
+
667
+ @pytest .mark .parametrize (
668
+ ("transform_cls" , "args_kwargs" ),
669
+ [
670
+ (prototype_transforms .RandomResizedCrop , ArgsKwargs (make_image (), scale = [0.3 , 0.7 ], ratio = [0.5 , 1.5 ])),
671
+ (prototype_transforms .RandomErasing , ArgsKwargs (make_image (), scale = (0.3 , 0.7 ), ratio = (0.5 , 1.5 ))),
672
+ (prototype_transforms .ColorJitter , ArgsKwargs (brightness = None , contrast = None , saturation = None , hue = None )),
673
+ (prototype_transforms .ElasticTransform , ArgsKwargs (alpha = [15.3 , 27.2 ], sigma = [2.5 , 3.9 ], size = [17 , 31 ])),
674
+ (prototype_transforms .GaussianBlur , ArgsKwargs (0.3 , 1.4 )),
675
+ (
676
+ prototype_transforms .RandomAffine ,
677
+ ArgsKwargs (degrees = [- 20.0 , 10.0 ], translate = None , scale_ranges = None , shears = None , img_size = [15 , 29 ]),
678
+ ),
679
+ (prototype_transforms .RandomCrop , ArgsKwargs (make_image (size = (61 , 47 )), output_size = (19 , 25 ))),
680
+ (prototype_transforms .RandomPerspective , ArgsKwargs (23 , 17 , 0.5 )),
681
+ (prototype_transforms .RandomRotation , ArgsKwargs (degrees = [- 20.0 , 10.0 ])),
682
+ (prototype_transforms .AutoAugment , ArgsKwargs (5 )),
683
+ ],
684
+ )
685
+ def test_get_params_jit (transform_cls , args_kwargs ):
686
+ args , kwargs = args_kwargs
687
+
688
+ torch .jit .script (transform_cls .get_params )(* args , ** kwargs )
689
+
690
+
658
691
@pytest .mark .parametrize (
659
692
("config" , "args_kwargs" ),
660
693
[
0 commit comments