
Commit 669f201

NicolasHug authored and facebook-github-bot committed

[fbsync] cleanup spatial_size -> canvas_size (#7783)

Reviewed By: matteobettini
Differential Revision: D48642317
fbshipit-source-id: 536290b00dbf1f613cdfb71c56eb4bab19f9771b
1 parent 00149cc · commit 669f201

File tree: 2 files changed (+16, -16 lines)


test/common_utils.py

Lines changed: 15 additions & 15 deletions
@@ -423,7 +423,7 @@ def load(self, device="cpu"):
         )
 
 
-def _parse_canvas_size(size, *, name="size"):
+def _parse_size(size, *, name="size"):
     if size == "random":
         raise ValueError("This should never happen")
     elif isinstance(size, int) and size > 0:
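For context, here is a minimal sketch of the renamed helper. Only the first branches are visible in the hunk above; the two-tuple branch and the final error are assumptions about how such a parser typically completes, not the file's actual code.

import pytest

def _parse_size(size, *, name="size"):
    if size == "random":
        raise ValueError("This should never happen")
    elif isinstance(size, int) and size > 0:
        # Assumption: a single positive int means a square (height, width).
        return (size, size)
    elif isinstance(size, (list, tuple)) and len(size) == 2:
        return tuple(size)
    else:
        # Assumed error path for anything else.
        raise pytest.UsageError(f"{name} can be a positive integer or a two-tuple, but got {size} instead.")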
@@ -478,13 +478,13 @@ def load(self, device):
 
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
-    canvas_size: Tuple[int, int] = dataclasses.field(init=False)
+    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
     memory_format: torch.memory_format = torch.contiguous_format
     canvas_size: Tuple[int, int] = dataclasses.field(init=False)
 
     def __post_init__(self):
-        self.canvas_size = self.canvas_size = self.shape[-2:]
+        self.spatial_size = self.canvas_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
     def load(self, device):
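The interesting part of this hunk is the chained assignment in __post_init__: the loader ends up exposing both the legacy spatial_size attribute and the new canvas_size attribute, bound to the same (H, W) value, so callers using either name keep working. A self-contained sketch of the pattern (the TensorInfo name is made up for illustration):

import dataclasses
from typing import Tuple

@dataclasses.dataclass
class TensorInfo:
    shape: Tuple[int, ...]
    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
    canvas_size: Tuple[int, int] = dataclasses.field(init=False)

    def __post_init__(self):
        # Chained assignment binds the old and the new name to the same tuple.
        self.spatial_size = self.canvas_size = self.shape[-2:]

info = TensorInfo(shape=(3, 32, 48))
assert info.spatial_size == info.canvas_size == (32, 48)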
@@ -550,7 +550,7 @@ def make_image_loader(
 ):
     if not constant_alpha:
         raise ValueError("This should never happen")
-    size = _parse_canvas_size(size)
+    size = _parse_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
@@ -590,7 +590,7 @@ def make_image_loaders(
 def make_image_loader_for_interpolation(
     size=(233, 147), *, color_space="RGB", dtype=torch.uint8, memory_format=torch.contiguous_format
 ):
-    size = _parse_canvas_size(size)
+    size = _parse_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device, memory_format):
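Both image-loader hunks follow the same closure pattern: parse size once, then capture it in an inner fn that only materializes a tensor when the loader is eventually asked to load. A stripped-down, hypothetical version of the pattern (the names and the randint fill are assumptions, not torchvision's actual code):

import torch

def make_loader(size=(16, 16), num_channels=3, dtype=torch.uint8):
    height, width = size
    shape = (num_channels, height, width)

    def fn(shape, dtype, device):
        # Deferred: nothing is allocated until load() is called.
        return torch.randint(0, 256, shape, dtype=dtype, device=device)

    def load(device="cpu"):
        return fn(shape, dtype, device)

    return load

image = make_loader()()  # tensor of shape (3, 16, 16) on CPU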
@@ -687,33 +687,33 @@ def sample_position(values, max_value):
     )
 
 
-def make_bounding_box_loader(*, extra_dims=(), format, canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE, dtype=torch.float32):
     if isinstance(format, str):
         format = datapoints.BoundingBoxFormat[format]
 
-    canvas_size = _parse_canvas_size(canvas_size, name="canvas_size")
+    spatial_size = _parse_size(spatial_size, name="canvas_size")
 
     def fn(shape, dtype, device):
         *batch_dims, num_coordinates = shape
         if num_coordinates != 4:
             raise pytest.UsageError()
 
         return make_bounding_box(
-            format=format, canvas_size=canvas_size, batch_dims=batch_dims, dtype=dtype, device=device
+            format=format, canvas_size=spatial_size, batch_dims=batch_dims, dtype=dtype, device=device
         )
 
-    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=canvas_size)
+    return BoundingBoxesLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
 
 
 def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(datapoints.BoundingBoxFormat),
-    canvas_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
+    spatial_size=DEFAULT_PORTRAIT_SPATIAL_SIZE,
     dtypes=(torch.float32, torch.float64, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, canvas_size=canvas_size)
+        yield make_bounding_box_loader(**params, spatial_size=spatial_size)
 
 
 make_bounding_boxes = from_loaders(make_bounding_box_loaders)
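make_bounding_box_loaders fans the single-loader factory out over a parameter grid. combinations_grid is torchvision's test helper; judging only from how it is called and unpacked above, it behaves roughly like this stand-in (an assumption, not the real implementation):

import itertools

def combinations_grid(**grid):
    # Yield one kwargs dict per element of the cartesian product of the grid.
    keys = list(grid)
    for values in itertools.product(*grid.values()):
        yield dict(zip(keys, values))

# combinations_grid(extra_dims=[(), (4,)], dtype=["f32", "i64"]) yields
# {"extra_dims": (), "dtype": "f32"}, {"extra_dims": (), "dtype": "i64"}, ...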
@@ -738,7 +738,7 @@ def make_detection_mask(size=DEFAULT_SIZE, *, num_objects=5, batch_dims=(), dtyp
 
 def make_detection_mask_loader(size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_objects=5, extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_canvas_size(size)
+    size = _parse_size(size)
 
     def fn(shape, dtype, device):
         *batch_dims, num_objects, height, width = shape
@@ -779,15 +779,15 @@ def make_segmentation_mask_loader(
     size=DEFAULT_PORTRAIT_SPATIAL_SIZE, *, num_categories=10, extra_dims=(), dtype=torch.uint8
 ):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    canvas_size = _parse_canvas_size(size)
+    size = _parse_size(size)
 
     def fn(shape, dtype, device):
         *batch_dims, height, width = shape
         return make_segmentation_mask(
             (height, width), num_categories=num_categories, batch_dims=batch_dims, dtype=dtype, device=device
         )
 
-    return MaskLoader(fn, shape=(*extra_dims, *canvas_size), dtype=dtype)
+    return MaskLoader(fn, shape=(*extra_dims, *size), dtype=dtype)
 
 
 def make_segmentation_mask_loaders(
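Based on the comment in the hunk, make_segmentation_mask produces an integer (*, H, W) mask whose pixel values encode the category. A plausible sketch of such a factory, with the randint fill being an assumption:

import torch

def make_segmentation_mask(size, *, num_categories=10, batch_dims=(), dtype=torch.uint8, device="cpu"):
    height, width = size
    # Each pixel holds a category id in [0, num_categories).
    return torch.randint(0, num_categories, (*batch_dims, height, width), dtype=dtype, device=device)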
@@ -841,7 +841,7 @@ def make_video_loader(
     extra_dims=(),
     dtype=torch.uint8,
 ):
-    size = _parse_canvas_size(size)
+    size = _parse_size(size)
 
     def fn(shape, dtype, device, memory_format):
         *batch_dims, num_frames, _, height, width = shape

test/test_transforms_v2_functional.py

Lines changed: 1 addition & 1 deletion
@@ -884,7 +884,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
     pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
     inv_pcoeffs = _get_perspective_coeffs(endpoints, startpoints)
 
-    for bboxes in make_bounding_boxes(canvas_size=canvas_size, extra_dims=((4,),)):
+    for bboxes in make_bounding_boxes(spatial_size=canvas_size, extra_dims=((4,),)):
         bboxes = bboxes.to(device)
 
         output_bboxes = F.perspective_bounding_boxes(
