
Commit 0e2240c

Fixing more tests.
1 parent 07e7e25 commit 0e2240c

6 files changed: 123 additions & 123 deletions

test/prototype_common_utils.py

Lines changed: 25 additions & 25 deletions
@@ -184,13 +184,13 @@ def load(self, device="cpu"):
         return args, kwargs
 
 
-DEFAULT_SQUARE_IMAGE_SIZE = 15
-DEFAULT_LANDSCAPE_IMAGE_SIZE = (7, 33)
-DEFAULT_PORTRAIT_IMAGE_SIZE = (31, 9)
-DEFAULT_IMAGE_SIZES = (DEFAULT_LANDSCAPE_IMAGE_SIZE, DEFAULT_PORTRAIT_IMAGE_SIZE, DEFAULT_SQUARE_IMAGE_SIZE, "random")
+DEFAULT_SQUARE_SPATIAL_SIZE = 15
+DEFAULT_LANDSCAPE_SPATIAL_SIZE = (7, 33)
+DEFAULT_PORTRAIT_SPATIAL_SIZE = (31, 9)
+DEFAULT_SPATIAL_SIZES = (DEFAULT_LANDSCAPE_SPATIAL_SIZE, DEFAULT_PORTRAIT_SPATIAL_SIZE, DEFAULT_SQUARE_SPATIAL_SIZE, "random")
 
 
-def _parse_image_size(size, *, name="size"):
+def _parse_spatial_size(size, *, name="size"):
     if size == "random":
         return tuple(torch.randint(15, 33, (2,)).tolist())
     elif isinstance(size, int) and size > 0:
@@ -246,11 +246,11 @@ def load(self, device):
 @dataclasses.dataclass
 class ImageLoader(TensorLoader):
     color_space: features.ColorSpace
-    image_size: Tuple[int, int] = dataclasses.field(init=False)
+    spatial_size: Tuple[int, int] = dataclasses.field(init=False)
     num_channels: int = dataclasses.field(init=False)
 
     def __post_init__(self):
-        self.image_size = self.shape[-2:]
+        self.spatial_size = self.shape[-2:]
         self.num_channels = self.shape[-3]
 
 
@@ -277,7 +277,7 @@ def make_image_loader(
     dtype=torch.float32,
     constant_alpha=True,
 ):
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_channels = get_num_channels(color_space)
 
     def fn(shape, dtype, device):
@@ -295,7 +295,7 @@ def fn(shape, dtype, device):
 
 def make_image_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
         features.ColorSpace.GRAY,
         features.ColorSpace.GRAY_ALPHA,
@@ -316,7 +316,7 @@ def make_image_loaders(
 @dataclasses.dataclass
 class BoundingBoxLoader(TensorLoader):
     format: features.BoundingBoxFormat
-    image_size: Tuple[int, int]
+    spatial_size: Tuple[int, int]
 
 
 def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
@@ -331,7 +331,7 @@ def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
     ).reshape(low.shape)
 
 
-def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtype=torch.float32):
+def make_bounding_box_loader(*, extra_dims=(), format, spatial_size="random", dtype=torch.float32):
     if isinstance(format, str):
         format = features.BoundingBoxFormat[format]
     if format not in {
@@ -341,7 +341,7 @@ def make_bounding_box_loader(*, extra_dims=(), format, image_size="random", dtyp
     }:
         raise pytest.UsageError(f"Can't make bounding box in format {format}")
 
-    image_size = _parse_image_size(image_size, name="image_size")
+    spatial_size = _parse_spatial_size(spatial_size, name="spatial_size")
 
     def fn(shape, dtype, device):
         *extra_dims, num_coordinates = shape
@@ -350,10 +350,10 @@ def fn(shape, dtype, device):
 
         if any(dim == 0 for dim in extra_dims):
             return features.BoundingBox(
-                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=image_size
+                torch.empty(*extra_dims, 4, dtype=dtype, device=device), format=format, spatial_size=spatial_size
             )
 
-        height, width = image_size
+        height, width = spatial_size
 
         if format == features.BoundingBoxFormat.XYXY:
             x1 = torch.randint(0, width // 2, extra_dims)
@@ -375,10 +375,10 @@ def fn(shape, dtype, device):
             parts = (cx, cy, w, h)
 
         return features.BoundingBox(
-            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=image_size
+            torch.stack(parts, dim=-1).to(dtype=dtype, device=device), format=format, spatial_size=spatial_size
         )
 
-    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, image_size=image_size)
+    return BoundingBoxLoader(fn, shape=(*extra_dims, 4), dtype=dtype, format=format, spatial_size=spatial_size)
 
 
 make_bounding_box = from_loader(make_bounding_box_loader)
@@ -388,11 +388,11 @@ def make_bounding_box_loaders(
     *,
     extra_dims=DEFAULT_EXTRA_DIMS,
     formats=tuple(features.BoundingBoxFormat),
-    image_size="random",
+    spatial_size="random",
     dtypes=(torch.float32, torch.int64),
 ):
     for params in combinations_grid(extra_dims=extra_dims, format=formats, dtype=dtypes):
-        yield make_bounding_box_loader(**params, image_size=image_size)
+        yield make_bounding_box_loader(**params, spatial_size=spatial_size)
 
 
 make_bounding_boxes = from_loaders(make_bounding_box_loaders)
@@ -475,7 +475,7 @@ class MaskLoader(TensorLoader):
 
 def make_detection_mask_loader(size="random", *, num_objects="random", extra_dims=(), dtype=torch.uint8):
     # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_objects = int(torch.randint(1, 11, ())) if num_objects == "random" else num_objects
 
     def fn(shape, dtype, device):
@@ -489,7 +489,7 @@ def fn(shape, dtype, device):
 
 
 def make_detection_mask_loaders(
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     num_objects=(1, 0, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
@@ -503,7 +503,7 @@ def make_detection_mask_loaders(
 
 def make_segmentation_mask_loader(size="random", *, num_categories="random", extra_dims=(), dtype=torch.uint8):
     # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_categories = int(torch.randint(1, 11, ())) if num_categories == "random" else num_categories
 
     def fn(shape, dtype, device):
@@ -518,7 +518,7 @@ def fn(shape, dtype, device):
 
 def make_segmentation_mask_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
     dtypes=(torch.uint8,),
@@ -532,7 +532,7 @@ def make_segmentation_mask_loaders(
 
 def make_mask_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     num_objects=(1, 0, "random"),
    num_categories=(1, 2, "random"),
     extra_dims=DEFAULT_EXTRA_DIMS,
@@ -559,7 +559,7 @@ def make_video_loader(
     extra_dims=(),
     dtype=torch.uint8,
 ):
-    size = _parse_image_size(size)
+    size = _parse_spatial_size(size)
     num_frames = int(torch.randint(1, 5, ())) if num_frames == "random" else num_frames
 
     def fn(shape, dtype, device):
@@ -576,7 +576,7 @@ def fn(shape, dtype, device):
 
 def make_video_loaders(
     *,
-    sizes=DEFAULT_IMAGE_SIZES,
+    sizes=DEFAULT_SPATIAL_SIZES,
     color_spaces=(
         features.ColorSpace.GRAY,
         features.ColorSpace.RGB,

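Note: `_parse_spatial_size` is `_parse_image_size` under its new name; only the identifier changes. As a reading aid, here is a minimal sketch of the helper's contract. Only the "random" branch is fully visible in the hunk above; the remaining branches and the error message are assumptions modeled on how the callers in this diff use it.

# Sketch of _parse_spatial_size; branches after the first are assumptions.
from collections.abc import Sequence

import pytest
import torch


def _parse_spatial_size(size, *, name="size"):
    if size == "random":
        # Visible branch: draw a random (height, width) pair in [15, 33).
        return tuple(torch.randint(15, 33, (2,)).tolist())
    elif isinstance(size, int) and size > 0:
        # Assumed: a positive int means a square (size, size).
        return (size, size)
    elif isinstance(size, Sequence) and len(size) == 2:
        # Assumed: a (height, width) pair is passed through as a tuple.
        return tuple(size)
    else:
        raise pytest.UsageError(
            f"'{name}' can be 'random', a positive integer, or a two-element sequence, but got {size}."
        )
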
test/prototype_transforms_kernel_infos.py

Lines changed: 24 additions & 24 deletions
@@ -145,7 +145,7 @@ def sample_inputs_horizontal_flip_bounding_box():
         formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
     ):
         yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size
+            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
         )
 
 
@@ -185,9 +185,9 @@ def sample_inputs_horizontal_flip_video():
     )
 
 
-def _get_resize_sizes(image_size):
-    height, width = image_size
-    length = max(image_size)
+def _get_resize_sizes(spatial_size):
+    height, width = spatial_size
+    length = max(spatial_size)
     yield length
     yield [length]
     yield (length,)
@@ -252,7 +252,7 @@ def reference_inputs_resize_image_tensor():
 def sample_inputs_resize_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
-            yield ArgsKwargs(bounding_box_loader, size=size, image_size=bounding_box_loader.spatial_size)
+            yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
 
 
 def sample_inputs_resize_mask():
@@ -394,7 +394,7 @@ def sample_inputs_affine_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
-            image_size=bounding_box_loader.spatial_size,
+            spatial_size=bounding_box_loader.spatial_size,
             **affine_params,
         )
 
@@ -422,9 +422,9 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
     return true_matrix
 
 
-def reference_affine_bounding_box(bounding_box, *, format, image_size, angle, translate, scale, shear, center=None):
+def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
     if center is None:
-        center = [s * 0.5 for s in image_size[::-1]]
+        center = [s * 0.5 for s in spatial_size[::-1]]
 
     def transform(bbox):
         affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
@@ -473,7 +473,7 @@ def reference_inputs_affine_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
-            image_size=bounding_box_loader.spatial_size,
+            spatial_size=bounding_box_loader.spatial_size,
             **affine_kwargs,
         )
 
@@ -650,7 +650,7 @@ def sample_inputs_vertical_flip_bounding_box():
         formats=[features.BoundingBoxFormat.XYXY], dtypes=[torch.float32]
     ):
         yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size
+            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
         )
 
 
@@ -729,7 +729,7 @@ def sample_inputs_rotate_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
-            image_size=bounding_box_loader.spatial_size,
+            spatial_size=bounding_box_loader.spatial_size,
             angle=_ROTATE_ANGLES[0],
         )
 
@@ -1001,7 +1001,7 @@ def sample_inputs_pad_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
-            image_size=bounding_box_loader.spatial_size,
+            spatial_size=bounding_box_loader.spatial_size,
             padding=padding,
             padding_mode="constant",
         )
@@ -1131,8 +1131,8 @@ def sample_inputs_perspective_video():
     )
 
 
-def _get_elastic_displacement(image_size):
-    return torch.rand(1, *image_size, 2)
+def _get_elastic_displacement(spatial_size):
+    return torch.rand(1, *spatial_size, 2)
 
 
 def sample_inputs_elastic_image_tensor():
@@ -1212,7 +1212,7 @@ def sample_inputs_elastic_video():
     )
 
 
-_CENTER_CROP_IMAGE_SIZES = [(16, 16), (7, 33), (31, 9)]
+_CENTER_CROP_SPATIAL_SIZES = [(16, 16), (7, 33), (31, 9)]
 _CENTER_CROP_OUTPUT_SIZES = [[4, 3], [42, 70], [4], 3, (5, 2), (6,)]
 
 
@@ -1231,7 +1231,7 @@ def sample_inputs_center_crop_image_tensor():
 
 def reference_inputs_center_crop_image_tensor():
     for image_loader, output_size in itertools.product(
-        make_image_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES
+        make_image_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()]), _CENTER_CROP_OUTPUT_SIZES
     ):
         yield ArgsKwargs(image_loader, output_size=output_size)
 
@@ -1241,7 +1241,7 @@ def sample_inputs_center_crop_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
-            image_size=bounding_box_loader.spatial_size,
+            spatial_size=bounding_box_loader.spatial_size,
             output_size=output_size,
         )
 
@@ -1254,7 +1254,7 @@ def sample_inputs_center_crop_mask():
 
 def reference_inputs_center_crop_mask():
     for mask_loader, output_size in itertools.product(
-        make_mask_loaders(sizes=_CENTER_CROP_IMAGE_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
+        make_mask_loaders(sizes=_CENTER_CROP_SPATIAL_SIZES, extra_dims=[()], num_objects=[1]), _CENTER_CROP_OUTPUT_SIZES
     ):
         yield ArgsKwargs(mask_loader, output_size=output_size)
 
@@ -1820,7 +1820,7 @@ def sample_inputs_adjust_saturation_video():
 def sample_inputs_clamp_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
         yield ArgsKwargs(
-            bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.spatial_size
+            bounding_box_loader, format=bounding_box_loader.format, spatial_size=bounding_box_loader.spatial_size
         )
 
 
@@ -1834,7 +1834,7 @@ def sample_inputs_clamp_bounding_box():
 _FIVE_TEN_CROP_SIZES = [7, (6,), [5], (6, 5), [7, 6]]
 
 
-def _get_five_ten_crop_image_size(size):
+def _get_five_ten_crop_spatial_size(size):
     if isinstance(size, int):
         crop_height = crop_width = size
     elif len(size) == 1:
@@ -1847,28 +1847,28 @@ def _get_five_ten_crop_image_size(size):
 def sample_inputs_five_crop_image_tensor():
     for size in _FIVE_TEN_CROP_SIZES:
         for image_loader in make_image_loaders(
-            sizes=[_get_five_ten_crop_image_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
+            sizes=[_get_five_ten_crop_spatial_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
         ):
             yield ArgsKwargs(image_loader, size=size)
 
 
 def reference_inputs_five_crop_image_tensor():
     for size in _FIVE_TEN_CROP_SIZES:
-        for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]):
+        for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
             yield ArgsKwargs(image_loader, size=size)
 
 
 def sample_inputs_ten_crop_image_tensor():
     for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
         for image_loader in make_image_loaders(
-            sizes=[_get_five_ten_crop_image_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
+            sizes=[_get_five_ten_crop_spatial_size(size)], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]
         ):
             yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
 
 
 def reference_inputs_ten_crop_image_tensor():
     for size, vertical_flip in itertools.product(_FIVE_TEN_CROP_SIZES, [False, True]):
-        for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_image_size(size)], extra_dims=[()]):
+        for image_loader in make_image_loaders(sizes=[_get_five_ten_crop_spatial_size(size)], extra_dims=[()]):
             yield ArgsKwargs(image_loader, size=size, vertical_flip=vertical_flip)
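End to end, the rename means test authors now pass and read `spatial_size` wherever `image_size` used to appear, matching the `spatial_size=` keyword that `features.BoundingBox` already takes in the hunks above. A hypothetical usage sketch follows; the import path and the `load()` call are assumptions based on the loader API shown in this diff, not confirmed by the commit:

# Hypothetical usage of the renamed surface; import path and load() signature
# are assumed from the diffs above.
from prototype_common_utils import make_bounding_box_loader
from torchvision.prototype import features

loader = make_bounding_box_loader(
    format=features.BoundingBoxFormat.XYXY,
    spatial_size=(32, 32),  # previously image_size=(32, 32)
)
assert loader.spatial_size == (32, 32)  # BoundingBoxLoader field is renamed too

box = loader.load(device="cpu")  # TensorLoader.load, per the hunk context above
# features.BoundingBox already carried spatial_size; the loaders now match it.
assert tuple(box.spatial_size) == loader.spatial_size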