
Add FX feature extraction as an alternative to intermediate_layer_getter #4302


Merged
22 commits merged into pytorch:main on Sep 6, 2021

Conversation


@alexander-soare alexander-soare commented Aug 21, 2021

Motivation is similar to #3597. I just happened to pick up on this while working on timm: huggingface/pytorch-image-models#800

This builds on #3597 with a few adjustments:

  • Keep track of leaf ops as well as leaf modules, giving the user even finer control
  • Handle reuse of the same op/module by appending an incremented _{int} suffix to qualified node names. FX does this automatically for repeated ops, but not for modules.
  • Add utility for printing node qualified names of a model.
  • Add wrapper module which converts the input model to a graph module but keeps the input model's non-parametric properties.

A caveat, and possibly a dealbreaker that we should figure out how to handle:

  • Static control flow is frozen in place when the GraphModule is produced. This is especially problematic for control flow that depends on the state of model.training

Other things I will do once we are ready to commit:

  • Tests
  • More documentation

cc @datumbox

@alexander-soare alexander-soare marked this pull request as draft August 21, 2021 14:47

alexander-soare commented Aug 21, 2021

@fmassa just wanted to make sure this topic stays active so decided to get a draft going anyway. I have some checklist items above explaining what I think needs to be done to close this out. Can you please help me add what's needed there?

I believe you mentioned you'd chat to the fx team about the train/eval problem so I'll be on standby to pick it back up based on the outcome of that.

Also, do you think there's an appropriate place to add more docs than what's in the docstrings?


fmassa commented Aug 24, 2021

Thanks a ton for the PR, let me know when you'll want me to have a closer look at it.

For the points you brought, I have a few questions:

Handle reuse of same op/module using an incremented _{int} suffix on qualified node names. This is automatically done by fx for repeated ops, but not for modules.

Maybe I'm missing something here, but is it even possible to have two modules with the same name? If this is for handling things like layer3.2 being equal to layer3, is the best way to increment the node names (i.e., creating a new node) instead of having a one-to-many mapping from node to module?

Add wrapper module which converts the input model to a graph module but keeps the input model's non-parametric properties.

Should this be a new nn.Module, or should we instead re-use the same module returned by FX and add the missing attributes there? What are the benefits of wrapping it in an nn.Module? The current module wrapper doesn't have the same state-dict as the original model, which makes it a bit annoying to use seamlessly for loading state-dicts etc.

Static control flow is frozen into place when the GraphModule is produced. This is especially not a good thing for control flow that depends on the state of model.training

I've discussed this with @jamesr66a. His view was that this is a special case of the more general re-specialization problem (like specializing over shapes / ranks of Tensors). As of now FX is not going to handle it explicitly, but nothing blocks us from building our own tooling for it, and I think we should.

A naive way would be to trace the model twice (one in training=True and one with training=False) and have a wrapper nn.Module be returned that does dispatch to one of the two traces depending on the training mode.

A naive solution is something in the lines of

class WrappedModel(nn.Module):
    def __init__(self, model, ...):
        super().__init__()
        # Remember the original mode so we can restore it after tracing
        mode = model.training
        self.train_model = trace(model.train(), ...)
        self.eval_model = trace(model.eval(), ...)
        model.train(mode)

    def forward(self, *args, **kwargs):
        # Dispatch to the trace that matches the current training mode
        if self.training:
            return self.train_model(*args, **kwargs)
        return self.eval_model(*args, **kwargs)

but ideally we wouldn't duplicate parameters nor change the namespace of the parameters.
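On the parameter-duplication concern, a small check (a sketch, not part of the PR) suggests that two FX traces of the same root module reference the very same Parameter objects, because GraphModule copies submodules by reference. So tracing twice costs no extra parameter memory; only the wrapper's state-dict namespace is at stake:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())

# Trace the same root twice; each GraphModule holds references to the
# original submodules rather than deep copies.
t1 = symbolic_trace(model)
t2 = symbolic_trace(model)

# Expected to be True: both traces share the underlying Parameter objects
shared = t1.get_parameter('0.weight') is t2.get_parameter('0.weight')
print(shared)
```

This is why the remaining work is mostly about keeping the state-dict keys of the wrapper aligned with the original model, not about memory.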

Also, do you think there's an appropriate place to add more docs than what's in the docstrings?

Yes, I think we should either make a new doc page under https://github.com/pytorch/vision/tree/main/docs/source or use the new galleries to better illustrate how this functionality can be used.
But this can come once we have finalized the APIs / etc, and can come as a follow-up PR.

I have some checklist items above explaining what I think needs to be done to close this out. Can you please help me add what's needed there?

The points I've made above might still require iterating a bit on some of the aspects implemented here (but please let me know if I'm missing something!). Also, adding tests will be important to validate that the code behaves as expected (including dummy nn.Modules that exercise the train / eval differences, if this gets implemented now).

From my perspective, we can break this PR into multiple ones if that facilitates things for you:

  • get an initial version that works for all models in torchvision and handles the issues I mentioned in the first point in Use FX to have a more robust intermediate feature extraction #3597 (comment) . It would also be good to know what percentage of timm models are covered with this initial version (including what are most common failure modes), so we can plan accordingly next steps
  • extend support for train / eval differences
  • gallery with examples
  • fix some of the long-tail models to make the code more robust
  • (maybe?) upstream the feature to PyTorch

Thoughts?


alexander-soare commented Aug 28, 2021

@fmassa thanks for your input! I've made some good progress with the train/eval problem but I'm out of time for this week and will move onto tests/docs next week. Might be worth waiting till then before you have a closer look.

But to answer the points that you already have:

Maybe I'm missing something here, but is it even possible to have two modules with the same name?

I'm talking about something like: user defines self.relu = nn.ReLU() then uses it twice in some forward method. Without my modification, the qualnames would both come out as walk.path.to.relu. With my modification we get: walk.path.to.relu and walk.path.to.relu_1. So it's possible to disambiguate if we want to choose these as return nodes.
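The behaviour being extended here can be seen with plain torch.fx: both calls hit the same call_module target, but FX already gives the graph nodes distinct names. A minimal sketch:

```python
import torch.nn as nn
from torch.fx import symbolic_trace

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.relu = nn.ReLU()  # a single module instance...

    def forward(self, x):
        x = self.relu(self.conv(x))  # ...used here
        return self.relu(x)          # ...and again here

gm = symbolic_trace(Net())
mods = [(n.target, n.name) for n in gm.graph.nodes if n.op == 'call_module']
print(mods)
# [('conv', 'conv'), ('relu', 'relu'), ('relu', 'relu_1')]
```

Both relu calls share the target 'relu', but the nodes are named 'relu' and 'relu_1'; the PR applies the same _{int} convention to the user-facing qualified names so return nodes can be picked unambiguously.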

... is the best way to increment the node names (i.e., creating a new node) instead of having a one-to-many mapping from node to module?

FX creates the new nodes and the node to module mapping, and I'm thinking we leave that behaviour as it is. All we're doing is making a node to string mapping. And the only reason we do that is because we want to use our nicer convention for naming nodes so it's easy for our user to specify which nodes they want.

Should this be a new nn.Module, or should we instead re-use the same module returned by FX and add the missing attributes there?

I agree. I just deleted it and incorporated the desired behaviour into build_feature_graph_net.

A naive way would be to trace the model twice ... but ideally we wouldn't duplicate parameters nor change the namespace of the parameters.

Good suggestion. That's what I did + some gymnastics to not duplicate params or change the namespace.

It would also be good to know what percentage of timm models are covered with this initial version

100% of timm models are covered when following the guidelines here. When I say "covered" I mean that they can be converted, their forward and backward methods work, they are scriptable + forward + backward (wherever the original model is scriptable), and their output features match those of the conventional methods. Would it make sense to put a version of those guidelines in the docs and gradually phase them out as/if core FX updates solve those problems?


fmassa commented Sep 1, 2021

Hi @alexander-soare

Thanks for the clarifications, makes sense!

Let's get this first version of the PR merged into torchvision this week. Just tests + docs should be good for now.

Would it make sense to put a version of those guidelines in the docs and gradually phase them out as/if core FX updates solve those problems

Good question, I think it might be good to have some high-level guidelines in the documentation as well. I'm not sure FX will be able to capture assert without using torch._assert, but I'd feel a bit reluctant to propose that users rely on a private API.
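For context, the difference FX cares about can be sketched as follows (using the private torch._assert, which is exactly why recommending it is uncomfortable):

```python
import torch
from torch.fx import symbolic_trace

class Checked(torch.nn.Module):
    def forward(self, x):
        # A plain `assert x.dim() == 4` would fail to trace: it forces the
        # Proxy into a bool, which FX rejects as data-dependent control flow.
        # torch._assert is FX-traceable, but it is private API.
        torch._assert(x.dim() == 4, "expected an NCHW batch")
        return x * 2

gm = symbolic_trace(Checked())
out = gm(torch.rand(1, 3, 4, 4))  # the assert survives into the traced graph
```

With a plain assert the trace either fails or the check is silently baked out, depending on whether the condition touches a Proxy.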

I'll do a first pass on the PR today with some more high-level comments.


@fmassa fmassa left a comment


This is looking great, thanks a lot @alexander-soare !

I think we are almost ready to get this merged.

Also, I was wondering if we should add it in a different location in torchvision that is not private, maybe torchvision.models.utils or torchvision.models.feature_extraction or something like that?


alexander-soare commented Sep 1, 2021

@fmassa (sorry, edited this a few times as I was working)

Also, I was wondering if we should add it in a different location in torchvision that is not private, maybe torchvision.models.utils or torchvision.models.feature_extraction or something like that?

That makes sense. I used torchvision.models.feature_extraction and added a dedicated docs page. I also moved IntermediateLayerGetter there and made sure to reroute references to it. Let me know if there are any issues with that and I can revert it.

I'd feel a bit reluctant to propose users to rely on a private API.

Makes sense - see my list of tips in the docstring for build_feature_graph_net. I'll hold off adding the torch._assert tip unless you say otherwise.


fmassa commented Sep 2, 2021

Thanks for the changes @alexander-soare !

I also moved IntermediateLayerGetter there and made sure to reroute references to it. Let me know if there are any issues with that and I can revert it.

I've made IntermediateLayerGetter private because of all the limitations it had. The way I see things, build_feature_graph_net entirely replaces the need for IntermediateLayerGetter, so I would be fine keeping it in _utils.py for now, maybe issuing a deprecation warning (but let's do that in a separate PR), and not putting it in the documentation either (that's why it isn't documented).

I'm bikeshedding here, but I'm wondering if we can find other names for build_feature_graph_net and print_graph_node_qualified_names. Here is my reasoning:

  • print_graph_node_qualified_names uses rather specific nomenclature, and I'm not sure FX uses qualified_names much in its documentation. Maybe something like get_graph_node_names might be better?
  • build_feature_graph_net is a bit weird to me as well, maybe get_submodel or get_intermediate_execution_graph or something like that?


@fmassa fmassa left a comment


This is awesome, thanks a ton @alexander-soare !

I've made a few more comments, and I think you can make the PR ready for a more detailed review already.
I've looked into the test and some of the APIs we are providing, but I would like to pull the PR locally and play a bit with it once you mark the PR ready for review.


alexander-soare commented Sep 3, 2021

@fmassa thanks for the continued support on this! I've sorted it all out in the latest commit. Two more things

Thing 1:

About the naming. Cool, I've changed them. I went for build_feature_extractor which I hope is also in line with your motivation (abstracting from specific nomenclature). If not, I have no problem at all changing it to whatever works.

I also dropped the use of the term "qualified node name" in the parts that are visible to the user. I do have one bikesheddy concern about this though. Calling it just a "node name" might not be specific enough, and may clash with other notions. An actual fx Node already has some sort of notion of a name attached to it. In fact, fx docs even mention qualified node names. But that's slightly different from the node naming convention I've put together. BUT, happy to go ahead with this PR as is, and leave this as a note.

Thing 2:

I would like to pull the PR locally and play a bit with it once you mark the PR ready for review.

Maybe it's ready for you to go ahead and do that - and in fact that would be welcome at this point. But before marking as ready, I wanted to have a go at training a detection model as the final test. Hoping to do that on the weekend.

@@ -32,6 +32,7 @@ architectures, and common image transformations for computer vision.
:caption: Package Reference

datasets
feature_extraction
Contributor


It looks like feature_extraction is a submodule or a folder in torchvision, but I think the files are placed under models/feature_extraction


@alexander-soare alexander-soare Sep 4, 2021


Wasn't keen on putting it in the models docs as that seems to have a consistent theme of talking about models, how to load them, and their metrics. And I'm a noob with auto-docs so didn't know if I could nest it in the table of contents but still have its own page. Any suggestions?


@alexander-soare alexander-soare Sep 4, 2021


Re. your comment below:

Probably there is small confusion in placing the files?

Yes probably :) Happy to hear further suggestions if you have any.

Contributor


I'm pretty much noob too 😄 Probably the maintainers here would tell you better solution for docs.

@oke-aditya
Contributor

Hi, I just had a very small look at this. I was very eager for this feature.
Probably there is small confusion in placing the files?
Or maybe I have misunderstood something.

@alexander-soare
Contributor Author

I trained ImageNet classification with a pretrained resnet18 plus concatenated FPN features, using both IntermediateLayerGetter and build_feature_extractor. I trained each for 1 epoch, printing loss values (to 6 decimal places) at approximately 100 intervals. There was no difference between the two feature extraction methods.

@fmassa this covers off my "Thing 2" above so from my end this is ready for review.

@alexander-soare alexander-soare marked this pull request as ready for review September 4, 2021 15:41

@fmassa fmassa left a comment


I've tried the PR locally and it worked great, thanks!

I've just noted one potential (unwanted?) side-effect of the prefix matching, let me know what you think.

Also, the test failures can be fixed with some of the comments I made.

I think after this batch of comments is addressed we are good to merge, thanks a ton @alexander-soare !


fmassa commented Sep 6, 2021

@alexander-soare also, following some discussion we had, maybe using create_feature_extractor would be better than build_feature_extractor (the build might look like the builder pattern, which is not what we are doing here).
Would you mind changing it as well?

@alexander-soare
Contributor Author

@fmassa that's all sorted. Thanks!

@@ -8,8 +8,8 @@
from .mobilenet import *
from .mnasnet import *
from .shufflenetv2 import *
from .efficientnet import *

@fmassa fmassa Sep 6, 2021


Can you add this line back?

EDIT: don't bother, I'll do it here

Contributor Author


Yep on it :) Sorry!


@alexander-soare alexander-soare Sep 6, 2021


Nope, there's more to do. For tests, need to add a leaf module that can't be traced through from effnet. Will send it soon. @fmassa


alexander-soare commented Sep 6, 2021

@fmassa see my comment above about your last push. Will pull your changes and tack on the last bit


fmassa commented Sep 6, 2021

@alexander-soare should we change the implementation of EfficientNet so that it can be properly traced?

Anyway, we can for now skip testing efficientnets, and tackle that in a follow-up PR?

cc @datumbox


fmassa commented Sep 6, 2021

@alexander-soare don't bother about fixing EfficientNet, we will do it ourselves. For now, can you just disable efficientnets in the testing of FX, so that we can move forward? Something like filtering efficientnet from the get_classification_models should be enough


fmassa commented Sep 6, 2021

@alexander-soare FYI we are changing StochasticDepth so that it is FX-traceable, see #4372 for the fix.

Anyway, we can get your diff merged once tests finish (lint is failing btw), and then revert the latest changes so that we test StochasticDepth as well


alexander-soare commented Sep 6, 2021

For now, can you just disable efficientnets in the testing of FX, so that we can move forward?

@fmassa yeah, this makes sense. I was just trying to make sure there wasn't another hidden reason for the issue. So as it stands, StochasticDepth is treated as a leaf module (in the testing), and I can revert that once you're ready.


fmassa commented Sep 6, 2021

@alexander-soare there are still some minor test failures following your last changes, do you mind fixing it so that we can get this merged?

Traceback (most recent call last):
  File "/root/project/test/test_backbone_utils.py", line 245, in test_leaf_module_and_function
    'autowrap_functions': [leaf_function]}).train()
  File "/root/project/test/test_backbone_utils.py", line 58, in _create_feature_extractor
    suppress_diff_warning=True)
TypeError: functools.partial object got multiple values for keyword argument 'tracer_kwargs'


alexander-soare commented Sep 6, 2021

@fmassa sorted - introduced it last time... This is the one though. I can feel it (also I let the tests finish before I pushed this time 😛 )

@fmassa fmassa merged commit 72d650a into pytorch:main Sep 6, 2021

fmassa commented Sep 6, 2021

Thanks a ton for all your work @alexander-soare !

@alexander-soare
Contributor Author

@fmassa was a pleasure. And thank you again for all the support!

facebook-github-bot pushed a commit that referenced this pull request Sep 9, 2021
…layer_getter (#4302)

Summary:
* add fx feature extraction util

* Make it possible to use train and eval mode

* FX feature extraction - Tweaks and small bug fixes

* FX feature extraction - add tests

* move to feature_extraction.py, add LeafModuleAwareTracer, add docs

* Tweaks to docs

* addressing latest round of feedback

* undo line spacing changes

* change type hints in docstrings

* fix sphinx indentation

* expose feature_extraction

* add maskrcnn example

* add api refernce subheading

* address latest review notes, refactor names, fix regex, cosmetics

* Add back efficientnet to models

* fix tests for effnet

* fix linting issue

* fix test tracer kwargs

Reviewed By: fmassa

Differential Revision: D30793334

fbshipit-source-id: 34b3497ce75c5b4773f4f3bab328330c114b4193

Co-authored-by: Francisco Massa <[email protected]>