Commit f5a3ad7

simplify embedding + first transformer block TP (#314)
As titled: we can directly specify the rowwise-parallel embedding's output layouts to be sharded on the sequence dim, so we no longer need the PrepareModuleInput for the first layer. Switching to output_layouts = Shard(1) also triggers a reduce_scatter instead of an allreduce for the embedding layer, which could give some small perf wins.
1 parent 3295448 commit f5a3ad7

1 file changed: +1 -5 lines


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 1 addition & 5 deletions
@@ -160,18 +160,14 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
+                output_layouts=Shard(1),
             ),
             "output": col_parallel_strategy(
                 input_layouts=Shard(1),
                 output_layouts=(Shard(-1) if loss_parallel else Replicate()),
                 use_local_output=not loss_parallel,
             ),
             "norm": SequenceParallel(),
-            "layers.0": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Shard(1), None),
-                use_local_output=True,
-            ),
         },
     )
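For reference, below is a minimal, self-contained sketch of applying this kind of TP plan with PyTorch's parallelize_module API. It is an illustration under assumptions, not torchtitan's exact code: ToyModel, the 2-way mesh, nn.LayerNorm, and the use of ColwiseParallel in place of col_parallel_strategy are stand-ins chosen to keep the example short.

import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, Shard  # older PyTorch: torch.distributed._tensor
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    SequenceParallel,
    parallelize_module,
)

class ToyModel(nn.Module):
    # Toy stand-in mirroring the module names in torchtitan's Llama model.
    def __init__(self, vocab: int = 128, dim: int = 64):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab, dim)
        self.norm = nn.LayerNorm(dim)  # torchtitan uses RMSNorm; LayerNorm keeps the sketch short
        self.output = nn.Linear(dim, vocab, bias=False)

    def forward(self, tokens):
        return self.output(self.norm(self.tok_embeddings(tokens)))

# Assumes launch via torchrun with 2 GPUs; init_device_mesh sets up the process group.
tp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))

model = parallelize_module(
    ToyModel().cuda(),
    tp_mesh,
    {
        # The rowwise-parallel embedding now emits sequence-sharded (Shard(1)) output,
        # turning its collective into a reduce_scatter and removing the need for a
        # PrepareModuleInput on the first transformer block.
        "tok_embeddings": RowwiseParallel(
            input_layouts=Replicate(),
            output_layouts=Shard(1),
        ),
        "norm": SequenceParallel(),
        # Stand-in for col_parallel_strategy in the real plan (loss parallel omitted here).
        "output": ColwiseParallel(
            input_layouts=Shard(1),
            output_layouts=Replicate(),
        ),
    },
)

Because the embedding's output is declared Shard(1), the downstream SequenceParallel norm and the first transformer block receive sequence-sharded activations directly, which is what allows the PrepareModuleInput redistribution to be dropped.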
