6
6
7
7
from ..utils import load_state_dict_from_url
8
8
9
- from .rpn import AnchorGenerator
9
+ from . import _utils as det_utils
10
+ from .anchor_utils import AnchorGenerator
10
11
from .transform import GeneralizedRCNNTransform
11
12
from .backbone_utils import resnet_fpn_backbone
12
13
from ...ops .feature_pyramid_network import LastLevelP6P7
@@ -173,7 +174,7 @@ class RetinaNet(nn.Module):
173
174
>>> import torch
174
175
>>> import torchvision
175
176
>>> from torchvision.models.detection import RetinaNet
176
- >>> from torchvision.models.detection.rpn import AnchorGenerator
177
+ >>> from torchvision.models.detection.anchor_utils import AnchorGenerator
177
178
>>> # load a pre-trained model for classification and return
178
179
>>> # only the features
179
180
>>> backbone = torchvision.models.mobilenet_v2(pretrained=True).features
@@ -187,8 +188,8 @@ class RetinaNet(nn.Module):
187
188
>>> # ratios. We have a Tuple[Tuple[int]] because each feature
188
189
>>> # map could potentially have different sizes and
189
190
>>> # aspect ratios
190
- >>> anchor_generator = AnchorGenerator(sizes=(( 32, 64, 128, 256, 512),) ,
191
- >>> aspect_ratios=(( 0.5, 1.0, 2.0),) )
191
+ >>> anchor_generator = AnchorGenerator(sizes=[[x, x * 2 ** (1.0 / 3), x * 2 ** (2.0 / 3)] for x in [ 32, 64, 128, 256, 512]] ,
192
+ >>> aspect_ratios=[[ 0.5, 1.0, 2.0]] * 5 )
192
193
>>>
193
194
>>> # put the pieces together inside a RetinaNet model
194
195
>>> model = RetinaNet(backbone,
@@ -218,14 +219,16 @@ def __init__(self, backbone, num_classes,
218
219
assert isinstance (anchor_generator , (AnchorGenerator , type (None )))
219
220
220
221
if anchor_generator is None :
221
- anchor_sizes = ((32 ,), (64 ,), (128 ,), (256 ,), (512 ,))
222
+ # TODO: Set correct default values
223
+ anchor_sizes = [[x , x * 2 ** (1.0 / 3 ), x * 2 ** (2.0 / 3 )] for x in [32 , 64 , 128 , 256 , 512 ]]
222
224
aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
223
- self . anchor_generator = AnchorGenerator (
225
+ anchor_generator = AnchorGenerator (
224
226
anchor_sizes , aspect_ratios
225
227
)
228
+ self .anchor_generator = anchor_generator
226
229
227
230
if head is None :
228
- head = RetinaNetHead (backbone .out_channels , num_classes , anchor_generator .num_anchors_per_location ())
231
+ head = RetinaNetHead (backbone .out_channels , num_classes , anchor_generator .num_anchors_per_location ()[ 0 ] )
229
232
self .head = head
230
233
231
234
if image_mean is None :
0 commit comments