Skip to content

Commit 6032775

Browse files
authored
Refactor BoxOps tests to use parameterize (#5380)
* Refactor BoxOps tests to use parameterize * Refactor BoxOps tests to use parameterize * Refactor BoxOps to use parameterize, addressed comments from PR#5380 * Refactor BoxOps to use parameterize, addressed minor styling comments from PR#5380 * Refactor BoxOps to use parameterize, addressed typing errorsfrom PR#5380 * Refactor BoxOps to use parameterize, addressed minor naming comments for PR#5380
1 parent 21790df commit 6032775

File tree

1 file changed

+134
-99
lines changed

1 file changed

+134
-99
lines changed

test/test_ops.py

Lines changed: 134 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,13 @@
22
import os
33
from abc import ABC, abstractmethod
44
from functools import lru_cache
5-
from typing import Tuple
5+
from typing import Callable, List, Tuple
66

77
import numpy as np
88
import pytest
99
import torch
1010
import torch.fx
11-
from common_utils import needs_cuda, cpu_and_gpu, assert_equal
11+
from common_utils import assert_equal, cpu_and_gpu, needs_cuda
1212
from PIL import Image
1313
from torch import nn, Tensor
1414
from torch.autograd import gradcheck
@@ -1101,114 +1101,149 @@ def test_bbox_convert_jit(self):
11011101
torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)
11021102

11031103

1104-
class TestBoxArea:
1105-
def test_box_area(self):
1106-
def area_check(box, expected, tolerance=1e-4):
1107-
out = ops.box_area(box)
1108-
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
1109-
1110-
# Check for int boxes
1111-
for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
1112-
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype)
1113-
expected = torch.tensor([10000, 0])
1114-
area_check(box_tensor, expected)
1115-
1116-
# Check for float32 and float64 boxes
1117-
for dtype in [torch.float32, torch.float64]:
1118-
box_tensor = torch.tensor(
1119-
[
1120-
[285.3538, 185.5758, 1193.5110, 851.4551],
1121-
[285.1472, 188.7374, 1192.4984, 851.0669],
1122-
[279.2440, 197.9812, 1189.4746, 849.2019],
1123-
],
1124-
dtype=dtype,
1125-
)
1126-
expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64)
1127-
area_check(box_tensor, expected, tolerance=0.05)
1128-
1129-
# Check for float16 box
1130-
box_tensor = torch.tensor(
1131-
[[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]],
1132-
dtype=torch.float16,
1133-
)
1134-
expected = torch.tensor([605113.875, 600495.1875, 592247.25])
1135-
area_check(box_tensor, expected)
1136-
1137-
def test_box_area_jit(self):
1138-
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
1139-
TOLERANCE = 1e-3
1140-
expected = ops.box_area(box_tensor)
1141-
scripted_fn = torch.jit.script(ops.box_area)
1142-
scripted_area = scripted_fn(box_tensor)
1143-
torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=TOLERANCE)
1104+
class BoxTestBase(ABC):
1105+
@abstractmethod
1106+
def _target_fn(self) -> Tuple[bool, Callable]:
1107+
pass
11441108

1109+
def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Tensor:
1110+
is_binary_fn = self._target_fn()[0]
1111+
target_fn = self._target_fn()[1]
1112+
box_operation = torch.jit.script(target_fn) if run_as_script else target_fn
1113+
return box_operation(box, box) if is_binary_fn else box_operation(box)
11451114

1146-
class TestBoxIou:
1147-
def test_iou(self):
1148-
def iou_check(box, expected, tolerance=1e-4):
1149-
out = ops.box_iou(box, box)
1115+
def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
1116+
def assert_close(box: Tensor, expected: Tensor, tolerance):
1117+
out = self._perform_box_operation(box)
11501118
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
11511119

1152-
# Check for int boxes
1153-
for dtype in [torch.int16, torch.int32, torch.int64]:
1154-
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
1155-
expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]])
1156-
iou_check(box, expected)
1120+
for dtype in dtypes:
1121+
actual_box = torch.tensor(test_input, dtype=dtype)
1122+
expected_box = torch.tensor(expected)
1123+
assert_close(actual_box, expected_box, tolerance)
11571124

1158-
# Check for float boxes
1159-
for dtype in [torch.float16, torch.float32, torch.float64]:
1160-
box_tensor = torch.tensor(
1161-
[
1162-
[285.3538, 185.5758, 1193.5110, 851.4551],
1163-
[285.1472, 188.7374, 1192.4984, 851.0669],
1164-
[279.2440, 197.9812, 1189.4746, 849.2019],
1165-
],
1166-
dtype=dtype,
1167-
)
1168-
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
1169-
iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)
1125+
def _run_jit_test(self, test_input: List) -> None:
1126+
box_tensor = torch.tensor(test_input, dtype=torch.float)
1127+
expected = self._perform_box_operation(box_tensor, True)
1128+
scripted_area = self._perform_box_operation(box_tensor, True)
1129+
torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3)
11701130

1171-
def test_iou_jit(self):
1172-
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
1173-
TOLERANCE = 1e-3
1174-
expected = ops.box_iou(box_tensor, box_tensor)
1175-
scripted_fn = torch.jit.script(ops.box_iou)
1176-
scripted_iou = scripted_fn(box_tensor, box_tensor)
1177-
torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE)
11781131

1132+
class TestBoxArea(BoxTestBase):
1133+
def _target_fn(self) -> Tuple[bool, Callable]:
1134+
return (False, ops.box_area)
11791135

