
Commit 3f3bc38

Enable CP
This PR adds experimental flags and functions to enable context parallelism. Currently only FSDP + CP and CP-only configurations are supported; CP + TP is still being tested.

ghstack-source-id: 651f322
Pull Request resolved: #433
1 parent 6234a06 commit 3f3bc38

File tree: 5 files changed, +108 / -37 lines

estimation.py

Lines changed: 1 addition & 0 deletions
@@ -67,6 +67,7 @@ def estimate_memory(job_config: JobConfig):
 
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions
@@ -323,6 +323,12 @@ def __init__(self):
             action="store_true",
             help="Enable CompiledAutograd to compile the backward.",
         )
+        self.parser.add_argument(
+            "--experimental.context_parallel_degree",
+            type=int,
+            default=1,
+            help="Context parallelism degree. 1 means disabled.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
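
Because the new option name contains a dot, argparse stores it under the literal attribute name "experimental.context_parallel_degree", which torchtitan's config machinery then surfaces as job_config.experimental.context_parallel_degree (that is how the other files in this diff read it). A minimal standalone sketch of that parsing behavior, using plain argparse rather than torchtitan's JobConfig:

import argparse

# Standalone sketch, not torchtitan's JobConfig: shows how the dotted flag
# name parses and why it cannot be read back with plain attribute access.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.context_parallel_degree",
    type=int,
    default=1,
    help="Context parallelism degree. 1 means disabled.",
)

args = parser.parse_args(["--experimental.context_parallel_degree", "2"])

# argparse keeps the dot in the destination name, so use vars()/getattr.
cp_degree = vars(args)["experimental.context_parallel_degree"]
print(cp_degree)  # 2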

torchtitan/parallelisms/__init__.py

Lines changed: 13 additions & 6 deletions
@@ -24,6 +24,7 @@
 @dataclass
 class ParallelDims:
     dp: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -35,22 +36,24 @@ def __post_init__(self):
         self._validate()
 
     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
+        dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
+            self.dp = dp = self.world_size // (cp * tp * pp)
         assert dp >= 1, dp
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert dp * cp * tp * pp == self.world_size, (
+            f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
+            f"!= WORLD_SIZE({self.world_size})"
+        )
         assert self.dp_type in ("fsdp", "ddp")
 
     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
         ):
             if d > 1:
                 dims.append(d)
@@ -63,6 +66,10 @@ def build_mesh(self, device_type):
     def dp_enabled(self):
         return self.dp > 1
 
+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
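
To make the new validation concrete: with dp left at -1, the data-parallel degree is inferred from whatever the other degrees leave over, and build_mesh only materializes dimensions whose degree exceeds 1. A small self-contained sketch of the same arithmetic (plain Python with assumed degrees, not the ParallelDims dataclass itself):

# Standalone sketch of the validation arithmetic above, with assumed degrees.
world_size = 8                 # e.g. a single 8-GPU host
dp, cp, tp, pp = -1, 2, 1, 1   # dp == -1 means "infer from the rest"

if dp == -1:
    dp = world_size // (cp * tp * pp)   # 8 // (2 * 1 * 1) = 4

assert dp * cp * tp * pp == world_size, "degrees must multiply to WORLD_SIZE"

# build_mesh keeps only degrees > 1, so this job gets a 2-D ("dp", "cp") mesh.
dims_and_names = [
    (d, n) for d, n in zip([pp, dp, cp, tp], ["pp", "dp", "cp", "tp"]) if d > 1
]
print(dims_and_names)  # [(4, 'dp'), (2, 'cp')]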

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 36 additions & 1 deletion
@@ -19,9 +19,15 @@
 
 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
+
+try:
+    from torch.distributed._tensor.experimental.attention import enable_context_parallel
+except ImportError:
+    print("The PyTorch version does not include the experimental CP APIs.")
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -451,17 +457,43 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
     return model
 
 
+def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
+    """
+    Apply context parallelism to the model. This is an experimental feature.
+    """
+    if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
+        raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
+    cp_mesh = world_mesh["cp"]
+    callers = []
+    for layer_id, transformer_block in model.layers.items():
+        callers.append(transformer_block.attention)
+    enable_context_parallel(seq_dim=2, callers=callers, device_mesh=cp_mesh)
+    logger.info("Applied CP to the model")
+
+    return model
+
+
 def apply_fsdp(
     model: nn.Module,
     world_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
 ):
+
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """
 
-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+    if parallel_dims.cp_enabled:
+        # Temporary solution to enable FSDP + CP
+        dp_mesh = init_device_mesh(
+            world_mesh.device_type,
+            (parallel_dims.dp * parallel_dims.cp,),
+            mesh_dim_names=["dp"],
+        )
+    else:
+        dp_mesh = world_mesh["dp"]
+
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
 
     mp_policy = MixedPrecisionPolicy(
@@ -538,6 +570,9 @@ def parallelize_llama(
     if job_config.training.compile:
         model = apply_compile(model, job_config)
 
+    if parallel_dims.cp_enabled:
+        model = apply_cp(model, world_mesh, parallel_dims, job_config)
+
     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
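
The FSDP + CP branch above deliberately avoids slicing the existing world mesh and instead builds a fresh 1-D "dp" mesh of size dp * cp, so FSDP shards parameters across both the data-parallel and context-parallel ranks. A rough sketch of what that flattening means for group membership, assuming dp=4, cp=2 and a default ("dp", "cp") rank ordering (plain Python, no torch.distributed initialization):

# Rough sketch of the temporary FSDP + CP flattening, under assumed degrees.
dp, cp = 4, 2
world_size = dp * cp

# With a ("dp", "cp") mesh and default ordering, each CP group is a run of
# cp consecutive ranks that jointly hold one sharded sequence.
cp_groups = [list(range(start, start + cp)) for start in range(0, world_size, cp)]
print(cp_groups)          # [[0, 1], [2, 3], [4, 5], [6, 7]]

# The workaround hands FSDP one flat "dp" mesh spanning all dp * cp ranks,
# so every parameter is sharded 8 ways here instead of only 4.
flat_fsdp_group = list(range(world_size))
print(flat_fsdp_group)    # [0, 1, 2, 3, 4, 5, 6, 7]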

train.py

Lines changed: 52 additions & 30 deletions
@@ -11,6 +11,7 @@
 
 from dataclasses import dataclass, field
 from datetime import timedelta
+from functools import partial
 from io import BytesIO
 from timeit import default_timer as timer
 from typing import Any, Dict, List
@@ -20,6 +21,7 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._tensor.experimental.attention import context_parallel_buffers
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
@@ -169,6 +171,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -213,6 +216,20 @@ def main(job_config: JobConfig):
         job_config.experimental.enable_compiled_autograd,
     )
 
+    if parallel_dims.cp_enabled:
+        cp_mesh = world_mesh["cp"]
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=cp_mesh.get_local_rank(),
+            cp_world_size=cp_mesh.size(),
+        )
+    else:
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=0,
+            cp_world_size=1,
+        )
+
     # loss fn can be shared by pipeline-parallel or non-pp execution
     def loss_fn(pred, labels):
         return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
@@ -371,38 +388,43 @@ def loss_fn(pred, labels):
             ntokens_since_last_log += labels.numel()
             data_loading_times.append(timer() - data_load_start)
 
-            input_ids = input_ids.cuda()
-            labels = labels.cuda()
             optimizers.zero_grad()
 
-            if parallel_dims.pp_enabled:
-                # pipeline parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
-                with train_context():
-                    if pp_mesh.get_local_rank() == 0:
-                        pp_schedule.step(input_ids)
-                    elif is_last_stage:
-                        losses = []
-                        pp_schedule.step(target=labels, losses=losses)
-                    else:
-                        pp_schedule.step()
-
-                # accumulate losses across pipeline microbatches
-                loss = (
-                    torch.mean(torch.stack(losses))
-                    if is_last_stage
-                    else torch.Tensor([-1.0])
-                )
-            else:
-                # Non-PP forward / backward
-                with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
-                    loss.backward()
+            with context_parallel_ctx(
+                buffers=[input_ids, labels, model.freqs_cis],
+                seq_dims=[1, 1, 0],
+                keep_orig_buffers=[False, False, True],
+            ):
+                input_ids = input_ids.cuda()
+                labels = labels.cuda()
+                if parallel_dims.pp_enabled:
+                    # pipeline parallel forward / backward inside step() call
+                    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                    with train_context():
+                        if pp_mesh.get_local_rank() == 0:
+                            pp_schedule.step(input_ids)
+                        elif is_last_stage:
+                            losses = []
+                            pp_schedule.step(target=labels, losses=losses)
+                        else:
+                            pp_schedule.step()
+
+                    # accumulate losses across pipeline microbatches
+                    loss = (
+                        torch.mean(torch.stack(losses))
+                        if is_last_stage
+                        else torch.Tensor([-1.0])
+                    )
+                else:
+                    # Non-PP forward / backward
+                    with train_context():
+                        pred = model(input_ids)
+                        loss = loss_fn(pred, labels)
+                        # pred.shape=(bs, seq_len, vocab_size)
+                        # need to free to before bwd to avoid peaking memory
+                        del pred
+                        loss.backward()
 
             # clip gradients
             for model in model_parts:
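
For intuition about what the context_parallel_ctx call above does with buffers=[input_ids, labels, model.freqs_cis] and seq_dims=[1, 1, 0]: each buffer is split along its sequence dimension so that every CP rank only works on its own slice of the tokens (keep_orig_buffers=[False, False, True] presumably also preserves the full freqs_cis buffer). An illustrative stand-in built on torch.chunk with assumed shapes, not the experimental API itself:

import torch

# Illustrative stand-in for the sequence sharding, with assumed shapes.
cp_rank, cp_world_size = 0, 2
bs, seq_len, head_dim = 2, 8192, 64

input_ids = torch.randint(0, 32000, (bs, seq_len))
labels = torch.randint(0, 32000, (bs, seq_len))
freqs_cis = torch.randn(seq_len, head_dim)   # shape is hypothetical

# Mirror seq_dims=[1, 1, 0]: shard each buffer along its sequence dimension
# and keep only this rank's chunk.
local_input_ids = torch.chunk(input_ids, cp_world_size, dim=1)[cp_rank]
local_labels = torch.chunk(labels, cp_world_size, dim=1)[cp_rank]
local_freqs_cis = torch.chunk(freqs_cis, cp_world_size, dim=0)[cp_rank]

print(local_input_ids.shape, local_labels.shape, local_freqs_cis.shape)
# torch.Size([2, 4096]) torch.Size([2, 4096]) torch.Size([4096, 64])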
