@@ -1,11 +1,16 @@
 import math
 import numbers
 import warnings
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple
 
+import PIL.Image
 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision.ops import masks_to_boxes
 from torchvision.prototype import features
+
 from torchvision.prototype.transforms import functional as F
+from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
 
 from ._transform import _RandomApplyTransform
 from ._utils import has_any, is_simple_tensor, query_chw
@@ -178,3 +183,187 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             return self._mixup_onehotlabel(inpt, lam_adjusted)
         else:
             return inpt
+
+
+class SimpleCopyPaste(_RandomApplyTransform):
+    def __init__(
+        self,
+        p: float = 0.5,
+        blending: bool = True,
+        resize_interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ) -> None:
+        super().__init__(p=p)
+        self.resize_interpolation = resize_interpolation
+        self.blending = blending
+
+    def _copy_paste(
+        self,
+        image: Any,
+        target: Dict[str, Any],
+        paste_image: Any,
+        paste_target: Dict[str, Any],
+        random_selection: torch.Tensor,
+        blending: bool = True,
+        resize_interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+    ) -> Tuple[Any, Dict[str, Any]]:
+
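+        # `random_selection` indexes the paste instances to keep; `new_like` re-wraps the
+        # subset so the feature metadata (e.g. the bounding box format) is preserved.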
+        paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
+        paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
+        paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])
+
+        masks = target["masks"]
+
+        # We resize the source and paste data if they have different sizes.
+        # This is a departure from the TF implementation, which assumes equal-sized
+        # inputs (for example, coming from LSJ data augmentations).
+        size1 = image.shape[-2:]
+        size2 = paste_image.shape[-2:]
+        if size1 != size2:
+            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
+            paste_masks = F.resize(paste_masks, size=size1)
+            paste_boxes = F.resize(paste_boxes, size=size1)
+
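+        # The union of the selected instance masks serves as the alpha mask for pasting.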
+        paste_alpha_mask = paste_masks.sum(dim=0) > 0
+
+        if blending:
+            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
+
+        # Copy-paste images:
+        image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
+
+        # Copy-paste masks:
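+        # The pasted region occludes the source masks; instances that become empty are dropped.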
+        masks = masks * (~paste_alpha_mask)
+        non_all_zero_masks = masks.sum((-1, -2)) > 0
+        masks = masks[non_all_zero_masks]
+
+        # Do a shallow copy of the target dict
+        out_target = {k: v for k, v in target.items()}
+
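+        # Pasted instances are appended after the surviving source instances; the boxes
+        # and labels below are concatenated in the same order.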
+        out_target["masks"] = torch.cat([masks, paste_masks])
+
+        # Copy-paste boxes and labels
+        bbox_format = target["boxes"].format
+        xyxy_boxes = masks_to_boxes(masks)
+        # TODO: masks_to_boxes produces boxes with x2y2 inclusive, but x2y2 should be
+        # exclusive, so we add +1 to x2y2. This needs further investigation.
+        xyxy_boxes[:, 2:] += 1
+        boxes = F.convert_bounding_box_format(
+            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
+        )
+        out_target["boxes"] = torch.cat([boxes, paste_boxes])
+
+        labels = target["labels"][non_all_zero_masks]
+        out_target["labels"] = torch.cat([labels, paste_labels])
+
+        # Check for degenerate boxes and remove them
+        boxes = F.convert_bounding_box_format(
+            out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
+        )
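+        # A box is degenerate when x2 <= x1 or y2 <= y1 in XYXY coordinates.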
+        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+        if degenerate_boxes.any():
+            valid_targets = ~degenerate_boxes.any(dim=1)
+
+            # Index the boxes stored in out_target so they stay in bbox_format
+            # (`boxes` was converted to XYXY only for the validity check).
+            out_target["boxes"] = out_target["boxes"][valid_targets]
+            out_target["masks"] = out_target["masks"][valid_targets]
+            out_target["labels"] = out_target["labels"][valid_targets]
+
+        return image, out_target
+
+    def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
+        # Fetch all images, bboxes, masks and labels from the unstructured input, i.e.
+        # List[image], List[BoundingBox], List[SegmentationMask], List[Label]
+        images, bboxes, masks, labels = [], [], [], []
+        for obj in flat_sample:
+            if isinstance(obj, features.Image) or is_simple_tensor(obj):
+                images.append(obj)
+            elif isinstance(obj, PIL.Image.Image):
+                images.append(pil_to_tensor(obj))
+            elif isinstance(obj, features.BoundingBox):
+                bboxes.append(obj)
+            elif isinstance(obj, features.SegmentationMask):
+                masks.append(obj)
+            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+                labels.append(obj)
+
+        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain equal-sized lists of Images, "
+                "BoundingBoxes, SegmentationMasks and Labels or OneHotLabels."
+            )
+
+        targets = []
+        for bbox, mask, label in zip(bboxes, masks, labels):
+            targets.append({"boxes": bbox, "masks": mask, "labels": label})
+
+        return images, targets
+
+    def _insert_outputs(
+        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
+    ) -> None:
+        c0, c1, c2, c3 = 0, 0, 0, 0
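+        # c0..c3 track how many images, boxes, masks and labels have been written back
+        # so far, so each output is matched to the flat-sample slot it came from.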
+        for i, obj in enumerate(flat_sample):
+            if isinstance(obj, features.Image):
+                flat_sample[i] = features.Image.new_like(obj, output_images[c0])
+                c0 += 1
+            elif isinstance(obj, PIL.Image.Image):
+                flat_sample[i] = F.to_image_pil(output_images[c0])
+                c0 += 1
+            elif is_simple_tensor(obj):
+                flat_sample[i] = output_images[c0]
+                c0 += 1
+            elif isinstance(obj, features.BoundingBox):
+                flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
+                c1 += 1
+            elif isinstance(obj, features.SegmentationMask):
+                flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
+                c2 += 1
+            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+                flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
+                c3 += 1
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        flat_sample, spec = tree_flatten(sample)
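+        # tree_flatten lets the transform accept arbitrarily nested inputs (lists,
+        # tuples, dicts); the original structure is restored with tree_unflatten below.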
+
+        images, targets = self._extract_image_targets(flat_sample)
+
+        # images = [t1, t2, ..., tN]
+        # Let's define paste_images as a rolled list of the input images:
+        # paste_images = [tN, t1, ..., tN-1]
+        # FYI: in TF they mix data on the dataset level
+        images_rolled = images[-1:] + images[:-1]
+        targets_rolled = targets[-1:] + targets[:-1]
+
+        output_images, output_targets = [], []
+
+        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
+
+            # Random paste targets selection:
+            num_masks = len(paste_target["masks"])
+
+            if num_masks < 1:
+                # Such a degenerate case with num_masks=0 can happen with LSJ
+                # Let's just return (image, target)
+                output_image, output_target = image, target
+            else:
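+                # Sampling with replacement and then deduplicating yields a random,
+                # non-empty subset of the paste instances.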
+                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
+                random_selection = torch.unique(random_selection)
+
+                output_image, output_target = self._copy_paste(
+                    image,
+                    target,
+                    paste_image,
+                    paste_target,
+                    random_selection=random_selection,
+                    blending=self.blending,
+                    resize_interpolation=self.resize_interpolation,
+                )
+            output_images.append(output_image)
+            output_targets.append(output_target)
+
+        # Insert updated images and targets into input flat_sample
+        self._insert_outputs(flat_sample, output_images, output_targets)
+
+        return tree_unflatten(flat_sample, spec)
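
For context, here is a minimal usage sketch of the transform added above (not part of the commit). It assumes the prototype `features` API imported in this file; the exact feature constructor arguments (e.g. `image_size` on `BoundingBox`) are assumptions and have shifted between versions:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import SimpleCopyPaste

    def make_sample(seed: int):
        # One detection-style sample: an image plus boxes, masks and labels.
        g = torch.Generator().manual_seed(seed)
        image = features.Image(torch.randint(0, 256, (3, 224, 224), generator=g, dtype=torch.uint8))
        mask_data = torch.zeros(1, 224, 224, dtype=torch.uint8)
        mask_data[0, 50:150, 50:150] = 1  # one square instance
        masks = features.SegmentationMask(mask_data)
        boxes = features.BoundingBox(
            torch.tensor([[50.0, 50.0, 150.0, 150.0]]),
            format=features.BoundingBoxFormat.XYXY,
            image_size=(224, 224),
        )
        labels = features.Label(torch.tensor([1]))
        return image, {"boxes": boxes, "masks": masks, "labels": labels}

    # Instances are pasted across samples (each image receives instances from the
    # previous one in the rolled list), so the batch needs at least two samples.
    batch = [make_sample(0), make_sample(1)]
    augmented = SimpleCopyPaste(p=1.0, blending=True)(batch)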