Commit 43c08cd

Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel"
Note: This PR is for showcasing purposes only and is almost a reverse of #190. At the cost of a model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (the sequence dim) instead of dim 0 (the folded dim), which incurs an extra `aten.cat` after each collective.

Stats from awgu:

> for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%)

Experiment on the 8-layer `debug_model`:

before: ![before](https://github.com/pytorch/torchtitan/assets/150487191/04e5ea4b-fa9e-48e5-92be-582841cb2796)

after: ![after](https://github.com/pytorch/torchtitan/assets/150487191/38c39506-462d-485a-a16c-48770a28edb0)

[ghstack-poisoned]
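A rough, standalone illustration of why the folded dim 0 is the cheap dim to gather on (not code from this PR; all sizes below are made up): an all-gather writes each rank's shard back-to-back into one flat buffer, which already *is* the global tensor when sharding happens on a folded dim 0, while a dim-1 (sequence) sharding generally has to be re-interleaved with the extra `aten.cat` quoted above.

```python
import torch

# Made-up sizes, small enough to check by hand.
world_size, bs, seqlen, dim = 4, 2, 8, 3
full = torch.randn(bs, seqlen, dim)

# Case A: shard the *folded* tensor on dim 0. The rank-ordered flat buffer an
# all-gather produces is already the folded global tensor, so recovering the
# full activation is a free view.
folded_shards = full.reshape(-1, dim).chunk(world_size, dim=0)
flat_buffer = torch.cat([s.reshape(-1) for s in folded_shards])  # what all-gather yields
assert torch.equal(flat_buffer.view(bs, seqlen, dim), full)

# Case B: shard on dim 1 (the sequence dim). The flat buffer is rank-ordered,
# not sequence-ordered, so reconstructing the global tensor needs an extra
# cat along dim 1 after every collective -- the op this PR removes.
seq_shards = full.chunk(world_size, dim=1)
flat_buffer = torch.cat([s.reshape(-1) for s in seq_shards])
assert not torch.equal(flat_buffer.view(bs, seqlen, dim), full)
assert torch.equal(torch.cat(seq_shards, dim=1), full)  # the extra aten.cat
```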
1 parent b579e87 commit 43c08cd

2 files changed: +8 −8 lines changed

torchtitan/models/llama/model.py

Lines changed: 3 additions & 7 deletions
```diff
@@ -427,15 +427,10 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.

         """
-        bs = tokens.shape[0]
         # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage
+        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
         # fold batch dimension and sequence dimension for more efficient allgather/reduce_scatter
-        if self.tok_embeddings:
-            tokens = tokens.view(-1)
-            h = self.tok_embeddings(tokens)
-        else:
-            h = tokens
-            h = h.view(-1, self.model_args.dim)
+        h = h.view(-1, self.model_args.dim)

         seqlen = self.model_args.max_seq_len
         freqs_cis = self.freqs_cis[0:seqlen]
@@ -444,6 +439,7 @@ def forward(self, tokens: torch.Tensor):

         h = self.norm(h) if self.norm else h
         # unfold batch and sequence dimension
+        bs = tokens.shape[0]
         h = h.view(bs, -1, self.model_args.dim)
         output = self.output(h).float() if self.output else h
         return output
```
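For shape intuition, a standalone sketch (sizes are illustrative, not torchtitan code) of the fold/unfold that the updated forward performs:

```python
import torch

bs, seqlen, dim = 2, 8192, 256      # made-up sizes; dim plays the role of model_args.dim
h = torch.randn(bs, seqlen, dim)    # embedding output on the first stage

h = h.view(-1, dim)                 # fold: [bs * seqlen, dim]; SP shards/gathers dim 0
# ... transformer layers and the final norm operate on the folded tensor ...
h = h.view(bs, -1, dim)             # unfold before the output projection
assert h.shape == (bs, seqlen, dim)
```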

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 5 additions & 1 deletion
```diff
@@ -350,14 +350,18 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
-                output_layouts=Shard(0),
             ),
             "output": col_parallel_strategy(
                 input_layouts=Shard(0),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
             "norm": SequenceParallel(sequence_dim=0),
+            "layers.0": PrepareModuleInput(
+                input_layouts=(Replicate(), None),
+                desired_input_layouts=(Shard(0), None),
+                use_local_output=True,
+            ),
         },
     )

```
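For readability, a hedged restatement of the resulting plan as a self-contained function (`ColwiseParallel` stands in for `col_parallel_strategy`; the comments are my reading of the change, not text from the PR):

```python
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
    SequenceParallel,
)


def sketch_tp_plan(loss_parallel: bool = False) -> dict:
    return {
        # The embedding output is no longer sharded here (output_layouts=Shard(0)
        # was dropped); it stays at RowwiseParallel's default Replicate() ...
        "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
        "output": ColwiseParallel(
            input_layouts=Shard(0),
            output_layouts=Shard(-1) if loss_parallel else Replicate(),
            use_local_output=not loss_parallel,
        ),
        # Shard(0) / sequence_dim=0 refer to the folded batch*sequence dim.
        "norm": SequenceParallel(sequence_dim=0),
        # ... and is instead redistributed to Shard(0) at the first block's
        # input. The None entries leave the block's second positional argument
        # (freqs_cis) untouched.
        "layers.0": PrepareModuleInput(
            input_layouts=(Replicate(), None),
            desired_input_layouts=(Shard(0), None),
            use_local_output=True,
        ),
    }
```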