Add matmul/addmm bwd examples and add test coverage #748
Conversation
examples/matmul.py (Outdated)
```python
m, k = mat1.size()
k2, n = mat2.size()
bias = torch.broadcast_to(bias, [m, n])
return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])
```
Would you like to also add integration in benchmarks/run.py (similar to rms_norm-bwd), and test accuracy via tritonbench --metrics accuracy?
Also I believe these two *_tritonbench functions should probably just call the *_autograd functions.
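A rough sketch of what that suggestion could look like, assuming the `matmul_autograd`/`addmm_autograd` wrappers added in this PR; the wrapper names and signatures below are assumptions for illustration, not the PR's final code:

```python
from examples.matmul import matmul_autograd, addmm_autograd  # import path assumed

# Hypothetical tritonbench wrappers: delegate to the autograd entry points so the
# benchmark path exercises the same code as the examples, instead of re-implementing
# the kernel call.
def matmul_tritonbench(tb_op, mat1, mat2):
    return lambda: matmul_autograd(mat1, mat2)

def addmm_tritonbench(tb_op, bias, mat1, mat2):
    return lambda: addmm_autograd(bias, mat1, mat2)
```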
The tritonbench PR adding addmm-bwd and gemm-bwd landed: meta-pytorch/tritonbench#531
All tests passed except for gemm with partition-k. It seems partition-k in the forward pass is broken.
Addmm fwd

| (M, N, K) | triton_addmm-accuracy | pt2_triton_matmul-accuracy | helion_addmm_tritonbench-accuracy |
|---|---|---|---|
| (20120, 512, 1536) | 1 | 1 | 1 |
| (34579, 512, 1536) | 1 | 1 | 1 |
| (34839, 512, 1536) | 1 | 1 | 1 |
| average | 1 | 1 | 1 |

Addmm bwd

| (M, N, K) | triton_addmm-accuracy | pt2_triton_matmul-accuracy | helion_addmm_tritonbench-accuracy |
|---|---|---|---|
| (20120, 512, 1536) | 1 | 1 | 1 |
| (34579, 512, 1536) | 1 | 1 | 1 |
| (34839, 512, 1536) | 1 | 1 | 1 |
| average | 1 | 1 | 1 |

gemm fwd

| (M, N, K) | triton_tutorial_matmul-accuracy | matmul_partition_k-accuracy | triton_ops_matmul-accuracy | aten_tunableop_matmul-accuracy | pt2_triton_matmul-accuracy | streamk_matmul-accuracy | pt2_cutlass_matmul-accuracy | helion_matmul_tritonbench-accuracy |
|---|---|---|---|---|---|---|---|---|
| (256, 256, 256) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (384, 384, 384) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (512, 512, 512) | 1 | 1 | 1 | 1 | 1 | 1 | 1 | 1 |
| average | 1 | 0.333333 | 1 | 1 | 1 | 1 | 1 | 1 |

gemm bwd

| (M, N, K) | triton_tutorial_matmul-accuracy | matmul_partition_k-accuracy | triton_ops_matmul-accuracy | aten_tunableop_matmul-accuracy | pt2_triton_matmul-accuracy | streamk_matmul-accuracy | pt2_cutlass_matmul-accuracy | helion_matmul_tritonbench-accuracy |
|---|---|---|---|---|---|---|---|---|
| (256, 256, 256) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (384, 384, 384) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| (512, 512, 512) | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
| average | 1 | 0 | 1 | 1 | 1 | 1 | 1 | 1 |
- Fixed matmul_tritonbench to use addmm_autograd for gemm with bias testing
- Updated tritonbench operators for proper requires_grad handling
- gemm-bwd showing 100% accuracy for the Helion implementation
- Some gradient mismatches on larger shapes still being investigated

Current test results:
- addmm-bwd: 100% accuracy
- gemm-bwd: 100% accuracy for Helion; some framework gradient issues remain
The test failure seems unrelated to my PR. It complains about illegal memory access for 500+ tests. I saw similar failures in other PRs too.
yf225 left a comment:
thanks @tianrengao !
@tianrengao has imported this pull request. If you are a Meta employee, you can view this in D84363848.
I'm sorry, this is my pet peeve, but why do we need special matmul_bwd kernels? A backward for matmul is just matmul: if you have a forward matmul kernel that's performant for the different sizes and transposes, you already have a backward matmul kernel.
Sorry, I didn't see your comment before the PR was merged. Thanks for raising this question. I agree that the backward pass for matmul is another matmul, and reusing the forward kernel's optimized performance is a good point. In the Helion day-2 plan, addmm was on the list for backward support, so I added matmul along with addmm. I added the dedicated matmul_bwd kernel mainly for ease of use, so users don't have to implement their own backward logic; the backward needs to return gradients for both input matrices, which is why a separate kernel seemed helpful. As for why we don't simply reuse the forward matmul kernel for the backward pass, I followed the same structure as the forward pass to keep things consistent. If we find performance differences in the future, we can consider removing the dedicated kernel or reusing the forward one.
You are still implementing backward logic in a custom autograd function. The custom autograd function should instead just call matmul, rather than a special kernel. Chaining two matmuls in one kernel is very rarely profitable, and only for the smallest matmuls.
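A minimal sketch of that suggestion, assuming `matmul` is the existing forward Helion kernel; the class name `MatMulViaForward` and the import path are illustrative assumptions, not code from this PR:

```python
import torch

from examples.matmul import matmul  # the existing forward kernel (path assumed)

class MatMulViaForward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mat1, mat2):
        ctx.save_for_backward(mat1, mat2)
        return matmul(mat1, mat2)  # reuse the forward kernel as-is

    @staticmethod
    def backward(ctx, grad_out):
        mat1, mat2 = ctx.saved_tensors
        # grad_mat1 = grad_out @ mat2.T and grad_mat2 = mat1.T @ grad_out,
        # both computed with the same forward kernel on transposed operands
        # (the kernel would need to handle non-contiguous/transposed inputs).
        grad_mat1 = matmul(grad_out, mat2.t()) if ctx.needs_input_grad[0] else None
        grad_mat2 = matmul(mat1.t(), grad_out) if ctx.needs_input_grad[1] else None
        return grad_mat1, grad_mat2
```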
I see, I'll fix this in a follow-up PR soon. Thanks for pointing it out.
This PR adds backward pass (gradient computation) support for matrix multiplication and addmm operations in Helion, with comprehensive unit tests to ensure correctness against PyTorch baselines.
Key Changes
1. Backward Kernels
- `matmul_bwd`: Computes gradients for matrix multiplication (`grad_A = grad_C @ B.T`, `grad_B = A.T @ grad_C`)
- `addmm_bwd`: Computes gradients for the addmm operation with alpha/beta scaling support

2. PyTorch Autograd
- `MatMulFunction` & `AddMMFunction`: PyTorch autograd classes with proper `*grad_outputs` signatures
- `matmul_autograd` & `addmm_autograd`: User-friendly API functions

3. Unit Tests
- `test_matmul_bwd`: Validates the matrix multiplication backward pass against the PyTorch baseline
- `test_addmm_bwd`: Validates the addmm backward pass with gradient flow for all inputs

Usage Example
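A minimal usage sketch, assuming the wrappers are importable from `examples/matmul.py` (the import path, shapes, and dtypes below are illustrative assumptions):

```python
import torch

from examples.matmul import matmul_autograd, addmm_autograd  # import path assumed

mat1 = torch.randn(256, 512, device="cuda", dtype=torch.float16, requires_grad=True)
mat2 = torch.randn(512, 128, device="cuda", dtype=torch.float16, requires_grad=True)
bias = torch.randn(128, device="cuda", dtype=torch.float16, requires_grad=True)

# Matmul with gradients flowing through the Helion backward kernels.
out = matmul_autograd(mat1, mat2)
out.sum().backward()  # populates mat1.grad and mat2.grad

# addmm: gradients flow to bias, mat1, and mat2.
out2 = addmm_autograd(bias, mat1, mat2)
out2.sum().backward()
```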
Files Modified
- `examples/matmul.py`: Added backward kernels, autograd classes, and enhanced testing
- `test/test_examples.py`: Added `test_matmul_bwd` and `test_addmm_bwd` unit tests
- `test/test_examples.expected`: Updated expected outputs for new tests