@@ -290,6 +290,7 @@ class ScaleJitter(nn.Module):
290
290
`"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" <https://arxiv.org/abs/2012.07177>`_.
291
291
292
292
Args:
293
+ target_size (tuple of ints): The target size for the transform provided in (height, weight) format.
293
294
scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the
294
295
range a <= scale <= b.
295
296
interpolation (InterpolationMode): Desired interpolation enum defined by
@@ -298,10 +299,12 @@ class ScaleJitter(nn.Module):
298
299
299
300
def __init__ (
300
301
self ,
302
+ target_size : Tuple [int , int ],
301
303
scale_range : Tuple [float , float ] = (0.1 , 2.0 ),
302
304
interpolation : InterpolationMode = InterpolationMode .BILINEAR ,
303
305
):
304
306
super ().__init__ ()
307
+ self .target_size = target_size
305
308
self .scale_range = scale_range
306
309
self .interpolation = interpolation
307
310
@@ -314,15 +317,17 @@ def forward(
314
317
elif image .ndimension () == 2 :
315
318
image = image .unsqueeze (0 )
316
319
317
- old_width , old_height = F .get_image_size (image )
318
-
319
320
r = self .scale_range [0 ] + torch .rand (1 ) * (self .scale_range [1 ] - self .scale_range [0 ])
320
- new_width = int (old_width * r )
321
- new_height = int (old_height * r )
321
+ new_width = int (self . target_size [ 1 ] * r )
322
+ new_height = int (self . target_size [ 0 ] * r )
322
323
323
324
image = F .resize (image , [new_height , new_width ], interpolation = self .interpolation )
324
325
325
326
if target is not None :
326
327
target ["boxes" ] *= r
328
+ if "masks" in target :
329
+ target ["masks" ] = F .resize (
330
+ target ["masks" ], [new_height , new_width ], interpolation = InterpolationMode .NEAREST
331
+ )
327
332
328
333
return image , target
0 commit comments