[proto] Ported SimpleCopyPaste transform #6451
Changes to torchvision/prototype/transforms/_augment.py
@@ -1,11 +1,16 @@
import math
import numbers
import warnings
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple

import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features

from torchvision.prototype.transforms import functional as F
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor

from ._transform import _RandomApplyTransform
from ._utils import has_any, is_simple_tensor, query_chw
@@ -178,3 +183,187 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return self._mixup_onehotlabel(inpt, lam_adjusted)
        else:
            return inpt


class SimpleCopyPaste(_RandomApplyTransform):
    def __init__(
        self,
        p: float = 0.5,
        blending: bool = True,
        resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__(p=p)
        self.resize_interpolation = resize_interpolation
        self.blending = blending

    def _copy_paste(
        self,
        image: Any,
        target: Dict[str, Any],
        paste_image: Any,
        paste_target: Dict[str, Any],
        random_selection: torch.Tensor,
        blending: bool = True,
        resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
    ) -> Tuple[Any, Dict[str, Any]]:

        paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
        paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
        paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])

        masks = target["masks"]

        # We resize the source and paste data if they have different sizes.
        # This differs from the TF implementation; we introduce it here because the
        # original algorithm works on equal-sized data
        # (for example, coming from LSJ data augmentations).
        size1 = image.shape[-2:]
        size2 = paste_image.shape[-2:]
        if size1 != size2:
            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
            paste_masks = F.resize(paste_masks, size=size1)
            paste_boxes = F.resize(paste_boxes, size=size1)

        paste_alpha_mask = paste_masks.sum(dim=0) > 0

        if blending:
            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])

        # Copy-paste images:
        image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)

        # Copy-paste masks:
        masks = masks * (~paste_alpha_mask)
        non_all_zero_masks = masks.sum((-1, -2)) > 0
        masks = masks[non_all_zero_masks]

        # Do a shallow copy of the target dict
        out_target = {k: v for k, v in target.items()}

        out_target["masks"] = torch.cat([masks, paste_masks])

        # Copy-paste boxes and labels
        bbox_format = target["boxes"].format
        xyxy_boxes = masks_to_boxes(masks)
        # TODO: masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be
        # exclusive, so we add +1 to x2y2 here. This still needs to be investigated.
        xyxy_boxes[:, 2:] += 1
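        # Illustration of the inclusive behaviour (not part of this PR): for a mask whose
        # only nonzero pixel sits at row 3, col 5, torchvision.ops.masks_to_boxes returns
        # [5., 3., 5., 3.], i.e. x2/y2 equal x1/y1; without the +1 above such a box would
        # have zero width and height.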

        boxes = F.convert_bounding_box_format(
            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
        )
        out_target["boxes"] = torch.cat([boxes, paste_boxes])

        labels = target["labels"][non_all_zero_masks]
        out_target["labels"] = torch.cat([labels, paste_labels])

        # Check for degenerate boxes and remove them
        boxes = F.convert_bounding_box_format(
            out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
        )
        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
        if degenerate_boxes.any():
            valid_targets = ~degenerate_boxes.any(dim=1)

            out_target["boxes"] = boxes[valid_targets]
            out_target["masks"] = out_target["masks"][valid_targets]
            out_target["labels"] = out_target["labels"][valid_targets]

        return image, out_target

    def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
        # fetch all images, bboxes, masks and labels from the unstructured input
        # as List[image], List[BoundingBox], List[SegmentationMask], List[Label]
        images, bboxes, masks, labels = [], [], [], []
        for obj in flat_sample:
            if isinstance(obj, features.Image) or is_simple_tensor(obj):
                images.append(obj)
            elif isinstance(obj, PIL.Image.Image):
                images.append(pil_to_tensor(obj))
            elif isinstance(obj, features.BoundingBox):
                bboxes.append(obj)
            elif isinstance(obj, features.SegmentationMask):
                masks.append(obj)
            elif isinstance(obj, (features.Label, features.OneHotLabel)):
                labels.append(obj)

        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain equal-sized lists of Images, "
                "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
            )

        targets = []
        for bbox, mask, label in zip(bboxes, masks, labels):
            targets.append({"boxes": bbox, "masks": mask, "labels": label})

        return images, targets

    def _insert_outputs(
        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
    ) -> None:
        c0, c1, c2, c3 = 0, 0, 0, 0
Review comment on this line: Can't we save these indices the first time we iterate over the flat sample? If yes, maybe we can get away with only doing […]

Reply: yes, we can do that. I just found that passing indices everywhere would be a bit bulky...
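A rough sketch of the suggestion (illustrative only, not part of this PR; the _image_indices, _bbox_indices, _mask_indices and _label_indices attributes are hypothetical names that _extract_image_targets would have to fill while iterating over flat_sample):

    def _insert_outputs(
        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
    ) -> None:
        # Positions recorded during _extract_image_targets replace the c0..c3 counters.
        for c, i in enumerate(self._image_indices):
            obj = flat_sample[i]
            if isinstance(obj, features.Image):
                flat_sample[i] = features.Image.new_like(obj, output_images[c])
            elif isinstance(obj, PIL.Image.Image):
                flat_sample[i] = F.to_image_pil(output_images[c])
            else:  # simple tensor
                flat_sample[i] = output_images[c]
        for c, i in enumerate(self._bbox_indices):
            flat_sample[i] = features.BoundingBox.new_like(flat_sample[i], output_targets[c]["boxes"])
        for c, i in enumerate(self._mask_indices):
            flat_sample[i] = features.SegmentationMask.new_like(flat_sample[i], output_targets[c]["masks"])
        for c, i in enumerate(self._label_indices):
            flat_sample[i] = flat_sample[i].new_like(flat_sample[i], output_targets[c]["labels"])

The PR keeps the counter-based version below since, as noted in the reply above, threading the indices through was considered a bit bulky.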

        for i, obj in enumerate(flat_sample):
            if isinstance(obj, features.Image):
                flat_sample[i] = features.Image.new_like(obj, output_images[c0])
                c0 += 1
            elif isinstance(obj, PIL.Image.Image):
                flat_sample[i] = F.to_image_pil(output_images[c0])
                c0 += 1
            elif is_simple_tensor(obj):
                flat_sample[i] = output_images[c0]
                c0 += 1
            elif isinstance(obj, features.BoundingBox):
                flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
                c1 += 1
            elif isinstance(obj, features.SegmentationMask):
                flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
                c2 += 1
            elif isinstance(obj, (features.Label, features.OneHotLabel)):
                flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
                c3 += 1

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]

        flat_sample, spec = tree_flatten(sample)

        images, targets = self._extract_image_targets(flat_sample)

        # images = [t1, t2, ..., tN]
        # Let's define paste_images as a shifted list of the input images:
        # paste_images = [tN, t1, ..., tN-1]
        # FYI: in TF they mix data on the dataset level
        images_rolled = images[-1:] + images[:-1]
        targets_rolled = targets[-1:] + targets[:-1]

        output_images, output_targets = [], []

        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):

            # Random paste targets selection:
            num_masks = len(paste_target["masks"])

            if num_masks < 1:
                # Such a degenerate case with num_masks=0 can happen with LSJ
                # Let's just return (image, target)
                output_image, output_target = image, target
            else:
                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
                random_selection = torch.unique(random_selection)

                output_image, output_target = self._copy_paste(
                    image,
                    target,
                    paste_image,
                    paste_target,
                    random_selection=random_selection,
                    blending=self.blending,
                    resize_interpolation=self.resize_interpolation,
                )
            output_images.append(output_image)
            output_targets.append(output_target)

        # Insert updated images and targets into input flat_sample
        self._insert_outputs(flat_sample, output_images, output_targets)

        return tree_unflatten(flat_sample, spec)
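For context, a minimal usage sketch (not part of this PR). It assumes the class is exported as torchvision.prototype.transforms.SimpleCopyPaste and uses the prototype features constructors as of this PR (format/image_size keyword names); the inputs are equal-sized lists of images and target dicts, as _extract_image_targets expects:

import torch
from torchvision.ops import masks_to_boxes
from torchvision.prototype import features, transforms

copy_paste = transforms.SimpleCopyPaste(blending=True)  # assumes the transform is exported here

images, targets = [], []
for i in range(2):
    # One instance per image: a filled square whose box is derived from its mask.
    mask = torch.zeros(1, 224, 224, dtype=torch.uint8)
    mask[:, 20 + 40 * i : 80 + 40 * i, 20:80] = 1
    images.append(features.Image(torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8)))
    targets.append(
        {
            "boxes": features.BoundingBox(
                masks_to_boxes(mask), format=features.BoundingBoxFormat.XYXY, image_size=(224, 224)
            ),
            "masks": features.SegmentationMask(mask),
            "labels": features.Label(torch.tensor([i + 1])),
        }
    )

# forward() flattens the sample, pastes instances from the "rolled" images, and
# returns the same structure with updated images, boxes, masks and labels.
out_images, out_targets = copy_paste(images, targets)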