
Commit a0684e7

[feature] support no master weights option for low level zero plugin (#4816)
* [feature] support no master weights for low level zero plugin
* [feature] support no master weights for low level zero plugin, remove data copy when no master weights
* remove data copy and typecasting when no master weights
* not load weights to cpu when using no master weights
* fix grad: use fp16 grad when no master weights
* only do not update working param when no master weights
* fix: only do not update working param when no master weights
* fix: passing params in dict format in hybrid plugin
* fix: remove extra params (tp_process_group) in hybrid_parallel_plugin
1 parent 77a9328 commit a0684e7

File tree

3 files changed (+42, -28 lines)


colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 17 additions & 17 deletions
@@ -464,23 +464,23 @@ def __init__(
         if use_pipeline:
             init_pipeline_optimizer(optimizer, model)
         super().__init__(
-            optimizer,
-            initial_scale,
-            min_scale,
-            growth_factor,
-            backoff_factor,
-            growth_interval,
-            hysteresis,
-            max_scale,
-            clip_grad_norm,
-            verbose,
-            reduce_bucket_size,
-            communication_dtype,
-            overlap_communication,
-            partition_grad,
-            cpu_offload,
-            dp_process_group,
-            forced_dtype,
+            optimizer=optimizer,
+            initial_scale=initial_scale,
+            min_scale=min_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            max_scale=max_scale,
+            clip_grad_norm=clip_grad_norm,
+            verbose=verbose,
+            reduce_bucket_size=reduce_bucket_size,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            partition_grad=partition_grad,
+            cpu_offload=cpu_offload,
+            dp_process_group=dp_process_group,
+            forced_dtype=forced_dtype,
         )

     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
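Note: the hunk above replaces positional arguments with keyword arguments in the call to the parent optimizer's __init__. A minimal sketch of why this matters, using hypothetical classes rather than the real ColossalAI signatures: once the base signature gains a new parameter such as master_weights, a positional call silently binds values to the wrong parameters, while a keyword call stays correct.

# Hypothetical classes for illustration only; not the ColossalAI API.
class BaseZeroOptimizer:
    def __init__(self, optimizer, initial_scale=2**16, master_weights=True, cpu_offload=False):
        self.optimizer = optimizer
        self.initial_scale = initial_scale
        self.master_weights = master_weights
        self.cpu_offload = cpu_offload


class HybridZeroOptimizer(BaseZeroOptimizer):
    def __init__(self, optimizer, initial_scale, cpu_offload):
        # A positional call, super().__init__(optimizer, initial_scale, cpu_offload),
        # would bind cpu_offload to the newly inserted master_weights parameter.
        # Keyword arguments keep every value attached to the right parameter.
        super().__init__(optimizer=optimizer, initial_scale=initial_scale, cpu_offload=cpu_offload)


opt = HybridZeroOptimizer(optimizer=None, initial_scale=2**16, cpu_offload=True)
assert opt.master_weights is True and opt.cpu_offload is True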

colossalai/booster/plugin/low_level_zero_plugin.py

Lines changed: 4 additions & 2 deletions
@@ -262,6 +262,7 @@ def __init__(
         communication_dtype: Optional[torch.dtype] = None,
         overlap_communication: bool = True,
         cpu_offload: bool = False,
+        master_weights: bool = True,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -272,18 +273,19 @@ def __init__(
         self.precision = precision
         self.zero_optim_kwargs = dict(
             initial_scale=initial_scale,
+            min_scale=min_scale,
             growth_factor=growth_factor,
             backoff_factor=backoff_factor,
             growth_interval=growth_interval,
             hysteresis=hysteresis,
-            min_scale=min_scale,
             max_scale=max_scale,
             clip_grad_norm=max_norm,
             reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
             communication_dtype=communication_dtype,
             overlap_communication=overlap_communication,
-            cpu_offload=cpu_offload,
             partition_grad=(stage == 2),
+            cpu_offload=cpu_offload,
+            master_weights=master_weights,
         )
         self.verbose = verbose
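Based on the diff above, a minimal usage sketch of the new option: LowLevelZeroPlugin now accepts master_weights and forwards it to the ZeRO optimizer through zero_optim_kwargs. The model, optimizer, and distributed launch are placeholders; only stage, precision, and master_weights come from this commit, the remaining defaults are assumed.

# Sketch only: distributed initialization (colossalai launch) is omitted, and the
# model/optimizer are placeholders. master_weights=False skips the fp32 master
# copy and lets the optimizer update the fp16 working shards directly.
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

model = torch.nn.Linear(1024, 1024)                 # placeholder model
optimizer = torch.optim.AdamW(model.parameters())   # placeholder optimizer

plugin = LowLevelZeroPlugin(stage=2, precision="fp16", master_weights=False)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

Skipping the master copy trades the extra fp32 shard (and the per-step cast/copy, per the low_level_optim.py changes below) for optimizer updates performed directly in the working precision.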

colossalai/zero/low_level/low_level_optim.py

Lines changed: 21 additions & 9 deletions
@@ -75,6 +75,7 @@ def __init__(
         cpu_offload: bool = False,  # cpu offload
         dp_process_group: Optional[ProcessGroup] = None,  # the dp pg for comm
         forced_dtype: Optional[torch.dtype] = None,
+        master_weights: bool = True,  # master weights
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
@@ -106,6 +107,9 @@ def __init__(
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm

+        # master weights copy
+        self._master_weights = master_weights
+
         if forced_dtype:
             for group in self.optim.param_groups:
                 group_params = group["params"]
@@ -135,7 +139,6 @@ def __init__(
             self._working_param_groups[group_id] = group_params

             master_param_current_rank = self._create_master_param_current_rank(group_params)
-
             self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

             # need to replace the params in the `params` field in the optimizer
@@ -200,11 +203,18 @@ def _create_master_param_current_rank(self, param_list):
             with torch.no_grad():
                 if padding_size > 0:
                     padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
+                    # reset working params' ptr when no master weights
+                    if self._master_weights == False:
+                        param.data = padding_param[: param.numel()].view(param.shape)
                 else:
                     padding_param = param.data.view(-1)
                 splited_params = padding_param.split(padding_param.numel() // self._world_size)

-                splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                # use fp32 when master_weights is True
+                if self._master_weights is True:
+                    splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
+                else:
+                    splited_param_current_rank = splited_params[self._local_rank]
                 params_current_rank.append(splited_param_current_rank)
                 self._param_store.link_master_and_working_param(splited_param_current_rank, param)
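A self-contained sketch of the branch added to _create_master_param_current_rank above, reduced to a single tensor and plain Python (world_size, local_rank, and device are stand-ins, not the optimizer's attributes): the flattened parameter is padded so it splits evenly across ranks, and the local shard is either copied to fp32 as a master weight or kept as the fp16 working storage when master_weights is False.

import torch

def shard_param(param: torch.Tensor, world_size: int, local_rank: int,
                master_weights: bool, device: str = "cpu") -> torch.Tensor:
    # pad the flattened parameter so it splits evenly across ranks
    flat = param.data.view(-1)
    padding = (world_size - flat.numel() % world_size) % world_size
    if padding > 0:
        flat = torch.nn.functional.pad(flat, [0, padding])
    shards = flat.split(flat.numel() // world_size)
    if master_weights:
        # fp32 master copy, detached from the fp16 working tensor
        return shards[local_rank].detach().float().to(device)
    # no master copy: the local shard stays in the working dtype
    return shards[local_rank]

param = torch.randn(10, dtype=torch.float16)
shard = shard_param(param, world_size=4, local_rank=0, master_weights=False)
print(shard.dtype, shard.numel())  # torch.float16, 3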

@@ -402,9 +412,7 @@ def step(self, closure=None):
         # and should not be updated
         real_working_params = dict()
         real_master_params = dict()
-
         grad_index = 0 if self._partition_grads else self._local_rank
-
         for group_id in range(self.num_param_groups):
             master_params = self._master_param_groups_of_current_rank[group_id]
             real_working_params[group_id] = []
@@ -417,7 +425,12 @@ def step(self, closure=None):
                 grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
-                    grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                    # no need to copy fp32 grad if master_weights is False
+                    grad = (
+                        grads[grad_index].to(splited_param.dtype).to(splited_param.device)
+                        if self._master_weights
+                        else grads[grad_index]
+                    )
                     splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
@@ -445,17 +458,16 @@ def step(self, closure=None):
             release_param_grad(self._master_param_groups_of_current_rank[group_id])

         # update working partition updated by the current rank
-        dtype = real_working_params[0][0].dtype
+        # dtype = real_working_params[0][0].dtype
         for group_id in range(self.num_param_groups):
             master_working_param = self.optim.param_groups[group_id]["params"]
             for idx, splited_param in enumerate(master_working_param):
                 working_param = real_working_params[group_id][idx]
                 all_splited_param = [
-                    torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
+                    torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
                 ]
-                dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
+                dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
                 working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
-
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
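The last hunk gathers the updated shards in the working dtype (self._dtype) instead of the master-parameter dtype, so the copy back into the working parameter behaves the same whether the shards are fp32 master copies or the fp16 working storage itself. A simplified single-process sketch of that gather-and-copy-back step, with a plain list standing in for dist.all_gather and an assumed helper name:

from typing import List

import torch

def rebuild_working_param(shards: List[torch.Tensor], working_param: torch.Tensor,
                          working_dtype: torch.dtype) -> None:
    # stand-in for dist.all_gather: every rank contributes its updated shard,
    # cast to the working dtype before flattening and copying back
    gathered = [s.to(working_dtype) for s in shards]
    flat = torch.cat(gathered)
    working_param.data.copy_(flat[: working_param.numel()].reshape_as(working_param))

working = torch.zeros(10, dtype=torch.float16)
# pretend these are the fp32 master shards after the optimizer step (4 ranks, padded to 12 elements)
updated_shards = [torch.full((3,), float(r), dtype=torch.float32) for r in range(4)]
rebuild_working_param(updated_shards, working, torch.float16)
print(working)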
