diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index 496efcfd14..f211743770 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -176,6 +176,7 @@ def convert_sync_batchnorm(module, process_group=None): module_output.running_mean = module.running_mean module_output.running_var = module.running_var module_output.num_batches_tracked = module.num_batches_tracked + module_output.training = module.training if hasattr(module, "qconfig"): module_output.qconfig = module.qconfig for name, child in module.named_children():