Add FX feature extraction as an alternative to intermediate_layer_getter #4302

Merged
merged 22 commits into from
Sep 6, 2021
Commits
b84f8bd
add fx feature extraction util
alexander-soare Aug 21, 2021
b0e6402
Make it possible to use train and eval mode
alexander-soare Aug 28, 2021
23bb71f
FX feature extraction - Tweaks and small bug fixes
alexander-soare Sep 1, 2021
9dce734
FX feature extraction - add tests
alexander-soare Sep 1, 2021
fa23bd8
move to feature_extraction.py, add LeafModuleAwareTracer, add docs
alexander-soare Sep 1, 2021
0581ba7
Tweaks to docs
alexander-soare Sep 1, 2021
7ecd15b
addressing latest round of feedback
alexander-soare Sep 3, 2021
d4efb7d
undo line spacing changes
alexander-soare Sep 3, 2021
d6a834e
change type hints in docstrings
alexander-soare Sep 3, 2021
8348b7b
fix sphinx indentation
alexander-soare Sep 3, 2021
fc831e0
expose feature_extraction
alexander-soare Sep 4, 2021
82fea80
add maskrcnn example
alexander-soare Sep 4, 2021
8b51e04
add api refernce subheading
alexander-soare Sep 4, 2021
2c8e2f8
Merge branch 'main' into fx-feature-extraction
fmassa Sep 6, 2021
d7591c1
address latest review notes, refactor names, fix regex, cosmetics
alexander-soare Sep 6, 2021
2111ef9
Merge branch 'main' into fx-feature-extraction
fmassa Sep 6, 2021
2d6cdfd
Add back efficientnet to models
fmassa Sep 6, 2021
31f86f6
fix tests for effnet
alexander-soare Sep 6, 2021
a4973f8
fix linting issue
alexander-soare Sep 6, 2021
e8fec33
fix test tracer kwargs
alexander-soare Sep 6, 2021
48024a2
Merge branch 'main' into fx-feature-extraction
fmassa Sep 6, 2021
a7a5818
Merge branch 'main' into fx-feature-extraction
fmassa Sep 6, 2021
126 changes: 126 additions & 0 deletions docs/source/feature_extraction.rst
@@ -0,0 +1,126 @@
torchvision.models.feature_extraction
=====================================

.. currentmodule:: torchvision.models.feature_extraction

Feature extraction utilities let us tap into our models to access intermediate
transformations of our inputs. This is useful for a variety of computer vision
applications. A few examples:

- Visualizing feature maps.
- Extracting features to compute image descriptors for tasks like facial
  recognition, copy-detection, or image retrieval.
- Passing selected features to downstream sub-networks for end-to-end training
  with a specific task in mind. For example, passing a hierarchy of features
  to a Feature Pyramid Network with object detection heads.

Torchvision provides :func:`create_feature_extractor` for this purpose.
It works by following roughly these steps:

1. Symbolically tracing the model to get a graphical representation of
   how it transforms the input, step by step.
2. Setting the user-selected graph nodes as outputs.
3. Removing all redundant nodes (anything downstream of the output nodes).
4. Generating Python code from the resulting graph and bundling that into a
   PyTorch module together with the graph itself.

|

The `torch.fx documentation <https://pytorch.org/docs/stable/fx.html>`_
provides a more general and detailed explanation of the above procedure and
the inner workings of the symbolic tracing.
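The four steps above can be sketched directly with ``torch.fx``. This is a minimal illustration, not the implementation added in this PR: ``TinyNet`` and the output key ``"features"`` are hypothetical names chosen for the example, and the real utility handles more cases (train/eval graphs, leaf modules, truncated node names).

```python
import torch
from torch import fx, nn


class TinyNet(nn.Module):
    """Hypothetical toy model, used only for this illustration."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = self.relu(self.conv(x))
        x = self.pool(x).flatten(1)
        return self.fc(x)


model = TinyNet().eval()

# Step 1: symbolically trace the model to obtain a graph of its operations.
graph_module = fx.symbolic_trace(model)
graph = graph_module.graph

# Step 2: make the selected node the graph output (here, as a one-entry dict).
selected = next(n for n in graph.nodes if n.name == "relu")
original_output = next(n for n in graph.nodes if n.op == "output")
graph.erase_node(original_output)
graph.output({"features": selected})

# Step 3: drop everything downstream of the new output (pool, flatten, fc).
graph.eliminate_dead_code()

# Step 4: regenerate the Python forward code from the modified graph.
graph_module.recompile()

remaining = [n.name for n in graph.nodes]
with torch.no_grad():
    out = graph_module(torch.randn(1, 3, 16, 16))
print(remaining)
print(out["features"].shape)
```

After step 3 the graph no longer contains the pooling or classifier nodes, so the forward pass stops at the selected activation.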

Here is an example of how we might extract features for MaskRCNN:

.. code-block:: python

    import torch
    from torchvision.models import resnet50
    from torchvision.models.feature_extraction import get_graph_node_names
    from torchvision.models.feature_extraction import create_feature_extractor
    from torchvision.models.detection.mask_rcnn import MaskRCNN
    from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork


    # To assist you in designing the feature extractor you may want to print out
    # the available nodes for resnet50.
    m = resnet50()
    train_nodes, eval_nodes = get_graph_node_names(m)

    # The returned lists are the names of all the graph nodes (in order of
    # execution) for the input model traced in train mode and in eval mode
    # respectively. You'll find that `train_nodes` and `eval_nodes` are the same
    # for this example. But if the model contains control flow that's dependent
    # on the training mode, they may be different.

    # To specify the nodes you want to extract, you could select the final node
    # that appears in each of the main layers:
    return_nodes = {
        # node_name: user-specified key for output dict
        'layer1.2.relu_2': 'layer1',
        'layer2.3.relu_2': 'layer2',
        'layer3.5.relu_2': 'layer3',
        'layer4.2.relu_2': 'layer4',
    }

    # But `create_feature_extractor` can also accept truncated node specifications
    # like "layer1", as it will just pick the last node that's a descendant of
    # the specification. (Tip: be careful with this, especially when a layer
    # has multiple outputs. It's not always guaranteed that the last operation
    # performed is the one that corresponds to the output you desire. You should
    # consult the source code for the input model to confirm.)
    return_nodes = {
        'layer1': 'layer1',
        'layer2': 'layer2',
        'layer3': 'layer3',
        'layer4': 'layer4',
    }

    # Now you can build the feature extractor. This returns a module whose forward
    # method returns a dictionary like:
    # {
    #     'layer1': output of layer 1,
    #     'layer2': output of layer 2,
    #     'layer3': output of layer 3,
    #     'layer4': output of layer 4,
    # }
    create_feature_extractor(m, return_nodes=return_nodes)

    # Let's put all that together to wrap resnet50 with MaskRCNN

    # MaskRCNN requires a backbone with an attached FPN
    class Resnet50WithFPN(torch.nn.Module):
        def __init__(self):
            super(Resnet50WithFPN, self).__init__()
            # Get a resnet50 backbone
            m = resnet50()
            # Extract 4 main layers (note: you can also provide a list for return
            # nodes if the keys and the values are the same)
            self.body = create_feature_extractor(
                m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4'])
            # Dry run to get number of channels for FPN
            inp = torch.randn(2, 3, 224, 224)
            with torch.no_grad():
                out = self.body(inp)
            in_channels_list = [o.shape[1] for o in out.values()]
            # Build FPN
            self.out_channels = 256
            self.fpn = FeaturePyramidNetwork(
                in_channels_list, out_channels=self.out_channels)

        def forward(self, x):
            x = self.body(x)
            x = self.fpn(x)
            return x


    # Now we can build our model!
    model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval()


API Reference
-------------

.. autofunction:: create_feature_extractor

.. autofunction:: get_graph_node_names
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -34,6 +34,7 @@ architectures, and common image transformations for computer vision.
datasets
io
models
feature_extraction
ops
transforms
utils
242 changes: 242 additions & 0 deletions test/test_backbone_utils.py
@@ -1,11 +1,253 @@
from functools import partial
from itertools import chain
import random

import torch
from torchvision import models
import torchvision
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
from torchvision.models._utils import IntermediateLayerGetter

import pytest

from common_utils import set_rng_seed


def get_available_models():
    # TODO add a registration mechanism to torchvision.models
    return [k for k, v in models.__dict__.items()
            if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
def test_resnet_fpn_backbone(backbone_name):
    x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
    y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
    assert list(y.keys()) == ['0', '1', '2', '3', 'pool']


# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
def leaf_function(x):
    return int(x)


class TestFxFeatureExtraction:
    inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu')
    model_defaults = {
        'num_classes': 1,
        'pretrained': False
    }
    leaf_modules = [torchvision.ops.StochasticDepth]

    def _create_feature_extractor(self, *args, **kwargs):
        """
        Apply leaf modules
        """
        tracer_kwargs = {}
        if 'tracer_kwargs' not in kwargs:
            tracer_kwargs = {'leaf_modules': self.leaf_modules}
        else:
            tracer_kwargs = kwargs.pop('tracer_kwargs')
        return create_feature_extractor(
            *args, **kwargs,
            tracer_kwargs=tracer_kwargs,
            suppress_diff_warning=True)

    def _get_return_nodes(self, model):
        set_rng_seed(0)
        exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk']
        train_nodes, eval_nodes = get_graph_node_names(
            model, tracer_kwargs={'leaf_modules': self.leaf_modules},
            suppress_diff_warning=True)
        # Get rid of any nodes that don't return tensors as they cause issues
        # when testing backward pass.
        train_nodes = [n for n in train_nodes
                       if not any(x in n for x in exclude_nodes_filter)]
        eval_nodes = [n for n in eval_nodes
                      if not any(x in n for x in exclude_nodes_filter)]
        return random.sample(train_nodes, 10), random.sample(eval_nodes, 10)

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_build_fx_feature_extractor(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).eval()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        # Check that it works with both a list and dict for return nodes
        self._create_feature_extractor(
            model, train_return_nodes={v: v for v in train_return_nodes},
            eval_return_nodes=eval_return_nodes)
        self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check must specify return nodes
        with pytest.raises(AssertionError):
            self._create_feature_extractor(model)
        # Check return_nodes and train_return_nodes / eval_return_nodes
        # mutual exclusivity
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, return_nodes=train_return_nodes,
                train_return_nodes=train_return_nodes)
        # Check train_return_nodes / eval_return_nodes must both be specified
        with pytest.raises(AssertionError):
            self._create_feature_extractor(
                model, train_return_nodes=train_return_nodes)
        # Check invalid node name raises ValueError
        with pytest.raises(ValueError):
            # First just double check that this node really doesn't exist
            if not any(n.startswith('l') or n.startswith('l.') for n
                       in chain(train_return_nodes, eval_return_nodes)):
                self._create_feature_extractor(
                    model, train_return_nodes=['l'], eval_return_nodes=['l'])
            else:  # otherwise skip this check
                raise ValueError

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_forward_backward(self, model_name):
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        out = model(self.inp)
        sum([o.mean() for o in out.values()]).backward()

    def test_feature_extraction_methods_equivalence(self):
        model = models.resnet18(**self.model_defaults).eval()
        return_layers = {
            'layer1': 'layer1',
            'layer2': 'layer2',
            'layer3': 'layer3',
            'layer4': 'layer4'
        }

        ilg_model = IntermediateLayerGetter(
            model, return_layers).eval()
        fx_model = self._create_feature_extractor(model, return_layers)

        # Check that we have same parameters
        for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(),
                                      fx_model.named_parameters()):
            assert n1 == n2
            assert p1.equal(p2)

        # And that outputs match
        with torch.no_grad():
            ilg_out = ilg_model(self.inp)
            fgn_out = fx_model(self.inp)
        assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys()))
        for k in ilg_out.keys():
            assert ilg_out[k].equal(fgn_out[k])

    @pytest.mark.parametrize('model_name', get_available_models())
    def test_jit_forward_backward(self, model_name):
        set_rng_seed(0)
        model = models.__dict__[model_name](**self.model_defaults).train()
        train_return_nodes, eval_return_nodes = self._get_return_nodes(model)
        model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        model = torch.jit.script(model)
        fgn_out = model(self.inp)
        sum([o.mean() for o in fgn_out.values()]).backward()

    def test_train_eval(self):
        class TestModel(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout = torch.nn.Dropout(p=1.)

            def forward(self, x):
                x = x.mean()
                x = self.dropout(x)  # dropout
                if self.training:
                    x += 100  # add
                else:
                    x *= 0  # mul
                x -= 0  # sub
                return x

        model = TestModel()

        train_return_nodes = ['dropout', 'add', 'sub']
        eval_return_nodes = ['dropout', 'mul', 'sub']

        def checks(model, mode):
            with torch.no_grad():
                out = model(torch.ones(10, 10))
            if mode == 'train':
                # Check that dropout is respected
                assert out['dropout'].item() == 0
                # Check that control flow dependent on training_mode is respected
                assert out['sub'].item() == 100
                assert 'add' in out
                assert 'mul' not in out
            elif mode == 'eval':
                # Check that dropout is respected
                assert out['dropout'].item() == 1
                # Check that control flow dependent on training_mode is respected
                assert out['sub'].item() == 0
                assert 'mul' in out
                assert 'add' not in out

        # Starting from train mode
        model.train()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check that the models stay in their original training state
        assert model.training
        assert fx_model.training
        # Check outputs
        checks(fx_model, 'train')
        # Check outputs after switching to eval mode
        fx_model.eval()
        checks(fx_model, 'eval')

        # Starting from eval mode
        model.eval()
        fx_model = self._create_feature_extractor(
            model, train_return_nodes=train_return_nodes,
            eval_return_nodes=eval_return_nodes)
        # Check that the models stay in their original training state
        assert not model.training
        assert not fx_model.training
        # Check outputs
        checks(fx_model, 'eval')
        # Check outputs after switching to train mode
        fx_model.train()
        checks(fx_model, 'train')

    def test_leaf_module_and_function(self):
        class LeafModule(torch.nn.Module):
            def forward(self, x):
                # This would raise a TypeError if it were not in a leaf module
                int(x.shape[0])
                return torch.nn.functional.relu(x + 4)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 1, 3)
                self.leaf_module = LeafModule()

            def forward(self, x):
                leaf_function(x.shape[0])
                x = self.conv(x)
                return self.leaf_module(x)

        model = self._create_feature_extractor(
            TestModule(), return_nodes=['leaf_module'],
            tracer_kwargs={'leaf_modules': [LeafModule],
                           'autowrap_functions': [leaf_function]}).train()

        # Check that LeafModule was not traced into: its internal relu must not
        # appear as a node, while the module itself appears as a single node
        assert 'relu' not in [str(n) for n in model.graph.nodes]
        assert 'leaf_module' in [str(n) for n in model.graph.nodes]

        # Check forward
        out = model(self.inp)
        # And backward
        out['leaf_module'].mean().backward()
1 change: 1 addition & 0 deletions torchvision/models/__init__.py
@@ -13,3 +13,4 @@
from . import detection
from . import video
from . import quantization
from . import feature_extraction