Skip to content

Respect strict=False when loading detection models #5841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 20, 2022

Conversation

datumbox
Copy link
Contributor

@datumbox datumbox commented Apr 20, 2022

Fixes #5835

Tested with:

from torchvision.models import detection

fns = [v for k, v in detection.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
for fn in fns:
    try:
        model = fn(weights_backbone=None)
        model.load_state_dict({}, strict=False)
        print(fn.__name__, "PASS")
    except:
        print(fn.__name__, "FAIL")

Output:

fasterrcnn_resnet50_fpn PASS
fasterrcnn_resnet50_fpn_v2 PASS
fasterrcnn_mobilenet_v3_large_fpn PASS
fasterrcnn_mobilenet_v3_large_320_fpn PASS
fcos_resnet50_fpn PASS
keypointrcnn_resnet50_fpn PASS
maskrcnn_resnet50_fpn PASS
maskrcnn_resnet50_fpn_v2 PASS
retinanet_resnet50_fpn PASS
retinanet_resnet50_fpn_v2 PASS
ssd300_vgg16 PASS
ssdlite320_mobilenet_v3_large PASS

I'm not adding the test in the unit-tests to avoid slowing them down further. This gap can be resolved properly once we redesign the testing strategy for models.

Copy link
Member

@NicolasHug NicolasHug left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @datumbox

@datumbox datumbox merged commit 3122ea1 into pytorch:main Apr 20, 2022
@datumbox datumbox deleted the bugfix/strict_false branch April 20, 2022 11:44
facebook-github-bot pushed a commit that referenced this pull request May 5, 2022
Summary:
* Convert weights only if `old_key` is in `state_dict`

* Fix linter

Reviewed By: jdsgomes, NicolasHug

Differential Revision: D36095666

fbshipit-source-id: 32300797a9a5c4aa1abe2414dfad75480f898760
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

torchvision models no longer respect strict=False when loading
3 participants