|
2 | 2 | import os
|
3 | 3 | from abc import ABC, abstractmethod
|
4 | 4 | from functools import lru_cache
|
5 |
| -from typing import Tuple |
| 5 | +from typing import Callable, List, Tuple |
6 | 6 |
|
7 | 7 | import numpy as np
|
8 | 8 | import pytest
|
9 | 9 | import torch
|
10 | 10 | import torch.fx
|
11 |
| -from common_utils import needs_cuda, cpu_and_gpu, assert_equal |
| 11 | +from common_utils import assert_equal, cpu_and_gpu, needs_cuda |
12 | 12 | from PIL import Image
|
13 | 13 | from torch import nn, Tensor
|
14 | 14 | from torch.autograd import gradcheck
|
@@ -1101,114 +1101,149 @@ def test_bbox_convert_jit(self):
|
1101 | 1101 | torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE)
|
1102 | 1102 |
|
1103 | 1103 |
|
1104 |
| -class TestBoxArea: |
1105 |
| - def test_box_area(self): |
1106 |
| - def area_check(box, expected, tolerance=1e-4): |
1107 |
| - out = ops.box_area(box) |
1108 |
| - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) |
1109 |
| - |
1110 |
| - # Check for int boxes |
1111 |
| - for dtype in [torch.int8, torch.int16, torch.int32, torch.int64]: |
1112 |
| - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) |
1113 |
| - expected = torch.tensor([10000, 0]) |
1114 |
| - area_check(box_tensor, expected) |
1115 |
| - |
1116 |
| - # Check for float32 and float64 boxes |
1117 |
| - for dtype in [torch.float32, torch.float64]: |
1118 |
| - box_tensor = torch.tensor( |
1119 |
| - [ |
1120 |
| - [285.3538, 185.5758, 1193.5110, 851.4551], |
1121 |
| - [285.1472, 188.7374, 1192.4984, 851.0669], |
1122 |
| - [279.2440, 197.9812, 1189.4746, 849.2019], |
1123 |
| - ], |
1124 |
| - dtype=dtype, |
1125 |
| - ) |
1126 |
| - expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) |
1127 |
| - area_check(box_tensor, expected, tolerance=0.05) |
1128 |
| - |
1129 |
| - # Check for float16 box |
1130 |
| - box_tensor = torch.tensor( |
1131 |
| - [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], |
1132 |
| - dtype=torch.float16, |
1133 |
| - ) |
1134 |
| - expected = torch.tensor([605113.875, 600495.1875, 592247.25]) |
1135 |
| - area_check(box_tensor, expected) |
1136 |
| - |
1137 |
| - def test_box_area_jit(self): |
1138 |
| - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) |
1139 |
| - TOLERANCE = 1e-3 |
1140 |
| - expected = ops.box_area(box_tensor) |
1141 |
| - scripted_fn = torch.jit.script(ops.box_area) |
1142 |
| - scripted_area = scripted_fn(box_tensor) |
1143 |
| - torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=TOLERANCE) |
| 1104 | +class BoxTestBase(ABC): |
| 1105 | + @abstractmethod |
| 1106 | + def _target_fn(self) -> Tuple[bool, Callable]: |
| 1107 | + pass |
1144 | 1108 |
|
| 1109 | + def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Tensor: |
| 1110 | + is_binary_fn = self._target_fn()[0] |
| 1111 | + target_fn = self._target_fn()[1] |
| 1112 | + box_operation = torch.jit.script(target_fn) if run_as_script else target_fn |
| 1113 | + return box_operation(box, box) if is_binary_fn else box_operation(box) |
1145 | 1114 |
|
1146 |
| -class TestBoxIou: |
1147 |
| - def test_iou(self): |
1148 |
| - def iou_check(box, expected, tolerance=1e-4): |
1149 |
| - out = ops.box_iou(box, box) |
| 1115 | + def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: |
| 1116 | + def assert_close(box: Tensor, expected: Tensor, tolerance): |
| 1117 | + out = self._perform_box_operation(box) |
1150 | 1118 | torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)
|
1151 | 1119 |
|
1152 |
| - # Check for int boxes |
1153 |
| - for dtype in [torch.int16, torch.int32, torch.int64]: |
1154 |
| - box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype) |
1155 |
| - expected = torch.tensor([[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]) |
1156 |
| - iou_check(box, expected) |
| 1120 | + for dtype in dtypes: |
| 1121 | + actual_box = torch.tensor(test_input, dtype=dtype) |
| 1122 | + expected_box = torch.tensor(expected) |
| 1123 | + assert_close(actual_box, expected_box, tolerance) |
1157 | 1124 |
|
1158 |
| - # Check for float boxes |
1159 |
| - for dtype in [torch.float16, torch.float32, torch.float64]: |
1160 |
| - box_tensor = torch.tensor( |
1161 |
| - [ |
1162 |
| - [285.3538, 185.5758, 1193.5110, 851.4551], |
1163 |
| - [285.1472, 188.7374, 1192.4984, 851.0669], |
1164 |
| - [279.2440, 197.9812, 1189.4746, 849.2019], |
1165 |
| - ], |
1166 |
| - dtype=dtype, |
1167 |
| - ) |
1168 |
| - expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) |
1169 |
| - iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4) |
| 1125 | + def _run_jit_test(self, test_input: List) -> None: |
| 1126 | + box_tensor = torch.tensor(test_input, dtype=torch.float) |
| 1127 | + expected = self._perform_box_operation(box_tensor, True) |
| 1128 | + scripted_area = self._perform_box_operation(box_tensor, True) |
| 1129 | + torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) |
1170 | 1130 |
|
1171 |
| - def test_iou_jit(self): |
1172 |
| - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) |
1173 |
| - TOLERANCE = 1e-3 |
1174 |
| - expected = ops.box_iou(box_tensor, box_tensor) |
1175 |
| - scripted_fn = torch.jit.script(ops.box_iou) |
1176 |
| - scripted_iou = scripted_fn(box_tensor, box_tensor) |
1177 |
| - torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE) |
1178 | 1131 |
|
| 1132 | +class TestBoxArea(BoxTestBase): |
| 1133 | + def _target_fn(self) -> Tuple[bool, Callable]: |
| 1134 | + return (False, ops.box_area) |
1179 | 1135 |
|
1180 |
| -class TestGenBoxIou: |
1181 |
| - def test_gen_iou(self): |
1182 |
| - def gen_iou_check(box, expected, tolerance=1e-4): |
1183 |
| - out = ops.generalized_box_iou(box, box) |
1184 |
| - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) |
| 1136 | + def _generate_int_input() -> List[List[int]]: |
| 1137 | + return [[0, 0, 100, 100], [0, 0, 0, 0]] |
1185 | 1138 |
|
1186 |
| - # Check for int boxes |
1187 |
| - for dtype in [torch.int16, torch.int32, torch.int64]: |
1188 |
| - box = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=dtype) |
1189 |
| - expected = torch.tensor([[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]) |
1190 |
| - gen_iou_check(box, expected) |
| 1139 | + def _generate_int_expected() -> List[int]: |
| 1140 | + return [10000, 0] |
1191 | 1141 |
|
1192 |
| - # Check for float boxes |
1193 |
| - for dtype in [torch.float16, torch.float32, torch.float64]: |
1194 |
| - box_tensor = torch.tensor( |
1195 |
| - [ |
1196 |
| - [285.3538, 185.5758, 1193.5110, 851.4551], |
1197 |
| - [285.1472, 188.7374, 1192.4984, 851.0669], |
1198 |
| - [279.2440, 197.9812, 1189.4746, 849.2019], |
1199 |
| - ], |
1200 |
| - dtype=dtype, |
1201 |
| - ) |
1202 |
| - expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]) |
1203 |
| - gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3) |
| 1142 | + def _generate_float_input(index: int) -> List[List[float]]: |
| 1143 | + return [ |
| 1144 | + [ |
| 1145 | + [285.3538, 185.5758, 1193.5110, 851.4551], |
| 1146 | + [285.1472, 188.7374, 1192.4984, 851.0669], |
| 1147 | + [279.2440, 197.9812, 1189.4746, 849.2019], |
| 1148 | + ], |
| 1149 | + [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], |
| 1150 | + ][index] |
| 1151 | + |
| 1152 | + def _generate_float_expected(index: int) -> List[float]: |
| 1153 | + return [[604723.0806, 600965.4666, 592761.0085], [605113.875, 600495.1875, 592247.25]][index] |
| 1154 | + |
| 1155 | + @pytest.mark.parametrize( |
| 1156 | + "test_input, dtypes, tolerance, expected", |
| 1157 | + [ |
| 1158 | + pytest.param( |
| 1159 | + _generate_int_input(), |
| 1160 | + [torch.int8, torch.int16, torch.int32, torch.int64], |
| 1161 | + 1e-4, |
| 1162 | + _generate_int_expected(), |
| 1163 | + ), |
| 1164 | + pytest.param(_generate_float_input(0), [torch.float32, torch.float64], 0.05, _generate_float_expected(0)), |
| 1165 | + pytest.param(_generate_float_input(1), [torch.float16], 1e-4, _generate_float_expected(1)), |
| 1166 | + ], |
| 1167 | + ) |
| 1168 | + def test_box_area(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: |
| 1169 | + self._run_test(test_input, dtypes, tolerance, expected) |
| 1170 | + |
| 1171 | + def test_box_area_jit(self) -> None: |
| 1172 | + self._run_jit_test([[0, 0, 100, 100], [0, 0, 0, 0]]) |
| 1173 | + |
| 1174 | + |
| 1175 | +class TestBoxIou(BoxTestBase): |
| 1176 | + def _target_fn(self) -> Tuple[bool, Callable]: |
| 1177 | + return (True, ops.box_iou) |
| 1178 | + |
| 1179 | + def _generate_int_input() -> List[List[int]]: |
| 1180 | + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] |
| 1181 | + |
| 1182 | + def _generate_int_expected() -> List[List[float]]: |
| 1183 | + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] |
| 1184 | + |
| 1185 | + def _generate_float_input() -> List[List[float]]: |
| 1186 | + return [ |
| 1187 | + [285.3538, 185.5758, 1193.5110, 851.4551], |
| 1188 | + [285.1472, 188.7374, 1192.4984, 851.0669], |
| 1189 | + [279.2440, 197.9812, 1189.4746, 849.2019], |
| 1190 | + ] |
1204 | 1191 |
|
1205 |
| - def test_giou_jit(self): |
1206 |
| - box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float) |
1207 |
| - TOLERANCE = 1e-3 |
1208 |
| - expected = ops.generalized_box_iou(box_tensor, box_tensor) |
1209 |
| - scripted_fn = torch.jit.script(ops.generalized_box_iou) |
1210 |
| - scripted_iou = scripted_fn(box_tensor, box_tensor) |
1211 |
| - torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE) |
| 1192 | + def _generate_float_expected() -> List[List[float]]: |
| 1193 | + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] |
| 1194 | + |
| 1195 | + @pytest.mark.parametrize( |
| 1196 | + "test_input, dtypes, tolerance, expected", |
| 1197 | + [ |
| 1198 | + pytest.param( |
| 1199 | + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() |
| 1200 | + ), |
| 1201 | + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), |
| 1202 | + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-4, _generate_float_expected()), |
| 1203 | + ], |
| 1204 | + ) |
| 1205 | + def test_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: |
| 1206 | + self._run_test(test_input, dtypes, tolerance, expected) |
| 1207 | + |
| 1208 | + def test_iou_jit(self) -> None: |
| 1209 | + self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) |
| 1210 | + |
| 1211 | + |
| 1212 | +class TestGenBoxIou(BoxTestBase): |
| 1213 | + def _target_fn(self) -> Tuple[bool, Callable]: |
| 1214 | + return (True, ops.generalized_box_iou) |
| 1215 | + |
| 1216 | + def _generate_int_input() -> List[List[int]]: |
| 1217 | + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] |
| 1218 | + |
| 1219 | + def _generate_int_expected() -> List[List[float]]: |
| 1220 | + return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] |
| 1221 | + |
| 1222 | + def _generate_float_input() -> List[List[float]]: |
| 1223 | + return [ |
| 1224 | + [285.3538, 185.5758, 1193.5110, 851.4551], |
| 1225 | + [285.1472, 188.7374, 1192.4984, 851.0669], |
| 1226 | + [279.2440, 197.9812, 1189.4746, 849.2019], |
| 1227 | + ] |
| 1228 | + |
| 1229 | + def _generate_float_expected() -> List[List[float]]: |
| 1230 | + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] |
| 1231 | + |
| 1232 | + @pytest.mark.parametrize( |
| 1233 | + "test_input, dtypes, tolerance, expected", |
| 1234 | + [ |
| 1235 | + pytest.param( |
| 1236 | + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() |
| 1237 | + ), |
| 1238 | + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), |
| 1239 | + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), |
| 1240 | + ], |
| 1241 | + ) |
| 1242 | + def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: |
| 1243 | + self._run_test(test_input, dtypes, tolerance, expected) |
| 1244 | + |
| 1245 | + def test_giou_jit(self) -> None: |
| 1246 | + self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) |
1212 | 1247 |
|
1213 | 1248 |
|
1214 | 1249 | class TestMasksToBoxes:
|
|
0 commit comments