@@ -26,24 +26,29 @@ def _test_functional_geom_op(self, func, fn_kwargs):
26
26
transformed_pil_img = getattr (F , func )(pil_img , ** fn_kwargs )
27
27
self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
28
28
29
- def _test_geom_op (self , func , method , fn_kwargs = None , meth_kwargs = None ):
30
- if fn_kwargs is None :
31
- fn_kwargs = {}
29
+ def _test_class_geom_op (self , method , meth_kwargs = None ):
32
30
if meth_kwargs is None :
33
31
meth_kwargs = {}
32
+
34
33
tensor , pil_img = self ._create_data (height = 10 , width = 10 )
35
- transformed_tensor = getattr (F , func )(tensor , ** fn_kwargs )
36
- transformed_pil_img = getattr (F , func )(pil_img , ** fn_kwargs )
34
+ # test for class interface
35
+ f = getattr (T , method )(** meth_kwargs )
36
+ scripted_fn = torch .jit .script (f )
37
+
38
+ # set seed to reproduce the same transformation for tensor and PIL image
39
+ torch .manual_seed (12 )
40
+ transformed_tensor = f (tensor )
41
+ torch .manual_seed (12 )
42
+ transformed_pil_img = f (pil_img )
37
43
self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
38
44
39
- scripted_fn = torch .jit . script ( getattr ( F , func ) )
40
- transformed_tensor_script = scripted_fn (tensor , ** fn_kwargs )
45
+ torch .manual_seed ( 12 )
46
+ transformed_tensor_script = scripted_fn (tensor )
41
47
self .assertTrue (transformed_tensor .equal (transformed_tensor_script ))
42
48
43
- # test for class interface
44
- f = getattr (T , method )(** meth_kwargs )
45
- scripted_fn = torch .jit .script (f )
46
- scripted_fn (tensor )
49
+ def _test_geom_op (self , func , method , fn_kwargs = None , meth_kwargs = None ):
50
+ self ._test_functional_geom_op (func , fn_kwargs )
51
+ self ._test_class_geom_op (method , meth_kwargs )
47
52
48
53
def test_random_horizontal_flip (self ):
49
54
self ._test_geom_op ('hflip' , 'RandomHorizontalFlip' )
@@ -107,21 +112,20 @@ def test_crop(self):
107
112
'crop' , 'RandomCrop' , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
108
113
)
109
114
110
- tensor = torch .randint (0 , 255 , (3 , 10 , 10 ), dtype = torch .uint8 )
111
- # Test torchscript of transforms.RandomCrop with size as int
112
- f = T .RandomCrop (size = 5 )
113
- scripted_fn = torch .jit .script (f )
114
- scripted_fn (tensor )
115
-
116
- # Test torchscript of transforms.RandomCrop with size as [int, ]
117
- f = T .RandomCrop (size = [5 , ], padding = [2 , ])
118
- scripted_fn = torch .jit .script (f )
119
- scripted_fn (tensor )
120
-
121
- # Test torchscript of transforms.RandomCrop with size as list
122
- f = T .RandomCrop (size = [6 , 6 ])
123
- scripted_fn = torch .jit .script (f )
124
- scripted_fn (tensor )
115
+ sizes = [5 , [5 , ], [6 , 6 ]]
116
+ padding_configs = [
117
+ {"padding_mode" : "constant" , "fill" : 0 },
118
+ {"padding_mode" : "constant" , "fill" : 10 },
119
+ {"padding_mode" : "constant" , "fill" : 20 },
120
+ {"padding_mode" : "edge" },
121
+ {"padding_mode" : "reflect" },
122
+ ]
123
+
124
+ for size in sizes :
125
+ for padding_config in padding_configs :
126
+ config = dict (padding_config )
127
+ config ["size" ] = size
128
+ self ._test_class_geom_op ("RandomCrop" , config )
125
129
126
130
def test_center_crop (self ):
127
131
fn_kwargs = {"output_size" : (4 , 5 )}
0 commit comments