Skip to content

Commit fa57c16

Browse files
committed
added comments
1 parent 279e189 commit fa57c16

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

references/detection/transforms.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,6 @@ class SimpleCopyPaste(torch.nn.Module):
445445
def __init__(self, jittering_type: str = "LSJ"):
446446
super().__init__()
447447

448-
# TODO: Apply random scale jittering ( resize and crop )
449448
if jittering_type == "LSJ":
450449
scale_range = (0.1, 2.0)
451450
elif jittering_type == "SSJ":
@@ -476,33 +475,41 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens
476475
for i, (image, mask) in enumerate(zip(batch, target)):
477476
batch[i], target[i] = self.transforms(image, mask)
478477

478+
# create copy of batch and target as the original will be modified
479479
batch_rolled = batch.roll(1, 0).detach().clone()
480480
target_rolled = copy.deepcopy(target[-1:] + target[:-1])
481481

482482
# TODO: select a random subset of objects from one of the images and paste them onto the other image
483483

484484
# TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask
485485

486+
# collect binary paste masks for all images
486487
paste_masks = []
487488

488489
for source_image, paste_image, source_data, paste_data in zip(batch, batch_rolled, target, target_rolled):
489490
paste_alpha_mask = self.combine_masks(paste_data["masks"])
490491
paste_masks.append(paste_alpha_mask)
491492

493+
# update original masks
492494
for i, mask in enumerate(source_data["masks"]):
493495
source_data["masks"][i] = mask ^ paste_alpha_mask & mask
494496

497+
# remove masks where no annotations are present (all values are 0)
495498
mask_filter = source_data["masks"].sum((2, 1)).not_equal(0)
496499
filtered_masks = source_data["masks"][mask_filter]
500+
501+
# update bboxes based on new masks
497502
source_data["boxes"] = ops.masks_to_boxes(filtered_masks)
498503
# TODO: update area
499504

505+
# concatenate paste data with original data
500506
source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"]))
501507
source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"]))
502508
source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"]))
503509
source_data["area"] = torch.cat((source_data["area"], paste_data["area"]))
504510
source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"]))
505511

512+
# update the original images with paste images
506513
paste_masks = torch.stack(paste_masks)
507514
batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1))
508515

0 commit comments

Comments
 (0)