@@ -149,7 +149,7 @@ def test_rgb2hsv(self):
149
149
150
150
self .assertLess (max_diff , 1e-5 )
151
151
152
- def test_adjustments (self ):
152
+ def _test_adjustments (self , device ):
153
153
script_adjust_brightness = torch .jit .script (F_t .adjust_brightness )
154
154
script_adjust_contrast = torch .jit .script (F_t .adjust_contrast )
155
155
script_adjust_saturation = torch .jit .script (F_t .adjust_saturation )
@@ -164,16 +164,16 @@ def test_adjustments(self):
164
164
shape = (channels , dims [0 ], dims [1 ])
165
165
166
166
if torch .randint (0 , 2 , (1 ,)) == 0 :
167
- img = torch .rand (* shape , dtype = torch .float )
167
+ img = torch .rand (* shape , dtype = torch .float , device = device )
168
168
else :
169
- img = torch .randint (0 , 256 , shape , dtype = torch .uint8 )
169
+ img = torch .randint (0 , 256 , shape , dtype = torch .uint8 , device = device )
170
170
171
- factor = 3 * torch .rand (1 )
171
+ factor = 3 * torch .rand (1 ). item ()
172
172
img_clone = img .clone ()
173
173
for f , ft , sft in fns :
174
174
175
- ft_img = ft (img , factor )
176
- sft_img = sft (img , factor )
175
+ ft_img = ft (img , factor ). cpu ()
176
+ sft_img = sft (img , factor ). cpu ()
177
177
if not img .dtype .is_floating_point :
178
178
ft_img = ft_img .to (torch .float ) / 255
179
179
sft_img = sft_img .to (torch .float ) / 255
@@ -191,22 +191,29 @@ def test_adjustments(self):
191
191
self .assertTrue (torch .equal (img , img_clone ))
192
192
193
193
# test for class interface
194
- f = transforms .ColorJitter (brightness = factor . item () )
194
+ f = transforms .ColorJitter (brightness = factor )
195
195
scripted_fn = torch .jit .script (f )
196
196
scripted_fn (img )
197
197
198
- f = transforms .ColorJitter (contrast = factor . item () )
198
+ f = transforms .ColorJitter (contrast = factor )
199
199
scripted_fn = torch .jit .script (f )
200
200
scripted_fn (img )
201
201
202
- f = transforms .ColorJitter (saturation = factor . item () )
202
+ f = transforms .ColorJitter (saturation = factor )
203
203
scripted_fn = torch .jit .script (f )
204
204
scripted_fn (img )
205
205
206
206
f = transforms .ColorJitter (brightness = 1 )
207
207
scripted_fn = torch .jit .script (f )
208
208
scripted_fn (img )
209
209
210
+ def test_adjustments (self ):
211
+ self ._test_adjustments ("cpu" )
212
+
213
+ @unittest .skipIf (not torch .cuda .is_available (), reason = "Skip if no CUDA device" )
214
+ def test_adjustments_cuda (self ):
215
+ self ._test_adjustments ("cuda" )
216
+
210
217
def test_rgb_to_grayscale (self ):
211
218
script_rgb_to_grayscale = torch .jit .script (F_t .rgb_to_grayscale )
212
219
img_tensor = torch .randint (0 , 255 , (3 , 16 , 16 ), dtype = torch .uint8 )
0 commit comments