Skip to content

Commit fba4f42

Browse files
authored
allow to use custom norm_layer (#4621)
1 parent 54a4550 commit fba4f42

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

torchvision/models/regnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,8 @@ def _reset_parameters(self) -> None:
392392

393393

394394
def _regnet(arch: str, block_params: BlockParams, pretrained: bool, progress: bool, **kwargs: Any) -> RegNet:
395-
model = RegNet(block_params, norm_layer=partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1), **kwargs)
395+
norm_layer = kwargs.pop("norm_layer", partial(nn.BatchNorm2d, eps=1e-05, momentum=0.1))
396+
model = RegNet(block_params, norm_layer=norm_layer, **kwargs)
396397
if pretrained:
397398
if arch not in model_urls:
398399
raise ValueError(f"No checkpoint is available for model type {arch}")

0 commit comments

Comments
 (0)