Skip to content

Commit 3099e0c

Browse files
authored
Add missing type hints to anchor_utils (#6735)
* Use the variable name sizes instead of scales for consistency * Add the missing type hints * Restore the naming back to scales instead of sizes to avoid backwards incompatibility
1 parent 12adc54 commit 3099e0c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

torchvision/models/detection/anchor_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def generate_anchors(
6161
aspect_ratios: List[float],
6262
dtype: torch.dtype = torch.float32,
6363
device: torch.device = torch.device("cpu"),
64-
):
64+
) -> Tensor:
6565
scales = torch.as_tensor(scales, dtype=dtype, device=device)
6666
aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
6767
h_ratios = torch.sqrt(aspect_ratios)
@@ -76,7 +76,7 @@ def generate_anchors(
7676
def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
7777
self.cell_anchors = [cell_anchor.to(dtype=dtype, device=device) for cell_anchor in self.cell_anchors]
7878

79-
def num_anchors_per_location(self):
79+
def num_anchors_per_location(self) -> List[int]:
8080
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
8181

8282
# For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
@@ -201,7 +201,7 @@ def _generate_wh_pairs(
201201
_wh_pairs.append(torch.as_tensor(wh_pairs, dtype=dtype, device=device))
202202
return _wh_pairs
203203

204-
def num_anchors_per_location(self):
204+
def num_anchors_per_location(self) -> List[int]:
205205
# Estimate num of anchors based on aspect ratios: 2 default boxes + 2 * ratios of feaure map.
206206
return [2 + 2 * len(r) for r in self.aspect_ratios]
207207

0 commit comments

Comments
 (0)