diff --git a/run_train.sh b/run_train.sh
index 8aaf55de28..08ae54d92b 100755
--- a/run_train.sh
+++ b/run_train.sh
@@ -11,7 +11,7 @@ set -ex
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
 NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+export LOG_RANK=${LOG_RANK:-0,2}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
 
diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 7137579f18..b7f33464c5 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -375,6 +375,14 @@ class Parallelism:
     The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
     """
 
+    pipeline_parallel_expert_parallel_overlap: bool = True
+    """Whether to enable the optimization that overlaps expert parallel communication with
+    pipeline parallel computation. This is only effective when the pipeline parallel schedule
+    is DualPipeV, pipeline_parallel_degree > 1, and expert_parallel_degree > 1.
+
+    TODO: activation checkpointing is not yet supported; set activation_checkpoint.mode = "none".
+    """
+
     context_parallel_degree: int = 1
     """Context parallelism degree. 1 means disabled."""
 
diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py
new file mode 100644
index 0000000000..f5a6b7b7f7
--- /dev/null
+++ b/torchtitan/distributed/dual_pipe_v.py
@@ -0,0 +1,288 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+import threading
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from torch.distributed.pipelining.schedules import (
+    _Action,
+    _PipelineContext,
+    _PipelineScheduleRuntime,
+    _wait_batch_p2p,
+)
+from torch.distributed.pipelining.stage import _PipelineStageBase
+from torch.distributed.tensor import DeviceMesh, distribute_module
+from torch.profiler import record_function
+
+from torchtitan.distributed.expert_parallel import ExpertParallel
+
+from torchtitan.tools.utils import get_device_info
+
+"""
+Below are optimizations related to pipeline parallelism with expert parallelism.
+"""
+
+
+class DualPipeExpertParallel(ExpertParallel):
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        """
+        The execution order is:
+        A -> dispatch -> B -> module -> C -> combine -> D
+
+        Hooks are called in the order they are registered:
+        SyncHookA, _token_dispatch, SyncHookB (pre hooks)
+        SyncHookC, _token_combine, SyncHookD (post hooks)
+        """
+        inner_wrapped_module = self._wrap_with_pre_comm_hooks(module)
+        distributed_module = distribute_module(
+            inner_wrapped_module,
+            device_mesh,
+            partition_fn=ExpertParallel._partition_fn,
+            input_fn=self._token_dispatch,
+            output_fn=self._token_combine,
+        )
+        final_module = self._wrap_with_post_comm_hooks(distributed_module)
+        return final_module
+
+    def _wrap_with_pre_comm_hooks(self, module):
+        def inner_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "A"),) + input[1:]
+
+        def inner_post_hook(module, input, output):
+            return SyncHook.apply(output, "C")
+
+        module.register_forward_pre_hook(inner_pre_hook)
+        module.register_forward_hook(inner_post_hook)
+        return module
+
+    def _wrap_with_post_comm_hooks(self, module):
+        def outer_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "B"),) + input[1:]
+
+        def outer_post_hook(module, input, output):
+            return SyncHook.apply(output, "D")
+
+        module.register_forward_pre_hook(outer_pre_hook)
+        module.register_forward_hook(outer_post_hook)
+        return module
+
+
+class HookCoordinator:
+    def __init__(self):
+        # Barrier for 2 threads (forward and backward) to synchronize.
+        # This ensures that the two threads always alternate, executing one compute op
+        # and one comm op together.
+        self._execution_barrier = threading.Barrier(2)
+
+        self._coordination_enabled = False
+        self._cycle_count = 0
+        self._num_layers = None
+
+    def barrier(self):
+        """Barrier for 2 threads to synchronize"""
+        if not self.is_coordination_enabled():
+            return
+
+        try:
+            self._execution_barrier.wait()
+        except threading.BrokenBarrierError:
+            pass
+
+    def enable_coordination(self, num_layers: Optional[int] = None):
+        if num_layers is not None and num_layers > 0:
+            self._coordination_enabled = True
+            self._cycle_count = 0
+
+            # Reset barrier
+            self._execution_barrier = threading.Barrier(2)
+            self._num_layers = num_layers
+
+    def disable_coordination(self):
+        self._coordination_enabled = False
+        self._cycle_count = 0
+        self._execution_barrier.abort()  # Break barrier to unblock threads
+
+    def check_should_continue_coordination(self):
+        if self._num_layers is not None and self._cycle_count >= self._num_layers:
+            return False
+        return True
+
+    def is_coordination_enabled(self):
+        return self._coordination_enabled
+
+
+# Global coordinator
+_hook_coordinator = HookCoordinator()
+
+
+class SyncHook(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, hook_name=""):
+        ctx.hook_name = hook_name
+        # Handle edge case for the transformer-layer boundary
+        if _hook_coordinator._coordination_enabled and hook_name == "D":
+            _hook_coordinator._cycle_count += 1
+            if not _hook_coordinator.check_should_continue_coordination():
+                _hook_coordinator.disable_coordination()
+                return x
+
+        _hook_coordinator.barrier()
+        return x
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hook_name = ctx.hook_name
+
+        # Edge case: skip the initial barrier; all subsequent backward hooks will acquire it
+        if hook_name == "D" and _hook_coordinator._cycle_count == 0:
+            return grad_output, None
+
+        _hook_coordinator.barrier()
+        return grad_output, None
+
+
+def _count_moe_modules(model):
+    """Count MoE modules directly"""
+    from torchtitan.models.moe import MoE
+
+    moe_count = 0
+    for _, module in model.named_modules():
+        if isinstance(module, MoE):
+            moe_count += 1
+    return moe_count
+
+
+device_type, device_module = get_device_info()
+
+
+def overlap_callback(action: _Action, ctx: _PipelineContext):
+    """
+    Custom callback for OVERLAP_F_B computation that allows expert parallel communication
+    and pipeline parallel computation to overlap.
+    """
+    schedule = ctx.schedule_ref
+    assert isinstance(schedule, _PipelineScheduleRuntime)
+    stage_index_to_stage: dict[int, _PipelineStageBase] = {
+        stage.stage_index: stage for stage in schedule._stages
+    }
+    assert action.sub_actions is not None
+    fwd_action = action.sub_actions[0]
+    bwd_action = action.sub_actions[1]
+
+    # Get stages
+    forward_stage_index = fwd_action.stage_index
+    forward_mb_index = fwd_action.microbatch_index
+    assert forward_mb_index is not None
+    backward_stage_index = bwd_action.stage_index
+    backward_stage = stage_index_to_stage[backward_stage_index]
+
+    # Forward setup
+    arg_mbs = ctx.arg_mbs
+    kwarg_mbs = ctx.kwarg_mbs
+    assert arg_mbs is not None and kwarg_mbs is not None
+    fwd_recv_ops = schedule.fwd_recv_ops
+    forward_stage = stage_index_to_stage[forward_stage_index]
+    forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage
+    forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage
+
+    # Backward setup
+    backward_is_next_stage_on_this_rank = (
+        backward_stage.stage_index + 1 in stage_index_to_stage
+    )
+    backward_is_prev_stage_on_this_rank = (
+        backward_stage.stage_index - 1 in stage_index_to_stage
+    )
+    backward_mb_index = bwd_action.microbatch_index
+    assert backward_mb_index is not None
+    bwd_recv_ops = schedule.bwd_recv_ops
+
+    # Fwd receives
+    if (
+        not forward_stage.is_first
+        # no recv op expected for V-schedule special case
+        and not forward_is_prev_stage_on_this_rank
+    ):
+        assert (
+            forward_stage_index,
+            forward_mb_index,
+        ) in fwd_recv_ops, f"Computing {action=} before receiving input"
+        _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index)))
+
+    # Bwd receives
+    if (
+        not backward_stage.is_last
+        # no recv op expected for V-schedule special case
+        and not backward_is_next_stage_on_this_rank
+    ):
+        assert (
+            backward_stage_index,
+            backward_mb_index,
+        ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input"
+        _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index)))
+
+    # Count the MoE layers per stage in case the forward and backward stages differ;
+    # if they differ, only coordinate for the smaller number of layers.
+    min_num_layers = min(
+        _count_moe_modules(forward_stage.submod),
+        _count_moe_modules(backward_stage.submod),
+    )
+    # PP computation ========================================================
+    _hook_coordinator.enable_coordination(num_layers=min_num_layers)
+    main_stream = torch.accelerator.current_stream(device_module)
+
+    def run_backward():
+        schedule._assert_unsharded(backward_stage)
+        # Set the backward thread to use the same stream as forward
+        device_module.set_stream(main_stream)
+        with record_function(
+            f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}"
+        ):
+            loss = schedule._maybe_get_loss(backward_stage, backward_mb_index)
+            schedule.backward_counter[backward_stage_index] += 1
+            last_backward = (
+                schedule.backward_counter[backward_stage_index]
+                == schedule._n_microbatches
+            )
+            backward_stage.backward_one_chunk(
+                backward_mb_index,
+                loss=loss,
+                full_backward=True,
+                last_backward=last_backward,
+            )
+
+            if backward_is_prev_stage_on_this_rank:
+                stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input(
+                    backward_stage.get_local_bwd_output(backward_mb_index),
+                    backward_mb_index,
+                )
+
+    def run_forward():
+        schedule._assert_unsharded(forward_stage)
+        output = forward_stage.forward_one_chunk(
+            forward_mb_index,
+            arg_mbs[forward_mb_index],
+            kwarg_mbs[forward_mb_index],
+        )
+        schedule._maybe_compute_loss(
+            forward_stage, output, ctx.target_mbs, forward_mb_index
+        )
+        if forward_is_next_stage_on_this_rank:
+            stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input(
+                output, forward_mb_index
+            )
+
+    # Run forward and backward in parallel
+    thread = threading.Thread(target=run_backward, daemon=True)
+    thread.start()
+    run_forward()
+    thread.join()
+
+    _hook_coordinator.disable_coordination()
diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index 06dba40d6f..bafefddbec 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -18,6 +18,7 @@
     _PipelineSchedule,
     _PipelineScheduleRuntime,
     get_schedule_class,
+    OVERLAP_F_B,
     PipelineScheduleMulti,
     PipelineScheduleSingle,
     ScheduleDualPipeV,
@@ -27,6 +28,7 @@
 from torchtitan.components.loss import LossFunction, rescale_accumulated_loss
 from torchtitan.config import JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.distributed.dual_pipe_v import overlap_callback
 from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction
 from torchtitan.tools.logging import logger
 
@@ -209,6 +211,11 @@ def build_pipeline_schedule(
         f"with {n_microbatches} microbatches and {num_total_stages} stages."
     )
 
+    if job_config.parallelism.pipeline_parallel_expert_parallel_overlap and isinstance(
+        schedule, ScheduleDualPipeV
+    ):
+        schedule.register_custom_function(OVERLAP_F_B, overlap_callback)
+
     if pp_schedule_csv:
         assert schedule_class in [
             PipelineScheduleSingle,
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index 525bd96c13..911c2e7b17 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -97,8 +97,8 @@
         qk_rope_head_dim=64,
         v_head_dim=128,
         mscale=0.70,
-        use_flex_attn=True,
-        attn_mask_type="block_causal",
+        use_flex_attn=False,
+        # attn_mask_type="block_causal",
     ),
     "236B": DeepSeekV3ModelArgs(
         vocab_size=102400,
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 8d13a3f31f..ac92e757ba 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -101,6 +101,9 @@ def parallelize_deepseekv3(
                 else None
             ),
            etp_enabled=parallel_dims.etp_enabled,
+            dual_pipe_v=job_config.parallelism.pipeline_parallel_expert_parallel_overlap
+            and job_config.parallelism.pipeline_parallel_schedule.lower()
+            == "dualpipev",
         )
 
     model_compile_enabled = (
diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
index 1951cc4350..d0cdf2640f 100644
--- a/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/debug_model.toml
@@ -4,9 +4,10 @@ description = "DeepSeek-V3 debug training"
 print_config = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 10
+profile_freq = 1
+profiler_warmup = 0
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
 
@@ -30,17 +31,18 @@ lr = 8e-4
 eps = 1e-8
 
 [lr_scheduler]
-warmup_steps = 2  # lr scheduler warm up, normally 20% of the train steps
+warmup_steps = 0  # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.8  # lr scheduler decay ratio, 80% of the train steps
 decay_type = "linear"
 min_lr_factor = 0.0
 
 [training]
-local_batch_size = 8
-seq_len = 2048
+local_batch_size = 4
+seq_len = 4
 max_norm = 1.0  # grad norm clipping
-steps = 10
+steps = 6
 dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "c4"
 
 [parallelism]
 data_parallel_replicate_degree = 1
@@ -48,10 +50,10 @@ data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "1F1B"
+pipeline_parallel_degree = 2
+expert_parallel_degree = 2
 context_parallel_degree = 1
-expert_parallel_degree = 1
+pipeline_parallel_schedule = "DualPipeV"
 expert_tensor_parallel_degree = 1
 
 [checkpoint]
@@ -63,7 +65,7 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective"  # ["none", "selective", "full"]
+mode = "none"  # ["none", "selective", "full"]
 selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 00ec53310e..d8a38f80ae 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -4,7 +4,7 @@ description = "DeepSeek-V3 16B model training"
 print_config = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
@@ -38,7 +38,7 @@ min_lr_factor = 0.1
 local_batch_size = 4
 seq_len = 4096
 max_norm = 1.0  # grad norm clipping
-steps = 1000
+steps = 30
 dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
 data_parallel_replicate_degree = 1
@@ -47,9 +47,9 @@ data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default"  # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "Interleaved1F1B"
-expert_parallel_degree = 8
+pipeline_parallel_degree = 2
+pipeline_parallel_schedule = "DualPipeV"
+expert_parallel_degree = 4
 expert_tensor_parallel_degree = 1
 
 [checkpoint]
@@ -61,7 +61,7 @@ export_dtype = "float32"
 async_mode = "disabled"  # ["disabled", "async", "async_with_pinned_mem]"
 
 [activation_checkpoint]
-mode = "selective"  # ["none", "selective", "full"]
+mode = "none"  # ["none", "selective", "full"]
 selective_ac_option = 'op'  # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py
index 1f579ccd04..b54bb702d6 100644
--- a/torchtitan/models/llama4/infra/parallelize.py
+++ b/torchtitan/models/llama4/infra/parallelize.py
@@ -21,6 +21,7 @@
 from torchtitan.config.job_config import Compile as CompileConfig
 from torchtitan.distributed import NoParallel, ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
+from torchtitan.distributed.dual_pipe_v import DualPipeExpertParallel
 
 from torchtitan.distributed.expert_parallel import (
     ExpertParallel,
@@ -108,6 +109,8 @@ def parallelize_llama(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
+            dual_pipe_v=job_config.parallelism.pipeline_parallel_expert_parallel_overlap
+            and job_config.parallelism.pipeline_parallel_schedule.lower() == "dualpipev",
         )
 
     model_compile_enabled = (
@@ -440,6 +443,7 @@ def apply_moe_ep_tp(
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
+    dual_pipe_v: bool = False,
 ):
     assert ep_mesh is not None or tp_mesh is not None
 
@@ -491,7 +495,7 @@ def apply_moe_ep_tp(
     elif tp_mesh is None or not etp_enabled:
         experts_mesh = ep_mesh
         # input / output sharding on the batch / tokens dim
-        experts_plan = ExpertParallel()
+        experts_plan = DualPipeExpertParallel() if dual_pipe_v else ExpertParallel()
     else:
         experts_mesh = ep_tp_mesh
         experts_plan = ExpertTensorParallel()
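
Note (not part of the patch): the coordination in dual_pipe_v.py boils down to two Python threads, one running the forward chunk and one running the backward chunk, that meet at a shared threading.Barrier(2) at the SyncHook points A-D, so that one thread's expert-parallel communication overlaps with the other thread's compute. Below is a minimal standalone sketch of that barrier-alternation pattern; the names worker and NUM_PHASES are illustrative only and do not appear in the PR.

# --- Illustrative sketch only; not part of the patch above. ---
import threading

NUM_PHASES = 4                  # stand-in for the A/B/C/D hook points of one MoE layer
barrier = threading.Barrier(2)  # exactly two participants: forward and backward threads

def worker(name: str) -> None:
    for phase in range(NUM_PHASES):
        # ... this thread's work for the current phase (compute or EP comm) goes here ...
        print(f"{name}: finished phase {phase}")
        try:
            # Rendezvous with the peer thread so compute and comm stay interleaved.
            barrier.wait()
        except threading.BrokenBarrierError:
            # Peer aborted coordination (cf. HookCoordinator.disable_coordination()).
            break

fwd = threading.Thread(target=worker, args=("forward",))
bwd = threading.Thread(target=worker, args=("backward",))
fwd.start()
bwd.start()
fwd.join()
bwd.join()

If either side stops participating, aborting the barrier unblocks the peer via BrokenBarrierError instead of deadlocking, which is why disable_coordination() calls Barrier.abort().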