
Commit 7ee068a

Use parameters from retinanet for anchor generation.
1 parent b9d7344 commit 7ee068a
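
For context, the "parameters from retinanet" mentioned in the message are presumably the anchor scales of the RetinaNet paper: each pyramid level keeps its base size (32, 64, 128, 256, 512) and adds the scale multipliers 2^(1/3) and 2^(2/3). The sketch below only works through the arithmetic of the expression that appears in the diff; the rounded values are illustrative and are not part of the commit.

    # Per-level anchor sizes produced by the new default expression,
    # [x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in (32, 64, 128, 256, 512).
    base_sizes = [32, 64, 128, 256, 512]
    anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in base_sizes]
    for base, sizes in zip(base_sizes, anchor_sizes):
        print(base, [round(s, 1) for s in sizes])
    # 32  [32, 40.3, 50.8]
    # 64  [64, 80.6, 101.6]
    # 128 [128, 161.3, 203.2]
    # 256 [256, 322.5, 406.4]
    # 512 [512, 645.1, 812.7]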

File tree

1 file changed: 10 additions, 7 deletions


torchvision/models/detection/retinanet.py

Lines changed: 10 additions & 7 deletions

@@ -6,7 +6,8 @@
 
 from ..utils import load_state_dict_from_url
 
-from .rpn import AnchorGenerator
+from . import _utils as det_utils
+from .anchor_utils import AnchorGenerator
 from .transform import GeneralizedRCNNTransform
 from .backbone_utils import resnet_fpn_backbone
 from ...ops.feature_pyramid_network import LastLevelP6P7
@@ -173,7 +174,7 @@ class RetinaNet(nn.Module):
         >>> import torch
         >>> import torchvision
         >>> from torchvision.models.detection import RetinaNet
-        >>> from torchvision.models.detection.rpn import AnchorGenerator
+        >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
         >>> # load a pre-trained model for classification and return
         >>> # only the features
         >>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
@@ -187,8 +188,8 @@ class RetinaNet(nn.Module):
         >>> # ratios. We have a Tuple[Tuple[int]] because each feature
         >>> # map could potentially have different sizes and
         >>> # aspect ratios
-        >>> anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
-        >>>                                    aspect_ratios=((0.5, 1.0, 2.0),))
+        >>> anchor_generator = AnchorGenerator(sizes=[[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]],
+        >>>                                    aspect_ratios=[[0.5, 1.0, 2.0]] * 5)
         >>>
         >>> # put the pieces together inside a RetinaNet model
         >>> model = RetinaNet(backbone,
@@ -218,14 +219,16 @@ def __init__(self, backbone, num_classes,
         assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
 
         if anchor_generator is None:
-            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
+            # TODO: Set correct default values
+            anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]]
             aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
-            self.anchor_generator = AnchorGenerator(
+            anchor_generator = AnchorGenerator(
                 anchor_sizes, aspect_ratios
             )
+        self.anchor_generator = anchor_generator
 
         if head is None:
-            head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location())
+            head = RetinaNetHead(backbone.out_channels, num_classes, anchor_generator.num_anchors_per_location()[0])
         self.head = head
 
         if image_mean is None:
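
With these defaults the generator reports the same anchor count for every feature map (3 scales x 3 aspect ratios = 9), which is presumably why the head construction now takes num_anchors_per_location()[0] rather than the whole list. A minimal sketch of that check, assuming a torchvision build that already contains this commit (so that the anchor_utils module exists):

    from torchvision.models.detection.anchor_utils import AnchorGenerator

    # Defaults as introduced by this commit.
    anchor_sizes = [[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [32, 64, 128, 256, 512]]
    aspect_ratios = [[0.5, 1.0, 2.0]] * len(anchor_sizes)
    anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)

    # One count per feature map; every level yields 9 anchors per spatial location.
    print(anchor_generator.num_anchors_per_location())     # expected: [9, 9, 9, 9, 9]
    print(anchor_generator.num_anchors_per_location()[0])  # expected: 9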
