@@ -26,8 +26,7 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
2626 if np_pil_image .ndim == 2 :
2727 np_pil_image = np_pil_image [:, :, None ]
2828 pil_tensor = torch .as_tensor (np_pil_image .transpose ((2 , 0 , 1 )))
29- if msg is None :
30- msg = "tensor:\n {} \n did not equal PIL tensor:\n {}" .format (tensor , pil_tensor )
29+ msg = "{}: tensor:\n {} \n did not equal PIL tensor:\n {}" .format (msg , tensor , pil_tensor )
3130 self .assertTrue (tensor .equal (pil_tensor ), msg )
3231
3332 def approxEqualTensorToPIL (self , tensor , pil_image , tol = 1e-5 , msg = None ):
@@ -130,64 +129,6 @@ def test_rgb2hsv(self):
130129
131130 self .assertLess (max_diff , 1e-5 )
132131
133- def test_adjustments (self ):
134- script_adjust_brightness = torch .jit .script (F_t .adjust_brightness )
135- script_adjust_contrast = torch .jit .script (F_t .adjust_contrast )
136- script_adjust_saturation = torch .jit .script (F_t .adjust_saturation )
137-
138- fns = ((F .adjust_brightness , F_t .adjust_brightness , script_adjust_brightness ),
139- (F .adjust_contrast , F_t .adjust_contrast , script_adjust_contrast ),
140- (F .adjust_saturation , F_t .adjust_saturation , script_adjust_saturation ))
141-
142- for _ in range (20 ):
143- channels = 3
144- dims = torch .randint (1 , 50 , (2 ,))
145- shape = (channels , dims [0 ], dims [1 ])
146-
147- if torch .randint (0 , 2 , (1 ,)) == 0 :
148- img = torch .rand (* shape , dtype = torch .float )
149- else :
150- img = torch .randint (0 , 256 , shape , dtype = torch .uint8 )
151-
152- factor = 3 * torch .rand (1 )
153- img_clone = img .clone ()
154- for f , ft , sft in fns :
155-
156- ft_img = ft (img , factor )
157- sft_img = sft (img , factor )
158- if not img .dtype .is_floating_point :
159- ft_img = ft_img .to (torch .float ) / 255
160- sft_img = sft_img .to (torch .float ) / 255
161-
162- img_pil = transforms .ToPILImage ()(img )
163- f_img_pil = f (img_pil , factor )
164- f_img = transforms .ToTensor ()(f_img_pil )
165-
166- # F uses uint8 and F_t uses float, so there is a small
167- # difference in values caused by (at most 5) truncations.
168- max_diff = (ft_img - f_img ).abs ().max ()
169- max_diff_scripted = (sft_img - f_img ).abs ().max ()
170- self .assertLess (max_diff , 5 / 255 + 1e-5 )
171- self .assertLess (max_diff_scripted , 5 / 255 + 1e-5 )
172- self .assertTrue (torch .equal (img , img_clone ))
173-
174- # test for class interface
175- f = transforms .ColorJitter (brightness = factor .item ())
176- scripted_fn = torch .jit .script (f )
177- scripted_fn (img )
178-
179- f = transforms .ColorJitter (contrast = factor .item ())
180- scripted_fn = torch .jit .script (f )
181- scripted_fn (img )
182-
183- f = transforms .ColorJitter (saturation = factor .item ())
184- scripted_fn = torch .jit .script (f )
185- scripted_fn (img )
186-
187- f = transforms .ColorJitter (brightness = 1 )
188- scripted_fn = torch .jit .script (f )
189- scripted_fn (img )
190-
191132 def test_rgb_to_grayscale (self ):
192133 script_rgb_to_grayscale = torch .jit .script (F .rgb_to_grayscale )
193134
@@ -286,32 +227,76 @@ def test_pad(self):
286227 with self .assertRaises (ValueError , msg = "Padding can not be negative for symmetric padding_mode" ):
287228 F_t .pad (tensor , (- 2 , - 3 ), padding_mode = "symmetric" )
288229
289- def test_adjust_gamma (self ):
290- script_fn = torch .jit .script (F_t .adjust_gamma )
291- tensor , pil_img = self ._create_data (26 , 36 )
230+ def _test_adjust_fn (self , fn , fn_pil , fn_t , configs ):
231+ script_fn = torch .jit .script (fn )
292232
293- for dt in [torch .float64 , torch .float32 , None ]:
233+ torch .manual_seed (15 )
234+
235+ tensor , pil_img = self ._create_data (26 , 34 )
236+
237+ for dt in [None , torch .float32 , torch .float64 ]:
294238
295239 if dt is not None :
296240 tensor = F .convert_image_dtype (tensor , dt )
297241
298- gammas = [0.8 , 1.0 , 1.2 ]
299- gains = [0.7 , 1.0 , 1.3 ]
300- for gamma , gain in zip (gammas , gains ):
242+ for config in configs :
301243
302- adjusted_tensor = F_t .adjust_gamma (tensor , gamma , gain )
303- adjusted_pil = F_pil .adjust_gamma (pil_img , gamma , gain )
304- scripted_result = script_fn (tensor , gamma , gain )
305- self .assertEqual (adjusted_tensor .dtype , scripted_result .dtype )
306- self .assertEqual (adjusted_tensor .size ()[1 :], adjusted_pil .size [::- 1 ])
244+ adjusted_tensor = fn_t (tensor , ** config )
245+ adjusted_pil = fn_pil (pil_img , ** config )
246+ scripted_result = script_fn (tensor , ** config )
247+ msg = "{}, {}" .format (dt , config )
248+ self .assertEqual (adjusted_tensor .dtype , scripted_result .dtype , msg = msg )
249+ self .assertEqual (adjusted_tensor .size ()[1 :], adjusted_pil .size [::- 1 ], msg = msg )
307250
308251 rbg_tensor = adjusted_tensor
252+
309253 if adjusted_tensor .dtype != torch .uint8 :
310254 rbg_tensor = F .convert_image_dtype (adjusted_tensor , torch .uint8 )
311255
312- self .compareTensorToPIL (rbg_tensor , adjusted_pil )
256+ # Check that max difference does not exceed 1 in [0, 255] range
257+ # Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
258+ rbg_tensor = rbg_tensor .float ()
259+ adjusted_pil_tensor = torch .as_tensor (np .array (adjusted_pil ).transpose ((2 , 0 , 1 ))).to (rbg_tensor )
260+ max_diff = torch .abs (rbg_tensor - adjusted_pil_tensor ).max ().item ()
261+ self .assertLessEqual (
262+ max_diff ,
263+ 1.0 ,
264+ msg = "{}: tensor:\n {} \n did not equal PIL tensor:\n {}" .format (msg , rbg_tensor , adjusted_pil_tensor )
265+ )
266+
267+ self .assertTrue (adjusted_tensor .equal (scripted_result ), msg = msg )
268+
269+ def test_adjust_brightness (self ):
270+ self ._test_adjust_fn (
271+ F .adjust_brightness ,
272+ F_pil .adjust_brightness ,
273+ F_t .adjust_brightness ,
274+ [{"brightness_factor" : f } for f in [0.1 , 0.5 , 1.0 , 1.34 , 2.5 ]]
275+ )
276+
277+ def test_adjust_contrast (self ):
278+ self ._test_adjust_fn (
279+ F .adjust_contrast ,
280+ F_pil .adjust_contrast ,
281+ F_t .adjust_contrast ,
282+ [{"contrast_factor" : f } for f in [0.2 , 0.5 , 1.0 , 1.5 , 2.0 ]]
283+ )
313284
314- self .assertTrue (adjusted_tensor .equal (scripted_result ))
285+ def test_adjust_saturation (self ):
286+ self ._test_adjust_fn (
287+ F .adjust_saturation ,
288+ F_pil .adjust_saturation ,
289+ F_t .adjust_saturation ,
290+ [{"saturation_factor" : f } for f in [0.5 , 0.75 , 1.0 , 1.25 , 1.5 ]]
291+ )
292+
293+ def test_adjust_gamma (self ):
294+ self ._test_adjust_fn (
295+ F .adjust_gamma ,
296+ F_pil .adjust_gamma ,
297+ F_t .adjust_gamma ,
298+ [{"gamma" : g1 , "gain" : g2 } for g1 , g2 in zip ([0.8 , 1.0 , 1.2 ], [0.7 , 1.0 , 1.3 ])]
299+ )
315300
316301 def test_resize (self ):
317302 script_fn = torch .jit .script (F_t .resize )
0 commit comments