Commit 72d650a

Add FX feature extraction as an alternative to intermediate_layer_getter (#4302)
* add fx feature extraction util
* Make it possible to use train and eval mode
* FX feature extraction - Tweaks and small bug fixes
* FX feature extraction - add tests
* move to feature_extraction.py, add LeafModuleAwareTracer, add docs
* Tweaks to docs
* addressing latest round of feedback
* undo line spacing changes
* change type hints in docstrings
* fix sphinx indentation
* expose feature_extraction
* add maskrcnn example
* add api reference subheading
* address latest review notes, refactor names, fix regex, cosmetics
* Add back efficientnet to models
* fix tests for effnet
* fix linting issue
* fix test tracer kwargs

Co-authored-by: Francisco Massa <[email protected]>
1 parent 981ccfd commit 72d650a

5 files changed: +878 -0 lines changed

docs/source/feature_extraction.rst

Lines changed: 126 additions & 0 deletions
@@ -0,0 +1,126 @@
+torchvision.models.feature_extraction
+=====================================
+
+.. currentmodule:: torchvision.models.feature_extraction
+
+Feature extraction utilities let us tap into our models to access intermediate
+transformations of our inputs. This could be useful for a variety of
+applications in computer vision. Just a few examples are:
+
+- Visualizing feature maps.
+- Extracting features to compute image descriptors for tasks like facial
+  recognition, copy-detection, or image retrieval.
+- Passing selected features to downstream sub-networks for end-to-end training
+  with a specific task in mind. For example, passing a hierarchy of features
+  to a Feature Pyramid Network with object detection heads.
+
+Torchvision provides :func:`create_feature_extractor` for this purpose.
+It works by following roughly these steps:
+
+1. Symbolically tracing the model to get a graph representation of
+   how it transforms the input, step by step.
+2. Setting the user-selected graph nodes as outputs.
+3. Removing all redundant nodes (anything downstream of the output nodes).
+4. Generating Python code from the resulting graph and bundling that into a
+   PyTorch module together with the graph itself.
+
+|
+
+The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
+provides a more general and detailed explanation of the above procedure and
+the inner workings of the symbolic tracing.
+
+Here is an example of how we might extract features for MaskRCNN:
+
+.. code-block:: python
+
+  import torch
+  from torchvision.models import resnet50
+  from torchvision.models.feature_extraction import get_graph_node_names
+  from torchvision.models.feature_extraction import create_feature_extractor
+  from torchvision.models.detection.mask_rcnn import MaskRCNN
+  from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork
+
+
+  # To assist you in designing the feature extractor you may want to print out
+  # the available nodes for resnet50.
+  m = resnet50()
+  train_nodes, eval_nodes = get_graph_node_names(m)
+
+  # The returned lists are the names of all the graph nodes (in order of
+  # execution) for the input model traced in train mode and in eval mode
+  # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
+  # for this example. But if the model contains control flow that's dependent
+  # on the training mode, they may be different.
+
+  # To specify the nodes you want to extract, you could select the final node
+  # that appears in each of the main layers:
+  return_nodes = {
+      # node_name: user-specified key for output dict
+      'layer1.2.relu_2': 'layer1',
+      'layer2.3.relu_2': 'layer2',
+      'layer3.5.relu_2': 'layer3',
+      'layer4.2.relu_2': 'layer4',
+  }
+
+  # But `create_feature_extractor` can also accept truncated node specifications
+  # like "layer1", as it will just pick the last node that's a descendant
+  # of the specification. (Tip: be careful with this, especially when a layer
+  # has multiple outputs. It's not always guaranteed that the last operation
+  # performed is the one that corresponds to the output you desire. You should
+  # consult the source code for the input model to confirm.)
+  return_nodes = {
+      'layer1': 'layer1',
+      'layer2': 'layer2',
+      'layer3': 'layer3',
+      'layer4': 'layer4',
+  }
+
+  # Now you can build the feature extractor. This returns a module whose forward
+  # method returns a dictionary like:
+  # {
+  #     'layer1': output of layer 1,
+  #     'layer2': output of layer 2,
+  #     'layer3': output of layer 3,
+  #     'layer4': output of layer 4,
+  # }
+  create_feature_extractor(m, return_nodes=return_nodes)
+
+  # Let's put all that together to wrap resnet50 with MaskRCNN
+
+  # MaskRCNN requires a backbone with an attached FPN
+  class Resnet50WithFPN(torch.nn.Module):
+      def __init__(self):
+          super(Resnet50WithFPN, self).__init__()
+          # Get a resnet50 backbone
+          m = resnet50()
+          # Extract 4 main layers (note: you can also provide a list for return
+          # nodes if the keys and the values are the same)
+          self.body = create_feature_extractor(
+              m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4'])
+          # Dry run to get number of channels for FPN
+          inp = torch.randn(2, 3, 224, 224)
+          with torch.no_grad():
+              out = self.body(inp)
+          in_channels_list = [o.shape[1] for o in out.values()]
+          # Build FPN
+          self.out_channels = 256
+          self.fpn = FeaturePyramidNetwork(
+              in_channels_list, out_channels=self.out_channels)
+
+      def forward(self, x):
+          x = self.body(x)
+          x = self.fpn(x)
+          return x
+
+
+  # Now we can build our model!
+  model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()
+
+
+API Reference
+-------------
+
+.. autofunction:: create_feature_extractor
+
+.. autofunction:: get_graph_node_names
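Steps 2 and 3 of the procedure described in feature_extraction.rst above (mark the selected nodes as outputs, then drop everything that is no longer needed) can be pictured with a toy graph. The following is only a minimal pure-Python sketch under stated assumptions: the `prune_to_outputs` helper, the dictionary-based graph, and the node names are hypothetical illustrations and are not torchvision's or torch.fx's actual implementation.

```python
from typing import Dict, List


def prune_to_outputs(
    graph: Dict[str, List[str]], order: List[str], outputs: List[str]
) -> List[str]:
    """Return the nodes, in execution order, needed to compute `outputs`.

    `graph` maps each node name to the names of the nodes it consumes.
    This is a hypothetical stand-in for a traced graph, not a torch.fx Graph.
    """
    needed = set()
    stack = list(outputs)
    # Walk backwards from the requested outputs, collecting every ancestor.
    while stack:
        node = stack.pop()
        if node in needed:
            continue
        needed.add(node)
        stack.extend(graph[node])
    # Anything not collected is downstream of the outputs and can be dropped.
    return [n for n in order if n in needed]


# Toy "traced model": each node lists its inputs (names are illustrative).
toy_graph = {
    "x": [],
    "conv1": ["x"],
    "layer1": ["conv1"],
    "layer2": ["layer1"],
    "avgpool": ["layer2"],
    "fc": ["avgpool"],
}
execution_order = ["x", "conv1", "layer1", "layer2", "avgpool", "fc"]

kept = prune_to_outputs(toy_graph, execution_order, ["layer1", "layer2"])
print(kept)  # ['x', 'conv1', 'layer1', 'layer2']
```

In this toy example `avgpool` and `fc` are discarded because nothing requested depends on them, which mirrors why a feature extractor built for intermediate layers does not run the classification head.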

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ architectures, and common image transformations for computer vision.
    datasets
    io
    models
+   feature_extraction
    ops
    transforms
    utils

test/test_backbone_utils.py

Lines changed: 242 additions & 0 deletions
@@ -1,11 +1,253 @@
+from functools import partial
+from itertools import chain
+import random
+
 import torch
+from torchvision import models
+import torchvision
 from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
+from torchvision.models.feature_extraction import create_feature_extractor
+from torchvision.models.feature_extraction import get_graph_node_names
+from torchvision.models._utils import IntermediateLayerGetter
 
 import pytest
 
+from common_utils import set_rng_seed
+
+
+def get_available_models():
+    # TODO add a registration mechanism to torchvision.models
+    return [k for k, v in models.__dict__.items()
+            if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
+
 
 @pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
 def test_resnet_fpn_backbone(backbone_name):
     x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
     y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
     assert list(y.keys()) == ['0', '1', '2', '3', 'pool']
+
+
+# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
+def leaf_function(x):
+    return int(x)
+
+
+class TestFxFeatureExtraction:
+    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
+    model_defaults = {
+        'num_classes': 1,
+        'pretrained': False
+    }
+    leaf_modules = [torchvision.ops.StochasticDepth]
+
+    def _create_feature_extractor(self, *args, **kwargs):
+        """
+        Apply leaf modules
+        """
+        tracer_kwargs = {}
+        if 'tracer_kwargs' not in kwargs:
+            tracer_kwargs = {'leaf_modules': self.leaf_modules}
+        else:
+            tracer_kwargs = kwargs.pop('tracer_kwargs')
+        return create_feature_extractor(
+            *args, **kwargs,
+            tracer_kwargs=tracer_kwargs,
+            suppress_diff_warning=True)
+
+    def _get_return_nodes(self, model):
+        set_rng_seed(0)
+        exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk']
+        train_nodes, eval_nodes = get_graph_node_names(
+            model, tracer_kwargs={'leaf_modules': self.leaf_modules},
+            suppress_diff_warning=True)
+        # Get rid of any nodes that don't return tensors as they cause issues
+        # when testing backward pass.
+        train_nodes = [n for n in train_nodes
+                       if not any(x in n for x in exclude_nodes_filter)]
+        eval_nodes = [n for n in eval_nodes
+                      if not any(x in n for x in exclude_nodes_filter)]
+        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)
+
+    @pytest.mark.parametrize('model_name', get_available_models())
+    def test_build_fx_feature_extractor(self, model_name):
+        set_rng_seed(0)
+        model = models.__dict__[model_name](**self.model_defaults).eval()
+        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
+        # Check that it works with both a list and dict for return nodes
+        self._create_feature_extractor(
+            model, train_return_nodes={v: v for v in train_return_nodes},
+            eval_return_nodes=eval_return_nodes)
+        self._create_feature_extractor(
+            model, train_return_nodes=train_return_nodes,
+            eval_return_nodes=eval_return_nodes)
+        # Check that return nodes must be specified
+        with pytest.raises(AssertionError):
+            self._create_feature_extractor(model)
+        # Check that return_nodes and train_return_nodes / eval_return_nodes
+        # are mutually exclusive
+        with pytest.raises(AssertionError):
+            self._create_feature_extractor(
+                model, return_nodes=train_return_nodes,
+                train_return_nodes=train_return_nodes)
+        # Check train_return_nodes / eval_return_nodes must both be specified
+        with pytest.raises(AssertionError):
+            self._create_feature_extractor(
+                model, train_return_nodes=train_return_nodes)
+        # Check invalid node name raises ValueError
+        with pytest.raises(ValueError):
+            # First just double check that this node really doesn't exist
+            if not any(n.startswith('l') or n.startswith('l.') for n
+                       in chain(train_return_nodes, eval_return_nodes)):
+                self._create_feature_extractor(
+                    model, train_return_nodes=['l'], eval_return_nodes=['l'])
+            else:  # otherwise skip this check
+                raise ValueError
+
+    @pytest.mark.parametrize('model_name', get_available_models())
+    def test_forward_backward(self, model_name):
+        model = models.__dict__[model_name](**self.model_defaults).train()
+        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
+        model = self._create_feature_extractor(
+            model, train_return_nodes=train_return_nodes,
+            eval_return_nodes=eval_return_nodes)
+        out = model(self.inp)
+        sum([o.mean() for o in out.values()]).backward()
+
+    def test_feature_extraction_methods_equivalence(self):
+        model = models.resnet18(**self.model_defaults).eval()
+        return_layers = {
+            'layer1': 'layer1',
+            'layer2': 'layer2',
+            'layer3': 'layer3',
+            'layer4': 'layer4'
+        }
+
+        ilg_model = IntermediateLayerGetter(
+            model, return_layers).eval()
+        fx_model = self._create_feature_extractor(model, return_layers)
+
+        # Check that we have same parameters
+        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(),
+                                      fx_model.named_parameters()):
+            assert n1 == n2
+            assert p1.equal(p2)
+
+        # And that outputs match
+        with torch.no_grad():
+            ilg_out = ilg_model(self.inp)
+            fgn_out = fx_model(self.inp)
+        assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
+        for k in ilg_out.keys():
+            assert ilg_out[k].equal(fgn_out[k])
+
+    @pytest.mark.parametrize('model_name', get_available_models())
+    def test_jit_forward_backward(self, model_name):
+        set_rng_seed(0)
+        model = models.__dict__[model_name](**self.model_defaults).train()
+        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
+        model = self._create_feature_extractor(
+            model, train_return_nodes=train_return_nodes,
+            eval_return_nodes=eval_return_nodes)
+        model = torch.jit.script(model)
+        fgn_out = model(self.inp)
+        sum([o.mean() for o in fgn_out.values()]).backward()
+
+    def test_train_eval(self):
+        class TestModel(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dropout = torch.nn.Dropout(p=1.)
+
+            def forward(self, x):
+                x = x.mean()
+                x = self.dropout(x)  # dropout
+                if self.training:
+                    x += 100  # add
+                else:
+                    x *= 0  # mul
+                x -= 0  # sub
+                return x
+
+        model = TestModel()
+
+        train_return_nodes = ['dropout', 'add', 'sub']
+        eval_return_nodes = ['dropout', 'mul', 'sub']
+
+        def checks(model, mode):
+            with torch.no_grad():
+                out = model(torch.ones(10, 10))
+            if mode == 'train':
+                # Check that dropout is respected
+                assert out['dropout'].item() == 0
+                # Check that control flow dependent on training_mode is respected
+                assert out['sub'].item() == 100
+                assert 'add' in out
+                assert 'mul' not in out
+            elif mode == 'eval':
+                # Check that dropout is respected
+                assert out['dropout'].item() == 1
+                # Check that control flow dependent on training_mode is respected
+                assert out['sub'].item() == 0
+                assert 'mul' in out
+                assert 'add' not in out
+
+        # Starting from train mode
+        model.train()
+        fx_model = self._create_feature_extractor(
+            model, train_return_nodes=train_return_nodes,
+            eval_return_nodes=eval_return_nodes)
+        # Check that the models stay in their original training state
+        assert model.training
+        assert fx_model.training
+        # Check outputs
+        checks(fx_model, 'train')
+        # Check outputs after switching to eval mode
+        fx_model.eval()
+        checks(fx_model, 'eval')
+
+        # Starting from eval mode
+        model.eval()
+        fx_model = self._create_feature_extractor(
+            model, train_return_nodes=train_return_nodes,
+            eval_return_nodes=eval_return_nodes)
+        # Check that the models stay in their original training state
+        assert not model.training
+        assert not fx_model.training
+        # Check outputs
+        checks(fx_model, 'eval')
+        # Check outputs after switching to train mode
+        fx_model.train()
+        checks(fx_model, 'train')
+
+    def test_leaf_module_and_function(self):
+        class LeafModule(torch.nn.Module):
+            def forward(self, x):
+                # This would raise a TypeError if it were not in a leaf module
+                int(x.shape[0])
+                return torch.nn.functional.relu(x + 4)
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = torch.nn.Conv2d(3, 1, 3)
+                self.leaf_module = LeafModule()
+
+            def forward(self, x):
+                leaf_function(x.shape[0])
+                x = self.conv(x)
+                return self.leaf_module(x)
+
+        model = self._create_feature_extractor(
+            TestModule(), return_nodes=['leaf_module'],
+            tracer_kwargs={'leaf_modules': [LeafModule],
+                           'autowrap_functions': [leaf_function]}).train()
+
+        # Check that LeafModule was not traced through: its internal relu must
+        assert 'relu' not in [str(n) for n in model.graph.nodes]
+        assert 'leaf_module' in [str(n) for n in model.graph.nodes]
+
+        # Check forward
+        out = model(self.inp)
+        # And backward
+        out['leaf_module'].mean().backward()
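Several of the tests above depend on how a truncated node specification such as `'layer1'` is matched against full node names, and on an unknown name such as `'l'` raising `ValueError`. The following pure-Python sketch illustrates one plausible matching rule (exact match, or prefix match on a dot boundary, taking the last hit in execution order). It is an assumption made for illustration: the `resolve_node` helper and the node names are hypothetical, and torchvision's real matching logic may differ in detail.

```python
from typing import List


def resolve_node(query: str, node_names: List[str]) -> str:
    """Pick the last node (in execution order) that matches `query`.

    A node matches if its name equals `query` or starts with `query + "."`.
    Hypothetical helper for illustration only.
    """
    matches = [
        name for name in node_names
        if name == query or name.startswith(query + ".")
    ]
    if not matches:
        raise ValueError(f"node '{query}' is not present in the graph")
    return matches[-1]


# Illustrative node names, listed in execution order.
names = [
    "x",
    "conv1",
    "layer1.0.relu",
    "layer1.1.relu_1",
    "layer2.0.conv1",
    "layer2.0.relu",
    "fc",
]

print(resolve_node("layer1", names))    # layer1.1.relu_1
print(resolve_node("layer2.0", names))  # layer2.0.relu
```

Under this rule `'l'` matches nothing, because no node is named `l` and none starts with `l.`, so the lookup fails loudly instead of silently selecting `layer1`.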

torchvision/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -13,3 +13,4 @@
 from . import detection
 from . import video
 from . import quantization
+from . import feature_extraction
