Commit 13e5eb6

Merge branch 'main' into utils/sequence_to_str
2 parents: 4ccc7f1 + 5ddd564

File tree: 8 files changed (+230, -25 lines)

references/detection/transforms.py

Lines changed: 75 additions & 0 deletions
@@ -326,3 +326,78 @@ def forward(
         )
 
         return image, target
+
+
+class FixedSizeCrop(nn.Module):
+    def __init__(self, size, fill=0, padding_mode="constant"):
+        super().__init__()
+        size = tuple(T._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
+        self.crop_height = size[0]
+        self.crop_width = size[1]
+        self.fill = fill  # TODO: Fill is currently respected only on PIL. Apply tensor patch.
+        self.padding_mode = padding_mode
+
+    def _pad(self, img, target, padding):
+        # Taken from the functional_tensor.py pad
+        if isinstance(padding, int):
+            pad_left = pad_right = pad_top = pad_bottom = padding
+        elif len(padding) == 1:
+            pad_left = pad_right = pad_top = pad_bottom = padding[0]
+        elif len(padding) == 2:
+            pad_left = pad_right = padding[0]
+            pad_top = pad_bottom = padding[1]
+        else:
+            pad_left = padding[0]
+            pad_top = padding[1]
+            pad_right = padding[2]
+            pad_bottom = padding[3]
+
+        padding = [pad_left, pad_top, pad_right, pad_bottom]
+        img = F.pad(img, padding, self.fill, self.padding_mode)
+        if target is not None:
+            target["boxes"][:, 0::2] += pad_left
+            target["boxes"][:, 1::2] += pad_top
+            if "masks" in target:
+                target["masks"] = F.pad(target["masks"], padding, 0, "constant")
+
+        return img, target
+
+    def _crop(self, img, target, top, left, height, width):
+        img = F.crop(img, top, left, height, width)
+        if target is not None:
+            boxes = target["boxes"]
+            boxes[:, 0::2] -= left
+            boxes[:, 1::2] -= top
+            boxes[:, 0::2].clamp_(min=0, max=width)
+            boxes[:, 1::2].clamp_(min=0, max=height)
+
+            is_valid = (boxes[:, 0] < boxes[:, 2]) & (boxes[:, 1] < boxes[:, 3])
+
+            target["boxes"] = boxes[is_valid]
+            target["labels"] = target["labels"][is_valid]
+            if "masks" in target:
+                target["masks"] = F.crop(target["masks"][is_valid], top, left, height, width)
+
+        return img, target
+
+    def forward(self, img, target=None):
+        _, height, width = F.get_dimensions(img)
+        new_height = min(height, self.crop_height)
+        new_width = min(width, self.crop_width)
+
+        if new_height != height or new_width != width:
+            offset_height = max(height - self.crop_height, 0)
+            offset_width = max(width - self.crop_width, 0)
+
+            r = torch.rand(1)
+            top = int(offset_height * r)
+            left = int(offset_width * r)
+
+            img, target = self._crop(img, target, top, left, new_height, new_width)
+
+        pad_bottom = max(self.crop_height - new_height, 0)
+        pad_right = max(self.crop_width - new_width, 0)
+        if pad_bottom != 0 or pad_right != 0:
+            img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom])
+
+        return img, target
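
For orientation, a minimal usage sketch of the new transform on a synthetic sample (a sketch, not part of this commit; the import is hypothetical and assumes references/detection is on the path, since transforms.py is a training-script helper rather than installed API):

import torch
from transforms import FixedSizeCrop  # hypothetical import from references/detection/transforms.py

crop = FixedSizeCrop(size=(512, 512))
img = torch.randint(0, 256, (3, 600, 800), dtype=torch.uint8)
target = {
    "boxes": torch.tensor([[10.0, 20.0, 300.0, 400.0]]),
    "labels": torch.tensor([1]),
}
# Randomly crops to at most 512x512, then pads bottom/right up to exactly 512x512.
img, target = crop(img, target)
assert img.shape[-2:] == (512, 512)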

torchvision/prototype/features/_bounding_box.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox:
         from torchvision.prototype.transforms.functional import convert_bounding_box_format
 
         if isinstance(format, str):
-            format = BoundingBoxFormat[format]
+            format = BoundingBoxFormat.from_str(format.upper())
 
         return BoundingBox.new_like(
             self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format
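
The practical effect is case-insensitive string lookup; a small sketch relying only on the from_str helper referenced in the diff:

from torchvision.prototype.features import BoundingBoxFormat

# Before: BoundingBoxFormat["xywh"] raised a KeyError for non-uppercase strings.
# After: the name is upper-cased and resolved via from_str, so any casing works.
fmt = BoundingBoxFormat.from_str("xywh".upper())
assert fmt == BoundingBoxFormat.XYWH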

torchvision/prototype/transforms/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -7,7 +7,16 @@
 from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
-from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
+from ._geometry import (
+    HorizontalFlip,
+    Resize,
+    CenterCrop,
+    RandomResizedCrop,
+    FiveCrop,
+    TenCrop,
+    BatchMultiCrop,
+    RandomZoomOut,
+)
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import (

torchvision/prototype/transforms/_geometry.py

Lines changed: 85 additions & 0 deletions
@@ -256,3 +256,88 @@ def apply_recursively(obj: Any) -> Any:
             return obj
 
         return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
+
+
+class RandomZoomOut(Transform):
+    def __init__(
+        self, fill: Union[float, Sequence[float]] = 0.0, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
+    ) -> None:
+        super().__init__()
+
+        if fill is None:
+            fill = 0.0
+        self.fill = fill
+
+        self.side_range = side_range
+        if side_range[0] < 1.0 or side_range[0] > side_range[1]:
+            raise ValueError(f"Invalid canvas side range provided {side_range}.")
+
+        self.p = p
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        orig_c, orig_h, orig_w = get_image_dimensions(image)
+
+        r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
+        canvas_width = int(orig_w * r)
+        canvas_height = int(orig_h * r)
+
+        r = torch.rand(2)
+        left = int((canvas_width - orig_w) * r[0])
+        top = int((canvas_height - orig_h) * r[1])
+        right = canvas_width - (left + orig_w)
+        bottom = canvas_height - (top + orig_h)
+        padding = [left, top, right, bottom]
+
+        fill = self.fill
+        if not isinstance(fill, collections.abc.Sequence):
+            fill = [fill] * orig_c
+
+        return dict(padding=padding, fill=fill)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image) or is_simple_tensor(input):
+            # PyTorch's pad supports only integer fill values, so pad with 0 and overwrite the colour.
+            output = F.pad_image_tensor(input, params["padding"], fill=0, padding_mode="constant")
+
+            left, top, right, bottom = params["padding"]
+            fill = torch.tensor(params["fill"], dtype=input.dtype, device=input.device).view(-1, 1, 1)
+
+            if top > 0:
+                output[..., :top, :] = fill
+            if left > 0:
+                output[..., :, :left] = fill
+            if bottom > 0:
+                output[..., -bottom:, :] = fill
+            if right > 0:
+                output[..., :, -right:] = fill
+
+            if isinstance(input, features.Image):
+                output = features.Image.new_like(input, output)
+
+            return output
+        elif isinstance(input, PIL.Image.Image):
+            return F.pad_image_pil(
+                input,
+                params["padding"],
+                fill=tuple(int(v) if input.mode != "F" else v for v in params["fill"]),
+                padding_mode="constant",
+            )
+        elif isinstance(input, features.BoundingBox):
+            output = F.pad_bounding_box(input, params["padding"], format=input.format)
+
+            left, top, right, bottom = params["padding"]
+            height, width = input.image_size
+            height += top + bottom
+            width += left + right
+
+            return features.BoundingBox.new_like(input, output, image_size=(height, width))
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if torch.rand(1) >= self.p:
+            return sample
+
+        return super().forward(sample)
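
A short usage sketch under stated assumptions (plain tensors are dispatched through the is_simple_tensor branch above; p=1.0 forces the zoom-out to fire):

import torch
from torchvision.prototype import transforms

zoom_out = transforms.RandomZoomOut(fill=0.5, side_range=(1.0, 4.0), p=1.0)
img = torch.rand(3, 224, 224)
out = zoom_out(img)
# Each canvas side is scaled by a factor drawn from side_range, so the output never shrinks.
assert out.shape[-2] >= 224 and out.shape[-1] >= 224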

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -54,6 +54,7 @@
     rotate_image_pil,
     pad_image_tensor,
     pad_image_pil,
+    pad_bounding_box,
     crop_image_tensor,
     crop_image_pil,
     perspective_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 21 additions & 1 deletion
@@ -27,7 +27,7 @@ def horizontal_flip_bounding_box(
     bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]]
 
     return convert_bounding_box_format(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
     ).view(shape)
 
 
@@ -210,6 +210,26 @@ def rotate_image_pil(
 pad_image_tensor = _FT.pad
 pad_image_pil = _FP.pad
 
+
+def pad_bounding_box(
+    bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
+) -> torch.Tensor:
+    left, _, top, _ = _FT._parse_pad_padding(padding)
+
+    shape = bounding_box.shape
+
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    bounding_box[:, 0::2] += left
+    bounding_box[:, 1::2] += top
+
+    return convert_bounding_box_format(
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(shape)
+
+
 crop_image_tensor = _FT.crop
 crop_image_pil = _FP.crop
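
A quick sketch of what the new kernel does to a box in XYWH format (illustrative values; only the origin shifts by the left/top padding):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

box = torch.tensor([[10.0, 20.0, 30.0, 40.0]])  # XYWH: x, y, w, h
padded = F.pad_bounding_box(box, padding=[5, 7, 0, 0], format=features.BoundingBoxFormat.XYWH)
# left=5 shifts x, top=7 shifts y; width and height are untouched.
assert padded.tolist() == [[15.0, 27.0, 30.0, 40.0]]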

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 16 additions & 5 deletions
@@ -40,10 +40,13 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
 
 
 def convert_bounding_box_format(
-    bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
+    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, copy: bool = True
 ) -> torch.Tensor:
     if new_format == old_format:
-        return bounding_box.clone()
+        if copy:
+            return bounding_box.clone()
+        else:
+            return bounding_box
 
     if old_format == BoundingBoxFormat.XYWH:
         bounding_box = _xywh_to_xyxy(bounding_box)
@@ -89,10 +92,13 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
 
 
 def convert_image_color_space_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
+    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
 ) -> torch.Tensor:
     if new_color_space == old_color_space:
-        return image.clone()
+        if copy:
+            return image.clone()
+        else:
+            return image
 
     if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
         raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
@@ -135,11 +141,16 @@ def convert_image_color_space_tensor(
 }
 
 
-def convert_image_color_space_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
+def convert_image_color_space_pil(
+    image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
+) -> PIL.Image.Image:
     old_mode = image.mode
     try:
         new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
     except KeyError:
         raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")
 
+    if not copy and image.mode == new_mode:
+        return image
+
     return image.convert(new_mode)
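
The new copy flag lets hot paths skip the defensive clone when formats already match; a minimal sketch (note the * keyword-only marker was dropped, so the formats may now be passed positionally):

import torch
from torchvision.prototype.features import BoundingBoxFormat
from torchvision.prototype.transforms.functional import convert_bounding_box_format

box = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
same = convert_bounding_box_format(box, BoundingBoxFormat.XYXY, BoundingBoxFormat.XYXY)
assert same is not box  # default copy=True still clones
alias = convert_bounding_box_format(box, BoundingBoxFormat.XYXY, BoundingBoxFormat.XYXY, copy=False)
assert alias is box  # no clone; the caller takes responsibility for aliasing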

torchvision/transforms/functional_tensor.py

Lines changed: 21 additions & 17 deletions
@@ -350,6 +350,26 @@ def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
     raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
 
 
+def _parse_pad_padding(padding: List[int]) -> List[int]:
+    if isinstance(padding, int):
+        if torch.jit.is_scripting():
+            # This may be unreachable
+            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
+        pad_left = pad_right = pad_top = pad_bottom = padding
+    elif len(padding) == 1:
+        pad_left = pad_right = pad_top = pad_bottom = padding[0]
+    elif len(padding) == 2:
+        pad_left = pad_right = padding[0]
+        pad_top = pad_bottom = padding[1]
+    else:
+        pad_left = padding[0]
+        pad_top = padding[1]
+        pad_right = padding[2]
+        pad_bottom = padding[3]
+
+    return [pad_left, pad_right, pad_top, pad_bottom]
+
+
 def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
     _assert_image_tensor(img)
 
@@ -369,23 +389,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
         raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
 
-    if isinstance(padding, int):
-        if torch.jit.is_scripting():
-            # This maybe unreachable
-            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
-        pad_left = pad_right = pad_top = pad_bottom = padding
-    elif len(padding) == 1:
-        pad_left = pad_right = pad_top = pad_bottom = padding[0]
-    elif len(padding) == 2:
-        pad_left = pad_right = padding[0]
-        pad_top = pad_bottom = padding[1]
-    else:
-        pad_left = padding[0]
-        pad_top = padding[1]
-        pad_right = padding[2]
-        pad_bottom = padding[3]
-
-    p = [pad_left, pad_right, pad_top, pad_bottom]
+    p = _parse_pad_padding(padding)
 
     if padding_mode == "edge":
         # remap padding_mode str
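
To make the refactor concrete, a small sketch of the helper's normalization (a private helper, so the import is for illustration only and subject to change):

from torchvision.transforms.functional_tensor import _parse_pad_padding

# Every accepted input form normalizes to [left, right, top, bottom]:
assert _parse_pad_padding([2]) == [2, 2, 2, 2]
assert _parse_pad_padding([2, 5]) == [2, 2, 5, 5]
assert _parse_pad_padding([1, 2, 3, 4]) == [1, 3, 2, 4]  # input order is left, top, right, bottom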
