13
13
14
14
class Tester (TransformsTester ):
15
15
16
- def _test_functional_geom_op (self , func , fn_kwargs ):
16
+ def _test_functional_op (self , func , fn_kwargs ):
17
17
if fn_kwargs is None :
18
18
fn_kwargs = {}
19
19
tensor , pil_img = self ._create_data (height = 10 , width = 10 )
20
20
transformed_tensor = getattr (F , func )(tensor , ** fn_kwargs )
21
21
transformed_pil_img = getattr (F , func )(pil_img , ** fn_kwargs )
22
22
self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
23
23
24
- def _test_class_geom_op (self , method , meth_kwargs = None ):
24
+ def _test_class_op (self , method , meth_kwargs = None , test_exact_match = True , ** match_kwargs ):
25
25
if meth_kwargs is None :
26
26
meth_kwargs = {}
27
27
@@ -35,21 +35,24 @@ def _test_class_geom_op(self, method, meth_kwargs=None):
35
35
transformed_tensor = f (tensor )
36
36
torch .manual_seed (12 )
37
37
transformed_pil_img = f (pil_img )
38
- self .compareTensorToPIL (transformed_tensor , transformed_pil_img )
38
+ if test_exact_match :
39
+ self .compareTensorToPIL (transformed_tensor , transformed_pil_img , ** match_kwargs )
40
+ else :
41
+ self .approxEqualTensorToPIL (transformed_tensor .float (), transformed_pil_img , ** match_kwargs )
39
42
40
43
torch .manual_seed (12 )
41
44
transformed_tensor_script = scripted_fn (tensor )
42
45
self .assertTrue (transformed_tensor .equal (transformed_tensor_script ))
43
46
44
- def _test_geom_op (self , func , method , fn_kwargs = None , meth_kwargs = None ):
45
- self ._test_functional_geom_op (func , fn_kwargs )
46
- self ._test_class_geom_op (method , meth_kwargs )
47
+ def _test_op (self , func , method , fn_kwargs = None , meth_kwargs = None ):
48
+ self ._test_functional_op (func , fn_kwargs )
49
+ self ._test_class_op (method , meth_kwargs )
47
50
48
51
def test_random_horizontal_flip (self ):
49
- self ._test_geom_op ('hflip' , 'RandomHorizontalFlip' )
52
+ self ._test_op ('hflip' , 'RandomHorizontalFlip' )
50
53
51
54
def test_random_vertical_flip (self ):
52
- self ._test_geom_op ('vflip' , 'RandomVerticalFlip' )
55
+ self ._test_op ('vflip' , 'RandomVerticalFlip' )
53
56
54
57
def test_adjustments (self ):
55
58
fns = ['adjust_brightness' , 'adjust_contrast' , 'adjust_saturation' ]
@@ -80,30 +83,30 @@ def test_adjustments(self):
80
83
def test_pad (self ):
81
84
82
85
# Test functional.pad (PIL and Tensor) with padding as single int
83
- self ._test_functional_geom_op (
86
+ self ._test_functional_op (
84
87
"pad" , fn_kwargs = {"padding" : 2 , "fill" : 0 , "padding_mode" : "constant" }
85
88
)
86
89
# Test functional.pad and transforms.Pad with padding as [int, ]
87
90
fn_kwargs = meth_kwargs = {"padding" : [2 , ], "fill" : 0 , "padding_mode" : "constant" }
88
- self ._test_geom_op (
91
+ self ._test_op (
89
92
"pad" , "Pad" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
90
93
)
91
94
# Test functional.pad and transforms.Pad with padding as list
92
95
fn_kwargs = meth_kwargs = {"padding" : [4 , 4 ], "fill" : 0 , "padding_mode" : "constant" }
93
- self ._test_geom_op (
96
+ self ._test_op (
94
97
"pad" , "Pad" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
95
98
)
96
99
# Test functional.pad and transforms.Pad with padding as tuple
97
100
fn_kwargs = meth_kwargs = {"padding" : (2 , 2 , 2 , 2 ), "fill" : 127 , "padding_mode" : "constant" }
98
- self ._test_geom_op (
101
+ self ._test_op (
99
102
"pad" , "Pad" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
100
103
)
101
104
102
105
def test_crop (self ):
103
106
fn_kwargs = {"top" : 2 , "left" : 3 , "height" : 4 , "width" : 5 }
104
107
# Test transforms.RandomCrop with size and padding as tuple
105
108
meth_kwargs = {"size" : (4 , 5 ), "padding" : (4 , 4 ), "pad_if_needed" : True , }
106
- self ._test_geom_op (
109
+ self ._test_op (
107
110
'crop' , 'RandomCrop' , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
108
111
)
109
112
@@ -120,17 +123,17 @@ def test_crop(self):
120
123
for padding_config in padding_configs :
121
124
config = dict (padding_config )
122
125
config ["size" ] = size
123
- self ._test_class_geom_op ("RandomCrop" , config )
126
+ self ._test_class_op ("RandomCrop" , config )
124
127
125
128
def test_center_crop (self ):
126
129
fn_kwargs = {"output_size" : (4 , 5 )}
127
130
meth_kwargs = {"size" : (4 , 5 ), }
128
- self ._test_geom_op (
131
+ self ._test_op (
129
132
"center_crop" , "CenterCrop" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
130
133
)
131
134
fn_kwargs = {"output_size" : (5 ,)}
132
135
meth_kwargs = {"size" : (5 , )}
133
- self ._test_geom_op (
136
+ self ._test_op (
134
137
"center_crop" , "CenterCrop" , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
135
138
)
136
139
tensor = torch .randint (0 , 255 , (3 , 10 , 10 ), dtype = torch .uint8 )
@@ -149,7 +152,7 @@ def test_center_crop(self):
149
152
scripted_fn = torch .jit .script (f )
150
153
scripted_fn (tensor )
151
154
152
- def _test_geom_op_list_output (self , func , method , out_length , fn_kwargs = None , meth_kwargs = None ):
155
+ def _test_op_list_output (self , func , method , out_length , fn_kwargs = None , meth_kwargs = None ):
153
156
if fn_kwargs is None :
154
157
fn_kwargs = {}
155
158
if meth_kwargs is None :
@@ -178,37 +181,37 @@ def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, me
178
181
179
182
def test_five_crop (self ):
180
183
fn_kwargs = meth_kwargs = {"size" : (5 ,)}
181
- self ._test_geom_op_list_output (
184
+ self ._test_op_list_output (
182
185
"five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
183
186
)
184
187
fn_kwargs = meth_kwargs = {"size" : [5 , ]}
185
- self ._test_geom_op_list_output (
188
+ self ._test_op_list_output (
186
189
"five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
187
190
)
188
191
fn_kwargs = meth_kwargs = {"size" : (4 , 5 )}
189
- self ._test_geom_op_list_output (
192
+ self ._test_op_list_output (
190
193
"five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
191
194
)
192
195
fn_kwargs = meth_kwargs = {"size" : [4 , 5 ]}
193
- self ._test_geom_op_list_output (
196
+ self ._test_op_list_output (
194
197
"five_crop" , "FiveCrop" , out_length = 5 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
195
198
)
196
199
197
200
def test_ten_crop (self ):
198
201
fn_kwargs = meth_kwargs = {"size" : (5 ,)}
199
- self ._test_geom_op_list_output (
202
+ self ._test_op_list_output (
200
203
"ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
201
204
)
202
205
fn_kwargs = meth_kwargs = {"size" : [5 , ]}
203
- self ._test_geom_op_list_output (
206
+ self ._test_op_list_output (
204
207
"ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
205
208
)
206
209
fn_kwargs = meth_kwargs = {"size" : (4 , 5 )}
207
- self ._test_geom_op_list_output (
210
+ self ._test_op_list_output (
208
211
"ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
209
212
)
210
213
fn_kwargs = meth_kwargs = {"size" : [4 , 5 ]}
211
- self ._test_geom_op_list_output (
214
+ self ._test_op_list_output (
212
215
"ten_crop" , "TenCrop" , out_length = 10 , fn_kwargs = fn_kwargs , meth_kwargs = meth_kwargs
213
216
)
214
217
@@ -312,6 +315,24 @@ def test_random_perspective(self):
312
315
out2 = s_transform (tensor )
313
316
self .assertTrue (out1 .equal (out2 ))
314
317
318
+ def test_to_grayscale (self ):
319
+
320
+ meth_kwargs = {"num_output_channels" : 1 }
321
+ tol = 1.0 + 1e-10
322
+ self ._test_class_op (
323
+ "Grayscale" , meth_kwargs = meth_kwargs , test_exact_match = False , tol = tol , agg_method = "max"
324
+ )
325
+
326
+ meth_kwargs = {"num_output_channels" : 3 }
327
+ self ._test_class_op (
328
+ "Grayscale" , meth_kwargs = meth_kwargs , test_exact_match = False , tol = tol , agg_method = "max"
329
+ )
330
+
331
+ meth_kwargs = {}
332
+ self ._test_class_op (
333
+ "RandomGrayscale" , meth_kwargs = meth_kwargs , test_exact_match = False , tol = tol , agg_method = "max"
334
+ )
335
+
315
336
316
337
if __name__ == '__main__' :
317
338
unittest .main ()
0 commit comments