Commit 00b1b18

[Bug] Update torch.optim.Optimizer parameter states after tensor parallelism (#3835)
* 1. Fix torch.optim.Optimizer parameter-to-address mapping in TP
* 2. Fix DTensor broadcast issues in cpu_ram_efficient_loading
* Add DTensor guard on fsdp_utils full_param.to_local()
* Add unit test
* Add comments addressing feedback
* Fix ruff
1 parent a73fd3a commit 00b1b18

File tree

5 files changed: +167 -1 lines changed


src/accelerate/accelerator.py

Lines changed: 20 additions & 0 deletions
@@ -1590,6 +1590,8 @@ def _prepare_tp(self, *args):

         device_mesh = self.torch_device_mesh

+        old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
+
         for arg in result:
             if not isinstance(arg, torch.nn.Module):
                 continue
@@ -1613,6 +1615,24 @@ def _prepare_tp(self, *args):
                     dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
                     setattr(module_to_tp, param_type, dp)

+        new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
+        # Build a map from old to new params
+        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+
+        def _get_tensor_address(p):
+            if isinstance(p, DTensor):
+                return p._local_tensor.data_ptr()
+            return p.data_ptr()
+
+        for obj in result:
+            if isinstance(obj, torch.optim.Optimizer):
+                for param_group in obj.param_groups:
+                    # Each param_group originally maps to model parameters (e.g., from model.parameters()).
+                    # After _prepare_tp(), parameter references are replaced with DTensor instances.
+                    # Therefore, we remap the parameter references to their new DTensor addresses
+                    # so that the optimizer can correctly update the model parameters.
+                    param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
+
         return args

     def _prepare_cp(self, *args):
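
The change above keys each pre-TP parameter by its storage address and then points the optimizer's param_groups at the new DTensor-backed parameters. For intuition, the following minimal sketch reproduces the underlying problem and the remapping fix with plain torch, outside the accelerate code path; the toy Linear model, variable names, and SGD settings are illustrative assumptions, not the actual implementation.

# Minimal sketch: an optimizer holds references to Parameter objects, so replacing a
# module's parameter (as _prepare_tp() does when it installs DTensor-backed params)
# silently detaches the optimizer from the model unless param_groups are remapped.
import torch

model = torch.nn.Linear(4, 4, bias=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Record the storage address of every parameter the optimizer currently references.
old_addresses = {p.data_ptr(): name for name, p in model.named_parameters()}

# Simulate the parameter swap: a brand-new Parameter object replaces model.weight.
model.weight = torch.nn.Parameter(model.weight.detach().clone())

# Remap each param_group entry (looked up by its old address) to the new object.
new_named = dict(model.named_parameters())
mapping = {addr: new_named[name] for addr, name in old_addresses.items()}
for group in optimizer.param_groups:
    group["params"] = [mapping[p.data_ptr()] for p in group["params"]]

# A training step now updates the swapped-in parameter instead of the discarded one.
before = model.weight.detach().clone()
model(torch.randn(2, 4)).sum().backward()
optimizer.step()
assert not torch.equal(model.weight, before)

Without the remap, optimizer.step() would keep updating the discarded parameter objects and the prepared model would never train.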

src/accelerate/utils/fsdp_utils.py

Lines changed: 6 additions & 1 deletion
@@ -470,7 +470,7 @@ def fsdp2_load_full_state_dict(accelerator, model: torch.nn.Module, full_sd: dic
         full_sd (`dict`): The full state dict to load, can only be on rank 0
     """
     import torch.distributed as dist
-    from torch.distributed.tensor import distribute_tensor
+    from torch.distributed.tensor import DTensor, distribute_tensor

     # Model was previously copied to meta device
     meta_sharded_sd = model.state_dict()
@@ -506,6 +506,11 @@ def _cast_and_contiguous(tensor, to_contiguous, dtype):
         for (param_name, full_param), sharded_param in zip(full_sd.items(), meta_sharded_sd.values()):
             device_mesh = sharded_param.device_mesh
             full_param = full_param.detach().to(device_mesh.device_type)
+            if isinstance(full_param, DTensor):
+                # dist.broadcast() only supports torch.Tensor.
+                # After prepare_tp(), model parameters may become DTensor.
+                # To broadcast such a parameter, convert it to a local tensor first.
+                full_param = full_param.to_local()
             dist.broadcast(full_param, src=0, group=dist.group.WORLD)
             sharded_tensor = distribute_tensor(full_param, device_mesh, sharded_param.placements)
             to_contiguous, casting_dtype = _infer_parameter_dtype(
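
The guard above exists because dist.broadcast() operates on plain torch.Tensor objects, while after _prepare_tp() the parameters reaching this loop can already be DTensors. The following is a minimal sketch of that unwrap-then-broadcast pattern in isolation; it assumes a torchrun launch with two or more processes and a CPU gloo backend, and the mesh, placement, and tensor contents are illustrative rather than the accelerate code path.

# Sketch: unwrap a DTensor with to_local() before handing it to dist.broadcast().
# Run with e.g. `torchrun --nproc_per_node=2 this_script.py` (illustrative setup).
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate, distribute_tensor

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# A parameter that has already been converted to a DTensor (replicated for simplicity).
param = distribute_tensor(torch.arange(4, dtype=torch.float32), mesh, [Replicate()])

# Broadcast the local torch.Tensor rather than the DTensor wrapper itself.
payload = param.to_local() if isinstance(param, DTensor) else param
dist.broadcast(payload, src=0, group=dist.group.WORLD)

dist.destroy_process_group()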

tests/tp/fsdp2_tp_preparation.py

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import timedelta
+
+import torch
+from datasets import load_dataset
+from torch.utils.data import DataLoader
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from accelerate import Accelerator, InitProcessGroupKwargs
+from accelerate.parallelism_config import ParallelismConfig
+from accelerate.utils import FullyShardedDataParallelPlugin
+
+
+class LmHeadWrapper(torch.nn.Module):
+    def __init__(self, lm_head):
+        super().__init__()
+        self.lm_head = lm_head
+
+    def forward(self, x):
+        return self.lm_head(x)
+
+
+def build_simple_dataloader(tokenizer, seq_len=64, batch_size=2):
+    """Build a simple dataloader for reproduction."""
+    # Load small dataset
+    raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train[:1%]")
+    raw = raw.filter(lambda x: len(tokenizer(x["text"])["input_ids"]) > 0)
+    raw = raw.select(range(min(100, len(raw))))  # Use only 100 samples
+
+    def tok_fn(examples):
+        return tokenizer(examples["text"], truncation=True, max_length=seq_len)
+
+    ds = raw.map(tok_fn, batched=True, remove_columns=["text"])
+    ds.set_format(type="torch", columns=["input_ids"])
+
+    def collate(batch):
+        ids = [b["input_ids"] for b in batch]
+        labels = [x.clone() for x in ids]
+        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
+        x = torch.nn.utils.rnn.pad_sequence(ids, batch_first=True, padding_value=pad_id)
+        y = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
+        return {"input_ids": x, "labels": y}
+
+    return DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)
+
+
+def main():
+    # Configuration
+    MODEL_NAME = "Qwen/Qwen3-0.6B"
+    BATCH_SIZE = 2
+    SEQ_LEN = 64
+    TP = 2
+    DP = 4 // TP
+
+    # Setup Accelerator with FSDP2
+    init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
+    pc = ParallelismConfig(dp_shard_size=DP, tp_size=TP)
+
+    fsdp_plugin = FullyShardedDataParallelPlugin(
+        fsdp_version=2,
+        reshard_after_forward=True,
+        auto_wrap_policy="transformer_based_wrap",
+        state_dict_type="SHARDED_STATE_DICT",
+        activation_checkpointing=False,
+        cpu_ram_efficient_loading=True,
+    )
+
+    accelerator = Accelerator(kwargs_handlers=[init_kwargs], parallelism_config=pc, fsdp_plugin=fsdp_plugin)
+
+    rank = accelerator.process_index
+    print(f"[Rank {rank}] Initializing...")
+
+    # Load model with TP if needed
+    model_kwargs = {"tp_size": TP, "tp_plan": "auto", "device_mesh": accelerator.torch_device_mesh} if TP > 1 else {}
+
+    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, use_cache=False, **model_kwargs)
+
+    model.lm_head = LmHeadWrapper(model.lm_head)
+
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    print(f"[Rank {rank}] Building dataloader...")
+    loader = build_simple_dataloader(tokenizer, seq_len=SEQ_LEN, batch_size=BATCH_SIZE)
+
+    print(f"[Rank {rank}] Preparing with accelerator...")
+    # ERROR OCCURS HERE AT LINE 110 in original script
+    model, optimizer, loader = accelerator.prepare(model, optimizer, loader)
+
+    print(f"[Rank {rank}] Preparation successful!")
+
+
+if __name__ == "__main__":
+    main()

tests/tp/fsdp2_tp_preparation_config.yaml

Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+# FSDP2 Single Node Configuration
+# Status: CURRENT - Recommended for new single-node usage
+
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: 'no'
+num_machines: 1
+num_processes: 4 # Adjust for your GPU count
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false

tests/tp/test_tp.py

Lines changed: 16 additions & 0 deletions
@@ -13,6 +13,8 @@
 # limitations under the License.


+import os
+
 from accelerate.test_utils.testing import (
     TempDirTestCase,
     execute_subprocess_async,
@@ -61,3 +63,17 @@ def test_working_of_tp(self):
         )
         with patch_environment(omp_num_threads=1):
             execute_subprocess_async(cmd)
+
+    def test_working_of_tp_and_fsdp(self):
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        self.test_file_path = os.path.join(current_dir, "fsdp2_tp_preparation.py")
+        self.test_config_path = os.path.join(current_dir, "fsdp2_tp_preparation_config.yaml")
+        cmd = get_launch_command()
+        cmd.extend(
+            [
+                f"--config_file={self.test_config_path}",
+                self.test_file_path,
+            ]
+        )
+        with patch_environment(omp_num_threads=4):
+            execute_subprocess_async(cmd)
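
Outside the test harness, the same reproduction can be launched directly with the accelerate CLI, roughly `accelerate launch --config_file tests/tp/fsdp2_tp_preparation_config.yaml tests/tp/fsdp2_tp_preparation.py`. This assumes the config path inferred above, execution from the repository root, and at least 4 GPUs (num_processes: 4 with dp_shard_size=2 and tp_size=2).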
