@@ -445,7 +445,6 @@ class SimpleCopyPaste(torch.nn.Module):
445
445
def __init__ (self , jittering_type : str = "LSJ" ):
446
446
super ().__init__ ()
447
447
448
- # TODO: Apply random scale jittering ( resize and crop )
449
448
if jittering_type == "LSJ" :
450
449
scale_range = (0.1 , 2.0 )
451
450
elif jittering_type == "SSJ" :
@@ -476,33 +475,41 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
476
475
for i , (image , mask ) in enumerate (zip (batch , target )):
477
476
batch [i ], target [i ] = self .transforms (image , mask )
478
477
478
+ # create copy of batch and target as the original will be modified
479
479
batch_rolled = batch .roll (1 , 0 ).detach ().clone ()
480
480
target_rolled = copy .deepcopy (target [- 1 :] + target [:- 1 ])
481
481
482
482
# TODO: select a random subset of objects from one of the images and paste them onto the other image
483
483
484
484
# TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask
485
485
486
+ # collect binary paste masks for all images
486
487
paste_masks = []
487
488
488
489
for source_image , paste_image , source_data , paste_data in zip (batch , batch_rolled , target , target_rolled ):
489
490
paste_alpha_mask = self .combine_masks (paste_data ["masks" ])
490
491
paste_masks .append (paste_alpha_mask )
491
492
493
+ # update original masks
492
494
for i , mask in enumerate (source_data ["masks" ]):
493
495
source_data ["masks" ][i ] = mask ^ paste_alpha_mask & mask
494
496
497
+ # remove masks where no annotations are present (all values are 0)
495
498
mask_filter = source_data ["masks" ].sum ((2 , 1 )).not_equal (0 )
496
499
filtered_masks = source_data ["masks" ][mask_filter ]
500
+
501
+ # update bboxes based on new masks
497
502
source_data ["boxes" ] = ops .masks_to_boxes (filtered_masks )
498
503
# TODO: update area
499
504
505
+ # concatenate paste data with original data
500
506
source_data ["masks" ] = torch .cat ((source_data ["masks" ], paste_data ["masks" ]))
501
507
source_data ["boxes" ] = torch .cat ((source_data ["boxes" ], paste_data ["boxes" ]))
502
508
source_data ["labels" ] = torch .cat ((source_data ["labels" ], paste_data ["labels" ]))
503
509
source_data ["area" ] = torch .cat ((source_data ["area" ], paste_data ["area" ]))
504
510
source_data ["iscrowd" ] = torch .cat ((source_data ["iscrowd" ], paste_data ["iscrowd" ]))
505
511
512
+ # update the original images with paste images
506
513
paste_masks = torch .stack (paste_masks )
507
514
batch .mul_ (torch .unsqueeze (torch .logical_not (paste_masks ), 1 ))
508
515
0 commit comments