16
16
}
17
17
18
18
19
- def _segm_resnet (name , backbone_name , num_classes , aux , pretrained_backbone = True ):
20
- backbone = resnet .__dict__ [backbone_name ](
21
- pretrained = pretrained_backbone ,
22
- replace_stride_with_dilation = [False , True , True ])
23
-
24
- return_layers = {'layer4' : 'out' }
19
+ def _segm_model (name , backbone_name , num_classes , aux , pretrained_backbone = True ):
20
+ if 'resnet' in backbone_name :
21
+ backbone = resnet .__dict__ [backbone_name ](
22
+ pretrained = pretrained_backbone ,
23
+ replace_stride_with_dilation = [False , True , True ])
24
+ out_layer = 'layer4'
25
+ aux_layer = 'layer3'
26
+ else :
27
+ raise NotImplementedError ('backbone {} is not supported as of now' .format (backbone_name ))
28
+
29
+ return_layers = {out_layer : 'out' }
25
30
if aux :
26
- return_layers ['layer3' ] = 'aux'
31
+ return_layers [aux_layer ] = 'aux'
27
32
backbone = IntermediateLayerGetter (backbone , return_layers = return_layers )
28
33
29
34
aux_classifier = None
@@ -46,7 +51,7 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
46
51
def _load_model (arch_type , backbone , pretrained , progress , num_classes , aux_loss , ** kwargs ):
47
52
if pretrained :
48
53
aux_loss = True
49
- model = _segm_resnet (arch_type , backbone , num_classes , aux_loss , ** kwargs )
54
+ model = _segm_model (arch_type , backbone , num_classes , aux_loss , ** kwargs )
50
55
if pretrained :
51
56
arch = arch_type + '_' + backbone + '_coco'
52
57
model_url = model_urls [arch ]
0 commit comments