@@ -75,6 +75,7 @@ def __init__(
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         forced_dtype: Optional[torch.dtype] = None,
+        master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
@@ -106,6 +107,9 @@ def __init__(
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm

+        # master weights copy
+        self._master_weights = master_weights
+
         if forced_dtype:
             for group in self.optim.param_groups:
                 group_params = group["params"]
@@ -135,7 +139,6 @@ def __init__(
             self._working_param_groups[group_id] = group_params

             master_param_current_rank = self._create_master_param_current_rank(group_params)
-
             self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

             # need to replace the params in the `params` field in the optimizer
@@ -200,11 +203,18 @@ def _create_master_param_current_rank(self, param_list):
             with torch.no_grad():
                 if padding_size > 0:
                     padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+                    # reset working params' ptr when no master weights
+                    if self._master_weights == False:
+                        param.data = padding_param[: param.numel()].view(param.shape)
                 else:
                     padding_param = param.data.view(-1)
                 splited_params = padding_param.split(padding_param.numel() // self._world_size)

-                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                # use fp32 when master_weights is True
+                if self._master_weights is True:
+                    splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                else:
+                    splited_param_current_rank = splited_params[self._local_rank]
                 params_current_rank.append(splited_param_current_rank)
                 self._param_store.link_master_and_working_param(splited_param_current_rank, param)

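For context on this hunk, here is a minimal standalone sketch of the pad-and-split partitioning the method performs, with plain `world_size`/`rank` values standing in for the optimizer's process-group state (an illustration under those assumptions, not the library code):

```python
import torch

world_size, rank = 4, 1                  # stand-ins for self._world_size / self._local_rank
param = torch.randn(10, 3)               # 30 elements, not divisible by world_size

# pad the flattened parameter so it splits evenly across ranks
flat = param.data.view(-1)
padding_size = (world_size - flat.numel() % world_size) % world_size
if padding_size > 0:
    flat = torch.nn.functional.pad(flat, [0, padding_size])

shards = flat.split(flat.numel() // world_size)       # 4 shards of 8 elements each
master_shard_fp32 = shards[rank].detach().float()     # master_weights=True: separate fp32 copy
master_shard_shared = shards[rank]                    # master_weights=False: view of the padded working data
```

With `master_weights=False` the shard is just a view of the padded working tensor (whose storage the hunk also re-points `param.data` at), so no extra fp32 copy is allocated.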
@@ -402,9 +412,7 @@ def step(self, closure=None):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
-
         grad_index = 0 if self._partition_grads else self._local_rank
-
         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             real_working_params[group_id] = []
@@ -417,7 +425,12 @@ def step(self, closure=None):
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
-                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                    # no need to copy fp32 grad if master_weights is False
+                    grad = (
+                        grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                        if self._master_weights
+                        else grads[grad_index]
+                    )
                     splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
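A rough sketch of the ternary added here, using hypothetical tensors: with fp32 master weights the low-precision gradient shard is cast to the master shard's dtype and device, and without them the shard already carries the working dtype, so the gradient is attached as-is:

```python
import torch

master_weights = True                    # flip to False to skip the cast
splited_param = torch.zeros(8, dtype=torch.float32 if master_weights else torch.float16)
grad_shard = torch.randn(8, dtype=torch.float16)   # hypothetical partitioned gradient

grad = (
    grad_shard.to(splited_param.dtype).to(splited_param.device)
    if master_weights
    else grad_shard
)
splited_param.grad = grad                # dtypes match in either branch
```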
@@ -445,17 +458,16 @@ def step(self, closure=None):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])

         # update working partition updated by the current rank
-        dtype = real_working_params[0][0].dtype
+        # dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 all_splited_param = [
-                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                    torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
                 ]
-                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
-
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
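Finally, a sketch of the gather-and-copy-back path the last hunk touches: the updated shards are cast to the working dtype (now taken from `self._dtype`), gathered across ranks, flattened, and the padded tail is dropped before the result is copied into the working parameter. The all-gather is simulated locally with a plain list, and the shapes are hypothetical:

```python
import torch

world_size = 4
working_dtype = torch.float16                              # plays the role of self._dtype
working_param = torch.randn(10, 3, dtype=working_dtype)    # 30 elements, padded to 32
padded_numel = 32

# pretend these are the shards each rank just updated (fp32 when master_weights=True)
updated_shards = [torch.randn(padded_numel // world_size) for _ in range(world_size)]

# dist.all_gather would fill this list; each entry is cast to the working dtype
all_splited_param = [shard.to(working_dtype) for shard in updated_shards]

# flatten(...) corresponds to concatenating the gathered shards
flat = torch.cat(all_splited_param)
working_param.data.copy_(flat[: working_param.numel()].reshape_as(working_param))
```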