From b84f8bd368f623deb70e4da55178ef8edffd024b Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 21 Aug 2021 15:37:13 +0100 Subject: [PATCH 01/18] add fx feature extraction util --- torchvision/models/_utils.py | 288 ++++++++++++++++++++++++++++++++++- 1 file changed, 287 insertions(+), 1 deletion(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index df5ab9a044c..2dde075cc6a 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,7 +1,14 @@ +from typing import Any, Dict, Callable, List, Union from collections import OrderedDict +import warnings +import re +from pprint import pprint +from inspect import ismethod +import torch +from torch import Tensor from torch import nn -from typing import Dict +from torch import fx class IntermediateLayerGetter(nn.ModuleDict): @@ -64,3 +71,282 @@ def forward(self, x): out_name = self.return_layers[name] out[out_name] = x return out + + +class NodePathTracer(fx.Tracer): + """ + NodePathTracer is an FX tracer that, for each operation, also records the + qualified name of the Node from which the operation originated. A + qualified name here is a `.` seperated path walking the hierarchy from top + level module down to leaf operation or leaf module. The name of the top + level module is not included as part of the qualified name. For example, + if we trace a module who's forward method applies a ReLU module, the + qualified name for that node will simply be 'relu'. + + Some notes on the specifics: + - Nodes are recorded to `self.node_to_qualname` which is a dictionary + mapping a given Node object to its qualified name. + - Nodes are recorded in the order which they are executed during + tracing. + - When a duplicate qualified name is encountered, a suffix of the form + _{int} is added. The counter starts from 1. + """ + def __init__(self, *args, **kwargs): + super(NodePathTracer, self).__init__(*args, **kwargs) + # Track the qualified name of the Node being traced + self.current_module_qualname = '' + # A map from FX Node to the qualified name + self.node_to_qualname = OrderedDict() + + def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): + """ + Override of `fx.Tracer.call_module` + This override: + 1) Stores away the qualified name of the caller for restoration later + 2) Adds the qualified name of the caller to + `current_module_qualname` for retrieval by `create_proxy` + 3) Once a leaf module is reached, calls `create_proxy` + 4) Restores the caller's qualified name into current_module_qualname + """ + old_qualname = self.current_module_qualname + try: + module_qualname = self.path_of_module(m) + self.current_module_qualname = module_qualname + if not self.is_leaf_module(m, module_qualname): + out = forward(*args, **kwargs) + return out + return self.create_proxy('call_module', module_qualname, args, kwargs) + finally: + self.current_module_qualname = old_qualname + + def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, + name=None, type_expr=None) -> fx.proxy.Proxy: + """ + Override of `Tracer.create_proxy`. This override intercepts the recording + of every operation and stores away the current traced module's qualified + name in `node_to_qualname` + """ + proxy = super().create_proxy(kind, target, args, kwargs, name, type_expr) + self.node_to_qualname[proxy.node] = self._get_node_qualname( + self.current_module_qualname, proxy.node) + return proxy + + def _get_node_qualname( + self, module_qualname: str, node: fx.node.Node) -> str: + node_qualname = module_qualname + if node.op == 'call_module': + # Node terminates in a leaf module so the module_qualname is a + # complete description of the node + for existing_qualname in reversed(self.node_to_qualname.values()): + # Check to see if existing_qualname is of the form + # {node_qualname} or {node_qualname}_{int} + if re.match(rf'{node_qualname}(_[0-9]+)?$', + existing_qualname) is not None: + postfix = existing_qualname.replace(node_qualname, '') + if len(postfix): + # Existing_qualname is of the form {node_qualname}_{int} + next_index = int(postfix[1:]) + 1 + else: + # existing_qualname is of the form {node_qualname} + next_index = 1 + node_qualname += f'_{next_index}' + break + else: + # Node terminates in non- leaf module so the node name needs to be + # appended + if len(node_qualname) > 0: + # Only append '.' if we are deeper than the top level module + node_qualname += '.' + node_qualname += str(node) + return node_qualname + + +def print_graph_node_qualified_names( + model: nn.Module, tracer_kwargs: Dict = {}): + """ + Dev utility to prints nodes in order of execution. Useful for choosing + nodes for a FeatureGraphNet design. There are two reasons that qualified + node names can't easily be read directly from the code for a model: + 1. Not all submodules are traced through. Modules from `torch.nn` all + fall within this category. + 2. Node qualified names that occur more than once in the graph get a + `_{counter}` postfix. + + Args: + model (nn.Module): model on which we will extract the features + tracer_kwargs (Dict): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + """ + tracer = NodePathTracer(**tracer_kwargs) + tracer.trace(model) + pprint(list(tracer.node_to_qualname.values())) + + +def build_feature_graph_net( + model: nn.Module, + return_nodes: Union[List[str], Dict[str, str]], + tracer_kwargs: Dict = {}) -> fx.GraphModule: + """ + Creates a new graph module that returns intermediate nodes from a given + model as dictionary with user specified keys as strings, and the requested + outputs as values. This is achieved by re-writing the computation graph of + the model via FX to return the desired nodes as outputs. All unused nodes + are removed, together with their corresponding parameters. + + A note on node specification: A node qualified name is specified as a `.` + seperated path walking the hierarchy from top level module down to leaf + operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the + `return_nodes` argument should point to either a node's qualified name, + or some truncated version of it. For example, one could provide `blocks.5` + as a key, and the last node with that prefix will be selected. + `print_graph_node_qualified_names` is a useful helper function for getting + a list of qualified names of a model. + + Args: + model (nn.Module): model on which we will extract the features + return_nodes (Union[List[name], Dict[name, new_name]])): either a list + or a dict containing the names (or partial names - see note above) + of the nodes for which the activations will be returned. If it is + a `Dict`, the keys are the qualified node names, and the values + are the user-specified keys for the graph module's returned + dictionary. If it is a `List`, it is treated as a `Dict` mapping + node specification strings directly to output names. + tracer_kwargs (Dict): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + + NOTE: Static control flow will be frozen into place for the resulting + `GraphModule`. Among other consequences, this means that control flow + that relies on whether the model is in train or eval mode will be + frozen into place (except for leaf modules which are not traced + through). Therefore, calling `.train()` or `.eval()` on the resulting + `GraphModule` may not have all the desired effects. + + Examples:: + + >>> model = torchvision.models.resnet18() + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> graph_module = torchvision.models._utils.build_feature_graph_net(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = graph_module(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + + """ + if isinstance(return_nodes, list): + return_nodes = {n: n for n in return_nodes} + return_nodes = {str(k): str(v) for k, v in return_nodes.items()} + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model) + + name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ + graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] + # FIXME We don't know if we should expect this to happen + assert len(set(available_nodes)) == len(available_nodes), \ + "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" + # Check that all outputs in return_nodes are present in the model + for query in return_nodes.keys(): + if not any([m.startswith(query) for m in available_nodes]): + raise ValueError(f"return_node: {query} is not present in model") + + # Remove existing output nodes + orig_output_node = None + for n in reversed(graph_module.graph.nodes): + if n.op == "output": + orig_output_node = n + assert orig_output_node + # And remove it + graph_module.graph.erase_node(orig_output_node) + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + if 'tensor_constant' in str(n): + # NOTE Without this control flow we would get a None value for + # `module_qualname = tracer.node_to_qualname.get(n)`. + # On the other hand, we can safely assume that we'll never need to + # get this as an interesting intermediate node. + continue + module_qualname = tracer.node_to_qualname.get(n) + for query in return_nodes: + depth = query.count('.') + if '.'.join(module_qualname.split('.')[:depth + 1]) == query: + output_nodes[return_nodes[query]] = n + return_nodes.pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) + + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + graph_module = fx.GraphModule(graph_module, graph_module.graph, name) + return graph_module + + +class FeatureGraphNet(nn.Module): + """ + Wrap a `GraphModule` from `build_feature_graph_net` while also keeping the + original model's non-parameter properties for reference. The original + model's paremeters are discarded. + + See `build_feature_graph_net` docstring for more information. + + NOTE: This puts the input model into eval mode prior to tracing. This + means that any control flow dependent on the model being in train mode + will be lost. + """ + def __init__(self, model: nn.Module, + return_nodes: Union[List[str], Dict[str, str]], + tracer_kwargs: Dict = {}): + """ + Args: + model (nn.Module): model on which we will extract the features + return_nodes (Union[List[name], Dict[name, new_name]])): either a list + or a dict containing the names (or partial names - see note above) + of the nodes for which the activations will be returned. If it is + a `Dict`, the keys are the qualified node names, and the values + are the user-specified keys for the graph module's returned + dictionary. If it is a `List`, it is treated as a `Dict` mapping + node specification strings directly to output names. + tracer_kwargs (Dict): a dictionary of keywork arguments for + `NodePathTracer` (which passes them onto it's parent class + `torch.fx.Tracer`). + """ + super(FeatureGraphNet, self).__init__() + model.eval() + self.graph_module = build_feature_graph_net( + model, return_nodes, tracer_kwargs) + # Keep non-parameter model properties for reference + for attr_str in model.__dir__(): + attr = getattr(model, attr_str) + if (not attr_str.startswith('_') + and attr_str not in self.__dir__() + and not ismethod(attr) + and not isinstance(attr, (nn.Module, nn.Parameter))): + setattr(self, attr_str, attr) + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + return self.graph_module(x) + + def train(self, mode: bool = True): + """ + NOTE: This also covers `self.eval()` as that just calls self.train(False) + """ + if mode: + warnings.warn( + "Setting a FeatureGraphNet to training mode won't necessarily" + " have the desired effect. Control flow depending on" + " `self.training` will follow the `False` path. See" + " `FeatureGraphNet` doc-string for more details.") + + super(FeatureGraphNet, self).train(mode) From b0e64029eaa947228cfb6fdb6f288f799f7a6d96 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 28 Aug 2021 16:47:32 +0100 Subject: [PATCH 02/18] Make it possible to use train and eval mode --- torchvision/models/_utils.py | 357 +++++++++++++++++++++++------------ 1 file changed, 236 insertions(+), 121 deletions(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 2dde075cc6a..70ce60607cd 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -1,14 +1,16 @@ -from typing import Any, Dict, Callable, List, Union +from typing import Any, Dict, Callable, List, Union, Optional from collections import OrderedDict import warnings import re from pprint import pprint from inspect import ismethod +from copy import deepcopy +from itertools import chain import torch -from torch import Tensor from torch import nn from torch import fx +from torch.fx.graph_module import _copy_attr class IntermediateLayerGetter(nn.ModuleDict): @@ -120,7 +122,7 @@ def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): self.current_module_qualname = old_qualname def create_proxy(self, kind: str, target: fx.node.Target, args, kwargs, - name=None, type_expr=None) -> fx.proxy.Proxy: + name=None, type_expr=None, *_) -> fx.proxy.Proxy: """ Override of `Tracer.create_proxy`. This override intercepts the recording of every operation and stores away the current traced module's qualified @@ -151,6 +153,7 @@ def _get_node_qualname( next_index = 1 node_qualname += f'_{next_index}' break + pass else: # Node terminates in non- leaf module so the node name needs to be # appended @@ -161,6 +164,43 @@ def _get_node_qualname( return node_qualname +def _is_subseq(x, y): + """Check if y is a subseqence of x + https://stackoverflow.com/a/24017747/4391249 + """ + iter_x = iter(x) + return all(any(x_item == y_item for x_item in iter_x) for y_item in y) + + +def _warn_graph_differences( + train_tracer: NodePathTracer, eval_tracer: NodePathTracer): + """ + Utility function for warning the user if there are differences between + the train graph and the eval graph. + """ + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + + if len(train_nodes) == len(eval_nodes) and [ + t == e for t, e in zip(train_nodes, eval_nodes)]: + return + + suggestion_msg = ( + "When choosing nodes for feature extraction, you may need to specify " + "output nodes for train and eval mode separately.") + + if _is_subseq(train_nodes, eval_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in eval mode " + "are a subsequence of those obtained in train mode. ") + elif _is_subseq(eval_nodes, train_nodes): + msg = ("NOTE: The nodes obtained by tracing the model in train mode " + "are a subsequence of those obtained in eval mode. ") + else: + msg = ("The nodes obtained by tracing the model in train mode " + "are different to those obtained in eval mode. ") + warnings.warn(msg + suggestion_msg) + + def print_graph_node_qualified_names( model: nn.Module, tracer_kwargs: Dict = {}): """ @@ -171,6 +211,9 @@ def print_graph_node_qualified_names( fall within this category. 2. Node qualified names that occur more than once in the graph get a `_{counter}` postfix. + The model will be traced twice: once in train mode, and once in eval mode. + If there are discrepancies between the graphs produced, both sets will + be printed and the user will be warned. Args: model (nn.Module): model on which we will extract the features @@ -178,14 +221,93 @@ def print_graph_node_qualified_names( `NodePathTracer` (which passes them onto it's parent class `torch.fx.Tracer`). """ - tracer = NodePathTracer(**tracer_kwargs) - tracer.trace(model) - pprint(list(tracer.node_to_qualname.values())) + train_tracer = NodePathTracer(**tracer_kwargs) + train_tracer.trace(model.train()) + eval_tracer = NodePathTracer(**tracer_kwargs) + eval_tracer.trace(model.eval()) + train_nodes = list(train_tracer.node_to_qualname.values()) + eval_nodes = list(eval_tracer.node_to_qualname.values()) + if len(train_nodes) == len(eval_nodes) and [ + t == e for t, e in zip(train_nodes, eval_nodes)]: + # Nodes are aligned in train vs eval mode + pprint(list(train_tracer.node_to_qualname.values())) + return + print("Nodes from train mode:") + pprint(list(train_tracer.node_to_qualname.values())) + print() + print("Nodes from eval mode:") + pprint(list(eval_tracer.node_to_qualname.values())) + print() + _warn_graph_differences(train_tracer, eval_tracer) + + +class DualGraphModule(fx.GraphModule): + """ + A derivative of `fx.GraphModule`. Differs in the following ways: + - Requires a train and eval version of the underlying graph + - Copies submodules according to the nodes of both train and eval graphs. + - Calling train(mode) switches between train graph and eval graph. + """ + def __init__(self, + root: torch.nn.Module, + train_graph: fx.Graph, + eval_graph: fx.Graph, + class_name: str = 'GraphModule'): + """ + Args: + root (torch.nn.Module): module from which the copied module + hierarchy is built + train_graph (Graph): the graph that should be used in train mode + eval_graph (Graph): the graph that should be used in eval mode + """ + super(fx.GraphModule, self).__init__() + + self.__class__.__name__ = class_name + + self.train_graph = train_graph + self.eval_graph = eval_graph + + # Copy all get_attr and call_module ops (indicated by BOTH train and + # eval graphs) + for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): + if node.op in ['get_attr', 'call_module']: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + + # eval mode by default + self.eval() + self.graph = eval_graph + + # (borrowed from fx.GraphModule): + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + assert self.eval_graph._tracer_cls == self.train_graph._tracer_cls, \ + "Train mode and eval mode should use the same tracer class" + self._tracer_cls = None + if self.graph._tracer_cls and '' not in self.graph._tracer_cls.__qualname__: + self._tracer_cls = self.graph._tracer_cls + + def train(self, mode=True): + """ + Swap out the graph depending on the training mode. + NOTE this should be safe when calling model.eval() because that just + calls this with mode == False. + """ + if mode: + self.graph = self.train_graph + else: + self.graph = self.eval_graph + return super().train(mode=mode) def build_feature_graph_net( model: nn.Module, return_nodes: Union[List[str], Dict[str, str]], + train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, + eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Dict = {}) -> fx.GraphModule: """ Creates a new graph module that returns intermediate nodes from a given @@ -203,6 +325,10 @@ def build_feature_graph_net( `print_graph_node_qualified_names` is a useful helper function for getting a list of qualified names of a model. + An attempt is made to keep all non-parametric properties of the original + model, but existing properties of the constructed `GraphModule` are not + overwritten. + Args: model (nn.Module): model on which we will extract the features return_nodes (Union[List[name], Dict[name, new_name]])): either a list @@ -216,13 +342,6 @@ def build_feature_graph_net( `NodePathTracer` (which passes them onto it's parent class `torch.fx.Tracer`). - NOTE: Static control flow will be frozen into place for the resulting - `GraphModule`. Among other consequences, this means that control flow - that relies on whether the model is in train or eval mode will be - frozen into place (except for leaf modules which are not traced - through). Therefore, calling `.train()` or `.eval()` on the resulting - `GraphModule` may not have all the desired effects. - Examples:: >>> model = torchvision.models.resnet18() @@ -235,118 +354,114 @@ def build_feature_graph_net( >>> ('feat2', torch.Size([1, 256, 14, 14]))] """ - if isinstance(return_nodes, list): - return_nodes = {n: n for n in return_nodes} - return_nodes = {str(k): str(v) for k, v in return_nodes.items()} - - # Instantiate our NodePathTracer and use that to trace the model - tracer = NodePathTracer(**tracer_kwargs) - graph = tracer.trace(model) - - name = model.__class__.__name__ if isinstance(model, nn.Module) else model.__name__ - graph_module = fx.GraphModule(tracer.root, graph, name) - - available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] - # FIXME We don't know if we should expect this to happen - assert len(set(available_nodes)) == len(available_nodes), \ - "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" - # Check that all outputs in return_nodes are present in the model - for query in return_nodes.keys(): - if not any([m.startswith(query) for m in available_nodes]): - raise ValueError(f"return_node: {query} is not present in model") - - # Remove existing output nodes - orig_output_node = None - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_node = n - assert orig_output_node - # And remove it - graph_module.graph.erase_node(orig_output_node) - # Find nodes corresponding to return_nodes and make them into output_nodes - nodes = [n for n in graph_module.graph.nodes] - output_nodes = OrderedDict() - for n in reversed(nodes): - if 'tensor_constant' in str(n): - # NOTE Without this control flow we would get a None value for - # `module_qualname = tracer.node_to_qualname.get(n)`. - # On the other hand, we can safely assume that we'll never need to - # get this as an interesting intermediate node. - continue - module_qualname = tracer.node_to_qualname.get(n) - for query in return_nodes: - depth = query.count('.') - if '.'.join(module_qualname.split('.')[:depth + 1]) == query: - output_nodes[return_nodes[query]] = n - return_nodes.pop(query) - break - output_nodes = OrderedDict(reversed(list(output_nodes.items()))) + is_training = model.training + + assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ + ("If any of `train_return_nodes` and `eval_return_nodes` are " + "specified, then both should be specified") + + # Put *_return_nodes into Dict[str, str] format + def to_strdict(n) -> Dict[str, str]: + if isinstance(n, list): + return {str(i): str(i) for i in n} + return {str(k): str(v) for k, v in n.items()} + + if train_return_nodes is None: + return_nodes = to_strdict(return_nodes) + train_return_nodes = deepcopy(return_nodes) + eval_return_nodes = deepcopy(return_nodes) + else: + train_return_nodes = to_strdict(train_return_nodes) + eval_return_nodes = to_strdict(eval_return_nodes) + + # Repeat the tracing and graph rewriting for train and eval mode + tracers = {} + graphs = {} + mode_return_nodes : Dict[str, Dict[str, str]] = { + 'train': train_return_nodes, + 'eval': eval_return_nodes + } + for mode in ['train', 'eval']: + if mode == 'train': + model.train() + elif mode == 'eval': + model.eval() + + # Instantiate our NodePathTracer and use that to trace the model + tracer = NodePathTracer(**tracer_kwargs) + graph = tracer.trace(model) + + name = model.__class__.__name__ if isinstance( + model, nn.Module) else model.__name__ + graph_module = fx.GraphModule(tracer.root, graph, name) + + available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] + # FIXME We don't know if we should expect this to happen + assert len(set(available_nodes)) == len(available_nodes), \ + "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" + # Check that all outputs in return_nodes are present in the model + for query in mode_return_nodes[mode].keys(): + if not any([m.startswith(query) for m in available_nodes]): + raise ValueError(f"return_node: {query} is not present in model") + + # Remove existing output nodes (train mode) + orig_output_nodes = [] + for n in reversed(graph_module.graph.nodes): + if n.op == "output": + orig_output_nodes.append(n) + assert len(orig_output_nodes) + for n in orig_output_nodes: + graph_module.graph.erase_node(n) + + # Find nodes corresponding to return_nodes and make them into output_nodes + nodes = [n for n in graph_module.graph.nodes] + output_nodes = OrderedDict() + for n in reversed(nodes): + if 'tensor_constant' in str(n): + # NOTE Without this control flow we would get a None value for + # `module_qualname = tracer.node_to_qualname.get(n)`. + # On the other hand, we can safely assume that we'll never need to + # get this as an interesting intermediate node. + continue + module_qualname = tracer.node_to_qualname.get(n) + for query in mode_return_nodes[mode]: + depth = query.count('.') + if '.'.join(module_qualname.split('.')[:depth + 1]) == query: + output_nodes[mode_return_nodes[mode][query]] = n + mode_return_nodes[mode].pop(query) + break + output_nodes = OrderedDict(reversed(list(output_nodes.items()))) - # And add them in the end of the graph - with graph_module.graph.inserting_after(nodes[-1]): - graph_module.graph.output(output_nodes) + # And add them in the end of the graph + with graph_module.graph.inserting_after(nodes[-1]): + graph_module.graph.output(output_nodes) - # Remove unused modules / parameters - graph_module.graph.eliminate_dead_code() - graph_module.recompile() - graph_module = fx.GraphModule(graph_module, graph_module.graph, name) - return graph_module + # Remove unused modules / parameters + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + # Keep track of the tracer and graph so we can choose the main one + tracers[mode] = tracer + graphs[mode] = graph -class FeatureGraphNet(nn.Module): - """ - Wrap a `GraphModule` from `build_feature_graph_net` while also keeping the - original model's non-parameter properties for reference. The original - model's paremeters are discarded. + # Warn user if there are any discrepancies between the graphs of the + # train and eval modes + _warn_graph_differences(tracers['train'], tracers['eval']) - See `build_feature_graph_net` docstring for more information. + # Build the final graph module + graph_module = DualGraphModule( + model, graphs['train'], graphs['eval'], class_name=name) - NOTE: This puts the input model into eval mode prior to tracing. This - means that any control flow dependent on the model being in train mode - will be lost. - """ - def __init__(self, model: nn.Module, - return_nodes: Union[List[str], Dict[str, str]], - tracer_kwargs: Dict = {}): - """ - Args: - model (nn.Module): model on which we will extract the features - return_nodes (Union[List[name], Dict[name, new_name]])): either a list - or a dict containing the names (or partial names - see note above) - of the nodes for which the activations will be returned. If it is - a `Dict`, the keys are the qualified node names, and the values - are the user-specified keys for the graph module's returned - dictionary. If it is a `List`, it is treated as a `Dict` mapping - node specification strings directly to output names. - tracer_kwargs (Dict): a dictionary of keywork arguments for - `NodePathTracer` (which passes them onto it's parent class - `torch.fx.Tracer`). - """ - super(FeatureGraphNet, self).__init__() - model.eval() - self.graph_module = build_feature_graph_net( - model, return_nodes, tracer_kwargs) - # Keep non-parameter model properties for reference - for attr_str in model.__dir__(): - attr = getattr(model, attr_str) - if (not attr_str.startswith('_') - and attr_str not in self.__dir__() - and not ismethod(attr) - and not isinstance(attr, (nn.Module, nn.Parameter))): - setattr(self, attr_str, attr) - - def forward(self, x: Tensor) -> Dict[str, Tensor]: - return self.graph_module(x) - - def train(self, mode: bool = True): - """ - NOTE: This also covers `self.eval()` as that just calls self.train(False) - """ - if mode: - warnings.warn( - "Setting a FeatureGraphNet to training mode won't necessarily" - " have the desired effect. Control flow depending on" - " `self.training` will follow the `False` path. See" - " `FeatureGraphNet` doc-string for more details.") + # Keep non-parameter model properties for reference + for attr_str in model.__dir__(): + attr = getattr(model, attr_str) + if (not attr_str.startswith('_') + and attr_str not in graph_module.__dir__() + and not ismethod(attr) + and not isinstance(attr, (nn.Module, nn.Parameter))): + setattr(graph_module, attr_str, attr) - super(FeatureGraphNet, self).train(mode) + # Restore original training mode + graph_module.train(is_training) + + return graph_module From 23bb71f6549107f7897c135b275b25038fa2b7c8 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 1 Sep 2021 15:22:23 +0100 Subject: [PATCH 03/18] FX feature extraction - Tweaks and small bug fixes --- torchvision/models/_utils.py | 52 ++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 70ce60607cd..5b97ffca09c 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -181,8 +181,8 @@ def _warn_graph_differences( train_nodes = list(train_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values()) - if len(train_nodes) == len(eval_nodes) and [ - t == e for t, e in zip(train_nodes, eval_nodes)]: + if len(train_nodes) == len(eval_nodes) and all( + t == e for t, e in zip(train_nodes, eval_nodes)): return suggestion_msg = ( @@ -227,8 +227,8 @@ def print_graph_node_qualified_names( eval_tracer.trace(model.eval()) train_nodes = list(train_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values()) - if len(train_nodes) == len(eval_nodes) and [ - t == e for t, e in zip(train_nodes, eval_nodes)]: + if len(train_nodes) == len(eval_nodes) and all( + t == e for t, e in zip(train_nodes, eval_nodes)): # Nodes are aligned in train vs eval mode pprint(list(train_tracer.node_to_qualname.values())) return @@ -274,9 +274,9 @@ def __init__(self, assert isinstance(node.target, str) _copy_attr(root, self, node.target) - # eval mode by default - self.eval() - self.graph = eval_graph + # train mode by default + self.train() + self.graph = train_graph # (borrowed from fx.GraphModule): # Store the Tracer class responsible for creating a Graph separately as part of the @@ -292,20 +292,22 @@ def __init__(self, def train(self, mode=True): """ - Swap out the graph depending on the training mode. + Swap out the graph depending on the selected training mode. NOTE this should be safe when calling model.eval() because that just calls this with mode == False. """ - if mode: + # NOTE: Only set self.graph if the current graph is not the desired + # one. This saves us from recompiling the graph where not necessary. + if mode and not self.training: self.graph = self.train_graph - else: + elif not mode and self.training: self.graph = self.eval_graph return super().train(mode=mode) def build_feature_graph_net( model: nn.Module, - return_nodes: Union[List[str], Dict[str, str]], + return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, tracer_kwargs: Dict = {}) -> fx.GraphModule: @@ -331,13 +333,25 @@ def build_feature_graph_net( Args: model (nn.Module): model on which we will extract the features - return_nodes (Union[List[name], Dict[name, new_name]])): either a list + return_nodes (Optional[Union[List[str], Dict[str, str]]]): either a list or a dict containing the names (or partial names - see note above) of the nodes for which the activations will be returned. If it is a `Dict`, the keys are the qualified node names, and the values are the user-specified keys for the graph module's returned dictionary. If it is a `List`, it is treated as a `Dict` mapping - node specification strings directly to output names. + node specification strings directly to output names. In the case + that `train_return_nodes` and `eval_return_nodes` are specified, + this should not be specified. + train_return_nodes (Optional[Union[List[str], Dict[str, str]]]): + similar to `return_nodes`. This can be used if the return nodes + for train mode are different than those from eval mode. + If this is specified, `eval_return_nodes` must also be specified, + and `return_nodes` should not be specified. + eval_return_nodes (Optional[Union[List[str], Dict[str, str]]]): + similar to `return_nodes`. This can be used if the return nodes + for train mode are different than those from eval mode. + If this is specified, `train_return_nodes` must also be specified, + and `return_nodes` should not be specified. tracer_kwargs (Dict): a dictionary of keywork arguments for `NodePathTracer` (which passes them onto it's parent class `torch.fx.Tracer`). @@ -360,6 +374,10 @@ def build_feature_graph_net( ("If any of `train_return_nodes` and `eval_return_nodes` are " "specified, then both should be specified") + assert ((return_nodes is None) ^ (train_return_nodes is None)), \ + ("If `train_return_nodes` and `eval_return_nodes` are specified, " + "then both should be specified") + # Put *_return_nodes into Dict[str, str] format def to_strdict(n) -> Dict[str, str]: if isinstance(n, list): @@ -402,7 +420,12 @@ def to_strdict(n) -> Dict[str, str]: # Check that all outputs in return_nodes are present in the model for query in mode_return_nodes[mode].keys(): if not any([m.startswith(query) for m in available_nodes]): - raise ValueError(f"return_node: {query} is not present in model") + raise ValueError( + f"node: '{query}' is not present in model. Hint: use " + "`print_graph_node_qualified_names` to make sure the " + "`return_nodes` you specified are present. It may even " + "be that you need to specify `train_return_nodes` and " + "`eval_return_nodes` seperately.") # Remove existing output nodes (train mode) orig_output_nodes = [] @@ -462,6 +485,7 @@ def to_strdict(n) -> Dict[str, str]: setattr(graph_module, attr_str, attr) # Restore original training mode + model.train(is_training) graph_module.train(is_training) return graph_module From 9dce734d33b27d8f0a3349acdcf71b776568f747 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 1 Sep 2021 15:26:38 +0100 Subject: [PATCH 04/18] FX feature extraction - add tests --- test/test_backbone_utils.py | 161 +++++++++++++++++++++++++++++++++++ torchvision/models/_utils.py | 4 +- 2 files changed, 163 insertions(+), 2 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 712dccf11a8..9caac29b7db 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,5 +1,9 @@ +import unittest import torch +import torchvision from torchvision.models.detection.backbone_utils import resnet_fpn_backbone +from torchvision.models._utils import build_feature_graph_net +from torchvision.models._utils import IntermediateLayerGetter import pytest @@ -9,3 +13,160 @@ def test_resnet_fpn_backbone(backbone_name): x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] + + +class TestFeatureExtraction(unittest.TestCase): + model = torchvision.models.resnet18(pretrained=False, num_classes=1).eval() + return_layers = { + 'layer1': 'layer1', + 'layer2': 'layer2', + 'layer3': 'layer3', + 'layer4': 'layer4' + } + inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu') + expected_out_shapes = [ + torch.Size([1, 64, 56, 56]), + torch.Size([1, 128, 28, 28]), + torch.Size([1, 256, 14, 14]), + torch.Size([1, 512, 7, 7]) + ] + + def test_build_feature_graph_net(self): + # Check that it works with both a list and dict for return nodes + build_feature_graph_net(self.model, self.return_layers) + build_feature_graph_net(self.model, list(self.return_layers.keys())) + # Check return_nodes and train_return_nodes / eval_return nodes + # mutual exclusivity + with pytest.raises(AssertionError): + build_feature_graph_net( + self.model, return_nodes=self.return_layers, + train_return_nodes=self.return_layers) + # Check train_return_nodes / eval_return nodes must both be specified + with pytest.raises(AssertionError): + build_feature_graph_net( + self.model, train_return_nodes=self.return_layers) + + def test_feature_graph_net_forward_backward(self): + model = build_feature_graph_net(self.model, self.return_layers) + out = model(self.inp) + # Check output shape + for o, e in zip(out.values(), self.expected_out_shapes): + assert o.shape == e + # Backward + sum([o.mean() for o in out.values()]).backward() + + def test_intermediate_layer_getter_forward_backward(self): + model = IntermediateLayerGetter(self.model, self.return_layers).eval() + out = model(self.inp) + # Check output shape + for o, e in zip(out.values(), self.expected_out_shapes): + assert o.shape == e + # Backward + sum([o.mean() for o in out.values()]).backward() + + def test_feature_extraction_methods_equivalence(self): + ilg_model = IntermediateLayerGetter( + self.model, self.return_layers).eval() + fgn_model = build_feature_graph_net(self.model, self.return_layers) + + # Check that we have same parameters + for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), + fgn_model.named_parameters()): + self.assertEqual(n1, n2) + self.assertTrue(p1.equal(p2)) + + # And state_dict matches + for (n1, p1), (n2, p2) in zip(ilg_model.state_dict().items(), + fgn_model.state_dict().items()): + self.assertEqual(n1, n2) + self.assertTrue(p1.equal(p2)) + + with torch.no_grad(): + ilg_out = ilg_model(self.inp) + fgn_out = fgn_model(self.inp) + + self.assertEqual(ilg_out.keys(), fgn_out.keys()) + for k in ilg_out.keys(): + o1 = ilg_out[k] + o2 = fgn_out[k] + self.assertTrue(o1.equal(o2)) + + def test_intermediate_layer_getter_scriptable_forward_backward(self): + ilg_model = IntermediateLayerGetter( + self.model, self.return_layers).eval() + ilg_model = torch.jit.script(ilg_model) + ilg_out = ilg_model(self.inp) + sum([o.mean() for o in ilg_out.values()]).backward() + + def test_feature_graph_net_scriptable_forward_backward(self): + fgn_model = build_feature_graph_net(self.model, self.return_layers) + fgn_model = torch.jit.script(fgn_model) + fgn_out = fgn_model(self.inp) + sum([o.mean() for o in fgn_out.values()]).backward() + + def test_feature_graph_net_train_eval(self): + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(p=1.) + + def forward(self, x): + x = x.mean() + x = self.dropout(x) # dropout + if self.training: + x += 100 # add + else: + x *= 0 # mul + x -= 0 # sub + return x + + model = TestModel() + + train_return_nodes = ['dropout', 'add', 'sub'] + eval_return_nodes = ['dropout', 'mul', 'sub'] + + def checks(model, mode): + with torch.no_grad(): + out = model(torch.ones(10, 10)) + if mode == 'train': + # Check that dropout is respected + assert out['dropout'].item() == 0 + # Check that control flow dependent on training_mode is respected + assert out['sub'].item() == 100 + assert 'add' in out + assert 'mul' not in out + elif mode == 'eval': + # Check that dropout is respected + assert out['dropout'].item() == 1 + # Check that control flow dependent on training_mode is respected + assert out['sub'].item() == 0 + assert 'mul' in out + assert 'add' not in out + + # Starting from train mode + model.train() + fgn_model = build_feature_graph_net( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + # Check that the models stay in their original training state + assert model.training + assert fgn_model.training + # Check outputs + checks(fgn_model, 'train') + # Check outputs after switching to eval mode + fgn_model.eval() + checks(fgn_model, 'eval') + + # Starting from eval mode + model.eval() + fgn_model = build_feature_graph_net( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + # Check that the models stay in their original training state + assert not model.training + assert not fgn_model.training + # Check outputs + checks(fgn_model, 'eval') + # Check outputs after switching to train mode + fgn_model.train() + checks(fgn_model, 'train') diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 5b97ffca09c..8d90e6694e7 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -339,7 +339,7 @@ def build_feature_graph_net( a `Dict`, the keys are the qualified node names, and the values are the user-specified keys for the graph module's returned dictionary. If it is a `List`, it is treated as a `Dict` mapping - node specification strings directly to output names. In the case + node specification strings directly to output names. In the case that `train_return_nodes` and `eval_return_nodes` are specified, this should not be specified. train_return_nodes (Optional[Union[List[str], Dict[str, str]]]): @@ -395,7 +395,7 @@ def to_strdict(n) -> Dict[str, str]: # Repeat the tracing and graph rewriting for train and eval mode tracers = {} graphs = {} - mode_return_nodes : Dict[str, Dict[str, str]] = { + mode_return_nodes: Dict[str, Dict[str, str]] = { 'train': train_return_nodes, 'eval': eval_return_nodes } From fa23bd80591f50a1a38c9cef407a303886c3b3e5 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 1 Sep 2021 19:33:54 +0100 Subject: [PATCH 05/18] move to feature_extraction.py, add LeafModuleAwareTracer, add docs --- docs/source/feature_extraction.rst | 30 +++++ docs/source/index.rst | 1 + test/test_backbone_utils.py | 44 ++++++- .../models/detection/backbone_utils.py | 2 +- .../{_utils.py => feature_extraction.py} | 114 +++++++++++++++--- .../models/segmentation/segmentation.py | 2 +- 6 files changed, 174 insertions(+), 19 deletions(-) create mode 100644 docs/source/feature_extraction.rst rename torchvision/models/{_utils.py => feature_extraction.py} (83%) diff --git a/docs/source/feature_extraction.rst b/docs/source/feature_extraction.rst new file mode 100644 index 00000000000..993d8a22a3a --- /dev/null +++ b/docs/source/feature_extraction.rst @@ -0,0 +1,30 @@ +torchvision.feature_extraction +============================== + +.. currentmodule:: torchvision.models.feature_extraction + +Feature extraction utilities let us tap into our models to access intermediate +transformations of our inputs. This could be useful for a variety of +applications in computer vision. Just a few examples are: + +- Visualizing feature maps. +- Extracting features to compute image descriptors for tasks like facial + recognition, copy-detection, or image retrieval. +- Passing selected features to downstream sub-networks for end-to-end training + with a specific task in mind. For example, passing a hierarchy of features + to a Feature Pyramid Network with object detection heads. + +Torchvision provides two helpers for doing feature extraction. :class:`IntermediateLayerGetter` +is easy to understand and effective. That said, it only allows coarse control +over which features are extracted, and makes some assumptions about the layout +of the input module. :func:`build_feature_graph_net` is far more +flexible, but does have some rough edges as it requires that the input model +is symbolically traceable (see +`torch.fx documentation `_ for more +information on symbolic tracing). + +.. autoclass:: IntermediateLayerGetter + +.. autofunction:: build_feature_graph_net + +.. autofunction:: print_graph_node_qualified_names \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index b25a85d8617..5bd01c4242c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,6 +32,7 @@ architectures, and common image transformations for computer vision. :caption: Package Reference datasets + feature_extraction io models ops diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 9caac29b7db..491841eeaf3 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -2,8 +2,8 @@ import torch import torchvision from torchvision.models.detection.backbone_utils import resnet_fpn_backbone -from torchvision.models._utils import build_feature_graph_net -from torchvision.models._utils import IntermediateLayerGetter +from torchvision.models.feature_extraction import build_feature_graph_net +from torchvision.models.feature_extraction import IntermediateLayerGetter import pytest @@ -15,6 +15,11 @@ def test_resnet_fpn_backbone(backbone_name): assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] +# Needed by TestFeatureExtraction.test_feature_graph_net_leaf_module_and_function +def leaf_function(x): + return int(x) + + class TestFeatureExtraction(unittest.TestCase): model = torchvision.models.resnet18(pretrained=False, num_classes=1).eval() return_layers = { @@ -35,6 +40,9 @@ def test_build_feature_graph_net(self): # Check that it works with both a list and dict for return nodes build_feature_graph_net(self.model, self.return_layers) build_feature_graph_net(self.model, list(self.return_layers.keys())) + # Check must specify return nodes + with pytest.raises(AssertionError): + build_feature_graph_net(self.model) # Check return_nodes and train_return_nodes / eval_return nodes # mutual exclusivity with pytest.raises(AssertionError): @@ -170,3 +178,35 @@ def checks(model, mode): # Check outputs after switching to train mode fgn_model.train() checks(fgn_model, 'train') + + def test_feature_graph_net_leaf_module_and_function(self): + class LeafModule(torch.nn.Module): + def forward(self, x): + # This would raise a TypeError if it were not in a leaf module + int(x.shape[0]) + return torch.nn.functional.relu(x + 4) + + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 1, 3) + self.leaf_module = LeafModule() + + def forward(self, x): + leaf_function(x.shape[0]) + x = self.conv(x) + return self.leaf_module(x) + + model = build_feature_graph_net( + TestModule(), return_nodes=['leaf_module'], + tracer_kwargs={'leaf_modules': [LeafModule], + 'autowrap_functions': [leaf_function]}) + + # Check that LeafModule is not in the list of nodes + assert 'relu' not in [str(n) for n in model.graph.nodes] + assert 'leaf_module' in [str(n) for n in model.graph.nodes] + + # Check forward + out = model(self.inp) + # And backward + out['leaf_module'].mean().backward() diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 3178a81b52c..701fab94457 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -3,7 +3,7 @@ from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool from torchvision.ops import misc as misc_nn_ops -from .._utils import IntermediateLayerGetter +from ..feature_extraction import IntermediateLayerGetter from .. import mobilenet from .. import resnet diff --git a/torchvision/models/_utils.py b/torchvision/models/feature_extraction.py similarity index 83% rename from torchvision/models/_utils.py rename to torchvision/models/feature_extraction.py index 8d90e6694e7..c0b7b44bdaa 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/feature_extraction.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Callable, List, Union, Optional +from typing import Dict, Callable, List, Union, Optional, Set from collections import OrderedDict import warnings import re @@ -15,7 +15,7 @@ class IntermediateLayerGetter(nn.ModuleDict): """ - Module wrapper that returns intermediate layers from a model + Module wrapper that returns intermediate layers from a model. It has a strong assumption that the modules have been registered into the model in the same order as they are used. @@ -26,6 +26,9 @@ class IntermediateLayerGetter(nn.ModuleDict): assigned to the model. So if `model` is passed, `model.feature1` can be returned, but not `model.feature1.layer2`. + For a more flexibile feature extraction, see + :func:`build_feature_graph_net`. + Args: model (nn.Module): model on which we will extract the features return_layers (Dict[name, new_name]): a dict containing the names @@ -75,14 +78,33 @@ def forward(self, x): return out -class NodePathTracer(fx.Tracer): +class LeafModuleAwareTracer(fx.Tracer): + """ + An fx.Tracer that allows the user to specify a set of leaf modules, ie. + modules that are not to be traced through. The resulting graph ends up + having single nodes referencing calls to the leaf modules' forward methods. + """ + def __init__(self, *args, **kwargs): + self.leaf_modules = {} + if 'leaf_modules' in kwargs: + leaf_modules = kwargs.pop('leaf_modules') + self.leaf_modules = leaf_modules + super(LeafModuleAwareTracer, self).__init__(*args, **kwargs) + + def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: + if isinstance(m, tuple(self.leaf_modules)): + return True + return super().is_leaf_module(m, module_qualname) + + +class NodePathTracer(LeafModuleAwareTracer): """ NodePathTracer is an FX tracer that, for each operation, also records the qualified name of the Node from which the operation originated. A qualified name here is a `.` seperated path walking the hierarchy from top level module down to leaf operation or leaf module. The name of the top level module is not included as part of the qualified name. For example, - if we trace a module who's forward method applies a ReLU module, the + if we trace a module whose forward method applies a ReLU module, the qualified name for that node will simply be 'relu'. Some notes on the specifics: @@ -207,19 +229,27 @@ def print_graph_node_qualified_names( Dev utility to prints nodes in order of execution. Useful for choosing nodes for a FeatureGraphNet design. There are two reasons that qualified node names can't easily be read directly from the code for a model: + 1. Not all submodules are traced through. Modules from `torch.nn` all - fall within this category. + fall within this category. + 2. Node qualified names that occur more than once in the graph get a - `_{counter}` postfix. - The model will be traced twice: once in train mode, and once in eval mode. - If there are discrepancies between the graphs produced, both sets will - be printed and the user will be warned. + `_{counter}` postfix. + + The model is traced twice: once in train mode, and once in eval mode. + If there are discrepancies between the graphs produced, both sets of nodes + will be printed and the user will be warned. Args: - model (nn.Module): model on which we will extract the features + model (nn.Module): model for which we'd like to print node names tracer_kwargs (Dict): a dictionary of keywork arguments for - `NodePathTracer` (which passes them onto it's parent class + `NodePathTracer` (which in turn passes them onto it's parent class `torch.fx.Tracer`). + + Examples:: + + >>> model = torchvision.models.resnet18() + >>> print_graph_node_qualified_names(model) """ train_tracer = NodePathTracer(**tracer_kwargs) train_tracer.trace(model.train()) @@ -324,17 +354,38 @@ def build_feature_graph_net( `return_nodes` argument should point to either a node's qualified name, or some truncated version of it. For example, one could provide `blocks.5` as a key, and the last node with that prefix will be selected. - `print_graph_node_qualified_names` is a useful helper function for getting - a list of qualified names of a model. + :func:`print_graph_node_qualified_names` is a useful helper function + for getting a list of qualified node names of a model. An attempt is made to keep all non-parametric properties of the original model, but existing properties of the constructed `GraphModule` are not overwritten. + Not all models will be FX traceable, although with some massaging they can + be made to cooperate. Here's a (not exhaustive) list of tips: + + - If you don't need to trace through a particular, problematic + sub-module, turn it into a "leaf module" by passing a list of + `leaf_modules` as one of the `tracer_kwargs` (see example below). It + will not be traced through, but rather, the resulting graph will + hold a reference to that module's forward method. + - Likewise, you may turn functions into leaf functions by passing a + list of `autowrap_functions` as one of the `tracer_kwargs` (see + example below). + - Some inbuilt Python functions can be problematic. For instance, + `int` will raise an error during tracig. You may wrap them in your + own function and then pass that in `autowrap_functions` as one of + the `tracer_kwargs`. + - Use `torch.matmul(tensor_a, tensor_b)` instead of `tensor_a @ + tensor_b`. + + For further information on FX see the + `torch.fx documentation `_. + Args: model (nn.Module): model on which we will extract the features - return_nodes (Optional[Union[List[str], Dict[str, str]]]): either a list - or a dict containing the names (or partial names - see note above) + return_nodes (Optional[Union[List[str], Dict[str, str]]]): either a `List` + or a `Dict` containing the names (or partial names - see note above) of the nodes for which the activations will be returned. If it is a `Dict`, the keys are the qualified node names, and the values are the user-specified keys for the graph module's returned @@ -358,6 +409,7 @@ def build_feature_graph_net( Examples:: + >>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> graph_module = torchvision.models._utils.build_feature_graph_net(m, @@ -367,9 +419,41 @@ def build_feature_graph_net( >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] + >>> # Specifying leaf modules and leaf functions + >>> def leaf_function(x): + >>> # This would raise a TypeError if traced through + >>> return int(x) + >>> + >>> class LeafModule(torch.nn.Module): + >>> def forward(self, x): + >>> # This would raise a TypeError if traced through + >>> int(x.shape[0]) + >>> return torch.nn.functional.relu(x + 4) + >>> + >>> class MyModule(torch.nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.conv = torch.nn.Conv2d(3, 1, 3) + >>> self.leaf_module = LeafModule() + >>> + >>> def forward(self, x): + >>> leaf_function(x.shape[0]) + >>> x = self.conv(x) + >>> return self.leaf_module(x) + >>> + >>> model = build_feature_graph_net( + >>> MyModule(), return_nodes=['leaf_module'], + >>> tracer_kwargs={'leaf_modules': [LeafModule], + >>> 'autowrap_functions': [leaf_function]}) + """ is_training = model.training + assert any(arg is not None for arg in [ + return_nodes, train_return_nodes, eval_return_nodes]), ( + "Either `return_nodes` or `train_return_nodes` and " + "`eval_return_nodes` together, should be specified") + assert not ((train_return_nodes is None) ^ (eval_return_nodes is None)), \ ("If any of `train_return_nodes` and `eval_return_nodes` are " "specified, then both should be specified") diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 0f2f14c97ba..426f0e703fe 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,4 +1,4 @@ -from .._utils import IntermediateLayerGetter +from ..feature_extraction import IntermediateLayerGetter from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 from .. import resnet From 0581ba760ff7b7558a23d986676c8c46bc27fd12 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 1 Sep 2021 19:42:08 +0100 Subject: [PATCH 06/18] Tweaks to docs --- docs/source/feature_extraction.rst | 2 +- torchvision/models/feature_extraction.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/docs/source/feature_extraction.rst b/docs/source/feature_extraction.rst index 993d8a22a3a..94fa6d456b5 100644 --- a/docs/source/feature_extraction.rst +++ b/docs/source/feature_extraction.rst @@ -14,7 +14,7 @@ applications in computer vision. Just a few examples are: with a specific task in mind. For example, passing a hierarchy of features to a Feature Pyramid Network with object detection heads. -Torchvision provides two helpers for doing feature extraction. :class:`IntermediateLayerGetter` +Torchvision provides two utilities for doing feature extraction. :class:`IntermediateLayerGetter` is easy to understand and effective. That said, it only allows coarse control over which features are extracted, and makes some assumptions about the layout of the input module. :func:`build_feature_graph_net` is far more diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index c0b7b44bdaa..5288aba4281 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -231,10 +231,9 @@ def print_graph_node_qualified_names( node names can't easily be read directly from the code for a model: 1. Not all submodules are traced through. Modules from `torch.nn` all - fall within this category. - + fall within this category. 2. Node qualified names that occur more than once in the graph get a - `_{counter}` postfix. + `_{counter}` postfix. The model is traced twice: once in train mode, and once in eval mode. If there are discrepancies between the graphs produced, both sets of nodes From 7ecd15bad4e777e2e8e6466f6f05db3b1fc278ee Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 3 Sep 2021 14:02:19 +0100 Subject: [PATCH 07/18] addressing latest round of feedback --- docs/source/feature_extraction.rst | 30 +-- test/test_backbone_utils.py | 195 +++++++++------- torchvision/models/_utils.py | 61 +++++ .../models/detection/backbone_utils.py | 2 +- torchvision/models/feature_extraction.py | 214 ++++++------------ .../models/segmentation/segmentation.py | 2 +- 6 files changed, 261 insertions(+), 243 deletions(-) create mode 100644 torchvision/models/_utils.py diff --git a/docs/source/feature_extraction.rst b/docs/source/feature_extraction.rst index 94fa6d456b5..a235765d396 100644 --- a/docs/source/feature_extraction.rst +++ b/docs/source/feature_extraction.rst @@ -1,5 +1,5 @@ -torchvision.feature_extraction -============================== +torchvision.models.feature_extraction +===================================== .. currentmodule:: torchvision.models.feature_extraction @@ -14,17 +14,21 @@ applications in computer vision. Just a few examples are: with a specific task in mind. For example, passing a hierarchy of features to a Feature Pyramid Network with object detection heads. -Torchvision provides two utilities for doing feature extraction. :class:`IntermediateLayerGetter` -is easy to understand and effective. That said, it only allows coarse control -over which features are extracted, and makes some assumptions about the layout -of the input module. :func:`build_feature_graph_net` is far more -flexible, but does have some rough edges as it requires that the input model -is symbolically traceable (see -`torch.fx documentation `_ for more -information on symbolic tracing). +Torchvision provides :func:`build_feature_extractor` for this purpose. +It works by following roughly these steps: -.. autoclass:: IntermediateLayerGetter +1. Symbolically tracing the model to get a graphical representation of + how it transforms the input, step by step. +2. Setting the user-selected graph nodes as ouputs. +3. Removing all redundant nodes (anything downstream of the ouput nodes). +4. Generating python code from the resulting graph and bundling that, together + with the graph, and bundling that into a PyTorch module. -.. autofunction:: build_feature_graph_net +| -.. autofunction:: print_graph_node_qualified_names \ No newline at end of file +See `torch.fx documentation `_ for +more information on symbolic tracing. + +.. autofunction:: build_feature_extractor + +.. autofunction:: get_graph_node_names \ No newline at end of file diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 491841eeaf3..f48925e747e 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,12 +1,30 @@ -import unittest +from functools import partial +import random + import torch -import torchvision +from torchvision import models from torchvision.models.detection.backbone_utils import resnet_fpn_backbone -from torchvision.models.feature_extraction import build_feature_graph_net -from torchvision.models.feature_extraction import IntermediateLayerGetter +from torchvision.models.feature_extraction import build_feature_extractor +from torchvision.models.feature_extraction import get_graph_node_names +from torchvision.models._utils import IntermediateLayerGetter import pytest +from common_utils import set_rng_seed + + +# Suppress diff warning from build_feature_extractor +build_feature_extractor = partial( + build_feature_extractor, suppress_diff_warning=True) +get_graph_node_names = partial( + get_graph_node_names, suppress_diff_warning=True) + + +def get_available_models(): + # TODO add a registration mechanism to torchvision.models + return [k for k, v in models.__dict__.items() + if callable(v) and k[0].lower() == k[0] and k[0] != "_"] + @pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) def test_resnet_fpn_backbone(backbone_name): @@ -15,104 +33,105 @@ def test_resnet_fpn_backbone(backbone_name): assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] -# Needed by TestFeatureExtraction.test_feature_graph_net_leaf_module_and_function +# Needed by TestFxFeatureExtraction.test_leaf_module_and_function def leaf_function(x): return int(x) -class TestFeatureExtraction(unittest.TestCase): - model = torchvision.models.resnet18(pretrained=False, num_classes=1).eval() - return_layers = { - 'layer1': 'layer1', - 'layer2': 'layer2', - 'layer3': 'layer3', - 'layer4': 'layer4' - } +class TestFxFeatureExtraction: inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device='cpu') - expected_out_shapes = [ - torch.Size([1, 64, 56, 56]), - torch.Size([1, 128, 28, 28]), - torch.Size([1, 256, 14, 14]), - torch.Size([1, 512, 7, 7]) - ] - - def test_build_feature_graph_net(self): + model_defaults = { + 'num_classes': 1, + 'pretrained': False + } + + def _get_return_nodes(self, model): + set_rng_seed(0) + exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk'] + train_nodes, eval_nodes = get_graph_node_names(model) + # Get rid of any nodes that don't return tensors as they cause issues + # when testing backward pass. + train_nodes = [n for n in train_nodes + if not any(x in n for x in exclude_nodes_filter)] + eval_nodes = [n for n in eval_nodes + if not any(x in n for x in exclude_nodes_filter)] + return random.sample(train_nodes, 10), random.sample(eval_nodes, 10) + + @pytest.mark.parametrize('model_name', get_available_models()) + def test_build_fx_feature_extractor(self, model_name): + set_rng_seed(0) + model = models.__dict__[model_name](**self.model_defaults).eval() + train_return_nodes, eval_return_nodes = self._get_return_nodes(model) # Check that it works with both a list and dict for return nodes - build_feature_graph_net(self.model, self.return_layers) - build_feature_graph_net(self.model, list(self.return_layers.keys())) + build_feature_extractor( + model, train_return_nodes={v: v for v in train_return_nodes}, + eval_return_nodes=eval_return_nodes) + build_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) # Check must specify return nodes with pytest.raises(AssertionError): - build_feature_graph_net(self.model) + build_feature_extractor(model) # Check return_nodes and train_return_nodes / eval_return nodes # mutual exclusivity with pytest.raises(AssertionError): - build_feature_graph_net( - self.model, return_nodes=self.return_layers, - train_return_nodes=self.return_layers) + build_feature_extractor(model, return_nodes=train_return_nodes, + train_return_nodes=train_return_nodes) # Check train_return_nodes / eval_return nodes must both be specified with pytest.raises(AssertionError): - build_feature_graph_net( - self.model, train_return_nodes=self.return_layers) - - def test_feature_graph_net_forward_backward(self): - model = build_feature_graph_net(self.model, self.return_layers) - out = model(self.inp) - # Check output shape - for o, e in zip(out.values(), self.expected_out_shapes): - assert o.shape == e - # Backward - sum([o.mean() for o in out.values()]).backward() - - def test_intermediate_layer_getter_forward_backward(self): - model = IntermediateLayerGetter(self.model, self.return_layers).eval() + build_feature_extractor( + model, train_return_nodes=train_return_nodes) + + @pytest.mark.parametrize('model_name', get_available_models()) + def test_forward_backward(self, model_name): + model = models.__dict__[model_name](**self.model_defaults).train() + train_return_nodes, eval_return_nodes = self._get_return_nodes(model) + model = build_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) out = model(self.inp) - # Check output shape - for o, e in zip(out.values(), self.expected_out_shapes): - assert o.shape == e - # Backward sum([o.mean() for o in out.values()]).backward() def test_feature_extraction_methods_equivalence(self): + model = models.resnet18(**self.model_defaults).eval() + return_layers = { + 'layer1': 'layer1', + 'layer2': 'layer2', + 'layer3': 'layer3', + 'layer4': 'layer4' + } + ilg_model = IntermediateLayerGetter( - self.model, self.return_layers).eval() - fgn_model = build_feature_graph_net(self.model, self.return_layers) + model, return_layers).eval() + fx_model = build_feature_extractor(model, return_layers) # Check that we have same parameters for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), - fgn_model.named_parameters()): - self.assertEqual(n1, n2) - self.assertTrue(p1.equal(p2)) - - # And state_dict matches - for (n1, p1), (n2, p2) in zip(ilg_model.state_dict().items(), - fgn_model.state_dict().items()): - self.assertEqual(n1, n2) - self.assertTrue(p1.equal(p2)) + fx_model.named_parameters()): + assert n1 == n2 + assert p1.equal(p2) + # And that ouputs match with torch.no_grad(): ilg_out = ilg_model(self.inp) - fgn_out = fgn_model(self.inp) - - self.assertEqual(ilg_out.keys(), fgn_out.keys()) + fgn_out = fx_model(self.inp) + assert all(k1 == k2 for k1, k2 in zip(ilg_out.keys(), fgn_out.keys())) for k in ilg_out.keys(): - o1 = ilg_out[k] - o2 = fgn_out[k] - self.assertTrue(o1.equal(o2)) - - def test_intermediate_layer_getter_scriptable_forward_backward(self): - ilg_model = IntermediateLayerGetter( - self.model, self.return_layers).eval() - ilg_model = torch.jit.script(ilg_model) - ilg_out = ilg_model(self.inp) - sum([o.mean() for o in ilg_out.values()]).backward() - - def test_feature_graph_net_scriptable_forward_backward(self): - fgn_model = build_feature_graph_net(self.model, self.return_layers) - fgn_model = torch.jit.script(fgn_model) - fgn_out = fgn_model(self.inp) + assert ilg_out[k].equal(fgn_out[k]) + + @pytest.mark.parametrize('model_name', get_available_models()) + def test_jit_forward_backward(self, model_name): + set_rng_seed(0) + model = models.__dict__[model_name](**self.model_defaults).train() + train_return_nodes, eval_return_nodes = self._get_return_nodes(model) + model = build_feature_extractor( + model, train_return_nodes=train_return_nodes, + eval_return_nodes=eval_return_nodes) + model = torch.jit.script(model) + fgn_out = model(self.inp) sum([o.mean() for o in fgn_out.values()]).backward() - def test_feature_graph_net_train_eval(self): + def test_train_eval(self): class TestModel(torch.nn.Module): def __init__(self): super().__init__() @@ -153,33 +172,33 @@ def checks(model, mode): # Starting from train mode model.train() - fgn_model = build_feature_graph_net( + fx_model = build_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check that the models stay in their original training state assert model.training - assert fgn_model.training + assert fx_model.training # Check outputs - checks(fgn_model, 'train') + checks(fx_model, 'train') # Check outputs after switching to eval mode - fgn_model.eval() - checks(fgn_model, 'eval') + fx_model.eval() + checks(fx_model, 'eval') # Starting from eval mode model.eval() - fgn_model = build_feature_graph_net( + fx_model = build_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check that the models stay in their original training state assert not model.training - assert not fgn_model.training + assert not fx_model.training # Check outputs - checks(fgn_model, 'eval') + checks(fx_model, 'eval') # Check outputs after switching to train mode - fgn_model.train() - checks(fgn_model, 'train') + fx_model.train() + checks(fx_model, 'train') - def test_feature_graph_net_leaf_module_and_function(self): + def test_leaf_module_and_function(self): class LeafModule(torch.nn.Module): def forward(self, x): # This would raise a TypeError if it were not in a leaf module @@ -197,10 +216,10 @@ def forward(self, x): x = self.conv(x) return self.leaf_module(x) - model = build_feature_graph_net( + model = build_feature_extractor( TestModule(), return_nodes=['leaf_module'], tracer_kwargs={'leaf_modules': [LeafModule], - 'autowrap_functions': [leaf_function]}) + 'autowrap_functions': [leaf_function]}).train() # Check that LeafModule is not in the list of nodes assert 'relu' not in [str(n) for n in model.graph.nodes] diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py new file mode 100644 index 00000000000..8cd5315f5b7 --- /dev/null +++ b/torchvision/models/_utils.py @@ -0,0 +1,61 @@ +from collections import OrderedDict + +from torch import nn +from typing import Dict + + +class IntermediateLayerGetter(nn.ModuleDict): + """ + Module wrapper that returns intermediate layers from a model + It has a strong assumption that the modules have been registered + into the model in the same order as they are used. + This means that one should **not** reuse the same nn.Module + twice in the forward if you want this to work. + Additionally, it is only able to query submodules that are directly + assigned to the model. So if `model` is passed, `model.feature1` can + be returned, but not `model.feature1.layer2`. + Args: + model (nn.Module): model on which we will extract the features + return_layers (Dict[name, new_name]): a dict containing the names + of the modules for which the activations will be returned as + the key of the dict, and the value of the dict is the name + of the returned activation (which the user can specify). + Examples:: + >>> m = torchvision.models.resnet18(pretrained=True) + >>> # extract layer1 and layer3, giving as names `feat1` and feat2` + >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, + >>> {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = new_m(torch.rand(1, 3, 224, 224)) + >>> print([(k, v.shape) for k, v in out.items()]) + >>> [('feat1', torch.Size([1, 64, 56, 56])), + >>> ('feat2', torch.Size([1, 256, 14, 14]))] + """ + _version = 2 + __annotations__ = { + "return_layers": Dict[str, str], + } + + def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None: + if not set(return_layers).issubset([name for name, _ in model.named_children()]): + raise ValueError("return_layers are not present in model") + orig_return_layers = return_layers + return_layers = {str(k): str(v) for k, v in return_layers.items()} + layers = OrderedDict() + for name, module in model.named_children(): + layers[name] = module + if name in return_layers: + del return_layers[name] + if not return_layers: + break + + super(IntermediateLayerGetter, self).__init__(layers) + self.return_layers = orig_return_layers + + def forward(self, x): + out = OrderedDict() + for name, module in self.items(): + x = module(x) + if name in self.return_layers: + out_name = self.return_layers[name] + out[out_name] = x + return out diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 701fab94457..3178a81b52c 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -3,7 +3,7 @@ from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool from torchvision.ops import misc as misc_nn_ops -from ..feature_extraction import IntermediateLayerGetter +from .._utils import IntermediateLayerGetter from .. import mobilenet from .. import resnet diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 5288aba4281..3474b8d3ecf 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -1,9 +1,7 @@ -from typing import Dict, Callable, List, Union, Optional, Set +from typing import Dict, Callable, List, Union, Optional, Tuple from collections import OrderedDict import warnings import re -from pprint import pprint -from inspect import ismethod from copy import deepcopy from itertools import chain @@ -13,71 +11,6 @@ from torch.fx.graph_module import _copy_attr -class IntermediateLayerGetter(nn.ModuleDict): - """ - Module wrapper that returns intermediate layers from a model. - - It has a strong assumption that the modules have been registered - into the model in the same order as they are used. - This means that one should **not** reuse the same nn.Module - twice in the forward if you want this to work. - - Additionally, it is only able to query submodules that are directly - assigned to the model. So if `model` is passed, `model.feature1` can - be returned, but not `model.feature1.layer2`. - - For a more flexibile feature extraction, see - :func:`build_feature_graph_net`. - - Args: - model (nn.Module): model on which we will extract the features - return_layers (Dict[name, new_name]): a dict containing the names - of the modules for which the activations will be returned as - the key of the dict, and the value of the dict is the name - of the returned activation (which the user can specify). - - Examples:: - - >>> m = torchvision.models.resnet18(pretrained=True) - >>> # extract layer1 and layer3, giving as names `feat1` and feat2` - >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, - >>> {'layer1': 'feat1', 'layer3': 'feat2'}) - >>> out = new_m(torch.rand(1, 3, 224, 224)) - >>> print([(k, v.shape) for k, v in out.items()]) - >>> [('feat1', torch.Size([1, 64, 56, 56])), - >>> ('feat2', torch.Size([1, 256, 14, 14]))] - """ - _version = 2 - __annotations__ = { - "return_layers": Dict[str, str], - } - - def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None: - if not set(return_layers).issubset([name for name, _ in model.named_children()]): - raise ValueError("return_layers are not present in model") - orig_return_layers = return_layers - return_layers = {str(k): str(v) for k, v in return_layers.items()} - layers = OrderedDict() - for name, module in model.named_children(): - layers[name] = module - if name in return_layers: - del return_layers[name] - if not return_layers: - break - - super(IntermediateLayerGetter, self).__init__(layers) - self.return_layers = orig_return_layers - - def forward(self, x): - out = OrderedDict() - for name, module in self.items(): - x = module(x) - if name in self.return_layers: - out_name = self.return_layers[name] - out[out_name] = x - return out - - class LeafModuleAwareTracer(fx.Tracer): """ An fx.Tracer that allows the user to specify a set of leaf modules, ie. @@ -100,26 +33,29 @@ def is_leaf_module(self, m: nn.Module, module_qualname: str) -> bool: class NodePathTracer(LeafModuleAwareTracer): """ NodePathTracer is an FX tracer that, for each operation, also records the - qualified name of the Node from which the operation originated. A - qualified name here is a `.` seperated path walking the hierarchy from top - level module down to leaf operation or leaf module. The name of the top - level module is not included as part of the qualified name. For example, - if we trace a module whose forward method applies a ReLU module, the - qualified name for that node will simply be 'relu'. + name of the Node from which the operation originated. A node name here is + a `.` seperated path walking the hierarchy from top level module down to + leaf operation or leaf module. The name of the top level module is not + included as part of the node name. For example, if we trace a module whose + forward method applies a ReLU module, the name for that node will simply + be 'relu'. Some notes on the specifics: - Nodes are recorded to `self.node_to_qualname` which is a dictionary - mapping a given Node object to its qualified name. + mapping a given Node object to its node name. - Nodes are recorded in the order which they are executed during tracing. - - When a duplicate qualified name is encountered, a suffix of the form + - When a duplicate node name is encountered, a suffix of the form _{int} is added. The counter starts from 1. """ def __init__(self, *args, **kwargs): super(NodePathTracer, self).__init__(*args, **kwargs) # Track the qualified name of the Node being traced self.current_module_qualname = '' - # A map from FX Node to the qualified name + # A map from FX Node to the qualified name\# + # NOTE: This is loosely like the "qualified name" mentioned in the + # torch.fx docs https://pytorch.org/docs/stable/fx.html but adapted + # for the purposes of the torchvision feature extractor self.node_to_qualname = OrderedDict() def call_module(self, m: torch.nn.Module, forward: Callable, args, kwargs): @@ -198,7 +134,7 @@ def _warn_graph_differences( train_tracer: NodePathTracer, eval_tracer: NodePathTracer): """ Utility function for warning the user if there are differences between - the train graph and the eval graph. + the train graph nodes and the eval graph nodes. """ train_nodes = list(train_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values()) @@ -223,51 +159,54 @@ def _warn_graph_differences( warnings.warn(msg + suggestion_msg) -def print_graph_node_qualified_names( - model: nn.Module, tracer_kwargs: Dict = {}): +def get_graph_node_names( + model: nn.Module, tracer_kwargs: Dict = {}, + suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]: """ - Dev utility to prints nodes in order of execution. Useful for choosing - nodes for a FeatureGraphNet design. There are two reasons that qualified + Dev utility to return node names in order of execution. See note on node + names under :func:`build_feature_extractor`. Useful for seeing which node + names are available for feature extraction. There are two reasons that node names can't easily be read directly from the code for a model: 1. Not all submodules are traced through. Modules from `torch.nn` all fall within this category. - 2. Node qualified names that occur more than once in the graph get a - `_{counter}` postfix. + 2. Nodes representing the repeated application of the same operation + or leaf module get a `_{counter}` postfix. - The model is traced twice: once in train mode, and once in eval mode. - If there are discrepancies between the graphs produced, both sets of nodes - will be printed and the user will be warned. + The model is traced twice: once in train mode, and once in eval mode. Both + sets of nodes are returned. Args: model (nn.Module): model for which we'd like to print node names tracer_kwargs (Dict): a dictionary of keywork arguments for `NodePathTracer` (which in turn passes them onto it's parent class `torch.fx.Tracer`). + suppress_diff_warning (bool): whether to suppress a warning when there + are discrepancies between the train and eval version of the graph. + Defaults to False. + + Returns: + Tuple[List[str], List[str]]: a list of node names from tracing the + model in train mode, and another from tracing the model in eval + mode. Examples:: >>> model = torchvision.models.resnet18() - >>> print_graph_node_qualified_names(model) + >>> train_nodes, eval_nodes = get_graph_node_names(model) """ + is_training = model.training train_tracer = NodePathTracer(**tracer_kwargs) train_tracer.trace(model.train()) eval_tracer = NodePathTracer(**tracer_kwargs) eval_tracer.trace(model.eval()) train_nodes = list(train_tracer.node_to_qualname.values()) eval_nodes = list(eval_tracer.node_to_qualname.values()) - if len(train_nodes) == len(eval_nodes) and all( - t == e for t, e in zip(train_nodes, eval_nodes)): - # Nodes are aligned in train vs eval mode - pprint(list(train_tracer.node_to_qualname.values())) - return - print("Nodes from train mode:") - pprint(list(train_tracer.node_to_qualname.values())) - print() - print("Nodes from eval mode:") - pprint(list(eval_tracer.node_to_qualname.values())) - print() - _warn_graph_differences(train_tracer, eval_tracer) + if not suppress_diff_warning: + _warn_graph_differences(train_tracer, eval_tracer) + # Restore training state + model.train(is_training) + return train_nodes, eval_nodes class DualGraphModule(fx.GraphModule): @@ -334,12 +273,13 @@ def train(self, mode=True): return super().train(mode=mode) -def build_feature_graph_net( +def build_feature_extractor( model: nn.Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, eval_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, - tracer_kwargs: Dict = {}) -> fx.GraphModule: + tracer_kwargs: Dict = {}, + suppress_diff_warning: bool = False) -> fx.GraphModule: """ Creates a new graph module that returns intermediate nodes from a given model as dictionary with user specified keys as strings, and the requested @@ -347,18 +287,14 @@ def build_feature_graph_net( the model via FX to return the desired nodes as outputs. All unused nodes are removed, together with their corresponding parameters. - A note on node specification: A node qualified name is specified as a `.` - seperated path walking the hierarchy from top level module down to leaf - operation or leaf module. For instance `blocks.5.3.bn1`. The keys of the - `return_nodes` argument should point to either a node's qualified name, - or some truncated version of it. For example, one could provide `blocks.5` - as a key, and the last node with that prefix will be selected. - :func:`print_graph_node_qualified_names` is a useful helper function - for getting a list of qualified node names of a model. - - An attempt is made to keep all non-parametric properties of the original - model, but existing properties of the constructed `GraphModule` are not - overwritten. + A note on node specification: For the purposes of this feature extraction + utility, a node name is specified as a `.` seperated path walking the + hierarchy from top level module down to leaf operation or leaf module. For + instance `blocks.5.3.bn1`. The keys of the `return_nodes` argument should + point to either a node's name, or some truncated version of it. For + example, one could provide `blocks.5` as a key, and the last node with + that prefix will be selected. :func:`get_graph_node_names` is a useful + helper function for getting a list of node names of a model. Not all models will be FX traceable, although with some massaging they can be made to cooperate. Here's a (not exhaustive) list of tips: @@ -372,7 +308,7 @@ def build_feature_graph_net( list of `autowrap_functions` as one of the `tracer_kwargs` (see example below). - Some inbuilt Python functions can be problematic. For instance, - `int` will raise an error during tracig. You may wrap them in your + `int` will raise an error during tracing. You may wrap them in your own function and then pass that in `autowrap_functions` as one of the `tracer_kwargs`. - Use `torch.matmul(tensor_a, tensor_b)` instead of `tensor_a @ @@ -386,7 +322,7 @@ def build_feature_graph_net( return_nodes (Optional[Union[List[str], Dict[str, str]]]): either a `List` or a `Dict` containing the names (or partial names - see note above) of the nodes for which the activations will be returned. If it is - a `Dict`, the keys are the qualified node names, and the values + a `Dict`, the keys are the node names, and the values are the user-specified keys for the graph module's returned dictionary. If it is a `List`, it is treated as a `Dict` mapping node specification strings directly to output names. In the case @@ -405,15 +341,18 @@ def build_feature_graph_net( tracer_kwargs (Dict): a dictionary of keywork arguments for `NodePathTracer` (which passes them onto it's parent class `torch.fx.Tracer`). + suppress_diff_warning (bool): whether to suppress a warning when there + are discrepancies between the train and eval version of the graph. + Defaults to False. Examples:: >>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` - >>> graph_module = torchvision.models._utils.build_feature_graph_net(m, - >>> {'layer1': 'feat1', 'layer3': 'feat2'}) - >>> out = graph_module(torch.rand(1, 3, 224, 224)) + >>> model = torchvision.models._utils.build_feature_extractor( + >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) + >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) >>> [('feat1', torch.Size([1, 64, 56, 56])), >>> ('feat2', torch.Size([1, 256, 14, 14]))] @@ -440,7 +379,7 @@ def build_feature_graph_net( >>> x = self.conv(x) >>> return self.leaf_module(x) >>> - >>> model = build_feature_graph_net( + >>> model = build_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]}) @@ -496,16 +435,19 @@ def to_strdict(n) -> Dict[str, str]: model, nn.Module) else model.__name__ graph_module = fx.GraphModule(tracer.root, graph, name) - available_nodes = [f'{v}.{k}' for k, v in tracer.node_to_qualname.items()] + available_nodes = list(tracer.node_to_qualname.values()) # FIXME We don't know if we should expect this to happen assert len(set(available_nodes)) == len(available_nodes), \ "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues" # Check that all outputs in return_nodes are present in the model for query in mode_return_nodes[mode].keys(): - if not any([m.startswith(query) for m in available_nodes]): + # To check if a query is available we need to check that at least + # one of the available names starts with it up to a . + if not any([re.match(rf'^{query}.?', n) is not None + for n in available_nodes]): raise ValueError( f"node: '{query}' is not present in model. Hint: use " - "`print_graph_node_qualified_names` to make sure the " + "`get_graph_node_names` to make sure the " "`return_nodes` you specified are present. It may even " "be that you need to specify `train_return_nodes` and " "`eval_return_nodes` seperately.") @@ -513,7 +455,7 @@ def to_strdict(n) -> Dict[str, str]: # Remove existing output nodes (train mode) orig_output_nodes = [] for n in reversed(graph_module.graph.nodes): - if n.op == "output": + if n.op == 'output': orig_output_nodes.append(n) assert len(orig_output_nodes) for n in orig_output_nodes: @@ -523,13 +465,13 @@ def to_strdict(n) -> Dict[str, str]: nodes = [n for n in graph_module.graph.nodes] output_nodes = OrderedDict() for n in reversed(nodes): - if 'tensor_constant' in str(n): - # NOTE Without this control flow we would get a None value for - # `module_qualname = tracer.node_to_qualname.get(n)`. - # On the other hand, we can safely assume that we'll never need to - # get this as an interesting intermediate node. - continue module_qualname = tracer.node_to_qualname.get(n) + if module_qualname is None: + # NOTE - Know cases where this happens: + # - Node representing creation of a tensor constant - probably + # not interesting as a return node + # - When packing outputs into a named tuple like in InceptionV3 + continue for query in mode_return_nodes[mode]: depth = query.count('.') if '.'.join(module_qualname.split('.')[:depth + 1]) == query: @@ -552,21 +494,13 @@ def to_strdict(n) -> Dict[str, str]: # Warn user if there are any discrepancies between the graphs of the # train and eval modes - _warn_graph_differences(tracers['train'], tracers['eval']) + if not suppress_diff_warning: + _warn_graph_differences(tracers['train'], tracers['eval']) # Build the final graph module graph_module = DualGraphModule( model, graphs['train'], graphs['eval'], class_name=name) - # Keep non-parameter model properties for reference - for attr_str in model.__dir__(): - attr = getattr(model, attr_str) - if (not attr_str.startswith('_') - and attr_str not in graph_module.__dir__() - and not ismethod(attr) - and not isinstance(attr, (nn.Module, nn.Parameter))): - setattr(graph_module, attr_str, attr) - # Restore original training mode model.train(is_training) graph_module.train(is_training) diff --git a/torchvision/models/segmentation/segmentation.py b/torchvision/models/segmentation/segmentation.py index 426f0e703fe..0f2f14c97ba 100644 --- a/torchvision/models/segmentation/segmentation.py +++ b/torchvision/models/segmentation/segmentation.py @@ -1,4 +1,4 @@ -from ..feature_extraction import IntermediateLayerGetter +from .._utils import IntermediateLayerGetter from ..._internally_replaced_utils import load_state_dict_from_url from .. import mobilenetv3 from .. import resnet From d4efb7d3b31be99faa4c23f8b41d3efe2740748e Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 3 Sep 2021 14:41:34 +0100 Subject: [PATCH 08/18] undo line spacing changes --- torchvision/models/_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 8cd5315f5b7..df5ab9a044c 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -7,20 +7,25 @@ class IntermediateLayerGetter(nn.ModuleDict): """ Module wrapper that returns intermediate layers from a model + It has a strong assumption that the modules have been registered into the model in the same order as they are used. This means that one should **not** reuse the same nn.Module twice in the forward if you want this to work. + Additionally, it is only able to query submodules that are directly assigned to the model. So if `model` is passed, `model.feature1` can be returned, but not `model.feature1.layer2`. + Args: model (nn.Module): model on which we will extract the features return_layers (Dict[name, new_name]): a dict containing the names of the modules for which the activations will be returned as the key of the dict, and the value of the dict is the name of the returned activation (which the user can specify). + Examples:: + >>> m = torchvision.models.resnet18(pretrained=True) >>> # extract layer1 and layer3, giving as names `feat1` and feat2` >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m, From d6a834eae7fd6e9d6436cb9081a79c28714e77cb Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 3 Sep 2021 14:51:39 +0100 Subject: [PATCH 09/18] change type hints in docstrings --- torchvision/models/feature_extraction.py | 31 ++++++++++++------------ 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 3474b8d3ecf..9d18fa1db5d 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -178,17 +178,16 @@ def get_graph_node_names( Args: model (nn.Module): model for which we'd like to print node names - tracer_kwargs (Dict): a dictionary of keywork arguments for - `NodePathTracer` (which in turn passes them onto it's parent class + tracer_kwargs (dict): a dictionary of keywork arguments for + `NodePathTracer` (they are eventually passed onto `torch.fx.Tracer`). suppress_diff_warning (bool): whether to suppress a warning when there are discrepancies between the train and eval version of the graph. Defaults to False. Returns: - Tuple[List[str], List[str]]: a list of node names from tracing the - model in train mode, and another from tracing the model in eval - mode. + tuple(list, list): a list of node names from tracing the model in + train mode, and another from tracing the model in eval mode. Examples:: @@ -223,10 +222,10 @@ def __init__(self, class_name: str = 'GraphModule'): """ Args: - root (torch.nn.Module): module from which the copied module - hierarchy is built - train_graph (Graph): the graph that should be used in train mode - eval_graph (Graph): the graph that should be used in eval mode + root (nn.Module): module from which the copied module hierarchy is + built + train_graph (fx.Graph): the graph that should be used in train mode + eval_graph (fx.Graph): the graph that should be used in eval mode """ super(fx.GraphModule, self).__init__() @@ -319,8 +318,8 @@ def build_feature_extractor( Args: model (nn.Module): model on which we will extract the features - return_nodes (Optional[Union[List[str], Dict[str, str]]]): either a `List` - or a `Dict` containing the names (or partial names - see note above) + return_nodes (list or dict, optional): either a `List` or a `Dict` + containing the names (or partial names - see note above) of the nodes for which the activations will be returned. If it is a `Dict`, the keys are the node names, and the values are the user-specified keys for the graph module's returned @@ -328,17 +327,17 @@ def build_feature_extractor( node specification strings directly to output names. In the case that `train_return_nodes` and `eval_return_nodes` are specified, this should not be specified. - train_return_nodes (Optional[Union[List[str], Dict[str, str]]]): - similar to `return_nodes`. This can be used if the return nodes + train_return_nodes (list or dict, optional): similar to + `return_nodes`. This can be used if the return nodes for train mode are different than those from eval mode. If this is specified, `eval_return_nodes` must also be specified, and `return_nodes` should not be specified. - eval_return_nodes (Optional[Union[List[str], Dict[str, str]]]): - similar to `return_nodes`. This can be used if the return nodes + eval_return_nodes (list or dict, optional): similar to + `return_nodes`. This can be used if the return nodes for train mode are different than those from eval mode. If this is specified, `train_return_nodes` must also be specified, and `return_nodes` should not be specified. - tracer_kwargs (Dict): a dictionary of keywork arguments for + tracer_kwargs (dict): a dictionary of keywork arguments for `NodePathTracer` (which passes them onto it's parent class `torch.fx.Tracer`). suppress_diff_warning (bool): whether to suppress a warning when there From 8348b7b4b2b9c533124c440b3bbf5ae9421ade94 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Fri, 3 Sep 2021 15:08:46 +0100 Subject: [PATCH 10/18] fix sphinx indentation --- torchvision/models/feature_extraction.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 9d18fa1db5d..0e6ebaf545d 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -187,7 +187,7 @@ def get_graph_node_names( Returns: tuple(list, list): a list of node names from tracing the model in - train mode, and another from tracing the model in eval mode. + train mode, and another from tracing the model in eval mode. Examples:: From fc831e0729fa9e8712ba9d59d494d7e646c1113a Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 4 Sep 2021 16:09:43 +0100 Subject: [PATCH 11/18] expose feature_extraction --- docs/source/index.rst | 2 +- torchvision/models/__init__.py | 1 + torchvision/models/feature_extraction.py | 5 ++++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 5bd01c4242c..3e02cd34ad4 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -32,9 +32,9 @@ architectures, and common image transformations for computer vision. :caption: Package Reference datasets - feature_extraction io models + feature_extraction ops transforms utils diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 283e544e98e..6d2680f1a95 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -12,3 +12,4 @@ from . import detection from . import video from . import quantization +from .feature_extraction import * diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 0e6ebaf545d..e75316babf8 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -11,6 +11,9 @@ from torch.fx.graph_module import _copy_attr +__all__ = ['build_feature_extractor', 'get_graph_node_names'] + + class LeafModuleAwareTracer(fx.Tracer): """ An fx.Tracer that allows the user to specify a set of leaf modules, ie. @@ -349,7 +352,7 @@ def build_feature_extractor( >>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` - >>> model = torchvision.models._utils.build_feature_extractor( + >>> model = build_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) From 82fea80c3f797a08cba157bac4c2329f6144b148 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 4 Sep 2021 18:04:09 +0100 Subject: [PATCH 12/18] add maskrcnn example --- docs/source/feature_extraction.rst | 97 ++++++++++++++++++++++++++++-- 1 file changed, 93 insertions(+), 4 deletions(-) diff --git a/docs/source/feature_extraction.rst b/docs/source/feature_extraction.rst index a235765d396..7dec3a40855 100644 --- a/docs/source/feature_extraction.rst +++ b/docs/source/feature_extraction.rst @@ -21,13 +21,102 @@ It works by following roughly these steps: how it transforms the input, step by step. 2. Setting the user-selected graph nodes as ouputs. 3. Removing all redundant nodes (anything downstream of the ouput nodes). -4. Generating python code from the resulting graph and bundling that, together - with the graph, and bundling that into a PyTorch module. +4. Generating python code from the resulting graph and bundling that into a + PyTorch module together with the graph itself. | -See `torch.fx documentation `_ for -more information on symbolic tracing. +The `torch.fx documentation `_ +provides a more general and detailed explanation of the above procedure and +the inner workings of the symbolic tracing. + +Here is an example of how we might extract features for MaskRCNN: + +.. code-block:: python + + import torch + from torchvision.models import resnet50 + from torchvision.models.feature_extraction import get_graph_node_names + from torchvision.models.feature_extraction import build_feature_extractor + from torchvision.models.detection.mask_rcnn import MaskRCNN + from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork + + + # To assist you in designing the feature extractor you may want to print out + # the available nodes for resnet50. + m = resnet50() + train_nodes, eval_nodes = get_graph_node_names(resnet50()) + + # The lists returned, are the names of all the graph nodes (in order of + # execution) for the input model traced in train mode and in eval mode + # respectively. You'll find that `train_nodes` and `eval_nodes` are the same + # for this example. But if the model contains control flow that's dependent + # on the training mode, they may be different. + + # To specify the nodes you want to extract, you could select the final node + # that appears in each of the main layers: + return_nodes = { + # node_name: user-specified key for output dict + 'layer1.2.relu_2': 'layer1', + 'layer2.3.relu_2': 'layer2', + 'layer3.5.relu_2': 'layer3', + 'layer4.2.relu_2': 'layer4', + } + + # But `build_feature_extractor` can also accept truncated node specifications + # like "layer1", as it will just pick the last node that's a descendent of + # of the specification. (Tip: be careful with this, especially when a layer + # has multiple outputs. It's not always guaranteed that the last operation + # performed is the one that corresponds to the output you desire. You should + # consult the source code for the input model to confirm.) + return_nodes = { + 'layer1': 'layer1', + 'layer2': 'layer2', + 'layer3': 'layer3', + 'layer4': 'layer4', + } + + # Now you can build the feature extractor. This returns a module whose forward + # method returns a dictionary like: + # { + # 'layer1': ouput of layer 1, + # 'layer2': ouput of layer 2, + # 'layer3': ouput of layer 3, + # 'layer4': ouput of layer 4, + # } + build_feature_extractor(m, return_nodes=return_nodes) + + # Let's put all that together to wrap resnet50 with MaskRCNN + + # MaskRCNN requires a backbone with an attached FPN + class Resnet50WithFPN(torch.nn.Module): + def __init__(self): + super(Resnet50WithFPN, self).__init__() + # Get a resnet50 backbone + m = resnet50() + # Extract 4 main layers (note: you can also provide a list for return + # nodes if the keys and the values are the same) + self.body = build_feature_extractor( + m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4']) + # Dry run to get number of channels for FPN + inp = torch.randn(1, 3, 224, 224) + with torch.no_grad(): + out = self.body(inp) + in_channels_list = [o.shape[1] for o in out.values()] + # Build FPN + self.out_channels = 256 + self.fpn = FeaturePyramidNetwork( + in_channels_list, out_channels=self.out_channels) + + def forward(self, x): + x = self.body(x) + x = self.fpn(x) + return x + + + # Now we can build our model! + model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval() + .. autofunction:: build_feature_extractor From 8b51e041e76c5b7345b73ad0d57ff360e822df79 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 4 Sep 2021 18:07:25 +0100 Subject: [PATCH 13/18] add api refernce subheading --- docs/source/feature_extraction.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/feature_extraction.rst b/docs/source/feature_extraction.rst index 7dec3a40855..08e42f754b9 100644 --- a/docs/source/feature_extraction.rst +++ b/docs/source/feature_extraction.rst @@ -118,6 +118,9 @@ Here is an example of how we might extract features for MaskRCNN: model = MaskRCNN(Resnet50WithFPN(), num_classes=91).eval() +API Reference +------------- + .. autofunction:: build_feature_extractor .. autofunction:: get_graph_node_names \ No newline at end of file From d7591c19443d0a31ff130574c22b966b1af53a1a Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 6 Sep 2021 14:28:16 +0100 Subject: [PATCH 14/18] address latest review notes, refactor names, fix regex, cosmetics --- docs/source/feature_extraction.rst | 14 ++++---- test/test_backbone_utils.py | 42 +++++++++++++++--------- torchvision/models/__init__.py | 3 +- torchvision/models/feature_extraction.py | 32 +++++++++--------- 4 files changed, 49 insertions(+), 42 deletions(-) diff --git a/docs/source/feature_extraction.rst b/docs/source/feature_extraction.rst index 08e42f754b9..27dad21fa78 100644 --- a/docs/source/feature_extraction.rst +++ b/docs/source/feature_extraction.rst @@ -14,7 +14,7 @@ applications in computer vision. Just a few examples are: with a specific task in mind. For example, passing a hierarchy of features to a Feature Pyramid Network with object detection heads. -Torchvision provides :func:`build_feature_extractor` for this purpose. +Torchvision provides :func:`create_feature_extractor` for this purpose. It works by following roughly these steps: 1. Symbolically tracing the model to get a graphical representation of @@ -37,7 +37,7 @@ Here is an example of how we might extract features for MaskRCNN: import torch from torchvision.models import resnet50 from torchvision.models.feature_extraction import get_graph_node_names - from torchvision.models.feature_extraction import build_feature_extractor + from torchvision.models.feature_extraction import create_feature_extractor from torchvision.models.detection.mask_rcnn import MaskRCNN from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork @@ -63,7 +63,7 @@ Here is an example of how we might extract features for MaskRCNN: 'layer4.2.relu_2': 'layer4', } - # But `build_feature_extractor` can also accept truncated node specifications + # But `create_feature_extractor` can also accept truncated node specifications # like "layer1", as it will just pick the last node that's a descendent of # of the specification. (Tip: be careful with this, especially when a layer # has multiple outputs. It's not always guaranteed that the last operation @@ -84,7 +84,7 @@ Here is an example of how we might extract features for MaskRCNN: # 'layer3': ouput of layer 3, # 'layer4': ouput of layer 4, # } - build_feature_extractor(m, return_nodes=return_nodes) + create_feature_extractor(m, return_nodes=return_nodes) # Let's put all that together to wrap resnet50 with MaskRCNN @@ -96,10 +96,10 @@ Here is an example of how we might extract features for MaskRCNN: m = resnet50() # Extract 4 main layers (note: you can also provide a list for return # nodes if the keys and the values are the same) - self.body = build_feature_extractor( + self.body = create_feature_extractor( m, return_nodes=['layer1', 'layer2', 'layer3', 'layer4']) # Dry run to get number of channels for FPN - inp = torch.randn(1, 3, 224, 224) + inp = torch.randn(2, 3, 224, 224) with torch.no_grad(): out = self.body(inp) in_channels_list = [o.shape[1] for o in out.values()] @@ -121,6 +121,6 @@ Here is an example of how we might extract features for MaskRCNN: API Reference ------------- -.. autofunction:: build_feature_extractor +.. autofunction:: create_feature_extractor .. autofunction:: get_graph_node_names \ No newline at end of file diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index f48925e747e..26e50097ca1 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -1,10 +1,11 @@ from functools import partial +from itertools import chain import random import torch from torchvision import models from torchvision.models.detection.backbone_utils import resnet_fpn_backbone -from torchvision.models.feature_extraction import build_feature_extractor +from torchvision.models.feature_extraction import create_feature_extractor from torchvision.models.feature_extraction import get_graph_node_names from torchvision.models._utils import IntermediateLayerGetter @@ -13,9 +14,9 @@ from common_utils import set_rng_seed -# Suppress diff warning from build_feature_extractor -build_feature_extractor = partial( - build_feature_extractor, suppress_diff_warning=True) +# Suppress diff warning from create_feature_extractor +create_feature_extractor = partial( + create_feature_extractor, suppress_diff_warning=True) get_graph_node_names = partial( get_graph_node_names, suppress_diff_warning=True) @@ -63,30 +64,39 @@ def test_build_fx_feature_extractor(self, model_name): model = models.__dict__[model_name](**self.model_defaults).eval() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) # Check that it works with both a list and dict for return nodes - build_feature_extractor( + create_feature_extractor( model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes) - build_feature_extractor( + create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check must specify return nodes with pytest.raises(AssertionError): - build_feature_extractor(model) + create_feature_extractor(model) # Check return_nodes and train_return_nodes / eval_return nodes # mutual exclusivity with pytest.raises(AssertionError): - build_feature_extractor(model, return_nodes=train_return_nodes, - train_return_nodes=train_return_nodes) + create_feature_extractor(model, return_nodes=train_return_nodes, + train_return_nodes=train_return_nodes) # Check train_return_nodes / eval_return nodes must both be specified with pytest.raises(AssertionError): - build_feature_extractor( + create_feature_extractor( model, train_return_nodes=train_return_nodes) + # Check invalid node name raises ValueError + with pytest.raises(ValueError): + # First just double check that this node really doesn't exist + if not any(n.startswith('l') or n.startswith('l.') for n + in chain(train_return_nodes, eval_return_nodes)): + create_feature_extractor( + model, train_return_nodes=['l'], eval_return_nodes=['l']) + else: # otherwise skip this check + raise ValueError @pytest.mark.parametrize('model_name', get_available_models()) def test_forward_backward(self, model_name): model = models.__dict__[model_name](**self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) - model = build_feature_extractor( + model = create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) out = model(self.inp) @@ -103,7 +113,7 @@ def test_feature_extraction_methods_equivalence(self): ilg_model = IntermediateLayerGetter( model, return_layers).eval() - fx_model = build_feature_extractor(model, return_layers) + fx_model = create_feature_extractor(model, return_layers) # Check that we have same parameters for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), @@ -124,7 +134,7 @@ def test_jit_forward_backward(self, model_name): set_rng_seed(0) model = models.__dict__[model_name](**self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) - model = build_feature_extractor( + model = create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) model = torch.jit.script(model) @@ -172,7 +182,7 @@ def checks(model, mode): # Starting from train mode model.train() - fx_model = build_feature_extractor( + fx_model = create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check that the models stay in their original training state @@ -186,7 +196,7 @@ def checks(model, mode): # Starting from eval mode model.eval() - fx_model = build_feature_extractor( + fx_model = create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check that the models stay in their original training state @@ -216,7 +226,7 @@ def forward(self, x): x = self.conv(x) return self.leaf_module(x) - model = build_feature_extractor( + model = create_feature_extractor( TestModule(), return_nodes=['leaf_module'], tracer_kwargs={'leaf_modules': [LeafModule], 'autowrap_functions': [leaf_function]}).train() diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 4a87eaf2750..bbd0699fac0 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -8,9 +8,8 @@ from .mobilenet import * from .mnasnet import * from .shufflenetv2 import * -from .efficientnet import * from . import segmentation from . import detection from . import video from . import quantization -from .feature_extraction import * +from . import feature_extraction diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index e75316babf8..7d5d184cd97 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -11,7 +11,7 @@ from torch.fx.graph_module import _copy_attr -__all__ = ['build_feature_extractor', 'get_graph_node_names'] +__all__ = ['create_feature_extractor', 'get_graph_node_names'] class LeafModuleAwareTracer(fx.Tracer): @@ -167,7 +167,7 @@ def get_graph_node_names( suppress_diff_warning: bool = False) -> Tuple[List[str], List[str]]: """ Dev utility to return node names in order of execution. See note on node - names under :func:`build_feature_extractor`. Useful for seeing which node + names under :func:`create_feature_extractor`. Useful for seeing which node names are available for feature extraction. There are two reasons that node names can't easily be read directly from the code for a model: @@ -181,12 +181,12 @@ def get_graph_node_names( Args: model (nn.Module): model for which we'd like to print node names - tracer_kwargs (dict): a dictionary of keywork arguments for + tracer_kwargs (dict, optional): a dictionary of keywork arguments for `NodePathTracer` (they are eventually passed onto `torch.fx.Tracer`). - suppress_diff_warning (bool): whether to suppress a warning when there - are discrepancies between the train and eval version of the graph. - Defaults to False. + suppress_diff_warning (bool, optional): whether to suppress a warning + when there are discrepancies between the train and eval version of + the graph. Defaults to False. Returns: tuple(list, list): a list of node names from tracing the model in @@ -275,7 +275,7 @@ def train(self, mode=True): return super().train(mode=mode) -def build_feature_extractor( +def create_feature_extractor( model: nn.Module, return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, train_return_nodes: Optional[Union[List[str], Dict[str, str]]] = None, @@ -313,8 +313,6 @@ def build_feature_extractor( `int` will raise an error during tracing. You may wrap them in your own function and then pass that in `autowrap_functions` as one of the `tracer_kwargs`. - - Use `torch.matmul(tensor_a, tensor_b)` instead of `tensor_a @ - tensor_b`. For further information on FX see the `torch.fx documentation `_. @@ -340,19 +338,19 @@ def build_feature_extractor( for train mode are different than those from eval mode. If this is specified, `train_return_nodes` must also be specified, and `return_nodes` should not be specified. - tracer_kwargs (dict): a dictionary of keywork arguments for + tracer_kwargs (dict, optional): a dictionary of keywork arguments for `NodePathTracer` (which passes them onto it's parent class `torch.fx.Tracer`). - suppress_diff_warning (bool): whether to suppress a warning when there - are discrepancies between the train and eval version of the graph. - Defaults to False. + suppress_diff_warning (bool, optional): whether to suppress a warning + when there are discrepancies between the train and eval version of + the graph. Defaults to False. Examples:: >>> # Feature extraction with resnet >>> model = torchvision.models.resnet18() >>> # extract layer1 and layer3, giving as names `feat1` and feat2` - >>> model = build_feature_extractor( + >>> model = create_feature_extractor( >>> model, {'layer1': 'feat1', 'layer3': 'feat2'}) >>> out = model(torch.rand(1, 3, 224, 224)) >>> print([(k, v.shape) for k, v in out.items()]) @@ -381,7 +379,7 @@ def build_feature_extractor( >>> x = self.conv(x) >>> return self.leaf_module(x) >>> - >>> model = build_feature_extractor( + >>> model = create_feature_extractor( >>> MyModule(), return_nodes=['leaf_module'], >>> tracer_kwargs={'leaf_modules': [LeafModule], >>> 'autowrap_functions': [leaf_function]}) @@ -445,14 +443,14 @@ def to_strdict(n) -> Dict[str, str]: for query in mode_return_nodes[mode].keys(): # To check if a query is available we need to check that at least # one of the available names starts with it up to a . - if not any([re.match(rf'^{query}.?', n) is not None + if not any([re.match(rf'^{query}(\.|$)', n) is not None for n in available_nodes]): raise ValueError( f"node: '{query}' is not present in model. Hint: use " "`get_graph_node_names` to make sure the " "`return_nodes` you specified are present. It may even " "be that you need to specify `train_return_nodes` and " - "`eval_return_nodes` seperately.") + "`eval_return_nodes` separately.") # Remove existing output nodes (train mode) orig_output_nodes = [] From 2d6cdfd5be3c940dc6918ffc5a1b36d4b1e42800 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 6 Sep 2021 16:01:10 +0200 Subject: [PATCH 15/18] Add back efficientnet to models --- torchvision/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index bbd0699fac0..e57f4773c8c 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -8,6 +8,7 @@ from .mobilenet import * from .mnasnet import * from .shufflenetv2 import * +from .efficientnet import * from . import segmentation from . import detection from . import video From 31f86f67afb5716b81c76e46d11da0b5ccc68705 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 6 Sep 2021 15:21:55 +0100 Subject: [PATCH 16/18] fix tests for effnet --- test/test_backbone_utils.py | 38 ++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 26e50097ca1..0dfbf980186 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -4,6 +4,7 @@ import torch from torchvision import models +import torchvision from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.feature_extraction import create_feature_extractor from torchvision.models.feature_extraction import get_graph_node_names @@ -45,11 +46,22 @@ class TestFxFeatureExtraction: 'num_classes': 1, 'pretrained': False } + leaf_modules = [torchvision.ops.StochasticDepth] + + def _create_feature_extractor(self, *args, **kwargs): + """ + Apply leaf modules + """ + return create_feature_extractor( + *args, **kwargs, + tracer_kwargs={'leaf_modules': self.leaf_modules}, + suppress_diff_warning=True) def _get_return_nodes(self, model): set_rng_seed(0) exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk'] - train_nodes, eval_nodes = get_graph_node_names(model) + train_nodes, eval_nodes = get_graph_node_names( + model, tracer_kwargs={'leaf_modules': self.leaf_modules}) # Get rid of any nodes that don't return tensors as they cause issues # when testing backward pass. train_nodes = [n for n in train_nodes @@ -64,30 +76,30 @@ def test_build_fx_feature_extractor(self, model_name): model = models.__dict__[model_name](**self.model_defaults).eval() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) # Check that it works with both a list and dict for return nodes - create_feature_extractor( + self._create_feature_extractor( model, train_return_nodes={v: v for v in train_return_nodes}, eval_return_nodes=eval_return_nodes) - create_feature_extractor( + self._create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check must specify return nodes with pytest.raises(AssertionError): - create_feature_extractor(model) + self._create_feature_extractor(model) # Check return_nodes and train_return_nodes / eval_return nodes # mutual exclusivity with pytest.raises(AssertionError): - create_feature_extractor(model, return_nodes=train_return_nodes, + self._create_feature_extractor(model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes) # Check train_return_nodes / eval_return nodes must both be specified with pytest.raises(AssertionError): - create_feature_extractor( + self._create_feature_extractor( model, train_return_nodes=train_return_nodes) # Check invalid node name raises ValueError with pytest.raises(ValueError): # First just double check that this node really doesn't exist if not any(n.startswith('l') or n.startswith('l.') for n in chain(train_return_nodes, eval_return_nodes)): - create_feature_extractor( + self._create_feature_extractor( model, train_return_nodes=['l'], eval_return_nodes=['l']) else: # otherwise skip this check raise ValueError @@ -96,7 +108,7 @@ def test_build_fx_feature_extractor(self, model_name): def test_forward_backward(self, model_name): model = models.__dict__[model_name](**self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) - model = create_feature_extractor( + model = self._create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) out = model(self.inp) @@ -113,7 +125,7 @@ def test_feature_extraction_methods_equivalence(self): ilg_model = IntermediateLayerGetter( model, return_layers).eval() - fx_model = create_feature_extractor(model, return_layers) + fx_model = self._create_feature_extractor(model, return_layers) # Check that we have same parameters for (n1, p1), (n2, p2) in zip(ilg_model.named_parameters(), @@ -134,7 +146,7 @@ def test_jit_forward_backward(self, model_name): set_rng_seed(0) model = models.__dict__[model_name](**self.model_defaults).train() train_return_nodes, eval_return_nodes = self._get_return_nodes(model) - model = create_feature_extractor( + model = self._create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) model = torch.jit.script(model) @@ -182,7 +194,7 @@ def checks(model, mode): # Starting from train mode model.train() - fx_model = create_feature_extractor( + fx_model = self._create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check that the models stay in their original training state @@ -196,7 +208,7 @@ def checks(model, mode): # Starting from eval mode model.eval() - fx_model = create_feature_extractor( + fx_model = self._create_feature_extractor( model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes) # Check that the models stay in their original training state @@ -226,7 +238,7 @@ def forward(self, x): x = self.conv(x) return self.leaf_module(x) - model = create_feature_extractor( + model = self._create_feature_extractor( TestModule(), return_nodes=['leaf_module'], tracer_kwargs={'leaf_modules': [LeafModule], 'autowrap_functions': [leaf_function]}).train() From a4973f8605e7687387561b33265dbf928fe86d22 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 6 Sep 2021 15:25:47 +0100 Subject: [PATCH 17/18] fix linting issue --- test/test_backbone_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 0dfbf980186..04c51dbe1a4 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -88,8 +88,9 @@ def test_build_fx_feature_extractor(self, model_name): # Check return_nodes and train_return_nodes / eval_return nodes # mutual exclusivity with pytest.raises(AssertionError): - self._create_feature_extractor(model, return_nodes=train_return_nodes, - train_return_nodes=train_return_nodes) + self._create_feature_extractor( + model, return_nodes=train_return_nodes, + train_return_nodes=train_return_nodes) # Check train_return_nodes / eval_return nodes must both be specified with pytest.raises(AssertionError): self._create_feature_extractor( From e8fec3393a1a75ab021e4924a2f1fb1a36afe130 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Mon, 6 Sep 2021 15:46:40 +0100 Subject: [PATCH 18/18] fix test tracer kwargs --- test/test_backbone_utils.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index 04c51dbe1a4..d85128a46d0 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -15,13 +15,6 @@ from common_utils import set_rng_seed -# Suppress diff warning from create_feature_extractor -create_feature_extractor = partial( - create_feature_extractor, suppress_diff_warning=True) -get_graph_node_names = partial( - get_graph_node_names, suppress_diff_warning=True) - - def get_available_models(): # TODO add a registration mechanism to torchvision.models return [k for k, v in models.__dict__.items() @@ -52,16 +45,22 @@ def _create_feature_extractor(self, *args, **kwargs): """ Apply leaf modules """ + tracer_kwargs = {} + if 'tracer_kwargs' not in kwargs: + tracer_kwargs = {'leaf_modules': self.leaf_modules} + else: + tracer_kwargs = kwargs.pop('tracer_kwargs') return create_feature_extractor( *args, **kwargs, - tracer_kwargs={'leaf_modules': self.leaf_modules}, + tracer_kwargs=tracer_kwargs, suppress_diff_warning=True) def _get_return_nodes(self, model): set_rng_seed(0) exclude_nodes_filter = ['getitem', 'floordiv', 'size', 'chunk'] train_nodes, eval_nodes = get_graph_node_names( - model, tracer_kwargs={'leaf_modules': self.leaf_modules}) + model, tracer_kwargs={'leaf_modules': self.leaf_modules}, + suppress_diff_warning=True) # Get rid of any nodes that don't return tensors as they cause issues # when testing backward pass. train_nodes = [n for n in train_nodes