
Commit e9e3cae

Add matmul/addmm bwd examples and add test coverage
Differential Revision: D84363848
Pull Request resolved: #748
1 parent 0b8f9de commit e9e3cae

File tree

5 files changed: +595 -51 lines changed


benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -111,6 +111,11 @@ class RunResult:
         "examples.matmul",
         "addmm_tritonbench",
     ),
+    "addmm-bwd": (
+        "tritonbench.operators.addmm.operator",
+        "examples.matmul",
+        "addmm_tritonbench",
+    ),
     "geglu": (
         "tritonbench.operators.geglu.operator",
         "examples.geglu",
@@ -252,6 +257,14 @@ class RunResult:
             "num_inputs": 6,  # gemm takes long time on Benchmark CI, so use fewer inputs instead.
         },
     ),
+    "gemm-bwd": (
+        "tritonbench.operators.gemm.operator",
+        "examples.matmul",
+        "matmul_tritonbench",
+        {
+            "num_inputs": 10,  # gemm-bwd takes long time on Benchmark CI, so use fewer inputs instead.
+        },
+    ),
     "welford": (
         "tritonbench.operators.welford.operator",
         "examples.welford",

examples/matmul.py

Lines changed: 265 additions & 7 deletions
@@ -10,6 +10,7 @@
 from __future__ import annotations
 
 from typing import TYPE_CHECKING
+from typing import Any
 
 import torch
 from torch import Tensor
@@ -56,6 +57,208 @@ def matmul(
     return out
 
 
+@helion.kernel
+def matmul_bwd(
+    grad_out: Tensor,  # [m, n] gradient w.r.t output
+    mat1: Tensor,  # [m, k] first matrix
+    mat2: Tensor,  # [k, n] second matrix
+) -> tuple[Tensor, Tensor]:
+    """
+    Backward pass for matrix multiplication following Triton reference pattern.
+
+    For C = A @ B, given grad_C, computes:
+    - grad_A = grad_C @ B.T
+    - grad_B = A.T @ grad_C
+
+    Args:
+        grad_out: Gradient w.r.t output [m, n]
+        mat1: First matrix [m, k]
+        mat2: Second matrix [k, n]
+
+    Returns:
+        tuple[Tensor, Tensor]: (grad_mat1, grad_mat2)
+    """
+    # Get all dimensions first
+    m, n = grad_out.size()
+    m2, k = mat1.size()
+    k2, n2 = mat2.size()
+
+    # All assertions at the top
+    assert m == m2 and n == n2 and k == k2, "Size mismatch in matmul backward"
+
+    # Declare ALL output tensors at the top before any loops
+    grad_mat1 = torch.empty_like(mat1)
+    grad_mat2 = torch.empty_like(mat2)
+
+    # First loop block: compute grad_mat1 = grad_out @ mat2.T
+    for tile_m1, tile_k1 in hl.tile([m, k]):
+        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
+        for tile_n1 in hl.tile(n):
+            # Need mat2.T: mat2 is [k, n], so mat2[tile_k, tile_n].T gives [tile_n, tile_k]
+            acc1 = torch.addmm(
+                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
+            )
+        grad_mat1[tile_m1, tile_k1] = acc1.to(mat1.dtype)
+
+    # Second loop block: compute grad_mat2 = mat1.T @ grad_out
+    for tile_k2, tile_n2 in hl.tile([k, n]):
+        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
+        for tile_m2 in hl.tile(m):
+            # Need mat1.T: mat1 is [m, k], so mat1[tile_m, tile_k].T gives [tile_k, tile_m]
+            acc2 = torch.addmm(
+                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
+            )
+        grad_mat2[tile_k2, tile_n2] = acc2.to(mat2.dtype)
+
+    return grad_mat1, grad_mat2
+
+
+@helion.kernel
+def addmm_bwd(
+    grad_out: Tensor,  # [m, n] gradient w.r.t output
+    bias: Tensor,  # [m, n] or broadcastable bias tensor
+    mat1: Tensor,  # [m, k] first matrix
+    mat2: Tensor,  # [k, n] second matrix
+    alpha: float = 1.0,  # scalar multiplier for matmul
+    beta: float = 1.0,  # scalar multiplier for bias
+) -> tuple[Tensor, Tensor, Tensor]:
+    """
+    Backward pass for addmm operation following Triton reference pattern.
+
+    Forward: output = beta * bias + alpha * (mat1 @ mat2)
+
+    Based on the Triton kernel analysis:
+    - grad_input = beta * grad_out (with proper reduction for broadcasting)
+    - grad_mat1 = alpha * (grad_out @ mat2.T)
+    - grad_mat2 = alpha * (mat1.T @ grad_out)
+
+    Args:
+        grad_out: Gradient w.r.t output [m, n]
+        bias: Bias tensor [m, n] (or broadcastable)
+        mat1: First matrix [m, k]
+        mat2: Second matrix [k, n]
+        alpha: Scalar multiplier for matmul
+        beta: Scalar multiplier for bias
+
+    Returns:
+        tuple[Tensor, Tensor, Tensor]: (grad_input, grad_mat1, grad_mat2)
+    """
+    # Get all dimensions first
+    m, n = grad_out.size()
+    m2, k = mat1.size()
+    k2, n2 = mat2.size()
+
+    # All assertions at the top
+    assert m == m2 and n == n2 and k == k2, "Size mismatch in addmm backward"
+
+    # Declare ALL output tensors at the top before any loops
+    grad_input = torch.empty_like(bias)
+    grad_mat1 = torch.empty_like(mat1)
+    grad_mat2 = torch.empty_like(mat2)
+
+    # Handle grad_input = beta * grad_out (assuming same shape for now)
+    for tile_m3, tile_n3 in hl.tile([m, n]):
+        grad_input[tile_m3, tile_n3] = beta * grad_out[tile_m3, tile_n3]
+
+    # First loop block: compute grad_mat1 = alpha * (grad_out @ mat2.T)
+    for tile_m1, tile_k1 in hl.tile([m, k]):
+        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
+        for tile_n1 in hl.tile(n):
+            acc1 = torch.addmm(
+                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
+            )
+        grad_mat1[tile_m1, tile_k1] = (alpha * acc1).to(mat1.dtype)
+
+    # Second loop block: compute grad_mat2 = alpha * (mat1.T @ grad_out)
+    for tile_k2, tile_n2 in hl.tile([k, n]):
+        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
+        for tile_m2 in hl.tile(m):
+            acc2 = torch.addmm(
+                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
+            )
+        grad_mat2[tile_k2, tile_n2] = (alpha * acc2).to(mat2.dtype)
+
+    return grad_input, grad_mat1, grad_mat2
+
+
+# %%
+class MatMulFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        mat1: Tensor,
+        mat2: Tensor,
+    ) -> Tensor:
+        """Forward pass for matrix multiplication."""
+        result = matmul(mat1, mat2)
+        ctx.save_for_backward(mat1, mat2)
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: Any,  # noqa: ANN401
+        *grad_outputs: Tensor,
+    ) -> tuple[Tensor | None, Tensor | None]:
+        """Backward pass for matrix multiplication."""
+        grad_out = grad_outputs[0]
+        mat1, mat2 = ctx.saved_tensors
+        grad_mat1, grad_mat2 = matmul_bwd(grad_out, mat1, mat2)
+        return grad_mat1, grad_mat2
+
+
+def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor:
+    """Matrix multiplication with forward + backward support."""
+    return MatMulFunction.apply(mat1, mat2)  # type: ignore[no-any-return]
+
+
+class AddMMFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,  # noqa: ANN401
+        bias: Tensor,
+        mat1: Tensor,
+        mat2: Tensor,
+        alpha: float = 1.0,
+        beta: float = 1.0,
+    ) -> Tensor:
+        """Forward pass for addmm operation using helion matmul with epilogue."""
+        m, k = mat1.size()
+        k2, n = mat2.size()
+        input_broadcasted = torch.broadcast_to(bias, [m, n])
+
+        # Define epilogue that adds bias: alpha * acc + beta * bias
+        def addmm_epilogue(acc: Tensor, tile: tuple[Tensor, ...]) -> Tensor:
+            return alpha * acc + beta * input_broadcasted[tile[0], tile[1]]
+
+        result = matmul(mat1, mat2, addmm_epilogue)
+        ctx.save_for_backward(bias, mat1, mat2)
+        ctx.alpha = alpha
+        ctx.beta = beta
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: Any,  # noqa: ANN401
+        *grad_outputs: Tensor,
+    ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
+        """Backward pass for addmm operation."""
+        grad_out = grad_outputs[0]
+        bias, mat1, mat2 = ctx.saved_tensors
+        alpha = ctx.alpha
+        beta = ctx.beta
+        grad_input, grad_mat1, grad_mat2 = addmm_bwd(
+            grad_out, bias, mat1, mat2, alpha, beta
+        )
+        return grad_input, grad_mat1, grad_mat2, None, None
+
+
+def addmm_autograd(
+    bias: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
+) -> Tensor:
+    """AddMM operation with forward + backward support."""
+    return AddMMFunction.apply(bias, mat1, mat2, alpha, beta)  # type: ignore[no-any-return]
+
+
 # %%
 def autotune(m: int, k: int, n: int) -> None:
     """
@@ -129,6 +332,63 @@ def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
         (x, y),
     )
 
+    # Test matmul forward + backward pass
+    print("\n\n=== MatMul Forward + Backward Pass Test ===")
+    x_grad = torch.randn([m, k], device="cuda", dtype=torch.float16, requires_grad=True)
+    y_grad = torch.randn([k, n], device="cuda", dtype=torch.float16, requires_grad=True)
+
+    run_example(
+        matmul_autograd,
+        torch.matmul,
+        (x_grad, y_grad),
+        kernel_name="helion_matmul_autograd",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
+    # Test addmm forward + backward pass
+    print("\n\n=== AddMM Forward + Backward Pass Test ===")
+    input_grad = torch.randn(
+        [m, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat1_grad = torch.randn(
+        [m, k], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat2_grad = torch.randn(
+        [k, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+
+    # Use lambda to handle the keyword argument format for torch.addmm
+    run_example(
+        addmm_autograd,
+        lambda bias, mat1, mat2, alpha, beta: torch.addmm(
+            bias, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
+        kernel_name="helion_addmm_autograd",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
+    # Test addmm forward + backward with different alpha/beta values
+    print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
+    run_example(
+        addmm_autograd,
+        lambda bias, mat1, mat2, alpha, beta: torch.addmm(
+            bias, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
+        kernel_name="helion_addmm_autograd_scaled",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
 
 # %%
 def matmul_tritonbench(
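
The scaled test above exercises alpha=2.0, beta=0.5, where the backward should give grad_bias = beta * grad_out and scale grad_mat1/grad_mat2 by alpha. An equivalent standalone check, as a sketch under the same assumptions as before (CUDA device, examples.matmul importable, float16 tolerances are a guess):

import torch
from examples.matmul import addmm_autograd

bias = torch.randn(64, 48, device="cuda", dtype=torch.float16, requires_grad=True)
m1 = torch.randn(64, 32, device="cuda", dtype=torch.float16, requires_grad=True)
m2 = torch.randn(32, 48, device="cuda", dtype=torch.float16, requires_grad=True)
refs = [t.detach().clone().requires_grad_(True) for t in (bias, m1, m2)]

addmm_autograd(bias, m1, m2, 2.0, 0.5).sum().backward()          # Helion path
torch.addmm(refs[0], refs[1], refs[2], alpha=2.0, beta=0.5).sum().backward()  # eager reference

# grad_bias carries the beta factor; grad_mat1/grad_mat2 carry the alpha factor
for got, ref in zip((bias, m1, m2), refs):
    assert torch.allclose(got.grad, ref.grad, rtol=1e-2, atol=1e-2)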
@@ -145,8 +405,9 @@ def matmul_tritonbench(
         Callable: A callable that runs the matmul kernel with or without bias.
     """
     if bias is not None:
-        return lambda: matmul(a, b, lambda acc, tile: acc + bias[tile[1]])
-    return lambda: matmul(a, b)
+        # For gemm with bias, use matmul_autograd and add bias
+        return lambda: matmul_autograd(a, b) + bias
+    return lambda: matmul_autograd(a, b)
 
 
 def addmm_tritonbench(
@@ -160,12 +421,9 @@ def addmm_tritonbench(
         mat1 (torch.Tensor): Left matrix.
         mat2 (torch.Tensor): Right matrix.
     Returns:
-        Callable: A callable that runs the matmul kernel with bias.
+        Callable: A callable that runs the addmm autograd function with bias.
     """
-    m, k = mat1.size()
-    k2, n = mat2.size()
-    bias = torch.broadcast_to(bias, [m, n])
-    return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
+    return lambda: addmm_autograd(bias, mat1, mat2)
 
 
 # %%
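
Both tritonbench wrappers now route through the autograd functions, so backward measurements exercise matmul_bwd/addmm_bwd instead of eager PyTorch. As a rough illustration of the design (the wrappers' full parameter lists are abbreviated in this diff, so the call shape here is an assumption), each wrapper returns a zero-argument closure that the harness can call repeatedly:

# Hypothetical usage; `fn` is the lazily evaluated benchmark body.
fn = lambda: addmm_autograd(bias, mat1, mat2)  # what addmm_tritonbench returns
out = fn()                                     # forward pass that gets timed
out.sum().backward()                           # backward can then be measured separately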

0 commit comments
