from __future__ import annotations

from typing import TYPE_CHECKING
+from typing import Any

import torch
from torch import Tensor
@@ -56,6 +57,208 @@ def matmul(
    return out


+@helion.kernel
+def matmul_bwd(
+    grad_out: Tensor,  # [m, n] gradient w.r.t. the output
+    mat1: Tensor,  # [m, k] first matrix
+    mat2: Tensor,  # [k, n] second matrix
+) -> tuple[Tensor, Tensor]:
+    """
+    Backward pass for matrix multiplication, following the Triton reference pattern.
+
+    For C = A @ B, given grad_C, computes:
+    - grad_A = grad_C @ B.T
+    - grad_B = A.T @ grad_C
+
+    Args:
+        grad_out: Gradient w.r.t. the output [m, n]
+        mat1: First matrix [m, k]
+        mat2: Second matrix [k, n]
+
+    Returns:
+        tuple[Tensor, Tensor]: (grad_mat1, grad_mat2)
+    """
+    # Get all dimensions first
+    m, n = grad_out.size()
+    m2, k = mat1.size()
+    k2, n2 = mat2.size()
+
+    # All assertions at the top
+    assert m == m2 and n == n2 and k == k2, "Size mismatch in matmul backward"
+
+    # Declare ALL output tensors at the top, before any loops
+    grad_mat1 = torch.empty_like(mat1)
+    grad_mat2 = torch.empty_like(mat2)
+
+    # First loop block: compute grad_mat1 = grad_out @ mat2.T
+    for tile_m1, tile_k1 in hl.tile([m, k]):
+        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
+        for tile_n1 in hl.tile(n):
+            # Need mat2.T: mat2 is [k, n], so mat2[tile_k, tile_n].T gives [tile_n, tile_k]
+            acc1 = torch.addmm(
+                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
+            )
+        grad_mat1[tile_m1, tile_k1] = acc1.to(mat1.dtype)
+
+    # Second loop block: compute grad_mat2 = mat1.T @ grad_out
+    for tile_k2, tile_n2 in hl.tile([k, n]):
+        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
+        for tile_m2 in hl.tile(m):
+            # Need mat1.T: mat1 is [m, k], so mat1[tile_m, tile_k].T gives [tile_k, tile_m]
+            acc2 = torch.addmm(
+                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
+            )
+        grad_mat2[tile_k2, tile_n2] = acc2.to(mat2.dtype)
+
+    return grad_mat1, grad_mat2
+
+
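# The two tile loops above implement the standard matmul gradient identities.
# As a quick illustration (a sketch only; the tensors below are hypothetical and
# not defined in this example), the results can be cross-checked against eager
# PyTorch:
#
#   gA, gB = matmul_bwd(gC, A, B)          # A: [m, k], B: [k, n], gC: [m, n]
#   torch.testing.assert_close(gA, gC @ B.T, rtol=1e-2, atol=1e-2)
#   torch.testing.assert_close(gB, A.T @ gC, rtol=1e-2, atol=1e-2)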
+@helion.kernel
+def addmm_bwd(
+    grad_out: Tensor,  # [m, n] gradient w.r.t. the output
+    bias: Tensor,  # [m, n] or broadcastable bias tensor
+    mat1: Tensor,  # [m, k] first matrix
+    mat2: Tensor,  # [k, n] second matrix
+    alpha: float = 1.0,  # scalar multiplier for the matmul
+    beta: float = 1.0,  # scalar multiplier for the bias
+) -> tuple[Tensor, Tensor, Tensor]:
+    """
+    Backward pass for the addmm operation, following the Triton reference pattern.
+
+    Forward: output = beta * bias + alpha * (mat1 @ mat2)
+
+    Based on the Triton kernel analysis:
+    - grad_input = beta * grad_out (with proper reduction for broadcasting)
+    - grad_mat1 = alpha * (grad_out @ mat2.T)
+    - grad_mat2 = alpha * (mat1.T @ grad_out)
+
+    Args:
+        grad_out: Gradient w.r.t. the output [m, n]
+        bias: Bias tensor [m, n] (or broadcastable)
+        mat1: First matrix [m, k]
+        mat2: Second matrix [k, n]
+        alpha: Scalar multiplier for the matmul
+        beta: Scalar multiplier for the bias
+
+    Returns:
+        tuple[Tensor, Tensor, Tensor]: (grad_input, grad_mat1, grad_mat2)
+    """
+    # Get all dimensions first
+    m, n = grad_out.size()
+    m2, k = mat1.size()
+    k2, n2 = mat2.size()
+
+    # All assertions at the top
+    assert m == m2 and n == n2 and k == k2, "Size mismatch in addmm backward"
+
+    # Declare ALL output tensors at the top, before any loops
+    grad_input = torch.empty_like(bias)
+    grad_mat1 = torch.empty_like(mat1)
+    grad_mat2 = torch.empty_like(mat2)
+
+    # Handle grad_input = beta * grad_out (assuming same shape for now)
+    for tile_m3, tile_n3 in hl.tile([m, n]):
+        grad_input[tile_m3, tile_n3] = beta * grad_out[tile_m3, tile_n3]
+
+    # First loop block: compute grad_mat1 = alpha * (grad_out @ mat2.T)
+    for tile_m1, tile_k1 in hl.tile([m, k]):
+        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
+        for tile_n1 in hl.tile(n):
+            acc1 = torch.addmm(
+                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
+            )
+        grad_mat1[tile_m1, tile_k1] = (alpha * acc1).to(mat1.dtype)
+
+    # Second loop block: compute grad_mat2 = alpha * (mat1.T @ grad_out)
+    for tile_k2, tile_n2 in hl.tile([k, n]):
+        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
+        for tile_m2 in hl.tile(m):
+            acc2 = torch.addmm(
+                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
+            )
+        grad_mat2[tile_k2, tile_n2] = (alpha * acc2).to(mat2.dtype)
+
+    return grad_input, grad_mat1, grad_mat2
+
+
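# Same idea for addmm_bwd: a sketch of an eager cross-check (hypothetical tensors;
# assumes bias already has the full [m, n] shape, matching the kernel's
# "same shape" assumption above):
#
#   gi, gA, gB = addmm_bwd(gC, bias, A, B, 2.0, 0.5)
#   torch.testing.assert_close(gi, 0.5 * gC, rtol=1e-2, atol=1e-2)
#   torch.testing.assert_close(gA, 2.0 * (gC @ B.T), rtol=1e-2, atol=1e-2)
#   torch.testing.assert_close(gB, 2.0 * (A.T @ gC), rtol=1e-2, atol=1e-2)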
+# %%
+class MatMulFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        mat1: Tensor,
+        mat2: Tensor,
+    ) -> Tensor:
+        """Forward pass for matrix multiplication."""
+        result = matmul(mat1, mat2)
+        ctx.save_for_backward(mat1, mat2)
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: Any,  # noqa: ANN401
+        *grad_outputs: Tensor,
+    ) -> tuple[Tensor | None, Tensor | None]:
+        """Backward pass for matrix multiplication."""
+        grad_out = grad_outputs[0]
+        mat1, mat2 = ctx.saved_tensors
+        grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2)
+        return grad_mat1, grad_mat2
+
+
+def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor:
+    """Matrix multiplication with forward + backward support."""
+    return MatMulFunction.apply(mat1, mat2)  # type: ignore[no-any-return]
+
+
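# Minimal usage sketch for the wrapper (assumes a CUDA device is available;
# the tensors are placeholders, not part of this example):
#
#   x = torch.randn([64, 32], device="cuda", dtype=torch.float16, requires_grad=True)
#   y = torch.randn([32, 48], device="cuda", dtype=torch.float16, requires_grad=True)
#   out = matmul_autograd(x, y)
#   out.sum().backward()  # fills x.grad and y.grad via matmul_bwd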
+class AddMMFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        bias: Tensor,
+        mat1: Tensor,
+        mat2: Tensor,
+        alpha: float = 1.0,
+        beta: float = 1.0,
+    ) -> Tensor:
+        """Forward pass for addmm operation using helion matmul with epilogue."""
+        m, k = mat1.size()
+        k2, n = mat2.size()
+        input_broadcasted = torch.broadcast_to(bias, [m, n])
+
+        # Define epilogue that adds bias: alpha * acc + beta * bias
+        def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
+            return alpha * acc + beta * input_broadcasted[tile[0], tile[1]]
+
+        result = matmul(mat1, mat2, addmm_epilogue)
+        ctx.save_for_backward(bias, mat1, mat2)
+        ctx.alpha = alpha
+        ctx.beta = beta
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: Any,  # noqa: ANN401
+        *grad_outputs: Tensor,
+    ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
+        """Backward pass for addmm operation."""
+        grad_out = grad_outputs[0]
+        bias, mat1, mat2 = ctx.saved_tensors
+        alpha = ctx.alpha
+        beta = ctx.beta
+        grad_input, grad_mat1, grad_mat2 = addmm_bwd(
+            grad_out, bias, mat1, mat2, alpha, beta
+        )
+        return grad_input, grad_mat1, grad_mat2, None, None
+
+
+def addmm_autograd(
+    bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
+) -> Tensor:
+    """AddMM operation with forward + backward support."""
+    return AddMMFunction.apply(bias, mat1, mat2, alpha, beta)  # type: ignore[no-any-return]
+
+
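# Usage sketch mirroring torch.addmm (placeholder tensors; alpha and beta are plain
# Python floats, so backward() returns no gradient for them):
#
#   bias = torch.randn([64, 48], device="cuda", dtype=torch.float16, requires_grad=True)
#   x = torch.randn([64, 32], device="cuda", dtype=torch.float16, requires_grad=True)
#   y = torch.randn([32, 48], device="cuda", dtype=torch.float16, requires_grad=True)
#   out = addmm_autograd(bias, x, y, 2.0, 0.5)  # 0.5 * bias + 2.0 * (x @ y)
#   out.sum().backward()  # fills bias.grad, x.grad and y.grad via addmm_bwd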
# %%
def autotune(m: int, k: int, n: int) -> None:
    """
@@ -129,6 +332,63 @@ def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
        (x, y),
    )

+    # Test matmul forward + backward pass
+    print("\n\n=== MatMul Forward + Backward Pass Test ===")
+    x_grad = torch.randn([m, k], device="cuda", dtype=torch.float16, requires_grad=True)
+    y_grad = torch.randn([k, n], device="cuda", dtype=torch.float16, requires_grad=True)
+
+    run_example(
+        matmul_autograd,
+        torch.matmul,
+        (x_grad, y_grad),
+        kernel_name="helion_matmul_autograd",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
+    # Test addmm forward + backward pass
+    print("\n\n=== AddMM Forward + Backward Pass Test ===")
+    input_grad = torch.randn(
+        [m, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat1_grad = torch.randn(
+        [m, k], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat2_grad = torch.randn(
+        [k, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+
+    # Use a lambda to pass alpha/beta as keyword arguments to torch.addmm
+    run_example(
+        addmm_autograd,
+        lambda bias, mat1, mat2, alpha, beta: torch.addmm(
+            bias, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
+        kernel_name="helion_addmm_autograd",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
+    # Test addmm forward + backward with different alpha/beta values
+    print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
+    run_example(
+        addmm_autograd,
+        lambda bias, mat1, mat2, alpha, beta: torch.addmm(
+            bias, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
+        kernel_name="helion_addmm_autograd_scaled",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+

# %%
def matmul_tritonbench(
@@ -145,8 +405,9 @@ def matmul_tritonbench(
        Callable: A callable that runs the matmul kernel with or without bias.
    """
    if bias is not None:
-        return lambda: matmul(a, b, lambda acc, tile: acc + bias[tile[1]])
-    return lambda: matmul(a, b)
+        # For gemm with bias, use matmul_autograd and add bias
+        return lambda: matmul_autograd(a, b) + bias
+    return lambda: matmul_autograd(a, b)


def addmm_tritonbench(
@@ -160,12 +421,9 @@ def addmm_tritonbench(
        mat1 (torch.Tensor): Left matrix.
        mat2 (torch.Tensor): Right matrix.
    Returns:
-        Callable: A callable that runs the matmul kernel with bias.
+        Callable: A callable that runs the addmm autograd function with bias.
    """
-    m, k = mat1.size()
-    k2, n = mat2.size()
-    bias = torch.broadcast_to(bias, [m, n])
-    return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
+    return lambda: addmm_autograd(bias, mat1, mat2)


# %%