@@ -755,10 +755,11 @@ def test_randaug(self, inpt, interpolation, mocker):
755755 v2_transforms .InterpolationMode .BILINEAR ,
756756 ],
757757 )
758- def test_randaug_jit (self , interpolation ):
758+ @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
759+ def test_randaug_jit (self , interpolation , fill ):
759760 inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
760- t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
761- t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
761+ t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
762+ t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
762763
763764 tt_ref = torch .jit .script (t_ref )
764765 tt = torch .jit .script (t )
@@ -830,10 +831,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
830831 v2_transforms .InterpolationMode .BILINEAR ,
831832 ],
832833 )
833- def test_trivial_aug_jit (self , interpolation ):
834+ @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
835+ def test_trivial_aug_jit (self , interpolation , fill ):
834836 inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
835- t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation )
836- t = v2_transforms .TrivialAugmentWide (interpolation = interpolation )
837+ t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
838+ t = v2_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
837839
838840 tt_ref = torch .jit .script (t_ref )
839841 tt = torch .jit .script (t )
@@ -906,11 +908,12 @@ def test_augmix(self, inpt, interpolation, mocker):
906908 v2_transforms .InterpolationMode .BILINEAR ,
907909 ],
908910 )
909- def test_augmix_jit (self , interpolation ):
911+ @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
912+ def test_augmix_jit (self , interpolation , fill ):
910913 inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
911914
912- t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
913- t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
915+ t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
916+ t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
914917
915918 tt_ref = torch .jit .script (t_ref )
916919 tt = torch .jit .script (t )
0 commit comments