Skip to content

Commit cde7ff0

Browse files
authored
Pass custom scales on DefaultBoxGenerator and change default estimation. (#3766)
1 parent 3975ec5 commit cde7ff0

File tree

3 files changed

+31
-19
lines changed

3 files changed

+31
-19
lines changed

test/test_models_detection_anchor_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ def test_defaultbox_generator(self):
7676
dboxes = model(images, features)
7777

7878
dboxes_output = torch.tensor([
79-
[6.9750, 6.9750, 8.0250, 8.0250],
80-
[6.7315, 6.7315, 8.2685, 8.2685],
81-
[6.7575, 7.1288, 8.2425, 7.8712],
82-
[7.1288, 6.7575, 7.8712, 8.2425]
79+
[6.3750, 6.3750, 8.6250, 8.6250],
80+
[4.7443, 4.7443, 10.2557, 10.2557],
81+
[5.9090, 6.7045, 9.0910, 8.2955],
82+
[6.7045, 5.9090, 8.2955, 9.0910]
8383
])
8484

8585
self.assertEqual(len(dboxes), 2)

torchvision/models/detection/anchor_utils.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,19 @@ class DefaultBoxGenerator(nn.Module):
138138
Args:
139139
aspect_ratios (List[List[int]]): A list with all the aspect ratios used in each feature map.
140140
min_ratio (float): The minimum scale :math:`\text{s}_{\text{min}}` of the default boxes used in the estimation
141-
of the scales of each feature map.
141+
of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
142142
max_ratio (float): The maximum scale :math:`\text{s}_{\text{max}}` of the default boxes used in the estimation
143-
of the scales of each feature map.
143+
of the scales of each feature map. It is used only if the ``scales`` parameter is not provided.
144+
scales (List[float]], optional): The scales of the default boxes. If not provided it will be estimated using
145+
the ``min_ratio`` and ``max_ratio`` parameters.
144146
steps (List[int]], optional): It's a hyper-parameter that affects the tiling of defalt boxes. If not provided
145147
it will be estimated from the data.
146148
clip (bool): Whether the standardized values of default boxes should be clipped between 0 and 1. The clipping
147149
is applied while the boxes are encoded in format ``(cx, cy, w, h)``.
148150
"""
149151

150152
def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_ratio: float = 0.9,
151-
steps: Optional[List[int]] = None, clip: bool = True):
153+
scales: Optional[List[float]] = None, steps: Optional[List[int]] = None, clip: bool = True):
152154
super().__init__()
153155
if steps is not None:
154156
assert len(aspect_ratios) == len(steps)
@@ -158,15 +160,15 @@ def __init__(self, aspect_ratios: List[List[int]], min_ratio: float = 0.15, max_
158160
num_outputs = len(aspect_ratios)
159161

160162
# Estimation of default boxes scales
161-
# Inspired from https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_pascal.py#L311-L317
162-
min_centile = int(100 * min_ratio)
163-
max_centile = int(100 * max_ratio)
164-
conv4_centile = min_centile // 2 # assume half of min_ratio as in paper
165-
step = (max_centile - min_centile) // (num_outputs - 2)
166-
centiles = [conv4_centile, min_centile]
167-
for c in range(min_centile, max_centile + 1, step):
168-
centiles.append(c + step)
169-
self.scales = [c / 100 for c in centiles]
163+
if scales is None:
164+
if num_outputs > 1:
165+
range_ratio = max_ratio - min_ratio
166+
self.scales = [min_ratio + range_ratio * k / (num_outputs - 1.0) for k in range(num_outputs)]
167+
self.scales.append(1.0)
168+
else:
169+
self.scales = [min_ratio, max_ratio]
170+
else:
171+
self.scales = scales
170172

171173
self._wh_pairs = []
172174
for k in range(num_outputs):
@@ -207,9 +209,17 @@ def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Ten
207209
for k, f_k in enumerate(grid_sizes):
208210
# Now add the default boxes for each width-height pair
209211
for j in range(f_k[0]):
210-
cy = (j + 0.5) / (float(f_k[0]) if self.steps is None else image_size[1] / self.steps[k])
212+
if self.steps is not None:
213+
y_f_k = image_size[1] / self.steps[k]
214+
else:
215+
y_f_k = float(f_k[0])
216+
cy = (j + 0.5) / y_f_k
211217
for i in range(f_k[1]):
212-
cx = (i + 0.5) / (float(f_k[1]) if self.steps is None else image_size[0] / self.steps[k])
218+
if self.steps is not None:
219+
x_f_k = image_size[0] / self.steps[k]
220+
else:
221+
x_f_k = float(f_k[1])
222+
cx = (i + 0.5) / x_f_k
213223
default_boxes.extend([[cx, cy, w, h] for w, h in self._wh_pairs[k]])
214224

215225
dboxes = []

torchvision/models/detection/ssd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,9 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i
552552
pretrained_backbone = False
553553

554554
backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers, True)
555-
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], steps=[8, 16, 32, 64, 100, 300])
555+
anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]],
556+
scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
557+
steps=[8, 16, 32, 64, 100, 300])
556558
model = SSD(backbone, anchor_generator, (300, 300), num_classes,
557559
image_mean=[0.48235, 0.45882, 0.40784], image_std=[1., 1., 1.], **kwargs)
558560
if pretrained:

0 commit comments

Comments
 (0)