diff --git a/torchvision/models/resnet.py b/torchvision/models/resnet.py index 93f0d06f67b..70602705521 100644 --- a/torchvision/models/resnet.py +++ b/torchvision/models/resnet.py @@ -217,9 +217,9 @@ def __init__( # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual: for m in self.modules(): - if isinstance(m, Bottleneck): + if isinstance(m, Bottleneck) and m.bn3.weight is not None: nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] - elif isinstance(m, BasicBlock): + elif isinstance(m, BasicBlock) and m.bn2.weight is not None: nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] def _make_layer(