@@ -11,6 +11,7 @@
 
 from dataclasses import dataclass, field
 from datetime import timedelta
+from functools import partial
 from io import BytesIO
 from timeit import default_timer as timer
 from typing import Any, Dict, List
@@ -20,6 +21,7 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._tensor.experimental.attention import context_parallel_buffers
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
@@ -169,6 +171,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -213,6 +216,20 @@ def main(job_config: JobConfig):
         job_config.experimental.enable_compiled_autograd,
     )
 
+    if parallel_dims.cp_enabled:
+        cp_mesh = world_mesh["cp"]
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=cp_mesh.get_local_rank(),
+            cp_world_size=cp_mesh.size(),
+        )
+    else:
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=0,
+            cp_world_size=1,
+        )
+
     # loss fn can be shared by pipeline-parallel or non-pp execution
     def loss_fn(pred, labels):
         return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
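
Note that the branch above does not shard anything by itself; it only pre-binds the rank and world-size arguments with functools.partial, so the training loop can call context_parallel_ctx(...) with the same argument shape whether or not context parallelism is enabled. Below is a toy sketch of that binding pattern; fake_context_parallel_buffers is a made-up stand-in, not the torch API, and with cp_world_size=1 every "shard" would simply be the whole buffer, which is why the disabled branch behaves as a no-op.

import contextlib
from functools import partial

@contextlib.contextmanager
def fake_context_parallel_buffers(*, cp_rank, cp_world_size, buffers, seq_dims, keep_orig_buffers):
    # A real implementation would swap each buffer for its sequence shard here
    # and restore the originals on exit; this stand-in only reports the call.
    print(f"rank {cp_rank}/{cp_world_size}: sharding {len(buffers)} buffers")
    yield

# CP disabled: same call shape as the enabled case, but a single "shard" per buffer.
ctx = partial(fake_context_parallel_buffers, cp_rank=0, cp_world_size=1)
with ctx(buffers=[], seq_dims=[], keep_orig_buffers=[]):
    pass
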
@@ -371,38 +388,43 @@ def loss_fn(pred, labels):
             ntokens_since_last_log += labels.numel()
             data_loading_times.append(timer() - data_load_start)
 
-            input_ids = input_ids.cuda()
-            labels = labels.cuda()
             optimizers.zero_grad()
 
-            if parallel_dims.pp_enabled:
-                # pipeline parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
-                with train_context():
-                    if pp_mesh.get_local_rank() == 0:
-                        pp_schedule.step(input_ids)
-                    elif is_last_stage:
-                        losses = []
-                        pp_schedule.step(target=labels, losses=losses)
-                    else:
-                        pp_schedule.step()
-
-                # accumulate losses across pipeline microbatches
-                loss = (
-                    torch.mean(torch.stack(losses))
-                    if is_last_stage
-                    else torch.Tensor([-1.0])
-                )
-            else:
-                # Non-PP forward / backward
-                with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
-                    loss.backward()
+            with context_parallel_ctx(
+                buffers=[input_ids, labels, model.freqs_cis],
+                seq_dims=[1, 1, 0],
+                keep_orig_buffers=[False, False, True],
+            ):
+                input_ids = input_ids.cuda()
+                labels = labels.cuda()
+                if parallel_dims.pp_enabled:
+                    # pipeline parallel forward / backward inside step() call
+                    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                    with train_context():
+                        if pp_mesh.get_local_rank() == 0:
+                            pp_schedule.step(input_ids)
+                        elif is_last_stage:
+                            losses = []
+                            pp_schedule.step(target=labels, losses=losses)
+                        else:
+                            pp_schedule.step()
+
+                    # accumulate losses across pipeline microbatches
+                    loss = (
+                        torch.mean(torch.stack(losses))
+                        if is_last_stage
+                        else torch.Tensor([-1.0])
+                    )
+                else:
+                    # Non-PP forward / backward
+                    with train_context():
+                        pred = model(input_ids)
+                        loss = loss_fn(pred, labels)
+                        # pred.shape=(bs, seq_len, vocab_size)
+                        # need to free to before bwd to avoid peaking memory
+                        del pred
+                        loss.backward()
 
             # clip gradients
             for model in model_parts:
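
For a concrete sense of what the wrapped region changes: the three buffers are split along different axes, input_ids and labels along dim 1 (the sequence dimension of a (batch, seq_len) tensor) and the precomputed rotary table model.freqs_cis along dim 0, with keep_orig_buffers=True for freqs_cis, presumably so the full table stays available after the region exits. The sketch below only illustrates the resulting per-rank shapes; the sizes and the 2-rank CP group are assumptions for illustration, not values from this diff, and the cross-shard attention itself is handled by the experimental attention code, which is not modeled here.

import torch

bs, seq_len, head_dim = 4, 8192, 64   # assumed sizes, for illustration only
cp_rank, cp_world_size = 0, 2         # assumed: 2 context-parallel ranks, this is rank 0

buffers = {
    "input_ids": (torch.zeros(bs, seq_len, dtype=torch.long), 1),  # shard dim 1 (sequence)
    "labels":    (torch.zeros(bs, seq_len, dtype=torch.long), 1),  # shard dim 1 (sequence)
    "freqs_cis": (torch.zeros(seq_len, head_dim), 0),              # shard dim 0 (sequence)
}

for name, (buf, seq_dim) in buffers.items():
    local = buf.chunk(cp_world_size, dim=seq_dim)[cp_rank]  # this rank's sequence shard
    print(f"{name}: {tuple(buf.shape)} -> {tuple(local.shape)}")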