@@ -326,3 +326,78 @@ def forward(
326
326
)
327
327
328
328
return image , target
329
+
330
+
331
+ class FixedSizeCrop (nn .Module ):
332
+ def __init__ (self , size , fill = 0 , padding_mode = "constant" ):
333
+ super ().__init__ ()
334
+ size = tuple (T ._setup_size (size , error_msg = "Please provide only two dimensions (h, w) for size." ))
335
+ self .crop_height = size [0 ]
336
+ self .crop_width = size [1 ]
337
+ self .fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
338
+ self .padding_mode = padding_mode
339
+
340
+ def _pad (self , img , target , padding ):
341
+ # Taken from the functional_tensor.py pad
342
+ if isinstance (padding , int ):
343
+ pad_left = pad_right = pad_top = pad_bottom = padding
344
+ elif len (padding ) == 1 :
345
+ pad_left = pad_right = pad_top = pad_bottom = padding [0 ]
346
+ elif len (padding ) == 2 :
347
+ pad_left = pad_right = padding [0 ]
348
+ pad_top = pad_bottom = padding [1 ]
349
+ else :
350
+ pad_left = padding [0 ]
351
+ pad_top = padding [1 ]
352
+ pad_right = padding [2 ]
353
+ pad_bottom = padding [3 ]
354
+
355
+ padding = [pad_left , pad_top , pad_right , pad_bottom ]
356
+ img = F .pad (img , padding , self .fill , self .padding_mode )
357
+ if target is not None :
358
+ target ["boxes" ][:, 0 ::2 ] += pad_left
359
+ target ["boxes" ][:, 1 ::2 ] += pad_top
360
+ if "masks" in target :
361
+ target ["masks" ] = F .pad (target ["masks" ], padding , 0 , "constant" )
362
+
363
+ return img , target
364
+
365
+ def _crop (self , img , target , top , left , height , width ):
366
+ img = F .crop (img , top , left , height , width )
367
+ if target is not None :
368
+ boxes = target ["boxes" ]
369
+ boxes [:, 0 ::2 ] -= left
370
+ boxes [:, 1 ::2 ] -= top
371
+ boxes [:, 0 ::2 ].clamp_ (min = 0 , max = width )
372
+ boxes [:, 1 ::2 ].clamp_ (min = 0 , max = height )
373
+
374
+ is_valid = (boxes [:, 0 ] < boxes [:, 2 ]) & (boxes [:, 1 ] < boxes [:, 3 ])
375
+
376
+ target ["boxes" ] = boxes [is_valid ]
377
+ target ["labels" ] = target ["labels" ][is_valid ]
378
+ if "masks" in target :
379
+ target ["masks" ] = F .crop (target ["masks" ][is_valid ], top , left , height , width )
380
+
381
+ return img , target
382
+
383
+ def forward (self , img , target = None ):
384
+ _ , height , width = F .get_dimensions (img )
385
+ new_height = min (height , self .crop_height )
386
+ new_width = min (width , self .crop_width )
387
+
388
+ if new_height != height or new_width != width :
389
+ offset_height = max (height - self .crop_height , 0 )
390
+ offset_width = max (width - self .crop_width , 0 )
391
+
392
+ r = torch .rand (1 )
393
+ top = int (offset_height * r )
394
+ left = int (offset_width * r )
395
+
396
+ img , target = self ._crop (img , target , top , left , new_height , new_width )
397
+
398
+ pad_bottom = max (self .crop_height - new_height , 0 )
399
+ pad_right = max (self .crop_width - new_width , 0 )
400
+ if pad_bottom != 0 or pad_right != 0 :
401
+ img , target = self ._pad (img , target , [0 , 0 , pad_right , pad_bottom ])
402
+
403
+ return img , target
0 commit comments