Add FX feature extraction as an alternative to intermediate_layer_getter #4302
Merged

Commits (22)
- b84f8bd  add fx feature extraction util (alexander-soare)
- b0e6402  Make it possible to use train and eval mode (alexander-soare)
- 23bb71f  FX feature extraction - Tweaks and small bug fixes (alexander-soare)
- 9dce734  FX feature extraction - add tests (alexander-soare)
- fa23bd8  move to feature_extraction.py, add LeafModuleAwareTracer, add docs (alexander-soare)
- 0581ba7  Tweaks to docs (alexander-soare)
- 7ecd15b  addressing latest round of feedback (alexander-soare)
- d4efb7d  undo line spacing changes (alexander-soare)
- d6a834e  change type hints in docstrings (alexander-soare)
- 8348b7b  fix sphinx indentation (alexander-soare)
- fc831e0  expose feature_extraction (alexander-soare)
- 82fea80  add maskrcnn example (alexander-soare)
- 8b51e04  add api refernce subheading (alexander-soare)
- 2c8e2f8  Merge branch 'main' into fx-feature-extraction (fmassa)
- d7591c1  address latest review notes, refactor names, fix regex, cosmetics (alexander-soare)
- 2111ef9  Merge branch 'main' into fx-feature-extraction (fmassa)
- 2d6cdfd  Add back efficientnet to models (fmassa)
- 31f86f6  fix tests for effnet (alexander-soare)
- a4973f8  fix linting issue (alexander-soare)
- e8fec33  fix test tracer kwargs (alexander-soare)
- 48024a2  Merge branch 'main' into fx-feature-extraction (fmassa)
- a7a5818  Merge branch 'main' into fx-feature-extraction (fmassa)
@@ -0,0 +1,126 @@

torchvision.models.feature_extraction
=====================================

.. currentmodule:: torchvision.models.feature_extraction

Feature extraction utilities let us tap into our models to access intermediate
transformations of our inputs. This could be useful for a variety of
applications in computer vision. Just a few examples are:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating Python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

|

The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
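
For a concrete, simplified sense of what the tracing step produces, here is a
small sketch that runs plain ``torch.fx`` on a toy module. The toy module and
the exact node names shown are illustrative assumptions, not part of
torchvision:

.. code-block:: python

    import torch
    import torch.fx


    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, 3)

        def forward(self, x):
            x = self.conv(x)
            return torch.nn.functional.relu(x)


    # Step 1: symbolic tracing records every operation as a named graph node
    traced = torch.fx.symbolic_trace(TinyNet())

    # Each node corresponds to an input, a module call, a function call, or the output
    print([node.name for node in traced.graph.nodes])
    # Expected to look something like: ['x', 'conv', 'relu', 'output']

:func:`create_feature_extractor` then builds on such a graph by re-routing the
requested nodes to the output and pruning the redundant parts (steps 2-4 above).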

Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names
    from torchvision.models.feature_extraction import create_feature_extractor
    from torchvision.models.detection.mask_rcnn import MaskRCNN
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


    # To assist you in designing the feature extractor, you may want to print out
    # the available nodes for resnet50.
    m = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(resnet50())

    # The lists returned are the names of all the graph nodes (in order of
    # execution) for the input model traced in train mode and in eval mode
    # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
    # for this example. But if the model contains control flow that's dependent
    # on the training mode, they may be different.

    # To specify the nodes you want to extract, you could select the final node
    # that appears in each of the main layers:
    return_nodes = {
        # node_name: user-specified key for output dict
        'layer1.2.relu_2': 'layer1',
        'layer2.3.relu_2': 'layer2',
        'layer3.5.relu_2': 'layer3',
        'layer4.2.relu_2': 'layer4',
    }

    # But `create_feature_extractor` can also accept truncated node specifications
    # like "layer1", as it will just pick the last node that's a descendant of
    # the specification. (Tip: be careful with this, especially when a layer
    # has multiple outputs. It's not always guaranteed that the last operation
    # performed is the one that corresponds to the output you desire. You should
    # consult the source code for the input model to confirm.)
    return_nodes = {
        'layer1': 'layer1',
        'layer2': 'layer2',
        'layer3': 'layer3',
        'layer4': 'layer4',
    }

    # Now you can build the feature extractor. This returns a module whose forward
    # method returns a dictionary like:
    # {
    #     'layer1': output of layer 1,
    #     'layer2': output of layer 2,
    #     'layer3': output of layer 3,
    #     'layer4': output of layer 4,
    # }
    create_feature_extractor(m, return_nodes=return_nodes)

    # Let's put all that together to wrap resnet50 with MaskRCNN

    # MaskRCNN requires a backbone with an attached FPN
    class Resnet50WithFPN(torch.nn.Module):
        def __init__(self):
            super(Resnet50WithFPN, self).__init__()
            # Get a resnet50 backbone
            m = resnet50()
            # Extract 4 main layers (note: you can also provide a list for return
            # nodes if the keys and the values are the same)
            self.body = create_feature_extractor(
                m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4'])
            # Dry run to get number of channels for FPN
            inp = torch.randn(2, 3, 224, 224)
            with torch.no_grad():
                out = self.body(inp)
            in_channels_list = [o.shape[1] for o in out.values()]
            # Build FPN
            self.out_channels = 256
            self.fpn = FeaturePyramidNetwork(
                in_channels_list, out_channels=self.out_channels)

        def forward(self, x):
            x = self.body(x)
            x = self.fpn(x)
            return x


    # Now we can build our model!
    model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()


API Reference
-------------

.. autofunction:: create_feature_extractor

.. autofunction:: get_graph_node_names
@@ -1,11 +1,253 @@
from functools import partial
from itertools import chain
import random

import torch
from torchvision import models
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models._utils import IntermediateLayerGetter

import pytest

from common_utils import set_rng_seed


def get_available_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.__dict__.items()
            if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
def test_resnet_fpn_backbone(backbone_name):
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
    y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
    assert list(y.keys()) == ['0', '1', '2', '3', 'pool']


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    return int(x)


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
    model_defaults = {
        'num_classes': 1,
        'pretrained': False
    }
    leaf_modules = [torchvision.ops.StochasticDepth]

    def _create_feature_extractor(self, *args, **kwargs):
        """
        Apply leaf modules
        """
        tracer_kwargs = {}
        if 'tracer_kwargs' not in kwargs:
            tracer_kwargs = {'leaf_modules': self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop('tracer_kwargs')
        return create_feature_extractor(
            *args, **kwargs,
            tracer_kwargs=tracer_kwargs,
            suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk']
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={'leaf_modules': self.leaf_modules},
            suppress_diff_warning=True)
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes
                       if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes
                      if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes},
            eval_return_nodes=eval_return_nodes)
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check must specify return nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return nodes
        # mutual exclusivity
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes,
                train_return_nodes=train_return_nodes)
        # Check train_return_nodes / eval_return nodes must both be specified
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            if not any(n.startswith('l') or n.startswith('l.') for n
                       in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(
                    model, train_return_nodes=['l'], eval_return_nodes=['l'])
            else:  # otherwise skip this check
                raise ValueError

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_forward_backward(self, model_name):
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        out = model(self.inp)
        sum([o.mean() for o in out.values()]).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {
            'layer1': 'layer1',
            'layer2': 'layer2',
            'layer3': 'layer3',
            'layer4': 'layer4'
        }

        ilg_model = IntermediateLayerGetter(
            model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(),
                                      fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        sum([o.mean() for o in fgn_out.values()]).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.)

            def forward(self, x):
                x = x.mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ['dropout', 'add', 'sub']
        eval_return_nodes = ['dropout', 'mul', 'sub']

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == 'train':
                # Check that dropout is respected
                assert out['dropout'].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out['sub'].item() == 100
                assert 'add' in out
                assert 'mul' not in out
            elif mode == 'eval':
                # Check that dropout is respected
                assert out['dropout'].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out['sub'].item() == 0
                assert 'mul' in out
                assert 'add' not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, 'train')
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, 'eval')

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, 'eval')
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, 'train')

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(), return_nodes=['leaf_module'],
            tracer_kwargs={'leaf_modules': [LeafModule],
                           'autowrap_functions': [leaf_function]}).train()

        # Check that LeafModule is not in the list of nodes
        assert 'relu' not in [str(n) for n in model.graph.nodes]
        assert 'leaf_module' in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out['leaf_module'].mean().backward()
@@ -13,3 +13,4 @@
from . import detection
from . import video
from . import quantization
from . import feature_extraction
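
With this new export in place, the feature extraction utilities become importable
directly from torchvision.models.feature_extraction. A minimal sketch, assuming a
torchvision build that includes this PR:

    from torchvision.models import resnet18
    from torchvision.models.feature_extraction import create_feature_extractor

    # Extract the output of "layer4" under the same name, using the list form of
    # return_nodes (valid when the keys and the values would be identical)
    extractor = create_feature_extractor(resnet18(), return_nodes=['layer4'])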