Skip to content

Commit 56fb0bf

Browse files
authored
cleanup prototype transforms functional test (#5668)
1 parent 151e162 commit 56fb0bf

File tree

1 file changed

+18
-19
lines changed

1 file changed

+18
-19
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -379,12 +379,15 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
379379
np.max(transformed_points[:, 1]),
380380
]
381381
out_bbox = features.BoundingBox(
382-
out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=bbox.image_size, dtype=torch.float32
382+
out_bbox,
383+
format=features.BoundingBoxFormat.XYXY,
384+
image_size=bbox.image_size,
385+
dtype=torch.float32,
386+
device=bbox.device,
383387
)
384-
out_bbox = convert_bounding_box_format(
388+
return convert_bounding_box_format(
385389
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
386390
)
387-
return out_bbox.to(bbox.device)
388391

389392
image_size = (32, 38)
390393

@@ -439,8 +442,8 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
439442
[1, 1, 5, 5],
440443
]
441444
in_boxes = features.BoundingBox(
442-
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
443-
).to(device)
445+
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device
446+
)
444447
# Tested parameters
445448
angle = 63
446449
scale = 0.89
@@ -473,9 +476,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
473476
shear=(0, 0),
474477
)
475478

476-
assert len(output_boxes) == len(expected_bboxes)
477-
for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
478-
np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
479+
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
479480

480481

481482
@pytest.mark.parametrize("angle", [-54, 56])
@@ -589,12 +590,15 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
589590
out_bbox[3] -= tr_y
590591

591592
out_bbox = features.BoundingBox(
592-
out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float32
593+
out_bbox,
594+
format=features.BoundingBoxFormat.XYXY,
595+
image_size=image_size,
596+
dtype=torch.float32,
597+
device=bbox.device,
593598
)
594-
out_bbox = convert_bounding_box_format(
599+
return convert_bounding_box_format(
595600
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
596601
)
597-
return out_bbox.to(bbox.device)
598602

599603
image_size = (32, 38)
600604

@@ -630,9 +634,6 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
630634
expected_bboxes = torch.stack(expected_bboxes)
631635
else:
632636
expected_bboxes = expected_bboxes[0]
633-
print("input:", bboxes)
634-
print("output_bboxes:", output_bboxes)
635-
print("expected_bboxes:", expected_bboxes)
636637
torch.testing.assert_close(output_bboxes, expected_bboxes)
637638

638639

@@ -649,8 +650,8 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
649650
[image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10],
650651
]
651652
in_boxes = features.BoundingBox(
652-
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
653-
).to(device)
653+
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64, device=device
654+
)
654655
# Tested parameters
655656
angle = 45
656657
center = None if expand else [12, 23]
@@ -687,6 +688,4 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
687688
center=center,
688689
)
689690

690-
assert len(output_boxes) == len(expected_bboxes)
691-
for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
692-
np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
691+
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)

0 commit comments

Comments (0)