
Commit a7958ba

matmul bwd updated, test updated, fix issues
1 parent 22a9c97 commit a7958ba

File tree

3 files changed: +162, -7 lines


examples/matmul.py

Lines changed: 148 additions & 0 deletions
@@ -112,6 +112,74 @@ def matmul_bwd(
     return grad_mat1, grad_mat2


+@helion.kernel
+def addmm_bwd(
+    grad_out: Tensor,  # [m, n] gradient w.r.t. output
+    input: Tensor,  # [m, n] or broadcastable bias tensor
+    mat1: Tensor,  # [m, k] first matrix
+    mat2: Tensor,  # [k, n] second matrix
+    alpha: float = 1.0,  # scalar multiplier for matmul
+    beta: float = 1.0,  # scalar multiplier for bias
+) -> tuple[Tensor, Tensor, Tensor]:
+    """
+    Backward pass for the addmm operation, following the Triton reference pattern.
+
+    Forward: output = beta * input + alpha * (mat1 @ mat2)
+
+    Gradients, matching the Triton reference kernel:
+    - grad_input = beta * grad_out (with proper reduction for broadcasting)
+    - grad_mat1 = alpha * (grad_out @ mat2.T)
+    - grad_mat2 = alpha * (mat1.T @ grad_out)
+
+    Args:
+        grad_out: Gradient w.r.t. output [m, n]
+        input: Bias tensor [m, n] (or broadcastable)
+        mat1: First matrix [m, k]
+        mat2: Second matrix [k, n]
+        alpha: Scalar multiplier for matmul
+        beta: Scalar multiplier for bias
+
+    Returns:
+        tuple[Tensor, Tensor, Tensor]: (grad_input, grad_mat1, grad_mat2)
+    """
+    # Get all dimensions first
+    m, n = grad_out.size()
+    m2, k = mat1.size()
+    k2, n2 = mat2.size()
+
+    # All assertions at the top
+    assert m == m2 and n == n2 and k == k2, "Size mismatch in addmm backward"
+
+    # Declare ALL output tensors at the top, before any loops
+    grad_input = torch.empty_like(input)
+    grad_mat1 = torch.empty_like(mat1)
+    grad_mat2 = torch.empty_like(mat2)
+
+    # grad_input = beta * grad_out (assumes input and grad_out share the same shape for now)
+    for tile_m3, tile_n3 in hl.tile([m, n]):
+        grad_input[tile_m3, tile_n3] = beta * grad_out[tile_m3, tile_n3]
+
+    # First loop block: compute grad_mat1 = alpha * (grad_out @ mat2.T)
+    for tile_m1, tile_k1 in hl.tile([m, k]):
+        acc1 = hl.zeros([tile_m1, tile_k1], dtype=torch.float32)
+        for tile_n1 in hl.tile(n):
+            acc1 = torch.addmm(
+                acc1, grad_out[tile_m1, tile_n1], mat2[tile_k1, tile_n1].T
+            )
+        grad_mat1[tile_m1, tile_k1] = (alpha * acc1).to(mat1.dtype)
+
+    # Second loop block: compute grad_mat2 = alpha * (mat1.T @ grad_out)
+    for tile_k2, tile_n2 in hl.tile([k, n]):
+        acc2 = hl.zeros([tile_k2, tile_n2], dtype=torch.float32)
+        for tile_m2 in hl.tile(m):
+            acc2 = torch.addmm(
+                acc2, mat1[tile_m2, tile_k2].T, grad_out[tile_m2, tile_n2]
+            )
+        grad_mat2[tile_k2, tile_n2] = (alpha * acc2).to(mat2.dtype)
+
+    return grad_input, grad_mat1, grad_mat2
+
+
 # %%
 class MatMulFunction(torch.autograd.Function):
     @staticmethod
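For reference, the gradient formulas implemented above can be checked against PyTorch autograd in plain eager code. The following is a minimal sketch, not part of this commit; shapes, seed, and variable names are arbitrary, and the closing comment notes the broadcast-reduction case that the docstring mentions but the tiled grad_input loop does not yet handle.

# Standalone check of the addmm backward formulas against autograd
# (float64 so the comparisons are effectively exact).
import torch

torch.manual_seed(0)
m, k, n = 64, 32, 48
alpha, beta = 2.0, 0.5

inp = torch.randn(m, n, dtype=torch.float64, requires_grad=True)
mat1 = torch.randn(m, k, dtype=torch.float64, requires_grad=True)
mat2 = torch.randn(k, n, dtype=torch.float64, requires_grad=True)
grad_out = torch.randn(m, n, dtype=torch.float64)

# Autograd reference: out = beta * inp + alpha * (mat1 @ mat2)
torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).backward(grad_out)

# Closed-form gradients used by addmm_bwd
assert torch.allclose(inp.grad, beta * grad_out)
assert torch.allclose(mat1.grad, alpha * (grad_out @ mat2.T))
assert torch.allclose(mat2.grad, alpha * (mat1.T @ grad_out))

# If the bias were a broadcast 1-D tensor of shape [n], its gradient would
# additionally need a sum over the broadcast dimension:
#   grad_bias = beta * grad_out.sum(dim=0)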
@@ -141,6 +209,45 @@ def matmul_autograd(mat1: Tensor, mat2: Tensor) -> Tensor:
     return MatMulFunction.apply(mat1, mat2)  # type: ignore[no-any-return]


+class AddMMFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: Any,
+        input: Tensor,
+        mat1: Tensor,
+        mat2: Tensor,
+        alpha: float = 1.0,
+        beta: float = 1.0,
+    ) -> Tensor:
+        """Forward pass for addmm operation."""
+        result = torch.addmm(input, mat1, mat2, alpha=alpha, beta=beta)
+        ctx.save_for_backward(input, mat1, mat2)
+        ctx.alpha = alpha
+        ctx.beta = beta
+        return result
+
+    @staticmethod
+    def backward(
+        ctx: Any,
+        grad_out: Tensor,
+    ) -> tuple[Tensor | None, Tensor | None, Tensor | None, None, None]:
+        """Backward pass for addmm operation."""
+        input, mat1, mat2 = ctx.saved_tensors
+        alpha = ctx.alpha
+        beta = ctx.beta
+        grad_input, grad_mat1, grad_mat2 = addmm_bwd(
+            grad_out, input, mat1, mat2, alpha, beta
+        )
+        return grad_input, grad_mat1, grad_mat2, None, None
+
+
+def addmm_autograd(
+    input: Tensor, mat1: Tensor, mat2: Tensor, alpha: float = 1.0, beta: float = 1.0
+) -> Tensor:
+    """AddMM operation with forward + backward support."""
+    return AddMMFunction.apply(input, mat1, mat2, alpha, beta)  # type: ignore[no-any-return]
+
+
 # %%
 def autotune(m: int, k: int, n: int) -> None:
     """
@@ -230,6 +337,47 @@ def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
         bwd=True,
     )

+    # Test addmm forward + backward pass
+    print("\n\n=== AddMM Forward + Backward Pass Test ===")
+    input_grad = torch.randn(
+        [m, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat1_grad = torch.randn(
+        [m, k], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+    mat2_grad = torch.randn(
+        [k, n], device="cuda", dtype=torch.float16, requires_grad=True
+    )
+
+    # Use a lambda to handle the keyword argument format for torch.addmm
+    run_example(
+        addmm_autograd,
+        lambda input, mat1, mat2, alpha, beta: torch.addmm(
+            input, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 1.0, 1.0),
+        kernel_name="helion_addmm_autograd",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+
+    # Test addmm forward + backward with different alpha/beta values
+    print("\n\n=== AddMM Forward + Backward Test (Alpha=2.0, Beta=0.5) ===")
+    run_example(
+        addmm_autograd,
+        lambda input, mat1, mat2, alpha, beta: torch.addmm(
+            input, mat1, mat2, alpha=alpha, beta=beta
+        ),
+        (input_grad, mat1_grad, mat2_grad, 2.0, 0.5),
+        kernel_name="helion_addmm_autograd_scaled",
+        baseline_name="torch",
+        rtol=1e-2,
+        atol=1e-2,
+        bwd=True,
+    )
+

 # %%
 def matmul_tritonbench(

test/test_examples.expected

Lines changed: 0 additions & 2 deletions
@@ -2818,8 +2818,6 @@ def matmul_bwd(grad_out: Tensor, mat1: Tensor, mat2: Tensor, *, _launcher=_defau
     - grad_A = grad_C @ B.T
     - grad_B = A.T @ grad_C

-    Key: Declare ALL variables at the top before ANY loops to avoid TopLevelStatementBetweenLoops error
-
     Args:
         grad_out: Gradient w.r.t output [m, n]
         mat1: First matrix [m, k]

test/test_examples.py

Lines changed: 14 additions & 5 deletions
@@ -50,19 +50,28 @@ def test_matmul(self):
 
     def test_matmul_bwd(self):
         """Test backward pass for matmul computation."""
+        # Create tensors with requires_grad=True, as in the rms_norm_bwd test
+        mat1 = torch.randn(
+            [128, 128], device=DEVICE, dtype=torch.float32, requires_grad=True
+        )
+        mat2 = torch.randn(
+            [128, 128], device=DEVICE, dtype=torch.float32, requires_grad=True
+        )
         grad_out = torch.randn([128, 128], device=DEVICE, dtype=torch.float32)
-        mat1 = torch.randn([128, 128], device=DEVICE, dtype=torch.float32)
-        mat2 = torch.randn([128, 128], device=DEVICE, dtype=torch.float32)

-        args = (grad_out, mat1, mat2)
+        # Compute expected gradients with PyTorch autograd
+        mat1_torch = mat1.detach().clone().requires_grad_(True)
+        mat2_torch = mat2.detach().clone().requires_grad_(True)
+        result_torch = torch.matmul(mat1_torch, mat2_torch)
+        result_torch.backward(grad_out)

-        expected = (grad_out @ mat2.T, mat1.T @ grad_out)
+        args = (grad_out, mat1, mat2)

         self.assertExpectedJournal(
             check_example(
                 "matmul",
                 args,
-                expected,
+                (mat1_torch.grad, mat2_torch.grad),  # Expected: (grad_mat1, grad_mat2)
                 fn_name="matmul_bwd",
                 block_sizes=[
                     16,
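The switch from hand-coded expected gradients to autograd-derived ones is equivalent in exact arithmetic: for C = A @ B, autograd yields grad_A = grad_C @ B.T and grad_B = A.T @ grad_C, the same expressions the old test hard-coded. A small standalone sketch (not part of the test suite) illustrating the equivalence:

# Equivalence of the old closed-form expected gradients and the new
# autograd-derived ones, in float64 so allclose is effectively exact.
import torch

a = torch.randn(128, 64, dtype=torch.float64, requires_grad=True)
b = torch.randn(64, 32, dtype=torch.float64, requires_grad=True)
grad_c = torch.randn(128, 32, dtype=torch.float64)

(a @ b).backward(grad_c)  # autograd gradients for C = A @ B

assert torch.allclose(a.grad, grad_c @ b.T)  # old expected grad_mat1
assert torch.allclose(b.grad, a.T @ grad_c)  # old expected grad_mat2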
