34 changes: 17 additions & 17 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -464,23 +464,23 @@ def __init__(
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(
optimizer,
initial_scale,
min_scale,
growth_factor,
backoff_factor,
growth_interval,
hysteresis,
max_scale,
clip_grad_norm,
verbose,
reduce_bucket_size,
communication_dtype,
overlap_communication,
partition_grad,
cpu_offload,
dp_process_group,
forced_dtype,
optimizer=optimizer,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
clip_grad_norm=clip_grad_norm,
verbose=verbose,
reduce_bucket_size=reduce_bucket_size,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
partition_grad=partition_grad,
cpu_offload=cpu_offload,
dp_process_group=dp_process_group,
forced_dtype=forced_dtype,
)

def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
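The hunk above only changes how the hybrid parallel plugin's ZeRO optimizer wrapper forwards its arguments to `LowLevelZeroOptimizer.__init__`: the same values are now passed as keyword arguments instead of positional ones. A minimal sketch, with hypothetical classes rather than ColossalAI code, of the hazard that keyword forwarding avoids once a parent signature grows new parameters (such as the `master_weights` flag added elsewhere in this PR):

```python
# Minimal sketch with hypothetical classes (not ColossalAI code) of the hazard
# that keyword forwarding avoids: if a parameter is ever inserted into the
# middle of the parent signature, positional arguments silently shift, while
# keyword arguments keep binding to the parameters they name.

class OptimWrapperV1:
    def __init__(self, initial_scale=2**16, min_scale=1, growth_factor=2.0):
        self.initial_scale, self.min_scale, self.growth_factor = initial_scale, min_scale, growth_factor

class OptimWrapperV2:
    # the same wrapper after a hypothetical flag is inserted before min_scale
    def __init__(self, initial_scale=2**16, master_weights=True, min_scale=1, growth_factor=2.0):
        self.master_weights = master_weights
        self.initial_scale, self.min_scale, self.growth_factor = initial_scale, min_scale, growth_factor

# A positional call written against V1 still runs against V2, but the values
# land in the wrong slots: 1 becomes master_weights, 2.0 becomes min_scale.
shifted = OptimWrapperV2(2**16, 1, 2.0)
assert shifted.master_weights == 1 and shifted.min_scale == 2.0

# The keyword call stays correct no matter where new parameters are added.
safe = OptimWrapperV2(initial_scale=2**16, min_scale=1, growth_factor=2.0)
assert safe.master_weights is True and safe.min_scale == 1
```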
6 changes: 4 additions & 2 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -262,6 +262,7 @@ def __init__(
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True,
cpu_offload: bool = False,
master_weights: bool = True,
verbose: bool = False,
) -> None:
super().__init__()
@@ -272,18 +273,19 @@ def __init__(
self.precision = precision
self.zero_optim_kwargs = dict(
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
min_scale=min_scale,
max_scale=max_scale,
clip_grad_norm=max_norm,
reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(stage == 2),
cpu_offload=cpu_offload,
master_weights=master_weights,
)
self.verbose = verbose

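The plugin change above threads the new `master_weights` flag through to the ZeRO optimizer's kwargs. A hedged usage sketch, assuming the Booster-driven setup used in the ColossalAI examples (the launch call and model/optimizer wiring may differ by version and are partly elided):

```python
# Usage sketch, assuming the Booster API as used in ColossalAI examples;
# model/optimizer construction and the training loop are elided.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

# master_weights=False keeps the sharded optimizer parameters in the working
# (half-precision) dtype instead of an fp32 master copy, trading numerical
# headroom for lower memory use.
plugin = LowLevelZeroPlugin(stage=2, precision="fp16", master_weights=False)
booster = Booster(plugin=plugin)

# model = ...
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# model, optimizer, *_ = booster.boost(model, optimizer)
```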
30 changes: 21 additions & 9 deletions colossalai/zero/low_level/low_level_optim.py
@@ -75,6 +75,7 @@ def __init__(
cpu_offload: bool = False, # cpu offload
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
forced_dtype: Optional[torch.dtype] = None,
master_weights: bool = True, # master weights
):
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]["params"][0].dtype
@@ -106,6 +107,9 @@ def __init__(
# gradient clipping
self._clip_grad_norm = clip_grad_norm

# master weights copy
self._master_weights = master_weights

if forced_dtype:
for group in self.optim.param_groups:
group_params = group["params"]
@@ -135,7 +139,6 @@ def __init__(
self._working_param_groups[group_id] = group_params

master_param_current_rank = self._create_master_param_current_rank(group_params)

self._master_param_groups_of_current_rank[group_id] = master_param_current_rank

# need to replace the params in the `params` field in the optimizer
@@ -200,11 +203,18 @@ def _create_master_param_current_rank(self, param_list):
with torch.no_grad():
if padding_size > 0:
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
# reset working params' ptr when no master weights
if self._master_weights == False:
param.data = padding_param[: param.numel()].view(param.shape)
else:
padding_param = param.data.view(-1)
splited_params = padding_param.split(padding_param.numel() // self._world_size)

splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
# use fp32 when master_weights is True
if self._master_weights is True:
splited_param_current_rank = splited_params[self._local_rank].detach().float().to(device)
else:
splited_param_current_rank = splited_params[self._local_rank]
params_current_rank.append(splited_param_current_rank)
self._param_store.link_master_and_working_param(splited_param_current_rank, param)
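In `_create_master_param_current_rank`, each working parameter is padded to a multiple of the world size, split into per-rank shards, and the local shard is cast to fp32 only when `master_weights` is enabled; when it is disabled, the working parameter is re-pointed at the padded buffer and the shard stays in the working dtype. A standalone sketch of that partitioning with plain tensors (hypothetical `world_size`/`local_rank` values, no process group):

```python
# Standalone sketch of the shard/cast logic above, using plain tensors and
# hypothetical world_size/local_rank values instead of a real process group.
import torch
import torch.nn.functional as F

def shard_param(param: torch.Tensor, world_size: int, local_rank: int, master_weights: bool):
    flat = param.data.view(-1)
    padding_size = (world_size - flat.numel() % world_size) % world_size
    if padding_size > 0:
        flat = F.pad(flat, [0, padding_size])
    shards = flat.split(flat.numel() // world_size)
    local = shards[local_rank]
    # Only keep an fp32 master copy of the local shard when requested;
    # otherwise the optimizer state stays in the working dtype.
    return local.detach().float() if master_weights else local

param = torch.randn(10).half()  # 10 elements, padded to 12 for world_size=4
fp32_shard = shard_param(param, world_size=4, local_rank=0, master_weights=True)
fp16_shard = shard_param(param, world_size=4, local_rank=0, master_weights=False)
assert fp32_shard.dtype == torch.float32 and fp32_shard.numel() == 3
assert fp16_shard.dtype == torch.float16 and fp16_shard.numel() == 3
```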

@@ -402,9 +412,7 @@ def step(self, closure=None):
# and should not be updated
real_working_params = dict()
real_master_params = dict()

grad_index = 0 if self._partition_grads else self._local_rank

for group_id in range(self.num_param_groups):
master_params = self._master_param_groups_of_current_rank[group_id]
real_working_params[group_id] = []
@@ -417,7 +425,12 @@
grads = self._grad_store.get_partitioned_gradients_by_param_id(group_id, id(working_param))
if len(grads) > 0:
real_working_params[group_id].append(working_param)
grad = grads[grad_index].to(splited_param.dtype).to(splited_param.device)
# no need to copy fp32 grad if master_weights is False
grad = (
grads[grad_index].to(splited_param.dtype).to(splited_param.device)
if self._master_weights
else grads[grad_index]
)
splited_param.grad = grad
grad_partition_groups.append(grad)
real_master_params[group_id].append(splited_param)
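The gradient handling above attaches each partitioned gradient to the local master shard, upcasting it to the shard's dtype and device only when an fp32 master copy exists. A tiny sketch of that decision with plain tensors and hypothetical values:

```python
# Tiny sketch of the gradient-dtype decision above, using plain tensors.
import torch

def grad_for_shard(raw_grad: torch.Tensor, master_shard: torch.Tensor, master_weights: bool) -> torch.Tensor:
    # With fp32 master weights the fp16 gradient is upcast to match the shard;
    # without them the gradient is attached as-is, avoiding an extra fp32 copy.
    return raw_grad.to(master_shard.dtype).to(master_shard.device) if master_weights else raw_grad

raw_grad = torch.randn(3).half()
fp32_shard = torch.zeros(3, dtype=torch.float32)
fp16_shard = torch.zeros(3, dtype=torch.float16)
assert grad_for_shard(raw_grad, fp32_shard, True).dtype == torch.float32
assert grad_for_shard(raw_grad, fp16_shard, False) is raw_grad
```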
@@ -445,17 +458,16 @@ def step(self, closure=None):
release_param_grad(self._master_param_groups_of_current_rank[group_id])

# update working partition updated by the current rank
dtype = real_working_params[0][0].dtype
# dtype = real_working_params[0][0].dtype
for group_id in range(self.num_param_groups):
master_working_param = self.optim.param_groups[group_id]["params"]
for idx, splited_param in enumerate(master_working_param):
working_param = real_working_params[group_id][idx]
all_splited_param = [
torch.zeros(splited_param.shape, device="cuda", dtype=dtype) for _ in range(self._world_size)
torch.zeros(splited_param.shape, device="cuda", dtype=self._dtype) for _ in range(self._world_size)
]
dist.all_gather(all_splited_param, splited_param.cuda().to(dtype), group=self.dp_pg)
dist.all_gather(all_splited_param, splited_param.cuda().to(self._dtype), group=self.dp_pg)
working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))

self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
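After the local shards are updated, `step()` all-gathers every rank's shard cast to the working dtype (`self._dtype`) and copies the flattened, trimmed result back into the working parameter. A single-process sketch of that reassembly (the collective is simulated with a plain list, and `torch.cat` stands in for the repo's `flatten` helper):

```python
# Single-process sketch of the gather-and-copy-back step above. The real code
# uses dist.all_gather over the data-parallel group; here the "gathered" list
# is built locally.
import torch

world_size = 4
working_param = torch.randn(10).half()                  # fp16 working parameter, padded to 12
padded = torch.nn.functional.pad(working_param.view(-1), [0, 2])
shards = [s.clone().float() for s in padded.split(3)]   # per-rank fp32 master shards

# ... optimizer.step() would update each rank's shard here ...

# Each rank contributes its updated shard cast back to the working dtype
# (self._dtype in the hunk above); the flattened result is trimmed to the
# original number of elements and copied into the working parameter.
gathered = [shard.to(working_param.dtype) for shard in shards]  # stands in for dist.all_gather
working_param.data.copy_(
    torch.cat(gathered)[: working_param.numel()].reshape_as(working_param)
)
```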

def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float: