@@ -26,8 +26,7 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
26
26
if np_pil_image .ndim == 2 :
27
27
np_pil_image = np_pil_image [:, :, None ]
28
28
pil_tensor = torch .as_tensor (np_pil_image .transpose ((2 , 0 , 1 )))
29
- if msg is None :
30
- msg = "tensor:\n {} \n did not equal PIL tensor:\n {}" .format (tensor , pil_tensor )
29
+ msg = "{}: tensor:\n {} \n did not equal PIL tensor:\n {}" .format (msg , tensor , pil_tensor )
31
30
self .assertTrue (tensor .equal (pil_tensor ), msg )
32
31
33
32
def approxEqualTensorToPIL (self , tensor , pil_image , tol = 1e-5 , msg = None ):
@@ -130,64 +129,6 @@ def test_rgb2hsv(self):
130
129
131
130
self .assertLess (max_diff , 1e-5 )
132
131
133
- def test_adjustments (self ):
134
- script_adjust_brightness = torch .jit .script (F_t .adjust_brightness )
135
- script_adjust_contrast = torch .jit .script (F_t .adjust_contrast )
136
- script_adjust_saturation = torch .jit .script (F_t .adjust_saturation )
137
-
138
- fns = ((F .adjust_brightness , F_t .adjust_brightness , script_adjust_brightness ),
139
- (F .adjust_contrast , F_t .adjust_contrast , script_adjust_contrast ),
140
- (F .adjust_saturation , F_t .adjust_saturation , script_adjust_saturation ))
141
-
142
- for _ in range (20 ):
143
- channels = 3
144
- dims = torch .randint (1 , 50 , (2 ,))
145
- shape = (channels , dims [0 ], dims [1 ])
146
-
147
- if torch .randint (0 , 2 , (1 ,)) == 0 :
148
- img = torch .rand (* shape , dtype = torch .float )
149
- else :
150
- img = torch .randint (0 , 256 , shape , dtype = torch .uint8 )
151
-
152
- factor = 3 * torch .rand (1 )
153
- img_clone = img .clone ()
154
- for f , ft , sft in fns :
155
-
156
- ft_img = ft (img , factor )
157
- sft_img = sft (img , factor )
158
- if not img .dtype .is_floating_point :
159
- ft_img = ft_img .to (torch .float ) / 255
160
- sft_img = sft_img .to (torch .float ) / 255
161
-
162
- img_pil = transforms .ToPILImage ()(img )
163
- f_img_pil = f (img_pil , factor )
164
- f_img = transforms .ToTensor ()(f_img_pil )
165
-
166
- # F uses uint8 and F_t uses float, so there is a small
167
- # difference in values caused by (at most 5) truncations.
168
- max_diff = (ft_img - f_img ).abs ().max ()
169
- max_diff_scripted = (sft_img - f_img ).abs ().max ()
170
- self .assertLess (max_diff , 5 / 255 + 1e-5 )
171
- self .assertLess (max_diff_scripted , 5 / 255 + 1e-5 )
172
- self .assertTrue (torch .equal (img , img_clone ))
173
-
174
- # test for class interface
175
- f = transforms .ColorJitter (brightness = factor .item ())
176
- scripted_fn = torch .jit .script (f )
177
- scripted_fn (img )
178
-
179
- f = transforms .ColorJitter (contrast = factor .item ())
180
- scripted_fn = torch .jit .script (f )
181
- scripted_fn (img )
182
-
183
- f = transforms .ColorJitter (saturation = factor .item ())
184
- scripted_fn = torch .jit .script (f )
185
- scripted_fn (img )
186
-
187
- f = transforms .ColorJitter (brightness = 1 )
188
- scripted_fn = torch .jit .script (f )
189
- scripted_fn (img )
190
-
191
132
def test_rgb_to_grayscale (self ):
192
133
script_rgb_to_grayscale = torch .jit .script (F .rgb_to_grayscale )
193
134
@@ -286,32 +227,76 @@ def test_pad(self):
286
227
with self .assertRaises (ValueError , msg = "Padding can not be negative for symmetric padding_mode" ):
287
228
F_t .pad (tensor , (- 2 , - 3 ), padding_mode = "symmetric" )
288
229
289
- def test_adjust_gamma (self ):
290
- script_fn = torch .jit .script (F_t .adjust_gamma )
291
- tensor , pil_img = self ._create_data (26 , 36 )
230
+ def _test_adjust_fn (self , fn , fn_pil , fn_t , configs ):
231
+ script_fn = torch .jit .script (fn )
292
232
293
- for dt in [torch .float64 , torch .float32 , None ]:
233
+ torch .manual_seed (15 )
234
+
235
+ tensor , pil_img = self ._create_data (26 , 34 )
236
+
237
+ for dt in [None , torch .float32 , torch .float64 ]:
294
238
295
239
if dt is not None :
296
240
tensor = F .convert_image_dtype (tensor , dt )
297
241
298
- gammas = [0.8 , 1.0 , 1.2 ]
299
- gains = [0.7 , 1.0 , 1.3 ]
300
- for gamma , gain in zip (gammas , gains ):
242
+ for config in configs :
301
243
302
- adjusted_tensor = F_t .adjust_gamma (tensor , gamma , gain )
303
- adjusted_pil = F_pil .adjust_gamma (pil_img , gamma , gain )
304
- scripted_result = script_fn (tensor , gamma , gain )
305
- self .assertEqual (adjusted_tensor .dtype , scripted_result .dtype )
306
- self .assertEqual (adjusted_tensor .size ()[1 :], adjusted_pil .size [::- 1 ])
244
+ adjusted_tensor = fn_t (tensor , ** config )
245
+ adjusted_pil = fn_pil (pil_img , ** config )
246
+ scripted_result = script_fn (tensor , ** config )
247
+ msg = "{}, {}" .format (dt , config )
248
+ self .assertEqual (adjusted_tensor .dtype , scripted_result .dtype , msg = msg )
249
+ self .assertEqual (adjusted_tensor .size ()[1 :], adjusted_pil .size [::- 1 ], msg = msg )
307
250
308
251
rbg_tensor = adjusted_tensor
252
+
309
253
if adjusted_tensor .dtype != torch .uint8 :
310
254
rbg_tensor = F .convert_image_dtype (adjusted_tensor , torch .uint8 )
311
255
312
- self .compareTensorToPIL (rbg_tensor , adjusted_pil )
256
+ # Check that max difference does not exceed 1 in [0, 255] range
257
+ # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
258
+ rbg_tensor = rbg_tensor .float ()
259
+ adjusted_pil_tensor = torch .as_tensor (np .array (adjusted_pil ).transpose ((2 , 0 , 1 ))).to (rbg_tensor )
260
+ max_diff = torch .abs (rbg_tensor - adjusted_pil_tensor ).max ().item ()
261
+ self .assertLessEqual (
262
+ max_diff ,
263
+ 1.0 ,
264
+ msg = "{}: tensor:\n {} \n did not equal PIL tensor:\n {}" .format (msg , rbg_tensor , adjusted_pil_tensor )
265
+ )
266
+
267
+ self .assertTrue (adjusted_tensor .equal (scripted_result ), msg = msg )
268
+
269
+ def test_adjust_brightness (self ):
270
+ self ._test_adjust_fn (
271
+ F .adjust_brightness ,
272
+ F_pil .adjust_brightness ,
273
+ F_t .adjust_brightness ,
274
+ [{"brightness_factor" : f } for f in [0.1 , 0.5 , 1.0 , 1.34 , 2.5 ]]
275
+ )
276
+
277
+ def test_adjust_contrast (self ):
278
+ self ._test_adjust_fn (
279
+ F .adjust_contrast ,
280
+ F_pil .adjust_contrast ,
281
+ F_t .adjust_contrast ,
282
+ [{"contrast_factor" : f } for f in [0.2 , 0.5 , 1.0 , 1.5 , 2.0 ]]
283
+ )
313
284
314
- self .assertTrue (adjusted_tensor .equal (scripted_result ))
285
+ def test_adjust_saturation (self ):
286
+ self ._test_adjust_fn (
287
+ F .adjust_saturation ,
288
+ F_pil .adjust_saturation ,
289
+ F_t .adjust_saturation ,
290
+ [{"saturation_factor" : f } for f in [0.5 , 0.75 , 1.0 , 1.25 , 1.5 ]]
291
+ )
292
+
293
+ def test_adjust_gamma (self ):
294
+ self ._test_adjust_fn (
295
+ F .adjust_gamma ,
296
+ F_pil .adjust_gamma ,
297
+ F_t .adjust_gamma ,
298
+ [{"gamma" : g1 , "gain" : g2 } for g1 , g2 in zip ([0.8 , 1.0 , 1.2 ], [0.7 , 1.0 , 1.3 ])]
299
+ )
315
300
316
301
def test_resize (self ):
317
302
script_fn = torch .jit .script (F_t .resize )
0 commit comments