Skip to content

Commit 38f33e7

Browse files
committed
Merge branch 'master' of https://github.com/pytorch/vision into adjust_hue_tensor
2 parents 3f23938 + ca3d193 commit 38f33e7

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

test/test_functional_tensor.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,50 +67,58 @@ def test_crop(self):
6767
self.compareTensorToPIL(img_tensor_cropped, pil_img_cropped)
6868

6969
def test_hsv2rgb(self):
70+
scripted_fn = torch.jit.script(F_t._hsv2rgb)
7071
shape = (3, 100, 150)
71-
for _ in range(20):
72-
img = torch.rand(*shape, dtype=torch.float)
73-
ft_img = F_t._hsv2rgb(img).permute(1, 2, 0).flatten(0, 1)
72+
for _ in range(10):
73+
hsv_img = torch.rand(*shape, dtype=torch.float, device=self.device)
74+
rgb_img = F_t._hsv2rgb(hsv_img)
75+
ft_img = rgb_img.permute(1, 2, 0).flatten(0, 1)
7476

75-
h, s, v, = img.unbind(0)
76-
h = h.flatten().numpy()
77-
s = s.flatten().numpy()
78-
v = v.flatten().numpy()
77+
h, s, v, = hsv_img.unbind(0)
78+
h = h.flatten().cpu().numpy()
79+
s = s.flatten().cpu().numpy()
80+
v = v.flatten().cpu().numpy()
7981

8082
rgb = []
8183
for h1, s1, v1 in zip(h, s, v):
8284
rgb.append(colorsys.hsv_to_rgb(h1, s1, v1))
83-
84-
colorsys_img = torch.tensor(rgb, dtype=torch.float32)
85+
colorsys_img = torch.tensor(rgb, dtype=torch.float32, device=self.device)
8586
max_diff = (ft_img - colorsys_img).abs().max()
8687
self.assertLess(max_diff, 1e-5)
8788

89+
s_rgb_img = scripted_fn(hsv_img)
90+
self.assertTrue(rgb_img.allclose(s_rgb_img))
91+
8892
def test_rgb2hsv(self):
93+
scripted_fn = torch.jit.script(F_t._rgb2hsv)
8994
shape = (3, 150, 100)
90-
for _ in range(20):
91-
img = torch.rand(*shape, dtype=torch.float)
92-
ft_hsv_img = F_t._rgb2hsv(img).permute(1, 2, 0).flatten(0, 1)
95+
for _ in range(10):
96+
rgb_img = torch.rand(*shape, dtype=torch.float, device=self.device)
97+
hsv_img = F_t._rgb2hsv(rgb_img)
98+
ft_hsv_img = hsv_img.permute(1, 2, 0).flatten(0, 1)
9399

94-
r, g, b, = img.unbind(0)
95-
r = r.flatten().numpy()
96-
g = g.flatten().numpy()
97-
b = b.flatten().numpy()
100+
r, g, b, = rgb_img.unbind(0)
101+
r = r.flatten().cpu().numpy()
102+
g = g.flatten().cpu().numpy()
103+
b = b.flatten().cpu().numpy()
98104

99105
hsv = []
100106
for r1, g1, b1 in zip(r, g, b):
101107
hsv.append(colorsys.rgb_to_hsv(r1, g1, b1))
102108

103-
colorsys_img = torch.tensor(hsv, dtype=torch.float32)
109+
colorsys_img = torch.tensor(hsv, dtype=torch.float32, device=self.device)
104110

105111
ft_hsv_img_h, ft_hsv_img_sv = torch.split(ft_hsv_img, [1, 2], dim=1)
106112
colorsys_img_h, colorsys_img_sv = torch.split(colorsys_img, [1, 2], dim=1)
107113

108114
max_diff_h = ((colorsys_img_h * 2 * math.pi).sin() - (ft_hsv_img_h * 2 * math.pi).sin()).abs().max()
109115
max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv).abs().max()
110116
max_diff = max(max_diff_h, max_diff_sv)
111-
112117
self.assertLess(max_diff, 1e-5)
113118

119+
s_hsv_img = scripted_fn(rgb_img)
120+
self.assertTrue(hsv_img.allclose(s_hsv_img))
121+
114122
def test_rgb_to_grayscale(self):
115123
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)
116124

0 commit comments

Comments
 (0)