Skip to content

Commit 8897252

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] add tests for perspective start- / endpoints (#7226)
Reviewed By: vmoens Differential Revision: D44416261 fbshipit-source-id: ac7c6b974876948a38df23a6373b3b785772377e
1 parent 8a188ad commit 8897252

File tree

2 files changed

+32
-6
lines changed

2 files changed

+32
-6
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
get_num_channels,
1616
ImageLoader,
1717
InfoBase,
18+
make_bounding_box_loader,
1819
make_bounding_box_loaders,
20+
make_detection_mask_loader,
1921
make_image_loader,
2022
make_image_loaders,
2123
make_image_loaders_for_interpolation,
2224
make_mask_loaders,
25+
make_video_loader,
2326
make_video_loaders,
2427
mark_framework_limitation,
2528
TestMark,
@@ -1168,12 +1171,18 @@ def reference_inputs_pad_bounding_box():
11681171
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
11691172
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
11701173
]
1174+
_STARTPOINTS = [[0, 1], [2, 3], [4, 5], [6, 7]]
1175+
_ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]
11711176

11721177

11731178
def sample_inputs_perspective_image_tensor():
11741179
for image_loader in make_image_loaders(sizes=["random"]):
11751180
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
1176-
yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0])
1181+
yield ArgsKwargs(
1182+
image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
1183+
)
1184+
1185+
yield ArgsKwargs(make_image_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
11771186

11781187

11791188
def reference_inputs_perspective_image_tensor():
@@ -1200,25 +1209,38 @@ def reference_inputs_perspective_image_tensor():
12001209
def sample_inputs_perspective_bounding_box():
12011210
for bounding_box_loader in make_bounding_box_loaders():
12021211
yield ArgsKwargs(
1203-
bounding_box_loader, bounding_box_loader.format, None, None, coefficients=_PERSPECTIVE_COEFFS[0]
1212+
bounding_box_loader,
1213+
format=bounding_box_loader.format,
1214+
startpoints=None,
1215+
endpoints=None,
1216+
coefficients=_PERSPECTIVE_COEFFS[0],
12041217
)
12051218

1219+
format = datapoints.BoundingBoxFormat.XYXY
1220+
yield ArgsKwargs(
1221+
make_bounding_box_loader(format=format), format=format, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
1222+
)
1223+
12061224

12071225
def sample_inputs_perspective_mask():
12081226
for mask_loader in make_mask_loaders(sizes=["random"]):
1209-
yield ArgsKwargs(mask_loader, None, None, coefficients=_PERSPECTIVE_COEFFS[0])
1227+
yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
1228+
1229+
yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
12101230

12111231

12121232
def reference_inputs_perspective_mask():
12131233
for mask_loader, perspective_coeffs in itertools.product(
12141234
make_mask_loaders(extra_dims=[()], num_objects=[1]), _PERSPECTIVE_COEFFS
12151235
):
1216-
yield ArgsKwargs(mask_loader, None, None, coefficients=perspective_coeffs)
1236+
yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=perspective_coeffs)
12171237

12181238

12191239
def sample_inputs_perspective_video():
12201240
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
1221-
yield ArgsKwargs(video_loader, None, None, coefficients=_PERSPECTIVE_COEFFS[0])
1241+
yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
1242+
1243+
yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
12221244

12231245

12241246
KERNEL_INFOS.extend(

torchvision/prototype/datapoints/_bounding_box.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,11 @@ def perspective(
176176
coefficients: Optional[List[float]] = None,
177177
) -> BoundingBox:
178178
output = self._F.perspective_bounding_box(
179-
self.as_subclass(torch.Tensor), startpoints, endpoints, self.format, coefficients=coefficients
179+
self.as_subclass(torch.Tensor),
180+
format=self.format,
181+
startpoints=startpoints,
182+
endpoints=endpoints,
183+
coefficients=coefficients,
180184
)
181185
return BoundingBox.wrap_like(self, output)
182186

0 commit comments

Comments
 (0)