
Commit a201f90

add support for ddp+tp

1 parent d0ed9b4 commit a201f90

File tree

2 files changed: +36 -27 lines changed

- torchtitan/experiments/simple_fsdp/simple_fsdp.py
- torchtitan/experiments/simple_fsdp/tests/integration_tests.py

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 24 additions & 15 deletions

@@ -77,16 +77,26 @@ def _distribute_dtensor(
     assert inner_mesh.mesh_dim_names is not None
     submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
     spanned_mesh = outer_global_mesh[submesh_names]
-    shard_dim = placements[0].dim
-    split_factor = inner_spec.num_shards_map[shard_dim]
-    tensor_placement = (
-        (
-            _StridedShard(shard_dim, split_factor=split_factor)
-            if split_factor > 1
-            else placements[0]
-        ),
-        inner_spec.placements[0],
-    )
+
+    if placements[0].is_shard():
+        # for FSDP + TP dtensor placement
+        shard_dim = placements[0].dim
+        split_factor = inner_spec.num_shards_map[shard_dim]
+        tensor_placement = (
+            (
+                _StridedShard(shard_dim, split_factor=split_factor)
+                if split_factor > 1
+                else placements[0]
+            ),
+            inner_spec.placements[0],
+        )
+    elif placements[0].is_replicate():
+        # for DDP + TP dtensor placement
+        tensor_placement = (placements[0], inner_spec.placements[0])
+    else:
+        raise ValueError(
+            f"Unsupported placement {placements[0]} for distributing DTensor {tensor}"
+        )

     current_spec = DTensorSpec(
         mesh=outer_mesh,
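
Note on the hunk above: the previously unconditional FSDP+TP placement computation is now one branch of an if/elif, with a new Replicate branch for DDP+TP. Below is a minimal, self-contained sketch of that selection logic, not torchtitan's code: pick_placements, its arguments, and the default split_factor stand in for placements[0], inner_spec.placements[0], and inner_spec.num_shards_map[shard_dim], and the _StridedShard import path is an assumption (it is a private DTensor placement type).

# Hypothetical helper illustrating the placement selection added to
# _distribute_dtensor; argument names are stand-ins, not the real signature.
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.placement_types import _StridedShard  # private API; import path assumed


def pick_placements(dp_placement, tp_placement, split_factor=1):
    if dp_placement.is_shard():
        # FSDP + TP: shard on the dp dim, strided if TP already splits that dim further.
        dp = (
            _StridedShard(dp_placement.dim, split_factor=split_factor)
            if split_factor > 1
            else dp_placement
        )
        return (dp, tp_placement)
    if dp_placement.is_replicate():
        # DDP + TP (the case added by this commit): replicate across dp, keep the TP sharding.
        return (dp_placement, tp_placement)
    raise ValueError(f"Unsupported data-parallel placement: {dp_placement}")


print(pick_placements(Shard(0), Shard(0), split_factor=2))  # FSDP + TP
print(pick_placements(Replicate(), Shard(0)))               # DDP + TP

For DDP + TP the outer placement stays Replicate(), so no strided re-sharding of the dp dim is needed; the tuple is simply the replicated dp placement paired with the existing TP placement on the spanned (dp, tp) mesh.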
@@ -154,10 +164,8 @@ def replicate_compute(self, x):
         # the gradients are partial tensors that needs to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)

-        # NOTE: specifying mixed precision is only available in pytorch_intern24
-        # https://github.com/tianyu-l/pytorch_intern24/pull/20
-        # support for FSDP + TP (assuming TP shards the inner-most dim)
-        if self.mode == "fully_shard" and x._spec.mesh.ndim == 2:
+        # support for FSDP/DDP + TP (assuming TP shards the inner-most dim)
+        if x._spec.mesh.mesh_dim_names[-1] == "tp":
             dp_placement, tp_placement = x._spec.placements
             # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
             # after DeviceMesh supports slicing a non-root mesh
@@ -170,7 +178,8 @@ def replicate_compute(self, x):
                 sharded_local_tensor, dp_mesh, self.param_sharding
             )

-            # the actuall FSDP all-gather on dp_mesh
+            # the actual FSDP's fwd all-gather & bwd reduce-scatter
+            # DDP's all-reduce(bwd) on dp_mesh
             replicated_dtensor = sharded_dtensor.redistribute(
                 placements=self.compute_placements,
                 forward_dtype=self.param_dtype,
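
The two replicate_compute hunks above generalize the TP-aware path from FSDP-only (self.mode == "fully_shard") to any data-parallel mode whose mesh ends in a "tp" dim, and the updated comment spells out the collectives: FSDP all-gathers parameters in forward and reduce-scatters gradients in backward, while DDP all-reduces gradients in backward. A hedged sketch of that mapping using real DTensor placement types follows; the helper and the "replicate" mode string are assumptions (only "fully_shard" appears in the diff), and the collectives are read off the diff's comments.

# Hypothetical summary of what the redistribute in replicate_compute amounts to
# per data-parallel mode on dp_mesh.
from torch.distributed.tensor import Partial, Replicate, Shard


def dp_redistribute_summary(mode: str):
    """Return (stored param placement, compute placement, grad reduction target)."""
    if mode == "fully_shard":
        # FSDP: Shard -> Replicate is a fwd all-gather; Partial -> Shard is a bwd reduce-scatter.
        return Shard(0), Replicate(), Shard(0)
    if mode == "replicate":
        # DDP (mode name assumed): params stay replicated; Partial -> Replicate is a bwd all-reduce.
        return Replicate(), Replicate(), Replicate()
    raise ValueError(f"unknown data-parallel mode: {mode}")


for mode in ("fully_shard", "replicate"):
    stored, compute, grad_target = dp_redistribute_summary(mode)
    print(f"{mode}: param {stored} -> {compute} (fwd), grad {Partial()} -> {grad_target} (bwd)")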

torchtitan/experiments/simple_fsdp/tests/integration_tests.py

Lines changed: 12 additions & 12 deletions

@@ -131,7 +131,7 @@ def build_test_list():
             "hsdp",
             ngpu=4,
         ),
-        # TODO: Adds back after HSDP+TP & DDP+TP is supported by SimpleFSDP
+        # TODO: Adds back after HSDP+TP is supported by SimpleFSDP
         # OverrideDefinitions(
         #     [
         #         [
@@ -144,17 +144,17 @@ def build_test_list():
         #     "hsdp+tp",
        #     ngpu=8,
         # ),
-        # OverrideDefinitions(
-        #     [
-        #         [
-        #             "--parallelism.data_parallel_replicate_degree=2",
-        #             "--parallelism.tensor_parallel_degree=2",
-        #         ]
-        #     ],
-        #     "DDP+TP",
-        #     "ddp+tp",
-        #     ngpu=4,
-        # ),
+        OverrideDefinitions(
+            [
+                [
+                    "--parallelism.data_parallel_replicate_degree=2",
+                    "--parallelism.tensor_parallel_degree=2",
+                ]
+            ],
+            "DDP+TP",
+            "ddp+tp",
+            ngpu=4,
+        ),
         OverrideDefinitions(
             [
                 [

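The re-enabled DDP+TP integration test uses data_parallel_replicate_degree=2 and tensor_parallel_degree=2 on ngpu=4, i.e. a 2 (dp) x 2 (tp) device mesh. A tiny sanity sketch in plain Python; the flags are copied from the diff, the sizing check is illustrative.

# Flags taken verbatim from the re-enabled "ddp+tp" test entry.
ddp_tp_flags = [
    "--parallelism.data_parallel_replicate_degree=2",
    "--parallelism.tensor_parallel_degree=2",
]

dp_replicate, tp, ngpu = 2, 2, 4
assert dp_replicate * tp == ngpu  # the world size must factor into the (dp, tp) mesh

print("extra flags for the 'ddp+tp' test:", " ".join(ddp_tp_flags))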