-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[WIP/NO_MERGE] Prototype RegularizedShortcut #4549
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
ae2de45
Adding first implementation of the util.
datumbox e0b6aa4
Update the model graph.
datumbox b2faf3f
Adding delete method.
datumbox 74b005c
Fixing linter.
datumbox 5684459
Fixing types.
datumbox 3ecf8aa
Restoring break and moving lint.
datumbox 21c45ca
Merge branch 'main' into prototype/regularized_shortcut
datumbox 51bf33f
Allow delete to remove custom ops.
datumbox c6fbe63
Minor refactoring
datumbox da3a5e3
Pass model names.
datumbox File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import copy | ||
import operator | ||
import warnings | ||
from typing import Callable, Optional, Tuple, Union | ||
|
||
import torch | ||
from torch import fx | ||
from torchvision.models.feature_extraction import LeafModuleAwareTracer | ||
|
||
|
||
# TODO: Investigate what happens in the scenario of y = x + f1(x) + f2(x). | ||
|
||
|
||
class RegularizedShortcut(torch.nn.Module): | ||
def __init__(self, regularizer_layer: Callable[..., torch.nn.Module]): | ||
super().__init__() | ||
self._regularizer = regularizer_layer() | ||
|
||
def forward(self, input, result): | ||
return input + self._regularizer(result) | ||
|
||
|
||
def add_regularized_shortcut( | ||
model: torch.nn.Module, | ||
block_types: Union[type, Tuple[type, ...]], | ||
regularizer_layer: Callable[..., torch.nn.Module], | ||
inplace: bool = True, | ||
) -> torch.nn.Module: | ||
if not inplace: | ||
model = copy.deepcopy(model) | ||
|
||
reg_name = RegularizedShortcut.__name__.lower() | ||
tracer = fx.Tracer() | ||
modifications = {} | ||
for name, m in model.named_modules(): | ||
if isinstance(m, block_types): | ||
# Add the Layer directly on submodule prior tracing | ||
# workaround due to https://github.com/pytorch/pytorch/issues/66197 | ||
m.add_module(reg_name, RegularizedShortcut(regularizer_layer)) | ||
|
||
graph = tracer.trace(m) | ||
patterns = {operator.add, torch.add, "add"} | ||
|
||
input = None | ||
for node in graph.nodes: | ||
if node.op == "call_function": | ||
if node.target in patterns and len(node.args) == 2 and input in node.args: | ||
# TODO: ensure the arg2 has "input" as its ancestor | ||
with graph.inserting_after(node): | ||
# Always put the shortcut value first | ||
args = node.args if node.args[0] == input else node.args[::-1] | ||
node.replace_all_uses_with(graph.call_module(reg_name, args)) | ||
graph.erase_node(node) | ||
modifications[name] = graph | ||
break | ||
elif node.op == "placeholder": | ||
input = node | ||
|
||
if modifications: | ||
# Update the model by overwriting its modules | ||
for name, graph in modifications.items(): | ||
graph.lint() | ||
parent_name, child_name = name.rsplit(".", 1) | ||
parent = model.get_submodule(parent_name) | ||
previous_child = parent.get_submodule(child_name) | ||
new_child = fx.GraphModule(previous_child, graph, previous_child.__class__.__name__) | ||
parent.register_module(child_name, new_child) | ||
datumbox marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
warnings.warn( | ||
"No shortcut was detected. Please ensure you have provided the correct `block_types` parameter " | ||
"for this model." | ||
) | ||
|
||
return model | ||
|
||
|
||
def del_regularized_shortcut( | ||
model: torch.nn.Module, | ||
block_types: Union[type, Tuple[type, ...]] = RegularizedShortcut, | ||
op: Optional[Callable] = operator.add, | ||
inplace: bool = True, | ||
) -> torch.nn.Module: | ||
if isinstance(block_types, type): | ||
block_types = (block_types,) | ||
if not inplace: | ||
model = copy.deepcopy(model) | ||
|
||
tracer = LeafModuleAwareTracer(leaf_modules=block_types) | ||
graph = tracer.trace(model) | ||
for node in graph.nodes: | ||
# The isinstance() won't work if the model has already been traced before because it loses | ||
# the class info of submodules. See https://github.com/pytorch/pytorch/issues/66335 | ||
if node.op == "call_module" and isinstance(model.get_submodule(node.target), block_types): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jamesr66a We just figured out that FX traced models lose their submodule class information. This means that for a model that has been traced before, we can't use |
||
if op is not None: | ||
with graph.inserting_before(node): | ||
new_node = graph.call_function(op, node.args) | ||
node.replace_all_uses_with(new_node) | ||
else: | ||
if len(node.args) == 1: | ||
node.replace_all_uses_with(node.prev) | ||
else: | ||
raise ValueError("Can't eliminate an operator that receives more than 1 arguments.") | ||
graph.erase_node(node) | ||
|
||
return fx.GraphModule(model, graph, model.__class__.__name__) | ||
|
||
|
||
if __name__ == "__main__": | ||
from functools import partial | ||
|
||
from torchvision.models.resnet import resnet50, BasicBlock, Bottleneck | ||
from torchvision.ops.stochastic_depth import StochasticDepth | ||
|
||
out = [] | ||
batch = torch.randn((7, 3, 224, 224)) | ||
|
||
print("Before") | ||
model = resnet50() | ||
with torch.no_grad(): | ||
out.append(model(batch)) | ||
fx.symbolic_trace(model).graph.print_tabular() | ||
|
||
print("After addition") | ||
regularizer_layer = partial(StochasticDepth, p=0.0, mode="row") | ||
model = add_regularized_shortcut(model, (BasicBlock, Bottleneck), regularizer_layer) | ||
fx.symbolic_trace(model).graph.print_tabular() | ||
# print(model) | ||
with torch.no_grad(): | ||
out.append(model(batch)) | ||
|
||
print("After deletion") | ||
model = del_regularized_shortcut(model) | ||
fx.symbolic_trace(model).graph.print_tabular() | ||
with torch.no_grad(): | ||
out.append(model(batch)) | ||
|
||
for v in out[1:]: | ||
torch.testing.assert_allclose(out[0], v) |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.