@@ -67,50 +67,58 @@ def test_crop(self):
67
67
self .compareTensorToPIL (img_tensor_cropped , pil_img_cropped )
68
68
69
69
def test_hsv2rgb (self ):
70
+ scripted_fn = torch .jit .script (F_t ._hsv2rgb )
70
71
shape = (3 , 100 , 150 )
71
- for _ in range (20 ):
72
- img = torch .rand (* shape , dtype = torch .float )
73
- ft_img = F_t ._hsv2rgb (img ).permute (1 , 2 , 0 ).flatten (0 , 1 )
72
+ for _ in range (10 ):
73
+ hsv_img = torch .rand (* shape , dtype = torch .float , device = self .device )
74
+ rgb_img = F_t ._hsv2rgb (hsv_img )
75
+ ft_img = rgb_img .permute (1 , 2 , 0 ).flatten (0 , 1 )
74
76
75
- h , s , v , = img .unbind (0 )
76
- h = h .flatten ().numpy ()
77
- s = s .flatten ().numpy ()
78
- v = v .flatten ().numpy ()
77
+ h , s , v , = hsv_img .unbind (0 )
78
+ h = h .flatten ().cpu (). numpy ()
79
+ s = s .flatten ().cpu (). numpy ()
80
+ v = v .flatten ().cpu (). numpy ()
79
81
80
82
rgb = []
81
83
for h1 , s1 , v1 in zip (h , s , v ):
82
84
rgb .append (colorsys .hsv_to_rgb (h1 , s1 , v1 ))
83
-
84
- colorsys_img = torch .tensor (rgb , dtype = torch .float32 )
85
+ colorsys_img = torch .tensor (rgb , dtype = torch .float32 , device = self .device )
85
86
max_diff = (ft_img - colorsys_img ).abs ().max ()
86
87
self .assertLess (max_diff , 1e-5 )
87
88
89
+ s_rgb_img = scripted_fn (hsv_img )
90
+ self .assertTrue (rgb_img .allclose (s_rgb_img ))
91
+
88
92
def test_rgb2hsv (self ):
93
+ scripted_fn = torch .jit .script (F_t ._rgb2hsv )
89
94
shape = (3 , 150 , 100 )
90
- for _ in range (20 ):
91
- img = torch .rand (* shape , dtype = torch .float )
92
- ft_hsv_img = F_t ._rgb2hsv (img ).permute (1 , 2 , 0 ).flatten (0 , 1 )
95
+ for _ in range (10 ):
96
+ rgb_img = torch .rand (* shape , dtype = torch .float , device = self .device )
97
+ hsv_img = F_t ._rgb2hsv (rgb_img )
98
+ ft_hsv_img = hsv_img .permute (1 , 2 , 0 ).flatten (0 , 1 )
93
99
94
- r , g , b , = img .unbind (0 )
95
- r = r .flatten ().numpy ()
96
- g = g .flatten ().numpy ()
97
- b = b .flatten ().numpy ()
100
+ r , g , b , = rgb_img .unbind (0 )
101
+ r = r .flatten ().cpu (). numpy ()
102
+ g = g .flatten ().cpu (). numpy ()
103
+ b = b .flatten ().cpu (). numpy ()
98
104
99
105
hsv = []
100
106
for r1 , g1 , b1 in zip (r , g , b ):
101
107
hsv .append (colorsys .rgb_to_hsv (r1 , g1 , b1 ))
102
108
103
- colorsys_img = torch .tensor (hsv , dtype = torch .float32 )
109
+ colorsys_img = torch .tensor (hsv , dtype = torch .float32 , device = self . device )
104
110
105
111
ft_hsv_img_h , ft_hsv_img_sv = torch .split (ft_hsv_img , [1 , 2 ], dim = 1 )
106
112
colorsys_img_h , colorsys_img_sv = torch .split (colorsys_img , [1 , 2 ], dim = 1 )
107
113
108
114
max_diff_h = ((colorsys_img_h * 2 * math .pi ).sin () - (ft_hsv_img_h * 2 * math .pi ).sin ()).abs ().max ()
109
115
max_diff_sv = (colorsys_img_sv - ft_hsv_img_sv ).abs ().max ()
110
116
max_diff = max (max_diff_h , max_diff_sv )
111
-
112
117
self .assertLess (max_diff , 1e-5 )
113
118
119
+ s_hsv_img = scripted_fn (rgb_img )
120
+ self .assertTrue (hsv_img .allclose (s_hsv_img ))
121
+
114
122
def test_rgb_to_grayscale (self ):
115
123
script_rgb_to_grayscale = torch .jit .script (F .rgb_to_grayscale )
116
124
0 commit comments