Skip to content

Commit aecbb15

Browse files
authored
Add IntermediateLayerGetter on segmentation. (#5298)
1 parent b94004a commit aecbb15

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

torchvision/models/segmentation/deeplabv3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from .. import mobilenetv3
88
from .. import resnet
9-
from ..feature_extraction import create_feature_extractor
9+
from .._utils import IntermediateLayerGetter
1010
from ._utils import _SimpleSegmentationModel, _load_weights
1111
from .fcn import FCNHead
1212

@@ -121,7 +121,7 @@ def _deeplabv3_resnet(
121121
return_layers = {"layer4": "out"}
122122
if aux:
123123
return_layers["layer3"] = "aux"
124-
backbone = create_feature_extractor(backbone, return_layers)
124+
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
125125

126126
aux_classifier = FCNHead(1024, num_classes) if aux else None
127127
classifier = DeepLabHead(2048, num_classes)
@@ -144,7 +144,7 @@ def _deeplabv3_mobilenetv3(
144144
return_layers = {str(out_pos): "out"}
145145
if aux:
146146
return_layers[str(aux_pos)] = "aux"
147-
backbone = create_feature_extractor(backbone, return_layers)
147+
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
148148

149149
aux_classifier = FCNHead(aux_inplanes, num_classes) if aux else None
150150
classifier = DeepLabHead(out_inplanes, num_classes)

torchvision/models/segmentation/fcn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from torch import nn
44

55
from .. import resnet
6-
from ..feature_extraction import create_feature_extractor
6+
from .._utils import IntermediateLayerGetter
77
from ._utils import _SimpleSegmentationModel, _load_weights
88

99

@@ -57,7 +57,7 @@ def _fcn_resnet(
5757
return_layers = {"layer4": "out"}
5858
if aux:
5959
return_layers["layer3"] = "aux"
60-
backbone = create_feature_extractor(backbone, return_layers)
60+
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)
6161

6262
aux_classifier = FCNHead(1024, num_classes) if aux else None
6363
classifier = FCNHead(2048, num_classes)

torchvision/models/segmentation/lraspp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from ...utils import _log_api_usage_once
88
from .. import mobilenetv3
9-
from ..feature_extraction import create_feature_extractor
9+
from .._utils import IntermediateLayerGetter
1010
from ._utils import _load_weights
1111

1212

@@ -90,7 +90,7 @@ def _lraspp_mobilenetv3(backbone: mobilenetv3.MobileNetV3, num_classes: int) ->
9090
high_pos = stage_indices[-1] # use C5 which has output_stride = 16
9191
low_channels = backbone[low_pos].out_channels
9292
high_channels = backbone[high_pos].out_channels
93-
backbone = create_feature_extractor(backbone, {str(low_pos): "low", str(high_pos): "high"})
93+
backbone = IntermediateLayerGetter(backbone, return_layers={str(low_pos): "low", str(high_pos): "high"})
9494

9595
return LRASPP(backbone, low_channels, high_channels, num_classes)
9696

0 commit comments

Comments
 (0)