Skip to content

Commit d016cab

Browse files
committed
Rewritten adjust_* tests
- split test_adjustments into 3 separate tests - unified testing approach with test_adjust_gamma
1 parent 1078e1d commit d016cab

File tree

1 file changed

+59
-74
lines changed

1 file changed

+59
-74
lines changed

test/test_functional_tensor.py

Lines changed: 59 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
2626
if np_pil_image.ndim == 2:
2727
np_pil_image = np_pil_image[:, :, None]
2828
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
29-
if msg is None:
30-
msg = "tensor:\n{} \ndid not equal PIL tensor:\n{}".format(tensor, pil_tensor)
29+
msg = "{}: tensor:\n{} \ndid not equal PIL tensor:\n{}".format(msg, tensor, pil_tensor)
3130
self.assertTrue(tensor.equal(pil_tensor), msg)
3231

3332
def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
@@ -130,64 +129,6 @@ def test_rgb2hsv(self):
130129

131130
self.assertLess(max_diff, 1e-5)
132131

133-
def test_adjustments(self):
134-
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
135-
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
136-
script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
137-
138-
fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
139-
(F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
140-
(F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))
141-
142-
for _ in range(20):
143-
channels = 3
144-
dims = torch.randint(1, 50, (2,))
145-
shape = (channels, dims[0], dims[1])
146-
147-
if torch.randint(0, 2, (1,)) == 0:
148-
img = torch.rand(*shape, dtype=torch.float)
149-
else:
150-
img = torch.randint(0, 256, shape, dtype=torch.uint8)
151-
152-
factor = 3 * torch.rand(1)
153-
img_clone = img.clone()
154-
for f, ft, sft in fns:
155-
156-
ft_img = ft(img, factor)
157-
sft_img = sft(img, factor)
158-
if not img.dtype.is_floating_point:
159-
ft_img = ft_img.to(torch.float) / 255
160-
sft_img = sft_img.to(torch.float) / 255
161-
162-
img_pil = transforms.ToPILImage()(img)
163-
f_img_pil = f(img_pil, factor)
164-
f_img = transforms.ToTensor()(f_img_pil)
165-
166-
# F uses uint8 and F_t uses float, so there is a small
167-
# difference in values caused by (at most 5) truncations.
168-
max_diff = (ft_img - f_img).abs().max()
169-
max_diff_scripted = (sft_img - f_img).abs().max()
170-
self.assertLess(max_diff, 5 / 255 + 1e-5)
171-
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
172-
self.assertTrue(torch.equal(img, img_clone))
173-
174-
# test for class interface
175-
f = transforms.ColorJitter(brightness=factor.item())
176-
scripted_fn = torch.jit.script(f)
177-
scripted_fn(img)
178-
179-
f = transforms.ColorJitter(contrast=factor.item())
180-
scripted_fn = torch.jit.script(f)
181-
scripted_fn(img)
182-
183-
f = transforms.ColorJitter(saturation=factor.item())
184-
scripted_fn = torch.jit.script(f)
185-
scripted_fn(img)
186-
187-
f = transforms.ColorJitter(brightness=1)
188-
scripted_fn = torch.jit.script(f)
189-
scripted_fn(img)
190-
191132
def test_rgb_to_grayscale(self):
192133
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
193134

@@ -286,32 +227,76 @@ def test_pad(self):
286227
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
287228
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
288229

289-
def test_adjust_gamma(self):
290-
script_fn = torch.jit.script(F_t.adjust_gamma)
291-
tensor, pil_img = self._create_data(26, 36)
230+
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
231+
script_fn = torch.jit.script(fn)
292232

293-
for dt in [torch.float64, torch.float32, None]:
233+
torch.manual_seed(15)
234+
235+
tensor, pil_img = self._create_data(26, 34)
236+
237+
for dt in [None, torch.float32, torch.float64]:
294238

295239
if dt is not None:
296240
tensor = F.convert_image_dtype(tensor, dt)
297241

298-
gammas = [0.8, 1.0, 1.2]
299-
gains = [0.7, 1.0, 1.3]
300-
for gamma, gain in zip(gammas, gains):
242+
for config in configs:
301243

302-
adjusted_tensor = F_t.adjust_gamma(tensor, gamma, gain)
303-
adjusted_pil = F_pil.adjust_gamma(pil_img, gamma, gain)
304-
scripted_result = script_fn(tensor, gamma, gain)
305-
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
306-
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
244+
adjusted_tensor = fn_t(tensor, **config)
245+
adjusted_pil = fn_pil(pil_img, **config)
246+
scripted_result = script_fn(tensor, **config)
247+
msg = "{}, {}".format(dt, config)
248+
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
249+
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)
307250

308251
rbg_tensor = adjusted_tensor
252+
309253
if adjusted_tensor.dtype != torch.uint8:
310254
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
311255

312-
self.compareTensorToPIL(rbg_tensor, adjusted_pil)
256+
# Check that max difference does not exceed 1 in [0, 255] range
257+
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
258+
rbg_tensor = rbg_tensor.float()
259+
adjusted_pil_tensor = torch.as_tensor(np.array(adjusted_pil).transpose((2, 0, 1))).to(rbg_tensor)
260+
max_diff = torch.abs(rbg_tensor - adjusted_pil_tensor).max().item()
261+
self.assertLessEqual(
262+
max_diff,
263+
1.0,
264+
msg="{}: tensor:\n{} \ndid not equal PIL tensor:\n{}".format(msg, rbg_tensor, adjusted_pil_tensor)
265+
)
266+
267+
self.assertTrue(adjusted_tensor.equal(scripted_result), msg=msg)
268+
269+
def test_adjust_brightness(self):
270+
self._test_adjust_fn(
271+
F.adjust_brightness,
272+
F_pil.adjust_brightness,
273+
F_t.adjust_brightness,
274+
[{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
275+
)
276+
277+
def test_adjust_contrast(self):
278+
self._test_adjust_fn(
279+
F.adjust_contrast,
280+
F_pil.adjust_contrast,
281+
F_t.adjust_contrast,
282+
[{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
283+
)
313284

314-
self.assertTrue(adjusted_tensor.equal(scripted_result))
285+
def test_adjust_saturation(self):
286+
self._test_adjust_fn(
287+
F.adjust_saturation,
288+
F_pil.adjust_saturation,
289+
F_t.adjust_saturation,
290+
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.25, 1.5]]
291+
)
292+
293+
def test_adjust_gamma(self):
294+
self._test_adjust_fn(
295+
F.adjust_gamma,
296+
F_pil.adjust_gamma,
297+
F_t.adjust_gamma,
298+
[{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
299+
)
315300

316301
def test_resize(self):
317302
script_fn = torch.jit.script(F_t.resize)

0 commit comments

Comments
 (0)