From 33932e1bb27f56f6c63c4ba07794e8a062a4c3da Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Wed, 12 Nov 2025 16:29:41 -0800
Subject: [PATCH] WIP

[ghstack-poisoned]
---
 torchtitan/distributed/pipeline_parallel.py       | 50 +++++++++++++++++++
 .../train_configs/deepseek_v3_16b.toml            |  2 +-
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py
index bafefddbec..fe676f1781 100644
--- a/torchtitan/distributed/pipeline_parallel.py
+++ b/torchtitan/distributed/pipeline_parallel.py
@@ -40,6 +40,53 @@
 ]
 
 
+def _override_torch_ops_for_zero_bubble():
+    class MmSeparateWeightGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, i, w):
+            ctx.save_for_backward(i)
+            return w
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (i,) = ctx.saved_tensors
+            grad_weight = i.t().mm(grad_output)
+            return None, grad_weight
+
+    class MmSeparateInputGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, i, w):
+            ctx.save_for_backward(w)
+            return i
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (w,) = ctx.saved_tensors
+            grad_input = grad_output.mm(w.t())
+            return grad_input, None
+
+    class MmPassThrough(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x, y):
+            return torch.mm(x, y)
+
+        @staticmethod
+        def backward(ctx, gO):
+            return gO, gO
+
+    def split_mm(i, w):
+        print("split mul")
+        # Apply the pass-through node. y is passed to this node so that it can be
+        # saved for backward, but detach because we don't want to actually build
+        # this edge of the graph
+        w1 = MmSeparateWeightGrad.apply(i.detach(), w)
+        i1 = MmSeparateInputGrad.apply(i, w.detach())
+        return MmPassThrough.apply(i1, w1)
+
+    lib = torch.library.Library("aten", "IMPL")
+    lib.impl("mm", split_mm, "Autograd")
+
+
 def pipeline_llm(
     model: nn.Module,
     parallel_dims: ParallelDims,
@@ -51,6 +98,9 @@ def pipeline_llm(
 ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
     pp_mesh = parallel_dims.world_mesh["pp"]
 
+    if True:
+        _override_torch_ops_for_zero_bubble()
+
     # Determine the number of virtual stages based on schedule type
     schedule_class = get_schedule_class(
         job_config.parallelism.pipeline_parallel_schedule
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index d8a38f80ae..cb1fb30fe3 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -61,7 +61,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]"
 
 [activation_checkpoint]
-mode = "none" # ["none", "selective", "full"]
+mode = "selective" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
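
For reference, the override above splits aten::mm's backward into separate input-grad and weight-grad autograd nodes, which is what lets a zero-bubble pipeline schedule run dX and dW at different points in the schedule. The standalone sketch below is illustrative only and not part of the patch (it does not register the torch.library override); it simply checks that the split formulas used here, grad_input = grad_output @ W^T and grad_weight = X^T @ grad_output, match the gradients autograd computes for a plain torch.mm.

# Illustrative sanity check, separate from the patch: verify that the split
# dX / dW formulas used by the override reproduce torch.mm's own gradients.
import torch

torch.manual_seed(0)
x = torch.randn(4, 3, requires_grad=True)   # activations ("i" in the patch)
w = torch.randn(3, 5, requires_grad=True)   # weight ("w" in the patch)
go = torch.randn(4, 5)                      # upstream grad_output

out = torch.mm(x, w)
out.backward(go)

# Gradients as computed by the patch's separate backward nodes.
grad_input = go.mm(w.detach().t())    # what MmSeparateInputGrad.backward returns
grad_weight = x.detach().t().mm(go)   # what MmSeparateWeightGrad.backward returns

assert torch.allclose(x.grad, grad_input)
assert torch.allclose(w.grad, grad_weight)
print("split-mm gradients match torch.mm autograd")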