1180-
class TestGenBoxIou:
1181-
def test_gen_iou(self):
1182-
def gen_iou_check(box, expected, tolerance=1e-4):
1183-
out = ops.generalized_box_iou(box, box)
1184-
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
1136+
def _generate_int_input() -> List[List[int]]:
1137+
return [[0, 0, 100, 100], [0, 0, 0, 0]]
11851138

1186-
# Check for int boxes
1187-
for dtype in [torch.int16, torch.int32, torch.int64]:
1188-
box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype)
1189-
expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]])
1190-
gen_iou_check(box, expected)
1139+
def _generate_int_expected() -> List[int]:
1140+
return [10000, 0]
11911141

1192-
# Check for float boxes
1193-
for dtype in [torch.float16, torch.float32, torch.float64]:
1194-
box_tensor = torch.tensor(
1195-
[
1196-
[285.3538, 185.5758, 1193.5110, 851.4551],
1197-
[285.1472, 188.7374, 1192.4984, 851.0669],
1198-
[279.2440, 197.9812, 1189.4746, 849.2019],
1199-
],
1200-
dtype=dtype,
1201-
)
1202-
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
1203-
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
1142+
def _generate_float_input(index: int) -> List[List[float]]:
1143+
return [
1144+
[
1145+
[285.3538, 185.5758, 1193.5110, 851.4551],
1146+
[285.1472, 188.7374, 1192.4984, 851.0669],
1147+
[279.2440, 197.9812, 1189.4746, 849.2019],
1148+
],
1149+
[[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]],
1150+
][index]
1151+
1152+
def _generate_float_expected(index: int) -> List[float]:
1153+
return [[604723.0806, 600965.4666, 592761.0085], [605113.875, 600495.1875, 592247.25]][index]
1154+
1155+
@pytest.mark.parametrize(
1156+
"test_input, dtypes, tolerance, expected",
1157+
[
1158+
pytest.param(
1159+
_generate_int_input(),
1160+
[torch.int8, torch.int16, torch.int32, torch.int64],
1161+
1e-4,
1162+
_generate_int_expected(),
1163+
),
1164+
pytest.param(_generate_float_input(0), [torch.float32, torch.float64], 0.05, _generate_float_expected(0)),
1165+
pytest.param(_generate_float_input(1), [torch.float16], 1e-4, _generate_float_expected(1)),
1166+
],
1167+
)
1168+
def test_box_area(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
1169+
self._run_test(test_input, dtypes, tolerance, expected)
1170+
1171+
def test_box_area_jit(self) -> None:
1172+
self._run_jit_test([[0, 0, 100, 100], [0, 0, 0, 0]])
1173+
1174+
1175+
class TestBoxIou(BoxTestBase):
1176+
def _target_fn(self) -> Tuple[bool, Callable]:
1177+
return (True, ops.box_iou)
1178+
1179+
def _generate_int_input() -> List[List[int]]:
1180+
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
1181+
1182+
def _generate_int_expected() -> List[List[float]]:
1183+
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
1184+
1185+
def _generate_float_input() -> List[List[float]]:
1186+
return [
1187+
[285.3538, 185.5758, 1193.5110, 851.4551],
1188+
[285.1472, 188.7374, 1192.4984, 851.0669],
1189+
[279.2440, 197.9812, 1189.4746, 849.2019],
1190+
]
12041191

1205-
def test_giou_jit(self):
1206-
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
1207-
TOLERANCE = 1e-3
1208-
expected = ops.generalized_box_iou(box_tensor, box_tensor)
1209-
scripted_fn = torch.jit.script(ops.generalized_box_iou)
1210-
scripted_iou = scripted_fn(box_tensor, box_tensor)
1211-
torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE)
1192+
def _generate_float_expected() -> List[List[float]]:
1193+
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
1194+
1195+
@pytest.mark.parametrize(
1196+
"test_input, dtypes, tolerance, expected",
1197+
[
1198+
pytest.param(
1199+
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
1200+
),
1201+
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
1202+
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-4, _generate_float_expected()),
1203+
],
1204+
)
1205+
def test_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
1206+
self._run_test(test_input, dtypes, tolerance, expected)
1207+
1208+
def test_iou_jit(self) -> None:
1209+
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
1210+
1211+
1212+
class TestGenBoxIou(BoxTestBase):
1213+
def _target_fn(self) -> Tuple[bool, Callable]:
1214+
return (True, ops.generalized_box_iou)
1215+
1216+
def _generate_int_input() -> List[List[int]]:
1217+
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
1218+
1219+
def _generate_int_expected() -> List[List[float]]:
1220+
return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]
1221+
1222+
def _generate_float_input() -> List[List[float]]:
1223+
return [
1224+
[285.3538, 185.5758, 1193.5110, 851.4551],
1225+
[285.1472, 188.7374, 1192.4984, 851.0669],
1226+
[279.2440, 197.9812, 1189.4746, 849.2019],
1227+
]
1228+
1229+
def _generate_float_expected() -> List[List[float]]:
1230+
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]
1231+
1232+
@pytest.mark.parametrize(
1233+
"test_input, dtypes, tolerance, expected",
1234+
[
1235+
pytest.param(
1236+
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
1237+
),
1238+
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
1239+
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()),
1240+
],
1241+
)
1242+
def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None:
1243+
self._run_test(test_input, dtypes, tolerance, expected)
1244+
1245+
def test_giou_jit(self) -> None:
1246+
self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])
12121247

12131248

12141249
class TestMasksToBoxes:

0 commit comments

Comments
 (0)