 ]
 
 
+def _override_torch_ops_for_zero_bubble():
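+    # Split aten.mm's backward into two independent autograd nodes: one that
+    # produces only the input gradient and one that produces only the weight
+    # gradient. Zero-bubble schedules rely on this split to run the weight-grad
+    # matmul later than the input-grad matmul.
+
+    # Identity on w in forward; backward computes only grad_weight = i.T @ grad_output.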
+    class MmSeparateWeightGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, i, w):
+            ctx.save_for_backward(i)
+            return w
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (i,) = ctx.saved_tensors
+            grad_weight = i.t().mm(grad_output)
+            return None, grad_weight
+
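+    # Mirror of the node above: identity on i in forward; backward computes
+    # only grad_input = grad_output @ w.T.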
+    class MmSeparateInputGrad(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, i, w):
+            ctx.save_for_backward(w)
+            return i
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (w,) = ctx.saved_tensors
+            grad_input = grad_output.mm(w.t())
+            return grad_input, None
+
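+    # Performs the actual matmul in forward; backward just forwards the incoming
+    # gradient to both operands, i.e. into the two grad nodes above.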
+    class MmPassThrough(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x, y):
+            return torch.mm(x, y)
+
+        @staticmethod
+        def backward(ctx, gO):
+            return gO, gO
+
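+    # Replacement for aten::mm: wires the three nodes together so that a single
+    # mm call produces two independent backward nodes.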
+    def split_mm(i, w):
+        print("split mul")
+        # Each grad node saves only the operand it needs for its half of the
+        # backward; the other operand is passed in detached because we don't
+        # want to actually build that edge of the graph through this node.
+        w1 = MmSeparateWeightGrad.apply(i.detach(), w)
+        i1 = MmSeparateInputGrad.apply(i, w.detach())
+        return MmPassThrough.apply(i1, w1)
+
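+    # Register split_mm as the Autograd-key implementation of aten::mm, so
+    # subsequent mm calls use the split backward instead of the default one.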
+    lib = torch.library.Library("aten", "IMPL")
+    lib.impl("mm", split_mm, "Autograd")
+
+
 def pipeline_llm(
     model: nn.Module,
     parallel_dims: ParallelDims,
@@ -51,6 +98,9 @@ def pipeline_llm(
 ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
     pp_mesh = parallel_dims.world_mesh["pp"]
 
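+    # Install the mm override before any pipeline stages or schedules are built.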
+    if True:
+        _override_torch_ops_for_zero_bubble()
+
     # Determine the number of virtual stages based on schedule type
     schedule_class = get_schedule_class(
         job_config.parallelism.pipeline_parallel_schedule