Skip to content

Commit e51ddb0

Browse files
committed
Update bn_reinit
1 parent 12357aa commit e51ddb0

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

references/classification/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,9 @@ def bn_reinitialization(model: torch.nn.Module, gamma: float = 1.0, beta: float
197197
beta (float): The beta initial value.
198198
"""
199199
for module in model.modules():
200-
if isinstance(module, torch.nn._BatchNorm):
201-
module.weight.fill_(gamma)
202-
module.bias.fill_(beta)
200+
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
201+
torch.nn.init.constant_(module.weight, gamma)
202+
torch.nn.init.constant_(module.bias, beta)
203203

204204

205205
def accuracy(output, target, topk=(1,)):

0 commit comments

Comments
 (0)