
Commit 9da6b11

wconstab authored and H-Huang committed

WIP

ghstack-source-id: 7ce172d
Pull Request resolved: pytorch#2032

1 parent f3e551f · commit 9da6b11

File tree

2 files changed: +51 -1 lines changed

torchtitan/distributed/pipeline_parallel.py

Lines changed: 50 additions & 0 deletions

```diff
@@ -40,6 +40,53 @@
 ]
 
 
+def _override_torch_ops_for_zero_bubble():
+    class MmSeparateWeightGrad(torch.autograd.Function):
+        # Backward produces only the weight gradient; the input edge is
+        # severed at the call site via detach.
+        @staticmethod
+        def forward(ctx, i, w):
+            ctx.save_for_backward(i)
+            return w
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (i,) = ctx.saved_tensors
+            grad_weight = i.t().mm(grad_output)
+            return None, grad_weight
+
+    class MmSeparateInputGrad(torch.autograd.Function):
+        # Backward produces only the input gradient; the weight edge is
+        # severed at the call site via detach.
+        @staticmethod
+        def forward(ctx, i, w):
+            ctx.save_for_backward(w)
+            return i
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (w,) = ctx.saved_tensors
+            grad_input = grad_output.mm(w.t())
+            return grad_input, None
+
+    class MmPassThrough(torch.autograd.Function):
+        # Performs the actual matmul in forward and fans the incoming
+        # gradient out unchanged to the two separate-grad nodes above.
+        @staticmethod
+        def forward(ctx, x, y):
+            return torch.mm(x, y)
+
+        @staticmethod
+        def backward(ctx, gO):
+            return gO, gO
+
+    def split_mm(i, w):
+        # Route i and w through separate autograd nodes so that dX and dW
+        # become independent backward nodes (zero-bubble schedules run dW
+        # after dX). Each cross argument is detached so a helper node only
+        # records the single graph edge it is responsible for.
+        w1 = MmSeparateWeightGrad.apply(i.detach(), w)
+        i1 = MmSeparateInputGrad.apply(i, w.detach())
+        return MmPassThrough.apply(i1, w1)
+
+    lib = torch.library.Library("aten", "IMPL")
+    lib.impl("mm", split_mm, "Autograd")
+
+
 def pipeline_llm(
     model: nn.Module,
     parallel_dims: ParallelDims,
@@ -51,6 +98,9 @@ def pipeline_llm(
 ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
     pp_mesh = parallel_dims.world_mesh["pp"]
 
+    if True:
+        _override_torch_ops_for_zero_bubble()
+
     # Determine the number of virtual stages based on schedule type
     schedule_class = get_schedule_class(
         job_config.parallelism.pipeline_parallel_schedule
```
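For context, the split-mm pattern above can be exercised on its own. Below is a minimal, self-contained sketch (not part of this commit; the gradient check against plain `torch.mm` is added here for illustration) showing that routing the input and weight through separate autograd nodes reproduces the ordinary matmul gradients:

```python
import torch


class MmSeparateWeightGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i, w):
        ctx.save_for_backward(i)
        return w

    @staticmethod
    def backward(ctx, grad_output):
        (i,) = ctx.saved_tensors
        return None, i.t().mm(grad_output)


class MmSeparateInputGrad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i, w):
        ctx.save_for_backward(w)
        return i

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        return grad_output.mm(w.t()), None


class MmPassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, y):
        return torch.mm(x, y)

    @staticmethod
    def backward(ctx, gO):
        return gO, gO


def split_mm(i, w):
    w1 = MmSeparateWeightGrad.apply(i.detach(), w)
    i1 = MmSeparateInputGrad.apply(i, w.detach())
    return MmPassThrough.apply(i1, w1)


i = torch.randn(4, 3, requires_grad=True)
w = torch.randn(3, 5, requires_grad=True)

# Backward through the split graph: dX and dW come from separate nodes.
split_mm(i, w).sum().backward()
gi, gw = i.grad.clone(), w.grad.clone()

# Reference gradients from the fused op.
i.grad = w.grad = None
torch.mm(i, w).sum().backward()

assert torch.allclose(gi, i.grad) and torch.allclose(gw, w.grad)
```

The payoff is that dX and dW are distinct backward nodes: a zero-bubble pipeline schedule can run the input-gradient node immediately to unblock the upstream stage, and defer the weight-gradient node into what would otherwise be idle bubble time.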

torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -61,7 +61,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "none" # ["none", "selective", "full"]
+mode = "selective" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
```
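The `'op'` option checkpoints at operator rather than layer granularity. As a rough illustration of the idea only (a sketch assuming PyTorch 2.4+'s selective-checkpoint API; torchtitan's actual ops policy lives in its own code and may save a different op list):

```python
from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def policy(ctx, op, *args, **kwargs):
    # Illustrative policy: keep matmul outputs (expensive to recompute),
    # recompute everything else during backward.
    if op == torch.ops.aten.mm.default:
        return CheckpointPolicy.MUST_SAVE
    return CheckpointPolicy.PREFER_RECOMPUTE


def block(x, w1, w2):
    return torch.relu(x.mm(w1)).mm(w2)


x = torch.randn(8, 16, requires_grad=True)
w1 = torch.randn(16, 32, requires_grad=True)
w2 = torch.randn(32, 16, requires_grad=True)

out = checkpoint(
    block,
    x,
    w1,
    w2,
    use_reentrant=False,
    context_fn=partial(create_selective_checkpoint_contexts, policy),
)
out.sum().backward()
```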
