@@ -77,16 +77,26 @@ def _distribute_dtensor(
     assert inner_mesh.mesh_dim_names is not None
     submesh_names = outer_mesh.mesh_dim_names + inner_mesh.mesh_dim_names
     spanned_mesh = outer_global_mesh[submesh_names]
-    shard_dim = placements[0].dim
-    split_factor = inner_spec.num_shards_map[shard_dim]
-    tensor_placement = (
-        (
-            _StridedShard(shard_dim, split_factor=split_factor)
-            if split_factor > 1
-            else placements[0]
-        ),
-        inner_spec.placements[0],
-    )
+
+    if placements[0].is_shard():
+        # for FSDP + TP dtensor placement
+        shard_dim = placements[0].dim
+        split_factor = inner_spec.num_shards_map[shard_dim]
+        tensor_placement = (
+            (
+                _StridedShard(shard_dim, split_factor=split_factor)
+                if split_factor > 1
+                else placements[0]
+            ),
+            inner_spec.placements[0],
+        )
+    elif placements[0].is_replicate():
+        # for DDP + TP dtensor placement
+        tensor_placement = (placements[0], inner_spec.placements[0])
+    else:
+        raise ValueError(
+            f"Unsupported placement {placements[0]} for distributing DTensor {tensor}"
+        )
 
     current_spec = DTensorSpec(
         mesh=outer_mesh,
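
The new branch dispatches on the outer placement's `is_shard()` / `is_replicate()` predicates. Below is a minimal sketch of those predicates using the public `Shard` / `Replicate` placements, assuming a recent PyTorch where `torch.distributed.tensor` is the public path (older releases expose the same types under `torch.distributed._tensor`); the internal `_StridedShard` wrapping is omitted.

```python
from torch.distributed.tensor import Replicate, Shard

# Shard placements report is_shard() and carry the sharded tensor dim.
dp_shard = Shard(0)
print(dp_shard.is_shard(), dp_shard.dim)        # True 0

# Replicate placements report is_replicate() instead.
dp_replicate = Replicate()
print(dp_replicate.is_shard(), dp_replicate.is_replicate())  # False True

# The branch above combines the outer (dp) placement with the inner TP placement,
# e.g. (Shard(0), Shard(0)) for FSDP + TP or (Replicate(), Shard(0)) for DDP + TP.
print((dp_shard, Shard(0)))
print((dp_replicate, Shard(0)))
```
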
@@ -154,10 +164,8 @@ def replicate_compute(self, x):
         # the gradients are partial tensors that need to perform reduction
         # (i.e. DDP: allreduce, FSDP: reduce_scatter, HSDP: mix of both)
 
-        # NOTE: specifying mixed precision is only available in pytorch_intern24
-        # https://github.com/tianyu-l/pytorch_intern24/pull/20
-        # support for FSDP + TP (assuming TP shards the inner-most dim)
-        if self.mode == "fully_shard" and x._spec.mesh.ndim == 2:
+        # support for FSDP/DDP + TP (assuming TP shards the inner-most dim)
+        if x._spec.mesh.mesh_dim_names[-1] == "tp":
             dp_placement, tp_placement = x._spec.placements
             # TODO: remove tp_mesh as an input arg to data_parallel API and use x._spec.mesh["tp"]
             # after DeviceMesh supports slicing a non-root mesh
@@ -170,7 +178,8 @@ def replicate_compute(self, x):
                 sharded_local_tensor, dp_mesh, self.param_sharding
             )
 
-            # the actuall FSDP all-gather on dp_mesh
+            # the actual FSDP's fwd all-gather & bwd reduce-scatter,
+            # and DDP's bwd all-reduce, on dp_mesh
             replicated_dtensor = sharded_dtensor.redistribute(
                 placements=self.compute_placements,
                 forward_dtype=self.param_dtype,
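
For context on the last two hunks, here is a minimal, single-process sketch (not part of this PR) of a 2D `("dp", "tp")` DeviceMesh, the `mesh_dim_names[-1] == "tp"` check, slicing the `"dp"` submesh, and redistributing a dp-sharded DTensor to a replicated dp placement for compute. The gloo backend, the world size of 1, and the mesh sizes are illustrative assumptions so it runs on one CPU process; the `forward_dtype` mixed-precision argument used by the PR is internal and is omitted here.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard, distribute_tensor

# Single-process setup so the sketch is runnable on one CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

# 2D mesh: outer "dp" dim (FSDP/DDP), inner "tp" dim (tensor parallel).
mesh = init_device_mesh("cpu", (1, 1), mesh_dim_names=("dp", "tp"))
assert mesh.mesh_dim_names[-1] == "tp"  # the check used in replicate_compute
dp_mesh = mesh["dp"]                    # slice the dp submesh from the root mesh
print(dp_mesh.mesh_dim_names)           # ('dp',)

# FSDP + TP style parameter: sharded on dim 0 along both mesh dims.
param = distribute_tensor(torch.randn(8, 4), mesh, placements=[Shard(0), Shard(0)])

# "All-gather" on the dp mesh dim for compute: replicate dp, keep the tp sharding.
param_compute = param.redistribute(mesh, placements=[Replicate(), Shard(0)])
print(param_compute.placements)  # (Replicate(), Shard(dim=0))

dist.destroy_process_group()
```
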