Skip to content

Commit e1d54f1

Browse files
committed
Updated test_adjustments to run on CPU and CUDA
1 parent 91ebd75 commit e1d54f1

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

test/test_functional_tensor.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def test_rgb2hsv(self):
149149

150150
self.assertLess(max_diff, 1e-5)
151151

152-
def test_adjustments(self):
152+
def _test_adjustments(self, device):
153153
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
154154
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
155155
script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
@@ -164,16 +164,16 @@ def test_adjustments(self):
164164
shape = (channels, dims[0], dims[1])
165165

166166
if torch.randint(0, 2, (1,)) == 0:
167-
img = torch.rand(*shape, dtype=torch.float)
167+
img = torch.rand(*shape, dtype=torch.float, device=device)
168168
else:
169-
img = torch.randint(0, 256, shape, dtype=torch.uint8)
169+
img = torch.randint(0, 256, shape, dtype=torch.uint8, device=device)
170170

171-
factor = 3 * torch.rand(1)
171+
factor = 3 * torch.rand(1).item()
172172
img_clone = img.clone()
173173
for f, ft, sft in fns:
174174

175-
ft_img = ft(img, factor)
176-
sft_img = sft(img, factor)
175+
ft_img = ft(img, factor).cpu()
176+
sft_img = sft(img, factor).cpu()
177177
if not img.dtype.is_floating_point:
178178
ft_img = ft_img.to(torch.float) / 255
179179
sft_img = sft_img.to(torch.float) / 255
@@ -191,22 +191,29 @@ def test_adjustments(self):
191191
self.assertTrue(torch.equal(img, img_clone))
192192

193193
# test for class interface
194-
f = transforms.ColorJitter(brightness=factor.item())
194+
f = transforms.ColorJitter(brightness=factor)
195195
scripted_fn = torch.jit.script(f)
196196
scripted_fn(img)
197197

198-
f = transforms.ColorJitter(contrast=factor.item())
198+
f = transforms.ColorJitter(contrast=factor)
199199
scripted_fn = torch.jit.script(f)
200200
scripted_fn(img)
201201

202-
f = transforms.ColorJitter(saturation=factor.item())
202+
f = transforms.ColorJitter(saturation=factor)
203203
scripted_fn = torch.jit.script(f)
204204
scripted_fn(img)
205205

206206
f = transforms.ColorJitter(brightness=1)
207207
scripted_fn = torch.jit.script(f)
208208
scripted_fn(img)
209209

210+
def test_adjustments(self):
211+
self._test_adjustments("cpu")
212+
213+
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
214+
def test_adjustments_cuda(self):
215+
self._test_adjustments("cuda")
216+
210217
def test_rgb_to_grayscale(self):
211218
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
212219
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)

0 commit comments

Comments
 (0)