
Commit 7804725

NicolasHug authored and facebook-github-bot committed
[fbsync] remove spatial_size (#7734)
Reviewed By: matteobettini
Differential Revision: D48642265
fbshipit-source-id: 123d2a3157d4536ea9ac25e0192d54307b31ea1e
1 parent ae428c4 commit 7804725
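In short: datapoints.BoundingBoxes now takes canvas_size instead of spatial_size, and F.get_spatial_size() is replaced by F.get_size(). A minimal before/after sketch, assuming a torchvision build from this period that ships the datapoints module and the v2 transforms:

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    image = datapoints.Image(torch.rand(3, 480, 640))

    # Before this commit:
    #   box = datapoints.BoundingBoxes(
    #       [17, 16, 344, 495], format="XYXY", spatial_size=image.shape[-2:]
    #   )
    #   height, width = F.get_spatial_size(image)

    # After this commit:
    box = datapoints.BoundingBoxes(
        [17, 16, 344, 495], format="XYXY", canvas_size=image.shape[-2:]
    )
    height, width = F.get_size(image)  # [480, 640]: height and width of the image

The rename disambiguates the size of the canvas a box lives on from the size of the box itself, which the test helpers below previously conflated.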

29 files changed (+440, -491 lines)

gallery/plot_datapoints.py (3 additions, 3 deletions)

@@ -80,7 +80,7 @@
 # corresponding image alongside the actual values:
 
 bounding_box = datapoints.BoundingBoxes(
-    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+    [17, 16, 344, 495], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
 )
 print(bounding_box)
 
@@ -108,7 +108,7 @@ def __getitem__(self, item):
         target["boxes"] = datapoints.BoundingBoxes(
             boxes,
             format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
         )
         target["labels"] = labels
         target["masks"] = datapoints.Mask(masks)
@@ -129,7 +129,7 @@ def __call__(self, img, target):
         target["boxes"] = datapoints.BoundingBoxes(
             target["boxes"],
             format=datapoints.BoundingBoxFormat.XYXY,
-            spatial_size=F.get_spatial_size(img),
+            canvas_size=F.get_size(img),
         )
         target["masks"] = datapoints.Mask(target["masks"])
         return img, target
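The last two hunks show the recommended wrapping pattern for detection targets: derive the box's canvas_size from the image it annotates rather than hard-coding it. A hedged, self-contained sketch of that pattern (the tensors are dummies standing in for real dataset output):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    img = datapoints.Image(torch.rand(3, 512, 512))
    boxes = torch.tensor([[17.0, 16.0, 344.0, 495.0]])  # one XYXY box
    masks = torch.zeros(1, 512, 512, dtype=torch.bool)  # one instance mask

    target = {
        "boxes": datapoints.BoundingBoxes(
            boxes,
            format=datapoints.BoundingBoxFormat.XYXY,
            canvas_size=F.get_size(img),  # [H, W] of the image the boxes live on
        ),
        "labels": torch.tensor([1]),
        "masks": datapoints.Mask(masks),
    }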

gallery/plot_transforms_v2.py (1 addition, 1 deletion)

@@ -30,7 +30,7 @@ def load_data():
     masks = datapoints.Mask(merged_masks == labels.view(-1, 1, 1))
 
     bounding_boxes = datapoints.BoundingBoxes(
-        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, spatial_size=image.shape[-2:]
+        masks_to_boxes(masks), format=datapoints.BoundingBoxFormat.XYXY, canvas_size=image.shape[-2:]
     )
 
     return path, image, bounding_boxes, masks, labels
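Same mechanical change here, with the boxes derived from masks. A runnable toy version, using synthetic masks in place of the gallery's loaded data (torchvision.ops.masks_to_boxes returns one XYXY box per mask):

    import torch
    from torchvision import datapoints
    from torchvision.ops import masks_to_boxes

    masks = torch.zeros(2, 32, 32, dtype=torch.bool)
    masks[0, 4:10, 6:12] = True   # object 1
    masks[1, 20:30, 1:5] = True   # object 2

    bounding_boxes = datapoints.BoundingBoxes(
        masks_to_boxes(masks),                     # shape (2, 4), XYXY
        format=datapoints.BoundingBoxFormat.XYXY,
        canvas_size=masks.shape[-2:],              # (H, W) of the masks
    )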

test/common_utils.py (26 additions, 60 deletions)

@@ -412,7 +412,7 @@ def load(self, device="cpu"):
         )
 
 
-def _parse_spatial_size(size, *, name="size"):
+def _parse_canvas_size(size, *, name="size"):
     if size == "random":
         raise ValueError("This should never happen")
     elif isinstance(size, int) and size > 0:
@@ -467,12 +467,13 @@ def load(self, device):
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
     memory_format: torch.memory_format = torch.contiguous_format
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
 
     def __post_init__(self):
-        self.spatial_size = self.shape[-2:]
+        self.canvas_size = self.canvas_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
     def load(self, device):
@@ -538,7 +539,7 @@ def make_image_loader(
 ):
     if not constant_alpha:
         raise ValueError("This should never happen")
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
@@ -578,7 +579,7 @@ def make_image_loaders(
 def make_image_loader_for_interpolation(
     size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
 ):
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
@@ -623,43 +624,20 @@ def make_image_loaders_for_interpolation(
 class BoundingBoxesLoader(TensorLoader):
     format: datapoints.BoundingBoxFormat
     spatial_size: Tuple[int, int]
+    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
+
+    def __post_init__(self):
+        self.canvas_size = self.spatial_size
 
 
 def make_bounding_box(
-    size=None,
+    canvas_size=DEFAULT_SIZE,
     *,
     format=datapoints.BoundingBoxFormat.XYXY,
-    spatial_size=None,
     batch_dims=(),
     dtype=None,
     device="cpu",
 ):
-    """
-    size: Size of the actual bounding box, i.e.
-    - (box[3] - box[1], box[2] - box[0]) for XYXY
-    - (H, W) for XYWH and CXCYWH
-    spatial_size: Size of the reference object, e.g. an image. Corresponds to the .spatial_size attribute on
-    returned datapoints.BoundingBoxes
-
-    To generate a valid joint sample, you need to set spatial_size here to the same value as size on the other maker
-    functions, e.g.
-
-    .. code::
-
-        image = make_image=(size=size)
-        bounding_boxes = make_bounding_box(spatial_size=size)
-        assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
-
-    For convenience, if both size and spatial_size are omitted, spatial_size defaults to the same value as size for all
-    other maker functions, e.g.
-
-    .. code::
-
-        image = make_image=()
-        bounding_boxes = make_bounding_box()
-        assert F.get_spatial_size(bounding_boxes) == F.get_spatial_size(image)
-    """
-
     def sample_position(values, max_value):
         # We cannot use torch.randint directly here, because it only allows integer scalars as values for low and high.
         # However, if we have batch_dims, we need tensors as limits.
@@ -668,28 +646,16 @@ def sample_position(values, max_value):
     if isinstance(format, str):
        format = datapoints.BoundingBoxFormat[format]
 
-    if spatial_size is None:
-        if size is None:
-            spatial_size = DEFAULT_SIZE
-        else:
-            height, width = size
-            height_margin, width_margin = torch.randint(10, (2,)).tolist()
-            spatial_size = (height + height_margin, width + width_margin)
-
     dtype = dtype or torch.float32
 
     if any(dim == 0 for dim in batch_dims):
         return datapoints.BoundingBoxes(
-            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
+            torch.empty(*batch_dims, 4, dtype=dtype, device=device), format=format, canvas_size=canvas_size
         )
 
-    if size is None:
-        h, w = [torch.randint(1, s, batch_dims) for s in spatial_size]
-    else:
-        h, w = [torch.full(batch_dims, s, dtype=torch.int) for s in size]
-
-    y = sample_position(h, spatial_size[0])
-    x = sample_position(w, spatial_size[1])
+    h, w = [torch.randint(1, c, batch_dims) for c in canvas_size]
+    y = sample_position(h, canvas_size[0])
+    x = sample_position(w, canvas_size[1])
 
     if format is datapoints.BoundingBoxFormat.XYWH:
         parts = (x, y, w, h)
@@ -706,37 +672,37 @@ def sample_position(values, max_value):
         raise ValueError(f"Format {format} is not supported")
 
     return datapoints.BoundingBoxes(
-        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
+        torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, canvas_size=canvas_size
     )
 
 
-def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]
 
-    spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")
+    canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")
 
     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
         if num_coordinates != 4:
             raise pytest.UsageError()
 
         return make_bounding_box(
-            format=format, spatial_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
+            format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
         )
 
-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)
 
 
 def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
-    spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
+    canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, spatial_size=spatial_size)
+        yield make_bounding_box_loader(**params, canvas_size=canvas_size)
 
 
 make_bounding_boxes = from_loaders(make_bounding_box_loaders)
@@ -761,7 +727,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
 
 def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
 
     def fn(shape, dtype, device):
         *batch_dims, num_objects, height, width = shape
@@ -802,15 +768,15 @@ def make_segmentation_mask_loader(
     size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
 ):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    spatial_size = _parse_spatial_size(size)
+    canvas_size = _parse_canvas_size(size)
 
     def fn(shape, dtype, device):
         *batch_dims, height, width = shape
         return make_segmentation_mask(
             (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
         )
 
-    return MaskLoader(fn, shape=(*extra_dims, *spatial_size), dtype=dtype)
+    return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)
 
 
 def make_segmentation_mask_loaders(
@@ -860,7 +826,7 @@ def make_video_loader(
     extra_dims=(),
     dtype=torch.uint8,
 ):
-    size = _parse_spatial_size(size)
+    size = _parse_canvas_size(size)
 
     def fn(shape, dtype, device, memory_format):
         *batch_dims, num_frames, _, height, width = shape
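make_bounding_box loses its docstring together with the size parameter, but the joint-sample recipe the docstring described survives in simpler form: pass the image size as canvas_size and the two datapoints agree on their geometry. A sketch, assuming the helpers above are importable from test/common_utils.py:

    from torchvision.transforms.v2 import functional as F

    from common_utils import make_bounding_box, make_image  # test helpers shown above

    size = (256, 128)  # (H, W)
    image = make_image(size)
    boxes = make_bounding_box(canvas_size=size)

    # Both datapoints now report the same canvas geometry.
    assert F.get_size(boxes) == F.get_size(image) == list(size)

Note also that the loaders deliberately keep the old name at the boundary: make_bounding_box_loader still constructs BoundingBoxesLoader with spatial_size=canvas_size, and the new __post_init__ aliases it to canvas_size, so older call sites keep working during the transition.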

test/test_datapoints.py (2 additions, 2 deletions)

@@ -27,7 +27,7 @@ def test_mask_instance(data):
     "format", ["XYXY", "CXCYWH", datapoints.BoundingBoxFormat.XYXY, datapoints.BoundingBoxFormat.XYWH]
 )
 def test_bbox_instance(data, format):
-    bboxes = datapoints.BoundingBoxes(data, format=format, spatial_size=(32, 32))
+    bboxes = datapoints.BoundingBoxes(data, format=format, canvas_size=(32, 32))
     assert isinstance(bboxes, torch.Tensor)
     assert bboxes.ndim == 2 and bboxes.shape[1] == 4
     if isinstance(format, str):
@@ -164,7 +164,7 @@ def test_wrap_like():
     [
         datapoints.Image(torch.rand(3, 16, 16)),
         datapoints.Video(torch.rand(2, 3, 16, 16)),
-        datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, spatial_size=(10, 10)),
+        datapoints.BoundingBoxes([0.0, 1.0, 2.0, 3.0], format=datapoints.BoundingBoxFormat.XYXY, canvas_size=(10, 10)),
         datapoints.Mask(torch.randint(0, 256, (16, 16), dtype=torch.uint8)),
     ],
 )
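The assertions are unchanged; only the keyword moves. Reduced to a standalone check (a plausible paraphrase of test_bbox_instance, not a verbatim copy):

    import torch
    from torchvision import datapoints

    data = torch.randint(0, 32, size=(5, 4))
    bboxes = datapoints.BoundingBoxes(data, format="XYXY", canvas_size=(32, 32))

    assert isinstance(bboxes, torch.Tensor)           # datapoints subclass Tensor
    assert bboxes.ndim == 2 and bboxes.shape[1] == 4
    assert bboxes.canvas_size == (32, 32)             # renamed from .spatial_size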

test/test_prototype_transforms.py (17 additions, 17 deletions)

@@ -164,7 +164,7 @@ def test__copy_paste(self, label_type):
             labels = torch.nn.functional.one_hot(labels, num_classes=5)
         target = {
             "boxes": BoundingBoxes(
-                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", spatial_size=(32, 32)
+                torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", canvas_size=(32, 32)
             ),
             "masks": Mask(masks),
             "labels": label_type(labels),
@@ -179,7 +179,7 @@ def test__copy_paste(self, label_type):
             paste_labels = torch.nn.functional.one_hot(paste_labels, num_classes=5)
         paste_target = {
             "boxes": BoundingBoxes(
-                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", spatial_size=(32, 32)
+                torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", canvas_size=(32, 32)
             ),
             "masks": Mask(paste_masks),
             "labels": label_type(paste_labels),
@@ -210,13 +210,13 @@ class TestFixedSizeCrop:
     def test__get_params(self, mocker):
         crop_size = (7, 7)
         batch_shape = (10,)
-        spatial_size = (11, 5)
+        canvas_size = (11, 5)
 
         transform = transforms.FixedSizeCrop(size=crop_size)
 
         flat_inputs = [
-            make_image(size=spatial_size, color_space="RGB"),
-            make_bounding_box(format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=batch_shape),
+            make_image(size=canvas_size, color_space="RGB"),
+            make_bounding_box(format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=batch_shape),
         ]
         params = transform._get_params(flat_inputs)
@@ -295,7 +295,7 @@ def test__transform(self, mocker, needs):
 
     def test__transform_culling(self, mocker):
         batch_size = 10
-        spatial_size = (10, 10)
+        canvas_size = (10, 10)
 
         is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
         mocker.patch(
@@ -304,17 +304,17 @@ def test__transform_culling(self, mocker):
                 needs_crop=True,
                 top=0,
                 left=0,
-                height=spatial_size[0],
-                width=spatial_size[1],
+                height=canvas_size[0],
+                width=canvas_size[1],
                 is_valid=is_valid,
                 needs_pad=False,
             ),
         )
 
         bounding_boxes = make_bounding_box(
-            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
+            format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
         )
-        masks = make_detection_mask(size=spatial_size, batch_dims=(batch_size,))
+        masks = make_detection_mask(size=canvas_size, batch_dims=(batch_size,))
         labels = make_label(extra_dims=(batch_size,))
 
         transform = transforms.FixedSizeCrop((-1, -1))
@@ -334,23 +334,23 @@ def test__transform_culling(self, mocker):
 
     def test__transform_bounding_boxes_clamping(self, mocker):
         batch_size = 3
-        spatial_size = (10, 10)
+        canvas_size = (10, 10)
 
         mocker.patch(
             "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
             return_value=dict(
                 needs_crop=True,
                 top=0,
                 left=0,
-                height=spatial_size[0],
-                width=spatial_size[1],
+                height=canvas_size[0],
+                width=canvas_size[1],
                 is_valid=torch.full((batch_size,), fill_value=True),
                 needs_pad=False,
             ),
         )
 
         bounding_boxes = make_bounding_box(
-            format=BoundingBoxFormat.XYXY, spatial_size=spatial_size, batch_dims=(batch_size,)
+            format=BoundingBoxFormat.XYXY, canvas_size=canvas_size, batch_dims=(batch_size,)
         )
         mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_boxes")
 
@@ -496,7 +496,7 @@ def make_datapoints():
 
     pil_image = to_image_pil(make_image(size=size, color_space="RGB"))
     target = {
-        "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+        "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
         "labels": make_label(extra_dims=(num_objects,), categories=80),
         "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
     }
@@ -505,7 +505,7 @@ def make_datapoints():
 
     tensor_image = torch.Tensor(make_image(size=size, color_space="RGB"))
     target = {
-        "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+        "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
         "labels": make_label(extra_dims=(num_objects,), categories=80),
         "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
     }
@@ -514,7 +514,7 @@ def make_datapoints():
 
     datapoint_image = make_image(size=size, color_space="RGB")
     target = {
-        "boxes": make_bounding_box(spatial_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
+        "boxes": make_bounding_box(canvas_size=size, format="XYXY", batch_dims=(num_objects,), dtype=torch.float),
         "labels": make_label(extra_dims=(num_objects,), categories=80),
         "masks": make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long),
     }
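A closing observation on the migration strategy: wherever a constructor signature is public to other test files, the old keyword stays accepted and is aliased in __post_init__, as BoundingBoxesLoader does above. The pattern in isolation, as a generic stdlib-only model rather than torchvision code:

    import dataclasses
    from typing import Tuple

    @dataclasses.dataclass
    class Loader:
        spatial_size: Tuple[int, int]  # legacy name, still accepted at construction
        canvas_size: Tuple[int, int] = dataclasses.field(init=False)

        def __post_init__(self):
            # New code reads .canvas_size; old call sites passing spatial_size still work.
            self.canvas_size = self.spatial_size

    loader = Loader(spatial_size=(10, 10))
    assert loader.canvas_size == loader.spatial_size == (10, 10)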
