@@ -26,24 +26,29 @@ def _test_functional_geom_op(self, func, fn_kwargs):
2626 transformed_pil_img = getattr (F , func )(pil_img , ** fn_kwargs )
2727 self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
2828
29- def _test_geom_op (self , func , method , fn_kwargs = None , meth_kwargs = None ):
30- if fn_kwargs is None :
31- fn_kwargs = {}
29+ def _test_class_geom_op (self , method , meth_kwargs = None ):
3230 if meth_kwargs is None :
3331 meth_kwargs = {}
32+
3433 tensor , pil_img = self ._create_data (height = 10 , width = 10 )
35- transformed_tensor = getattr (F , func )(tensor , ** fn_kwargs )
36- transformed_pil_img = getattr (F , func )(pil_img , ** fn_kwargs )
34+ # test for class interface
35+ f = getattr (T , method )(** meth_kwargs )
36+ scripted_fn = torch .jit .script (f )
37+
38+ # set seed to reproduce the same transformation for tensor and PIL image
39+ torch .manual_seed (12 )
40+ transformed_tensor = f (tensor )
41+ torch .manual_seed (12 )
42+ transformed_pil_img = f (pil_img )
3743 self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
3844
39- scripted_fn = torch .jit . script ( getattr ( F , func ) )
40- transformed_tensor_script = scripted_fn (tensor , ** fn_kwargs )
45+ torch .manual_seed ( 12 )
46+ transformed_tensor_script = scripted_fn (tensor )
4147 self .assertTrue (transformed_tensor .equal (transformed_tensor_script ))
4248
43- # test for class interface
44- f = getattr (T , method )(** meth_kwargs )
45- scripted_fn = torch .jit .script (f )
46- scripted_fn (tensor )
49+ def _test_geom_op (self , func , method , fn_kwargs = None , meth_kwargs = None ):
50+ self ._test_functional_geom_op (func , fn_kwargs )
51+ self ._test_class_geom_op (method , meth_kwargs )
4752
4853 def test_random_horizontal_flip (self ):
4954 self ._test_geom_op ('hflip' , 'RandomHorizontalFlip' )
@@ -107,21 +112,20 @@ def test_crop(self):
107112 'crop' , 'RandomCrop' , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
108113 )
109114
110- tensor = torch .randint (0 , 255 , (3 , 10 , 10 ), dtype = torch .uint8 )
111- # Test torchscript of transforms.RandomCrop with size as int
112- f = T .RandomCrop (size = 5 )
113- scripted_fn = torch .jit .script (f )
114- scripted_fn (tensor )
115-
116- # Test torchscript of transforms.RandomCrop with size as [int, ]
117- f = T .RandomCrop (size = [5 , ], padding = [2 , ])
118- scripted_fn = torch .jit .script (f )
119- scripted_fn (tensor )
120-
121- # Test torchscript of transforms.RandomCrop with size as list
122- f = T .RandomCrop (size = [6 , 6 ])
123- scripted_fn = torch .jit .script (f )
124- scripted_fn (tensor )
115+ sizes = [5 , [5 , ], [6 , 6 ]]
116+ padding_configs = [
117+ {"padding_mode" : "constant" , "fill" : 0 },
118+ {"padding_mode" : "constant" , "fill" : 10 },
119+ {"padding_mode" : "constant" , "fill" : 20 },
120+ {"padding_mode" : "edge" },
121+ {"padding_mode" : "reflect" },
122+ ]
123+
124+ for size in sizes :
125+ for padding_config in padding_configs :
126+ config = dict (padding_config )
127+ config ["size" ] = size
128+ self ._test_class_geom_op ("RandomCrop" , config )
125129
126130 def test_center_crop (self ):
127131 fn_kwargs = {"output_size" : (4 , 5 )}
0 commit comments