diff --git a/docs/source/feature_extraction.rst b/docs/source/feature_extraction.rst new file mode 100644 index 00000000000..27dad21fa78 --- /dev/null +++ b/docs/source/feature_extraction.rst @@ -0,0 +1,126 @@ +torchvision.models.feature_extraction +===================================== + +.. currentmodule:: torchvision.models.feature_extraction + +Feature extraction utilities let us tap into our models to access intermediate +transformations of our inputs. This could be useful for a variety of +applications in computer vision. Just a few examples are: + +- Visualizing feature maps. +- Extracting features to compute image descriptors for tasks like facial + recognition, copy-detection, or image retrieval. +- Passing selected features to downstream sub-networks for end-to-end training + with a specific task in mind. For example, passing a hierarchy of features + to a Feature Pyramid Network with object detection heads. + +Torchvision provides :func:`create_feature_extractor` for this purpose. +It works by following roughly these steps: + +1. Symbolically tracing the model to get a graphical representation of + how it transforms the input, step by step. +2. Setting the user-selected graph nodes as ouputs. +3. Removing all redundant nodes (anything downstream of the ouput nodes). +4. Generating python code from the resulting graph and bundling that into a + PyTorch module together with the graph itself. + +| + +The `torch.fx documentation `_ +provides a more general and detailed explanation of the above procedure and +the inner workings of the symbolic tracing. + +Here is an example of how we might extract features for MaskRCNN: + +.. code-block:: python + + import torch + from torchvision.models import resnet50 + from torchvision.models.feature_extraction import get_graph_node_names + from torchvision.models.feature_extraction import create_feature_extractor + from torchvision.models.detection.mask_rcnn import MaskRCNN + from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork + + + # To assist you in designing the feature extractor you may want to print out + # the available nodes for resnet50. + m = resnet50() + train_nodes, eval_nodes = get_graph_node_names(resnet50()) + + # The lists returned, are the names of all the graph nodes (in order of + # execution) for the input model traced in train mode and in eval mode + # respectively. You'll find that `train_nodes` and `eval_nodes` are the same + # for this example. But if the model contains control flow that's dependent + # on the training mode, they may be different. + + # To specify the nodes you want to extract, you could select the final node + # that appears in each of the main layers: + return_nodes = { + # node_name: user-specified key for output dict + 'layer1.2.relu_2': 'layer1', + 'layer2.3.relu_2': 'layer2', + 'layer3.5.relu_2': 'layer3', + 'layer4.2.relu_2': 'layer4', + } + + # But `create_feature_extractor` can also accept truncated node specifications + # like "layer1", as it will just pick the last node that's a descendent of + # of the specification. (Tip: be careful with this, especially when a layer + # has multiple outputs. It's not always guaranteed that the last operation + # performed is the one that corresponds to the output you desire. You should + # consult the source code for the input model to confirm.) + return_nodes = { + 'layer1': 'layer1', + 'layer2': 'layer2', + 'layer3': 'layer3', + 'layer4': 'layer4', + } + + # Now you can build the feature extractor. This returns a module whose forward + # method returns a dictionary like: + # { + # 'layer1': ouput of layer 1, + # 'layer2': ouput of layer 2, + # 'layer3': ouput of layer 3, + # 'layer4': ouput of layer 4, + # } + create_feature_extractor(m, return_nodes=return_nodes) + + # Let's put all that together to wrap resnet50 with MaskRCNN + + # MaskRCNN requires a backbone with an attached FPN + class Resnet50WithFPN(torch.nn.Module): + def __init__(self): + super(Resnet50WithFPN, self).__init__() + # Get a resnet50 backbone + m = resnet50() + # Extract 4 main layers (note: you can also provide a list for return + # nodes if the keys and the values are the same) + self.body = create_feature_extractor( + m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4']) + # Dry run to get number of channels for FPN + inp = torch.randn(2, 3, 224, 224) + with torch.no_grad(): + out = self.body(inp) + in_channels_list = [o.shape[1] for o in out.values()] + # Build FPN + self.out_channels = 256 + self.fpn = FeaturePyramidNetwork( + in_channels_list, out_channels=self.out_channels) + + def forward(self, x): + x = self.body(x) + x = self.fpn(x) + return x + + + # Now we can build our model! + model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval() + + +API Reference +------------- + +.. autofunction:: create_feature_extractor + +.. autofunction:: get_graph_node_names \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index b25a85d8617..3e02cd34ad4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,6 +34,7 @@ architectures, and common image transformations for computer vision. datasets io models + feature_extraction ops transforms utils diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 712dccf11a8..d85128a46d0 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,11 +1,253 @@ +from functools import partial +from itertools import chain +import random + import torch +from torchvision import models +import torchvision from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +from torchvision.models.feature_extraction import create_feature_extractor +from torchvision.models.feature_extraction import get_graph_node_names +from torchvision.models._utils import IntermediateLayerGetter import pytest +from common_utils import set_rng_seed + + +def get_available_models(): + # TODO add a registration mechanism to torchvision.models + return [k for k, v in models.__dict__.items() + if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + @pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] + + +# Needed by TestFxFeatureExtraction.test_leaf_module_and_function +def leaf_function(x): + return int(x) + + +class TestFxFeatureExtraction: + inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu') + model_defaults = { + 'num_classes': 1, + 'pretrained': False + } + leaf_modules = [torchvision.ops.StochasticDepth] + + def _create_feature_extractor(self, *args, **kwargs): + """ + Apply leaf modules + """ + tracer_kwargs = {} + if 'tracer_kwargs' not in kwargs: + tracer_kwargs = {'leaf_modules': self.leaf_modules} + else: + tracer_kwargs = kwargs.pop('tracer_kwargs') + return create_feature_extractor( + *args, **kwargs, + tracer_kwargs=tracer_kwargs, + suppress_diff_warning=True) + + def _get_return_nodes(self, model): + set_rng_seed(0) + exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk'] + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': self.leaf_modules}, + suppress_diff_warning=True) + # Get rid of any nodes that don't return tensors as they cause issues + # when testing backward pass. + train_nodes = [n for n in train_nodes + if not any(x in n for x in exclude_nodes_filter)] + eval_nodes = [n for n in eval_nodes + if not any(x in n for x in exclude_nodes_filter)] + return random.sample(train_nodes, 10), random.sample(eval_nodes, 10) + + @pytest.mark.parametrize('model_name', get_available_models()) + def test_build_fx_feature_extractor(self, model_name): + set_rng_seed(0) + model = models.__dict__[model_name](**self.model_defaults).eval() + train_return_nodes, eval_return_nodes = self._get_return_nodes(model) + # Check that it works with both a list and dict for return nodes + self._create_feature_extractor( + model, train_return_nodes={v: v for v in train_return_nodes}, + eval_return_nodes=eval_return_nodes) + self._create_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + # Check must specify return nodes + with pytest.raises(AssertionError): + self._create_feature_extractor(model) + # Check return_nodes and train_return_nodes / eval_return nodes + # mutual exclusivity + with pytest.raises(AssertionError): + self._create_feature_extractor( + model, return_nodes=train_return_nodes, + train_return_nodes=train_return_nodes) + # Check train_return_nodes / eval_return nodes must both be specified + with pytest.raises(AssertionError): + self._create_feature_extractor( + model, train_return_nodes=train_return_nodes) + # Check invalid node name raises ValueError + with pytest.raises(ValueError): + # First just double check that this node really doesn't exist + if not any(n.startswith('l') or n.startswith('l.') for n + in chain(train_return_nodes, eval_return_nodes)): + self._create_feature_extractor( + model, train_return_nodes=['l'], eval_return_nodes=['l']) + else: # otherwise skip this check + raise ValueError + + @pytest.mark.parametrize('model_name', get_available_models()) + def test_forward_backward(self, model_name): + model = models.__dict__[model_name](**self.model_defaults).train() + train_return_nodes, eval_return_nodes = self._get_return_nodes(model) + model = self._create_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + out = model(self.inp) + sum([o.mean() for o in out.values()]).backward() + + def test_feature_extraction_methods_equivalence(self): + model = models.resnet18(**self.model_defaults).eval() + return_layers = { + 'layer1': 'layer1', + 'layer2': 'layer2', + 'layer3': 'layer3', + 'layer4': 'layer4' + } + + ilg_model = IntermediateLayerGetter( + model, return_layers).eval() + fx_model = self._create_feature_extractor(model, return_layers) + + # Check that we have same parameters + for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), + fx_model.named_parameters()): + assert n1 == n2 + assert p1.equal(p2) + + # And that ouputs match + with torch.no_grad(): + ilg_out = ilg_model(self.inp) + fgn_out = fx_model(self.inp) + assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys())) + for k in ilg_out.keys(): + assert ilg_out[k].equal(fgn_out[k]) + + @pytest.mark.parametrize('model_name', get_available_models()) + def test_jit_forward_backward(self, model_name): + set_rng_seed(0) + model = models.__dict__[model_name](**self.model_defaults).train() + train_return_nodes, eval_return_nodes = self._get_return_nodes(model) + model = self._create_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + model = torch.jit.script(model) + fgn_out = model(self.inp) + sum([o.mean() for o in fgn_out.values()]).backward() + + def test_train_eval(self): + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(p=1.) + + def forward(self, x): + x = x.mean() + x = self.dropout(x) # dropout + if self.training: + x += 100 # add + else: + x *= 0 # mul + x -= 0 # sub + return x + + model = TestModel() + + train_return_nodes = ['dropout', 'add', 'sub'] + eval_return_nodes = ['dropout', 'mul', 'sub'] + + def checks(model, mode): + with torch.no_grad(): + out = model(torch.ones(10, 10)) + if mode == 'train': + # Check that dropout is respected + assert out['dropout'].item() == 0 + # Check that control flow dependent on training_mode is respected + assert out['sub'].item() == 100 + assert 'add' in out + assert 'mul' not in out + elif mode == 'eval': + # Check that dropout is respected + assert out['dropout'].item() == 1 + # Check that control flow dependent on training_mode is respected + assert out['sub'].item() == 0 + assert 'mul' in out + assert 'add' not in out + + # Starting from train mode + model.train() + fx_model = self._create_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + # Check that the models stay in their original training state + assert model.training + assert fx_model.training + # Check outputs + checks(fx_model, 'train') + # Check outputs after switching to eval mode + fx_model.eval() + checks(fx_model, 'eval') + + # Starting from eval mode + model.eval() + fx_model = self._create_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + # Check that the models stay in their original training state + assert not model.training + assert not fx_model.training + # Check outputs + checks(fx_model, 'eval') + # Check outputs after switching to train mode + fx_model.train() + checks(fx_model, 'train') + + def test_leaf_module_and_function(self): + class LeafModule(torch.nn.Module): + def forward(self, x): + # This would raise a TypeError if it were not in a leaf module + int(x.shape[0]) + return torch.nn.functional.relu(x + 4) + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 1, 3) + self.leaf_module = LeafModule() + + def forward(self, x): + leaf_function(x.shape[0]) + x = self.conv(x) + return self.leaf_module(x) + + model = self._create_feature_extractor( + TestModule(), return_nodes=['leaf_module'], + tracer_kwargs={'leaf_modules': [LeafModule], + 'autowrap_functions': [leaf_function]}).train() + + # Check that LeafModule is not in the list of nodes + assert 'relu' not in [str(n) for n in model.graph.nodes] + assert 'leaf_module' in [str(n) for n in model.graph.nodes] + + # Check forward + out = model(self.inp) + # And backward + out['leaf_module'].mean().backward() diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 3c1519c1b42..e57f4773c8c 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -13,3 +13,4 @@ from . import detection from . import video from . import quantization +from . import feature_extraction diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py new file mode 100644 index 00000000000..7d5d184cd97 --- /dev/null +++ b/torchvision/models/feature_extraction.py @@ -0,0 +1,508 @@ +from typing import Dict, Callable, List, Union, Optional, Tuple +from collections import OrderedDict +import warnings +import re +from copy import deepcopy +from itertools import chain + +import torch +from torch import nn +from torch import fx +from torch.fx.graph_module import _copy_attr + + +__all__ = ['create_feature_extractor', 'get_graph_node_names'] + + +class LeafModuleAwareTracer(fx.Tracer): + """ + An fx.Tracer that allows the user to specify a set of leaf modules, ie. + modules that are not to be traced through. The resulting graph ends up + having single nodes referencing calls to the leaf modules' forward methods. + """ + def __init__(self, *args, **kwargs): + self.leaf_modules = {} + if 'leaf_modules' in kwargs: + leaf_modules = kwargs.pop('leaf_modules') + self.leaf_modules = leaf_modules + super(LeafModuleAwareTracer, self).__init__(*args, **kwargs) + + def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: + if isinstance(m, tuple(self.leaf_modules)): + return True + return super().is_leaf_module(m, module_qualname) + + +class NodePathTracer(LeafModuleAwareTracer): + """ + NodePathTracer is an FX tracer that, for each operation, also records the + name of the Node from which the operation originated. A node name here is + a `.` seperated path walking the hierarchy from top level module down to + leaf operation or leaf module. The name of the top level module is not + included as part of the node name. For example, if we trace a module whose + forward method applies a ReLU module, the name for that node will simply + be 'relu'. + + Some notes on the specifics: + - Nodes are recorded to `self.node_to_qualname` which is a dictionary + mapping a given Node object to its node name. + - Nodes are recorded in the order which they are executed during + tracing. + - When a duplicate node name is encountered, a suffix of the form + _{int} is added. The counter starts from 1. + """ + def __init__(self, *args, **kwargs): + super(NodePathTracer, self).__init__(*args, **kwargs) + # Track the qualified name of the Node being traced + self.current_module_qualname = '' + # A map from FX Node to the qualified name\# + # NOTE: This is loosely like the "qualified name" mentioned in the + # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted + # for the purposes of the torchvision feature extractor + self.node_to_qualname = OrderedDict() + + def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): + """ + Override of `fx.Tracer.call_module` + This override: + 1) Stores away the qualified name of the caller for restoration later + 2) Adds the qualified name of the caller to + `current_module_qualname` for retrieval by `create_proxy` + 3) Once a leaf module is reached, calls `create_proxy` + 4) Restores the caller's qualified name into current_module_qualname + """ + old_qualname = self.current_module_qualname + try: + module_qualname = self.path_of_module(m) + self.current_module_qualname = module_qualname + if not self.is_leaf_module(m, module_qualname): + out = forward(*args, **kwargs) + return out + return self.create_proxy('call_module', module_qualname, args, kwargs) + finally: + self.current_module_qualname = old_qualname + + def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, + name=None, type_expr=None, *_) -> fx.proxy.Proxy: + """ + Override of `Tracer.create_proxy`. This override intercepts the recording + of every operation and stores away the current traced module's qualified + name in `node_to_qualname` + """ + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) + self.node_to_qualname[proxy.node] = self._get_node_qualname( + self.current_module_qualname, proxy.node) + return proxy + + def _get_node_qualname( + self, module_qualname: str, node: fx.node.Node) -> str: + node_qualname = module_qualname + if node.op == 'call_module': + # Node terminates in a leaf module so the module_qualname is a + # complete description of the node + for existing_qualname in reversed(self.node_to_qualname.values()): + # Check to see if existing_qualname is of the form + # {node_qualname} or {node_qualname}_{int} + if re.match(rf'{node_qualname}(_[0-9]+)?$', + existing_qualname) is not None: + postfix = existing_qualname.replace(node_qualname, '') + if len(postfix): + # Existing_qualname is of the form {node_qualname}_{int} + next_index = int(postfix[1:]) + 1 + else: + # existing_qualname is of the form {node_qualname} + next_index = 1 + node_qualname += f'_{next_index}' + break + pass + else: + # Node terminates in non- leaf module so the node name needs to be + # appended + if len(node_qualname) > 0: + # Only append '.' if we are deeper than the top level module + node_qualname += '.' + node_qualname += str(node) + return node_qualname + + +def _is_subseq(x, y): + """Check if y is a subseqence of x + https://stackoverflow.com/a/24017747/4391249 + """ + iter_x = iter(x) + return all(any(x_item == y_item for x_item in iter_x) for y_item in y) + + +def _warn_graph_differences( + train_tracer: NodePathTracer, eval_tracer: NodePathTracer): + """ + Utility function for warning the user if there are differences between + the train graph nodes and the eval graph nodes. + """ + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + + if len(train_nodes) == len(eval_nodes) and all( + t == e for t, e in zip(train_nodes, eval_nodes)): + return + + suggestion_msg = ( + "When choosing nodes for feature extraction, you may need to specify " + "output nodes for train and eval mode separately.") + + if _is_subseq(train_nodes, eval_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in eval mode " + "are a subsequence of those obtained in train mode. ") + elif _is_subseq(eval_nodes, train_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in train mode " + "are a subsequence of those obtained in eval mode. ") + else: + msg = ("The nodes obtained by tracing the model in train mode " + "are different to those obtained in eval mode. ") + warnings.warn(msg + suggestion_msg) + + +def get_graph_node_names( + model: nn.Module, tracer_kwargs: Dict = {}, + suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]: + """ + Dev utility to return node names in order of execution. See note on node + names under :func:`create_feature_extractor`. Useful for seeing which node + names are available for feature extraction. There are two reasons that + node names can't easily be read directly from the code for a model: + + 1. Not all submodules are traced through. Modules from `torch.nn` all + fall within this category. + 2. Nodes representing the repeated application of the same operation + or leaf module get a `_{counter}` postfix. + + The model is traced twice: once in train mode, and once in eval mode. Both + sets of nodes are returned. + + Args: + model (nn.Module): model for which we'd like to print node names + tracer_kwargs (dict, optional): a dictionary of keywork arguments for + `NodePathTracer` (they are eventually passed onto + `torch.fx.Tracer`). + suppress_diff_warning (bool, optional): whether to suppress a warning + when there are discrepancies between the train and eval version of + the graph. Defaults to False. + + Returns: + tuple(list, list): a list of node names from tracing the model in + train mode, and another from tracing the model in eval mode. + + Examples:: + + >>> model = torchvision.models.resnet18() + >>> train_nodes, eval_nodes = get_graph_node_names(model) + """ + is_training = model.training + train_tracer = NodePathTracer(**tracer_kwargs) + train_tracer.trace(model.train()) + eval_tracer = NodePathTracer(**tracer_kwargs) + eval_tracer.trace(model.eval()) + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + if not suppress_diff_warning: + _warn_graph_differences(train_tracer, eval_tracer) + # Restore training state + model.train(is_training) + return train_nodes, eval_nodes + + +class DualGraphModule(fx.GraphModule): + """ + A derivative of `fx.GraphModule`. Differs in the following ways: + - Requires a train and eval version of the underlying graph + - Copies submodules according to the nodes of both train and eval graphs. + - Calling train(mode) switches between train graph and eval graph. + """ + def __init__(self, + root: torch.nn.Module, + train_graph: fx.Graph, + eval_graph: fx.Graph, + class_name: str = 'GraphModule'): + """ + Args: + root (nn.Module): module from which the copied module hierarchy is + built + train_graph (fx.Graph): the graph that should be used in train mode + eval_graph (fx.Graph): the graph that should be used in eval mode + """ + super(fx.GraphModule, self).__init__() + + self.__class__.__name__ = class_name + + self.train_graph = train_graph + self.eval_graph = eval_graph + + # Copy all get_attr and call_module ops (indicated by BOTH train and + # eval graphs) + for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): + if node.op in ['get_attr', 'call_module']: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + + # train mode by default + self.train() + self.graph = train_graph + + # (borrowed from fx.GraphModule): + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ + "Train mode and eval mode should use the same tracer class" + self._tracer_cls = None + if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: + self._tracer_cls = self.graph._tracer_cls + + def train(self, mode=True): + """ + Swap out the graph depending on the selected training mode. + NOTE this should be safe when calling model.eval() because that just + calls this with mode == False. + """ + # NOTE: Only set self.graph if the current graph is not the desired + # one. This saves us from recompiling the graph where not necessary. + if mode and not self.training: + self.graph = self.train_graph + elif not mode and self.training: + self.graph = self.eval_graph + return super().train(mode=mode) + + +def create_feature_extractor( + model: nn.Module, + return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + tracer_kwargs: Dict = {}, + suppress_diff_warning: bool = False) -> fx.GraphModule: + """ + Creates a new graph module that returns intermediate nodes from a given + model as dictionary with user specified keys as strings, and the requested + outputs as values. This is achieved by re-writing the computation graph of + the model via FX to return the desired nodes as outputs. All unused nodes + are removed, together with their corresponding parameters. + + A note on node specification: For the purposes of this feature extraction + utility, a node name is specified as a `.` seperated path walking the + hierarchy from top level module down to leaf operation or leaf module. For + instance `blocks.5.3.bn1`. The keys of the `return_nodes` argument should + point to either a node's name, or some truncated version of it. For + example, one could provide `blocks.5` as a key, and the last node with + that prefix will be selected. :func:`get_graph_node_names` is a useful + helper function for getting a list of node names of a model. + + Not all models will be FX traceable, although with some massaging they can + be made to cooperate. Here's a (not exhaustive) list of tips: + + - If you don't need to trace through a particular, problematic + sub-module, turn it into a "leaf module" by passing a list of + `leaf_modules` as one of the `tracer_kwargs` (see example below). It + will not be traced through, but rather, the resulting graph will + hold a reference to that module's forward method. + - Likewise, you may turn functions into leaf functions by passing a + list of `autowrap_functions` as one of the `tracer_kwargs` (see + example below). + - Some inbuilt Python functions can be problematic. For instance, + `int` will raise an error during tracing. You may wrap them in your + own function and then pass that in `autowrap_functions` as one of + the `tracer_kwargs`. + + For further information on FX see the + `torch.fx documentation `_. + + Args: + model (nn.Module): model on which we will extract the features + return_nodes (list or dict, optional): either a `List` or a `Dict` + containing the names (or partial names - see note above) + of the nodes for which the activations will be returned. If it is + a `Dict`, the keys are the node names, and the values + are the user-specified keys for the graph module's returned + dictionary. If it is a `List`, it is treated as a `Dict` mapping + node specification strings directly to output names. In the case + that `train_return_nodes` and `eval_return_nodes` are specified, + this should not be specified. + train_return_nodes (list or dict, optional): similar to + `return_nodes`. This can be used if the return nodes + for train mode are different than those from eval mode. + If this is specified, `eval_return_nodes` must also be specified, + and `return_nodes` should not be specified. + eval_return_nodes (list or dict, optional): similar to + `return_nodes`. This can be used if the return nodes + for train mode are different than those from eval mode. + If this is specified, `train_return_nodes` must also be specified, + and `return_nodes` should not be specified. + tracer_kwargs (dict, optional): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + suppress_diff_warning (bool, optional): whether to suppress a warning + when there are discrepancies between the train and eval version of + the graph. Defaults to False. + + Examples:: + + >>> # Feature extraction with resnet + >>> model = torchvision.models.resnet18() + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> model = create_feature_extractor( + >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = model(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + + >>> # Specifying leaf modules and leaf functions + >>> def leaf_function(x): + >>> # This would raise a TypeError if traced through + >>> return int(x) + >>> + >>> class LeafModule(torch.nn.Module): + >>> def forward(self, x): + >>> # This would raise a TypeError if traced through + >>> int(x.shape[0]) + >>> return torch.nn.functional.relu(x + 4) + >>> + >>> class MyModule(torch.nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.conv = torch.nn.Conv2d(3, 1, 3) + >>> self.leaf_module = LeafModule() + >>> + >>> def forward(self, x): + >>> leaf_function(x.shape[0]) + >>> x = self.conv(x) + >>> return self.leaf_module(x) + >>> + >>> model = create_feature_extractor( + >>> MyModule(), return_nodes=['leaf_module'], + >>> tracer_kwargs={'leaf_modules': [LeafModule], + >>> 'autowrap_functions': [leaf_function]}) + + """ + is_training = model.training + + assert any(arg is not None for arg in [ + return_nodes, train_return_nodes, eval_return_nodes]), ( + "Either `return_nodes` or `train_return_nodes` and " + "`eval_return_nodes` together, should be specified") + + assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ + ("If any of `train_return_nodes` and `eval_return_nodes` are " + "specified, then both should be specified") + + assert ((return_nodes is None) ^ (train_return_nodes is None)), \ + ("If `train_return_nodes` and `eval_return_nodes` are specified, " + "then both should be specified") + + # Put *_return_nodes into Dict[str, str] format + def to_strdict(n) -> Dict[str, str]: + if isinstance(n, list): + return {str(i): str(i) for i in n} + return {str(k): str(v) for k, v in n.items()} + + if train_return_nodes is None: + return_nodes = to_strdict(return_nodes) + train_return_nodes = deepcopy(return_nodes) + eval_return_nodes = deepcopy(return_nodes) + else: + train_return_nodes = to_strdict(train_return_nodes) + eval_return_nodes = to_strdict(eval_return_nodes) + + # Repeat the tracing and graph rewriting for train and eval mode + tracers = {} + graphs = {} + mode_return_nodes: Dict[str, Dict[str, str]] = { + 'train': train_return_nodes, + 'eval': eval_return_nodes + } + for mode in ['train', 'eval']: + if mode == 'train': + model.train() + elif mode == 'eval': + model.eval() + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model) + + name = model.__class__.__name__ if isinstance( + model, nn.Module) else model.__name__ + graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = list(tracer.node_to_qualname.values()) + # FIXME We don't know if we should expect this to happen + assert len(set(available_nodes)) == len(available_nodes), \ + "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" + # Check that all outputs in return_nodes are present in the model + for query in mode_return_nodes[mode].keys(): + # To check if a query is available we need to check that at least + # one of the available names starts with it up to a . + if not any([re.match(rf'^{query}(\.|$)', n) is not None + for n in available_nodes]): + raise ValueError( + f"node: '{query}' is not present in model. Hint: use " + "`get_graph_node_names` to make sure the " + "`return_nodes` you specified are present. It may even " + "be that you need to specify `train_return_nodes` and " + "`eval_return_nodes` separately.") + + # Remove existing output nodes (train mode) + orig_output_nodes = [] + for n in reversed(graph_module.graph.nodes): + if n.op == 'output': + orig_output_nodes.append(n) + assert len(orig_output_nodes) + for n in orig_output_nodes: + graph_module.graph.erase_node(n) + + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + module_qualname = tracer.node_to_qualname.get(n) + if module_qualname is None: + # NOTE - Know cases where this happens: + # - Node representing creation of a tensor constant - probably + # not interesting as a return node + # - When packing outputs into a named tuple like in InceptionV3 + continue + for query in mode_return_nodes[mode]: + depth = query.count('.') + if '.'.join(module_qualname.split('.')[:depth + 1]) == query: + output_nodes[mode_return_nodes[mode][query]] = n + mode_return_nodes[mode].pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) + + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + + # Keep track of the tracer and graph so we can choose the main one + tracers[mode] = tracer + graphs[mode] = graph + + # Warn user if there are any discrepancies between the graphs of the + # train and eval modes + if not suppress_diff_warning: + _warn_graph_differences(tracers['train'], tracers['eval']) + + # Build the final graph module + graph_module = DualGraphModule( + model, graphs['train'], graphs['eval'], class_name=name) + + # Restore original training mode + model.train(is_training) + graph_module.train(is_training) + + return graph_module