@@ -112,6 +112,74 @@ def matmul_bwd(
     return grad_mat1, grad_mat2


+@helion.kernel
+def addmm_bwd(
+    grad_out: Tensor,  # [m, n] gradient w.r.t. output
+    input: Tensor,  # [m, n] or broadcastable bias tensor
+    mat1: Tensor,  # [m, k] first matrix
+    mat2: Tensor,  # [k, n] second matrix
+    alpha: float = 1.0,  # scalar multiplier for matmul
+    beta: float = 1.0,  # scalar multiplier for bias
+) -> tuple[Tensor, Tensor, Tensor]:
+    """
+    Backward pass for the addmm operation, following the Triton reference pattern.
+
+    Forward: output = beta * input + alpha * (mat1 @ mat2)
+
+    Based on the Triton kernel analysis:
+    - grad_input = beta * grad_out (with proper reduction for broadcasting)
+    - grad_mat1 = alpha * (grad_out @ mat2.T)
+    - grad_mat2 = alpha * (mat1.T @ grad_out)
+
+    Args:
+        grad_out: Gradient w.r.t. output [m, n]
+        input: Bias tensor [m, n] (or broadcastable)
+        mat1: First matrix [m, k]
+        mat2: Second matrix [k, n]
+        alpha: Scalar multiplier for matmul
+        beta: Scalar multiplier for bias
+
+    Returns:
+        tuple[Tensor, Tensor, Tensor]: (grad_input, grad_mat1, grad_mat2)
+    """
+    # Get all dimensions first
+    m, n = grad_out.size()
+    m2, k = mat1.size()
+    k2, n2 = mat2.size()
+
+    # All assertions at the top
+    assert m == m2 and n == n2 and k == k2, "Size mismatch in addmm backward"
+
+    # Declare ALL output tensors at the top before any loops
+    grad_input = torch.empty_like(input)
+    grad_mat1 = torch.empty_like(mat1)
+    grad_mat2 = torch.empty_like(mat2)
+
+    # Handle grad_input = beta * grad_out (assuming same shape for now)
+    for tile_m3, tile_n3 in hl.tile([m, n]):
+        grad_input[tile_m3, tile_n3] = beta * grad_out[tile_m3, tile_n3]
+
+    # First loop block: compute grad_mat1 = alpha * (grad_out @ mat2.T)
+    for tile_m1, tile_k1 in hl.tile([m, k]):
+        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
+        for tile_n1 in hl.tile(n):
+            acc1 = torch.addmm(
+                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
+            )
+        grad_mat1[tile_m1, tile_k1] = (alpha * acc1).to(mat1.dtype)
+
+    # Second loop block: compute grad_mat2 = alpha * (mat1.T @ grad_out)
+    for tile_k2, tile_n2 in hl.tile([k, n]):
+        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
+        for tile_m2 in hl.tile(m):
+            acc2 = torch.addmm(
+                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
+            )
+        grad_mat2[tile_k2, tile_n2] = (alpha * acc2).to(mat2.dtype)
+
+    return grad_input, grad_mat1, grad_mat2
+
+
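For readers who want to confirm the closed-form gradients used by this kernel, here is a minimal eager-mode sketch (not part of the diff) that checks them against PyTorch autograd on small CPU tensors; the helper name reference_addmm_grads is made up for illustration:

import torch

def reference_addmm_grads(grad_out, inp, mat1, mat2, alpha, beta):
    # Closed-form gradients from the chain rule for
    # out = beta * inp + alpha * (mat1 @ mat2), assuming inp has the full [m, n] shape.
    return beta * grad_out, alpha * (grad_out @ mat2.T), alpha * (mat1.T @ grad_out)

inp = torch.randn(4, 5, requires_grad=True)
mat1 = torch.randn(4, 3, requires_grad=True)
mat2 = torch.randn(3, 5, requires_grad=True)
out = torch.addmm(inp, mat1, mat2, alpha=2.0, beta=0.5)
grad_out = torch.randn_like(out)
out.backward(grad_out)

gi, g1, g2 = reference_addmm_grads(grad_out, inp, mat1, mat2, 2.0, 0.5)
torch.testing.assert_close(inp.grad, gi)   # grad_input = beta * grad_out
torch.testing.assert_close(mat1.grad, g1)  # grad_mat1 = alpha * (grad_out @ mat2.T)
torch.testing.assert_close(mat2.grad, g2)  # grad_mat2 = alpha * (mat1.T @ grad_out)
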
 # %%
 class MatMulFunction(torch.autograd.Function):
     @staticmethod
@@ -141,6 +209,45 @@ def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor:
     return MatMulFunction.apply(mat1, mat2)  # type: ignore[no-any-return]


+class AddMMFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,
+        input: Tensor,
+        mat1: Tensor,
+        mat2: Tensor,
+        alpha: float = 1.0,
+        beta: float = 1.0,
+    ) -> Tensor:
+        """Forward pass for addmm operation."""
+        result = torch.addmm(input, mat1, mat2, alpha=alpha, beta=beta)
+        ctx.save_for_backward(input, mat1, mat2)
+        ctx.alpha = alpha
+        ctx.beta = beta
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: Any,
+        grad_out: Tensor,
+    ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
+        """Backward pass for addmm operation."""
+        input, mat1, mat2 = ctx.saved_tensors
+        alpha = ctx.alpha
+        beta = ctx.beta
+        grad_input, grad_mat1, grad_mat2 = addmm_bwd(
+            grad_out, input, mat1, mat2, alpha, beta
+        )
+        return grad_input, grad_mat1, grad_mat2, None, None
+
+
+def addmm_autograd(
+    input: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
+) -> Tensor:
+    """AddMM operation with forward + backward support."""
+    return AddMMFunction.apply(input, mat1, mat2, alpha, beta)  # type: ignore[no-any-return]
+
+
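A minimal usage sketch for the wrapper above, assuming a CUDA device is available and that addmm_autograd is imported from this example module (the shapes below are arbitrary):

inp = torch.randn(64, 32, device="cuda", dtype=torch.float16, requires_grad=True)
a = torch.randn(64, 16, device="cuda", dtype=torch.float16, requires_grad=True)
b = torch.randn(16, 32, device="cuda", dtype=torch.float16, requires_grad=True)

out = addmm_autograd(inp, a, b, alpha=2.0, beta=0.5)  # forward runs torch.addmm
out.sum().backward()  # backward routes through AddMMFunction.backward -> addmm_bwd
print(inp.grad.shape, a.grad.shape, b.grad.shape)
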
 # %%
 def autotune(m: int, k: int, n: int) -> None:
     """
@@ -230,6 +337,47 @@ def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
         bwd=True,
     )

+    # Test addmm forward + backward pass
+    print("\n\n=== AddMM Forward + Backward Pass Test ===")
+    input_grad = torch.randn(
+        [m, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat1_grad = torch.randn(
+        [m, k], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat2_grad = torch.randn(
+        [k, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+
+    # Use a lambda to handle the keyword-argument form of torch.addmm
+    run_example(
+        addmm_autograd,
+        lambda input, mat1, mat2, alpha, beta: torch.addmm(
+            input, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
+        kernel_name="helion_addmm_autograd",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
+    # Test addmm forward + backward with different alpha/beta values
+    print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
+    run_example(
+        addmm_autograd,
+        lambda input, mat1, mat2, alpha, beta: torch.addmm(
+            input, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
+        kernel_name="helion_addmm_autograd_scaled",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
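Neither test above exercises a broadcast bias: addmm_bwd currently assumes input has the full [m, n] shape (see the "assuming same shape for now" comment in the kernel). For reference, a hedged eager-mode sketch of the reduction a 1-D bias would need, with the bias gradient summed over the broadcast row dimension (tensor names are illustrative):

import torch

bias = torch.randn(5, requires_grad=True)  # 1-D bias, broadcast over rows
a = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, 5, requires_grad=True)

out = torch.addmm(bias, a, b, alpha=1.0, beta=2.0)  # bias broadcasts to [4, 5]
grad_out = torch.ones_like(out)
out.backward(grad_out)

# The bias gradient is beta * grad_out reduced over the broadcast dimension.
torch.testing.assert_close(bias.grad, 2.0 * grad_out.sum(dim=0))
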

 # %%
 def matmul_tritonbench(