@@ -116,28 +116,6 @@ def swa_update(
116116 return swa_update
117117
118118
119- def _load_state_dict_pre_hook (
120- module ,
121- state_dict ,
122- prefix ,
123- local_metadata ,
124- strict ,
125- missing_keys ,
126- unexpected_keys ,
127- error_msgs ,
128- ):
129- """Pre-hook to handle backward compatibility with tensor n_averaged."""
130- # Check if the old tensor n_averaged is present in the state dict
131- n_averaged_key = prefix + "n_averaged"
132- if n_averaged_key in state_dict :
133- # Convert tensor n_averaged to Python int for backward compatibility
134- n_averaged_tensor = state_dict [n_averaged_key ]
135- if isinstance (n_averaged_tensor , Tensor ):
136- module .n_averaged = int (n_averaged_tensor .item ())
137- # Remove the old tensor buffer from state_dict to avoid loading it
138- del state_dict [n_averaged_key ]
139-
140-
141119class AveragedModel (Module ):
142120 r"""Implements averaged model for Stochastic Weight Averaging (SWA) and Exponential Moving Average (EMA).
143121
@@ -237,7 +215,7 @@ class AveragedModel(Module):
237215 https://paperswithcode.com/method/polyak-averaging
238216 """
239217
240- n_averaged : int
218+ n_averaged : Tensor
241219
242220 def __init__ (
243221 self ,
@@ -256,25 +234,17 @@ def __init__(
256234 self .module = deepcopy (model )
257235 if device is not None :
258236 self .module = self .module .to (device )
259- self .n_averaged = 0
237+ self .register_buffer (
238+ "n_averaged" , torch .tensor (0 , dtype = torch .long , device = device )
239+ )
260240 self .avg_fn = avg_fn
261241 self .multi_avg_fn = multi_avg_fn
262242 self .use_buffers = use_buffers
263- self .register_load_state_dict_pre_hook (_load_state_dict_pre_hook )
264243
265244 def forward (self , * args , ** kwargs ):
266245 """Forward pass."""
267246 return self .module (* args , ** kwargs )
268247
269- def get_extra_state (self ) -> Any :
270- """Get extra state for serialization."""
271- return {"n_averaged" : self .n_averaged }
272-
273- def set_extra_state (self , state : Any ) -> None :
274- """Set extra state from deserialization."""
275- if isinstance (state , dict ) and "n_averaged" in state :
276- self .n_averaged = state ["n_averaged" ]
277-
278248 def update_parameters (self , model : Module ):
279249 """Update model parameters."""
280250 self_param = (
@@ -310,26 +280,28 @@ def update_parameters(self, model: Module):
310280 self .multi_avg_fn (
311281 self_params , # type: ignore[arg-type]
312282 model_params , # type: ignore[arg-type]
313- self .n_averaged ,
283+ self .n_averaged . to ( device ) ,
314284 )
315285 elif (
316286 device is not None
317287 and device .type in _get_foreach_kernels_supported_devices ()
318288 ):
319289 multi_avg_fn = get_swa_multi_avg_fn ()
320- multi_avg_fn (self_params , model_params , self .n_averaged )
290+ multi_avg_fn (
291+ self_params , model_params , self .n_averaged .to (device )
292+ )
321293 else :
322294 avg_fn = get_swa_avg_fn ()
295+ n_averaged = self .n_averaged .to (device )
323296 for p_averaged , p_model in zip (self_params , model_params ): # type: ignore[assignment]
324- p_averaged .copy_ (
325- avg_fn (p_averaged , p_model , self .n_averaged )
326- )
297+ p_averaged .copy_ (avg_fn (p_averaged , p_model , n_averaged ))
327298 else :
328299 for p_averaged , p_model in zip ( # type: ignore[assignment]
329300 self_param_detached , model_param_detached
330301 ):
302+ n_averaged = self .n_averaged .to (p_averaged .device )
331303 p_averaged .detach ().copy_ (
332- self .avg_fn (p_averaged .detach (), p_model , self . n_averaged )
304+ self .avg_fn (p_averaged .detach (), p_model , n_averaged )
333305 )
334306
335307 if not self .use_buffers :
0 commit comments