
Commit c6a715c

vfdev-5 and pmeier authored
[proto] Ported SimpleCopyPaste transform (#6451)
* WIP
* [proto] Added SimpleCopyPaste transform
* Refactored and cleaned the implementation and added tests
* Fixing code
* Fixed code formatting issue
* Minor updates
* Fixed merge issue

Co-authored-by: Philip Meier <[email protected]>
1 parent d8025b9 commit c6a715c

File tree

3 files changed: +282 −2 lines
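
This commit ports the Simple Copy-Paste augmentation (Ghiasi et al., "Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation") to the prototype transforms API. As a quick orientation before the diffs, a minimal construction sketch using only the parameters added in this commit (the values shown are the defaults):

import torchvision.prototype.transforms as transforms
from torchvision.transforms.functional import InterpolationMode

# Defaults as defined in torchvision/prototype/transforms/_augment.py below
transform = transforms.SimpleCopyPaste(
    p=0.5,                                            # probability, from _RandomApplyTransform
    blending=True,                                    # Gaussian-blur the paste alpha mask
    resize_interpolation=InterpolationMode.BILINEAR,  # used when source/paste sizes differ
)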

test/test_prototype_transforms.py

Lines changed: 91 additions & 0 deletions
@@ -1313,6 +1313,97 @@ def test__transform(self, mocker):
         mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
 
 
+class TestSimpleCopyPaste:
+    def create_fake_image(self, mocker, image_type):
+        if image_type == PIL.Image.Image:
+            return PIL.Image.new("RGB", (32, 32), 123)
+        return mocker.MagicMock(spec=image_type)
+
+    def test__extract_image_targets_assertion(self, mocker):
+        transform = transforms.SimpleCopyPaste()
+
+        flat_sample = [
+            # images, batch size = 2
+            self.create_fake_image(mocker, features.Image),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.Label),
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+        ]
+
+        with pytest.raises(TypeError, match="requires input sample to contain equal-sized list of Images"):
+            transform._extract_image_targets(flat_sample)
+
+    @pytest.mark.parametrize("image_type", [features.Image, PIL.Image.Image, torch.Tensor])
+    def test__extract_image_targets(self, image_type, mocker):
+        transform = transforms.SimpleCopyPaste()
+
+        flat_sample = [
+            # images, batch size = 2
+            self.create_fake_image(mocker, image_type),
+            self.create_fake_image(mocker, image_type),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.Label),
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+            # labels, bboxes, masks
+            mocker.MagicMock(spec=features.Label),
+            mocker.MagicMock(spec=features.BoundingBox),
+            mocker.MagicMock(spec=features.SegmentationMask),
+        ]
+
+        images, targets = transform._extract_image_targets(flat_sample)
+
+        assert len(images) == len(targets) == 2
+        if image_type == PIL.Image.Image:
+            torch.testing.assert_close(images[0], pil_to_tensor(flat_sample[0]))
+            torch.testing.assert_close(images[1], pil_to_tensor(flat_sample[1]))
+        else:
+            assert images[0] == flat_sample[0]
+            assert images[1] == flat_sample[1]
+
+    def test__copy_paste(self):
+        image = 2 * torch.ones(3, 32, 32)
+        masks = torch.zeros(2, 32, 32)
+        masks[0, 3:9, 2:8] = 1
+        masks[1, 20:30, 20:30] = 1
+        target = {
+            "boxes": features.BoundingBox(
+                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
+            ),
+            "masks": features.SegmentationMask(masks),
+            "labels": features.Label(torch.tensor([1, 2])),
+        }
+
+        paste_image = 10 * torch.ones(3, 32, 32)
+        paste_masks = torch.zeros(2, 32, 32)
+        paste_masks[0, 13:19, 12:18] = 1
+        paste_masks[1, 15:19, 1:8] = 1
+        paste_target = {
+            "boxes": features.BoundingBox(
+                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
+            ),
+            "masks": features.SegmentationMask(paste_masks),
+            "labels": features.Label(torch.tensor([3, 4])),
+        }
+
+        transform = transforms.SimpleCopyPaste()
+        random_selection = torch.tensor([0, 1])
+        output_image, output_target = transform._copy_paste(image, target, paste_image, paste_target, random_selection)
+
+        assert output_image.unique().tolist() == [2, 10]
+        assert output_target["boxes"].shape == (4, 4)
+        torch.testing.assert_close(output_target["boxes"][:2, :], target["boxes"])
+        torch.testing.assert_close(output_target["boxes"][2:, :], paste_target["boxes"])
+        torch.testing.assert_close(output_target["labels"], features.Label(torch.tensor([1, 2, 3, 4])))
+        assert output_target["masks"].shape == (4, 32, 32)
+        torch.testing.assert_close(output_target["masks"][:2, :], target["masks"])
+        torch.testing.assert_close(output_target["masks"][2:, :], paste_target["masks"])
+
+
 class TestFixedSizeCrop:
     def test__get_params(self, mocker):
         crop_size = (7, 7)
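
Assuming a source checkout at this revision with the test dependencies installed, the new test class can be run in isolation with pytest's keyword filter:

pytest test/test_prototype_transforms.py -k TestSimpleCopyPaste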

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 from ._transform import Transform  # usort: skip
 
-from ._augment import RandomCutmix, RandomErasing, RandomMixup
+from ._augment import RandomCutmix, RandomErasing, RandomMixup, SimpleCopyPaste
 from ._auto_augment import AugMix, AutoAugment, AutoAugmentPolicy, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
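
With the one-line re-export above, the new class becomes reachable from the prototype namespace. A quick sanity check, assuming this revision is installed:

from torchvision.prototype import transforms

print(transforms.SimpleCopyPaste)  # now exported alongside RandomCutmix, RandomErasing, RandomMixup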

torchvision/prototype/transforms/_augment.py

Lines changed: 190 additions & 1 deletion
@@ -1,11 +1,16 @@
 import math
 import numbers
 import warnings
-from typing import Any, Dict, Tuple
+from typing import Any, Dict, List, Tuple
 
+import PIL.Image
 import torch
+from torch.utils._pytree import tree_flatten, tree_unflatten
+from torchvision.ops import masks_to_boxes
 from torchvision.prototype import features
+
 from torchvision.prototype.transforms import functional as F
+from torchvision.transforms.functional import InterpolationMode, pil_to_tensor
 
 from ._transform import _RandomApplyTransform
 from ._utils import has_any, is_simple_tensor, query_chw
@@ -178,3 +183,187 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             return self._mixup_onehotlabel(inpt, lam_adjusted)
         else:
             return inpt
+
+
+class SimpleCopyPaste(_RandomApplyTransform):
+    def __init__(
+        self,
+        p: float = 0.5,
+        blending: bool = True,
+        resize_interpolation: InterpolationMode = F.InterpolationMode.BILINEAR,
+    ) -> None:
+        super().__init__(p=p)
+        self.resize_interpolation = resize_interpolation
+        self.blending = blending
+
+    def _copy_paste(
+        self,
+        image: Any,
+        target: Dict[str, Any],
+        paste_image: Any,
+        paste_target: Dict[str, Any],
+        random_selection: torch.Tensor,
+        blending: bool = True,
+        resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR,
+    ) -> Tuple[Any, Dict[str, Any]]:
+
+        paste_masks = paste_target["masks"].new_like(paste_target["masks"], paste_target["masks"][random_selection])
+        paste_boxes = paste_target["boxes"].new_like(paste_target["boxes"], paste_target["boxes"][random_selection])
+        paste_labels = paste_target["labels"].new_like(paste_target["labels"], paste_target["labels"][random_selection])
+
+        masks = target["masks"]
+
+        # We resize the source and paste data if they have different sizes.
+        # This is a deviation from the TF implementation, introduced here because
+        # the original algorithm works on equal-sized data
+        # (for example, coming from LSJ data augmentations).
+        size1 = image.shape[-2:]
+        size2 = paste_image.shape[-2:]
+        if size1 != size2:
+            paste_image = F.resize(paste_image, size=size1, interpolation=resize_interpolation)
+            paste_masks = F.resize(paste_masks, size=size1)
+            paste_boxes = F.resize(paste_boxes, size=size1)
+
+        paste_alpha_mask = paste_masks.sum(dim=0) > 0
+
+        if blending:
+            paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
+
+        # Copy-paste images:
+        image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
+
+        # Copy-paste masks:
+        masks = masks * (~paste_alpha_mask)
+        non_all_zero_masks = masks.sum((-1, -2)) > 0
+        masks = masks[non_all_zero_masks]
+
+        # Do a shallow copy of the target dict
+        out_target = {k: v for k, v in target.items()}
+
+        out_target["masks"] = torch.cat([masks, paste_masks])
+
+        # Copy-paste boxes and labels
+        bbox_format = target["boxes"].format
+        xyxy_boxes = masks_to_boxes(masks)
+        # TODO: masks_to_boxes produces bboxes with x2y2 inclusive, but x2y2 should be
+        # exclusive, so we add +1 to x2y2. This needs further investigation.
+        xyxy_boxes[:, 2:] += 1
+        boxes = F.convert_bounding_box_format(
+            xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
+        )
+        out_target["boxes"] = torch.cat([boxes, paste_boxes])
+
+        labels = target["labels"][non_all_zero_masks]
+        out_target["labels"] = torch.cat([labels, paste_labels])
+
+        # Check for degenerate boxes and remove them
+        boxes = F.convert_bounding_box_format(
+            out_target["boxes"], old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY
+        )
+        degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+        if degenerate_boxes.any():
+            valid_targets = ~degenerate_boxes.any(dim=1)
+
+            out_target["boxes"] = boxes[valid_targets]
+            out_target["masks"] = out_target["masks"][valid_targets]
+            out_target["labels"] = out_target["labels"][valid_targets]
+
+        return image, out_target
+
+    def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
+        # Fetch all images, bboxes, masks and labels from the unstructured input,
+        # as List[image], List[BoundingBox], List[SegmentationMask], List[Label]
+        images, bboxes, masks, labels = [], [], [], []
+        for obj in flat_sample:
+            if isinstance(obj, features.Image) or is_simple_tensor(obj):
+                images.append(obj)
+            elif isinstance(obj, PIL.Image.Image):
+                images.append(pil_to_tensor(obj))
+            elif isinstance(obj, features.BoundingBox):
+                bboxes.append(obj)
+            elif isinstance(obj, features.SegmentationMask):
+                masks.append(obj)
+            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+                labels.append(obj)
+
+        if not (len(images) == len(bboxes) == len(masks) == len(labels)):
+            raise TypeError(
+                f"{type(self).__name__}() requires input sample to contain equal-sized list of Images, "
+                "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
+            )
+
+        targets = []
+        for bbox, mask, label in zip(bboxes, masks, labels):
+            targets.append({"boxes": bbox, "masks": mask, "labels": label})
+
+        return images, targets
+
+    def _insert_outputs(
+        self, flat_sample: List[Any], output_images: List[Any], output_targets: List[Dict[str, Any]]
+    ) -> None:
+        c0, c1, c2, c3 = 0, 0, 0, 0
+        for i, obj in enumerate(flat_sample):
+            if isinstance(obj, features.Image):
+                flat_sample[i] = features.Image.new_like(obj, output_images[c0])
+                c0 += 1
+            elif isinstance(obj, PIL.Image.Image):
+                flat_sample[i] = F.to_image_pil(output_images[c0])
+                c0 += 1
+            elif is_simple_tensor(obj):
+                flat_sample[i] = output_images[c0]
+                c0 += 1
+            elif isinstance(obj, features.BoundingBox):
+                flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
+                c1 += 1
+            elif isinstance(obj, features.SegmentationMask):
+                flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
+                c2 += 1
+            elif isinstance(obj, (features.Label, features.OneHotLabel)):
+                flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
+                c3 += 1
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        flat_sample, spec = tree_flatten(sample)
+
+        images, targets = self._extract_image_targets(flat_sample)
+
+        # images = [t1, t2, ..., tN]
+        # Let's define paste_images as the shifted list of input images:
+        # paste_images = [t2, t3, ..., tN, t1]
+        # FYI: in TF they mix data on the dataset level
+        images_rolled = images[-1:] + images[:-1]
+        targets_rolled = targets[-1:] + targets[:-1]
+
+        output_images, output_targets = [], []
+
+        for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled):
+
+            # Random paste targets selection:
+            num_masks = len(paste_target["masks"])
+
+            if num_masks < 1:
+                # Such a degenerate case with num_masks=0 can happen with LSJ.
+                # Let's just return (image, target).
+                output_image, output_target = image, target
+            else:
+                random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
+                random_selection = torch.unique(random_selection)
+
+                output_image, output_target = self._copy_paste(
+                    image,
+                    target,
+                    paste_image,
+                    paste_target,
+                    random_selection=random_selection,
+                    blending=self.blending,
+                    resize_interpolation=self.resize_interpolation,
+                )
+            output_images.append(output_image)
+            output_targets.append(output_target)
+
+        # Insert updated images and targets into the input flat_sample
+        self._insert_outputs(flat_sample, output_images, output_targets)
+
+        return tree_unflatten(flat_sample, spec)
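
Putting the pieces together: forward() flattens an arbitrarily nested sample, pairs every image with the next one in the batch (the lists rolled by one), copy-pastes a random selection of instances across each pair, and re-inserts the outputs into the original structure. A minimal end-to-end sketch, not part of the commit, assuming the prototype features API at this revision; make_sample is a hypothetical helper and all shapes and values are illustrative:

import torch
from torchvision.prototype import features, transforms


def make_sample(fill: float, num_objects: int):
    # One 3x32x32 image plus per-instance masks, XYXY boxes and labels
    image = features.Image(fill * torch.ones(3, 32, 32))
    masks = torch.zeros(num_objects, 32, 32)
    boxes = []
    for i in range(num_objects):
        x0, y0 = 4 * i + 1, 4 * i + 2
        masks[i, y0 : y0 + 6, x0 : x0 + 6] = 1
        boxes.append([float(x0), float(y0), float(x0 + 6), float(y0 + 6)])
    target = {
        "boxes": features.BoundingBox(torch.tensor(boxes), format="XYXY", image_size=(32, 32)),
        "masks": features.SegmentationMask(masks),
        "labels": features.Label(torch.arange(num_objects) + 1),
    }
    return image, target


# A batch of two detection samples; tree_flatten recovers the images,
# boxes, masks and labels from this nested structure by type.
batch = [make_sample(2.0, 2), make_sample(10.0, 3)]

transform = transforms.SimpleCopyPaste(blending=True)
out = transform(batch)  # same nesting, with instances pasted across the pair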
