
Conversation

datumbox
Contributor

Currently, the majority of our detection models replace the BatchNorm2d layers with FrozenBatchNorm2d. This is a reasonable mitigation that improves training stability for small batch sizes. Unfortunately, our current implementation freezes the BNs even when they are completely randomly initialized. Since FrozenBatchNorm2d freezes both the running stats and the affine parameters, its parameters are initialized and permanently fixed to:

self.register_buffer("weight", torch.ones(num_features))         # affine scale, fixed at 1
self.register_buffer("bias", torch.zeros(num_features))          # affine shift, fixed at 0
self.register_buffer("running_mean", torch.zeros(num_features))  # mean estimate, fixed at 0
self.register_buffer("running_var", torch.ones(num_features))    # variance estimate, fixed at 1

Consequently, the BN layers are effectively disabled for users who train these models from scratch.
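To see why, here is a minimal pure-Python sketch of the per-element transform a frozen batch norm applies (the function name and the `eps` value are illustrative assumptions, not torchvision's API). With the default-initialized buffers above, the layer reduces to a near-identity that never changes during training:

```python
import math

def frozen_bn(x, weight=1.0, bias=0.0, running_mean=0.0, running_var=1.0, eps=1e-5):
    # A frozen BN only scales and shifts with fixed statistics;
    # nothing is ever updated during training.
    scale = weight / math.sqrt(running_var + eps)
    return (x - running_mean) * scale + bias

# With the defaults: y = x / sqrt(1 + eps), i.e. almost exactly the input.
```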

This PR fixes the issue by replacing the BNs with FrozenBNs only when at least some pre-trained weights are loaded.
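The selection logic can be sketched roughly as follows (`pick_norm_layer` is a hypothetical helper; the strings stand in for `misc_nn_ops.FrozenBatchNorm2d` and `nn.BatchNorm2d`):

```python
def pick_norm_layer(pretrained: bool, pretrained_backbone: bool) -> str:
    # Freeze the batch norms only when some pre-trained weights will be
    # loaded; otherwise keep regular BatchNorm2d so the stats can train.
    is_trained = pretrained or pretrained_backbone
    return "FrozenBatchNorm2d" if is_trained else "BatchNorm2d"
```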


facebook-github-bot commented Feb 18, 2022

💊 CI failures summary and remediations

As of commit d222c46 (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

1 failure not recognized by patterns:

Job Step Action
CircleCI cmake_macos_cpu curl -o conda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh
sh conda.sh -b
source $HOME/miniconda3/bin/activate
conda install -yq conda-build cmake
packaging/build_cmake.sh

This comment was automatically generated by Dr. CI.



Member

@fmassa fmassa left a comment


Thanks!

trainable_backbone_layers = _validate_trainable_layers(
    pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3
)
is_trained = pretrained or pretrained_backbone
Member


nit: if the model was trained for detection with large batch sizes from scratch and is then fine-tuned afterwards (still with large batch sizes), we would still end up using FrozenBatchNorm.

This is an OK heuristic, but it hints that we might want to make this an explicit constructor parameter in the future.
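Such an explicit parameter might look roughly like this (`build_backbone` and its signature are hypothetical names for illustration, not part of this PR; the strings again stand in for the actual norm classes):

```python
def build_backbone(pretrained: bool = False, norm_layer=None):
    # An explicit norm_layer argument lets the caller override the choice,
    # falling back to the pretrained-based heuristic only when unset.
    if norm_layer is None:
        norm_layer = "FrozenBatchNorm2d" if pretrained else "BatchNorm2d"
    return norm_layer
```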

@datumbox datumbox merged commit 350a3e8 into pytorch:main Mar 7, 2022
@datumbox datumbox deleted the bug/frozen_bn branch March 7, 2022 11:34
facebook-github-bot pushed a commit that referenced this pull request Mar 15, 2022
Reviewed By: vmoens

Differential Revision: D34878996

fbshipit-source-id: 690b04fe0810cbd45ed582067b79f7e4254c054e