Skip to content

Commit 77da44c

Browse files
committed
Making _segm_resnet() generic and reusable.
1 parent e04de77 commit 77da44c

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

torchvision/models/segmentation/segmentation.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,19 @@
1616
}
1717

1818

19-
def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True):
20-
backbone = resnet.__dict__[backbone_name](
21-
pretrained=pretrained_backbone,
22-
replace_stride_with_dilation=[False, True, True])
23-
24-
return_layers = {'layer4': 'out'}
19+
def _segm_model(name, backbone_name, num_classes, aux, pretrained_backbone=True):
20+
if 'resnet' in backbone_name:
21+
backbone = resnet.__dict__[backbone_name](
22+
pretrained=pretrained_backbone,
23+
replace_stride_with_dilation=[False, True, True])
24+
out_layer = 'layer4'
25+
aux_layer = 'layer3'
26+
else:
27+
raise NotImplementedError('backbone {} is not supported as of now'.format(backbone_name))
28+
29+
return_layers = {out_layer: 'out'}
2530
if aux:
26-
return_layers['layer3'] = 'aux'
31+
return_layers[aux_layer] = 'aux'
2732
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
2833

2934
aux_classifier = None
@@ -46,7 +51,7 @@ def _segm_resnet(name, backbone_name, num_classes, aux, pretrained_backbone=True
4651
def _load_model(arch_type, backbone, pretrained, progress, num_classes, aux_loss, **kwargs):
4752
if pretrained:
4853
aux_loss = True
49-
model = _segm_resnet(arch_type, backbone, num_classes, aux_loss, **kwargs)
54+
model = _segm_model(arch_type, backbone, num_classes, aux_loss, **kwargs)
5055
if pretrained:
5156
arch = arch_type + '_' + backbone + '_coco'
5257
model_url = model_urls[arch]

0 commit comments

Comments
 (0)