Skip to content

Commit c84dbfa

Browse files
authored
[prototype] Speed up Augment Transform Classes (#6835)
* Moving value estimation of `RandomErasing` from runtime to constructor.
* Speed up mixing on MixUp/Cutmix and small optimization on SimpleCopyPaste.
* Apply nits.
1 parent 8e0e715 commit c84dbfa

File tree

1 file changed

+16
-19
lines changed

1 file changed

+16
-19
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 16 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -40,24 +40,22 @@ def __init__(
4040
raise ValueError("Scale should be between 0 and 1")
4141
self.scale = scale
4242
self.ratio = ratio
43-
self.value = value
43+
if isinstance(value, (int, float)):
44+
self.value = [value]
45+
elif isinstance(value, str):
46+
self.value = None
47+
elif isinstance(value, tuple):
48+
self.value = list(value)
49+
else:
50+
self.value = value
4451
self.inplace = inplace
4552

4653
self._log_ratio = torch.log(torch.tensor(self.ratio))
4754

4855
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
4956
img_c, img_h, img_w = query_chw(flat_inputs)
5057

51-
if isinstance(self.value, (int, float)):
52-
value = [self.value]
53-
elif isinstance(self.value, str):
54-
value = None
55-
elif isinstance(self.value, tuple):
56-
value = list(self.value)
57-
else:
58-
value = self.value
59-
60-
if value is not None and not (len(value) in (1, img_c)):
58+
if self.value is not None and not (len(self.value) in (1, img_c)):
6159
raise ValueError(
6260
f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)"
6361
)
@@ -79,10 +77,10 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
7977
if not (h < img_h and w < img_w):
8078
continue
8179

82-
if value is None:
80+
if self.value is None:
8381
v = torch.empty([img_c, h, w], dtype=torch.float32).normal_()
8482
else:
85-
v = torch.tensor(value)[:, None, None]
83+
v = torch.tensor(self.value)[:, None, None]
8684

8785
i = torch.randint(0, img_h - h + 1, size=(1,)).item()
8886
j = torch.randint(0, img_w - w + 1, size=(1,)).item()
@@ -121,8 +119,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
121119
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
122120
if inpt.ndim < 2:
123121
raise ValueError("Need a batch of one hot labels")
124-
output = inpt.clone()
125-
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
122+
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
126123
return features.OneHotLabel.wrap_like(inpt, output)
127124

128125

@@ -136,8 +133,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
136133
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
137134
if inpt.ndim < expected_ndim:
138135
raise ValueError("The transform expects a batched input")
139-
output = inpt.clone()
140-
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
136+
output = inpt.roll(1, 0).mul_(1.0 - lam).add_(inpt.mul(lam))
141137

142138
if isinstance(inpt, (features.Image, features.Video)):
143139
output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type]
@@ -243,11 +239,12 @@ def _copy_paste(
243239
if blending:
244240
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=[5, 5], sigma=[2.0])
245241

242+
inverse_paste_alpha_mask = paste_alpha_mask.logical_not()
246243
# Copy-paste images:
247-
image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
244+
image = image.mul(inverse_paste_alpha_mask).add_(paste_image.mul(paste_alpha_mask))
248245

249246
# Copy-paste masks:
250-
masks = masks * (~paste_alpha_mask)
247+
masks = masks * inverse_paste_alpha_mask
251248
non_all_zero_masks = masks.sum((-1, -2)) > 0
252249
masks = masks[non_all_zero_masks]
253250

0 commit comments

Comments (0)