
Conversation


@tianrengao tianrengao commented Oct 1, 2025

This PR adds backward pass (gradient computation) support for matrix multiplication and addmm operations in Helion, with comprehensive unit tests to ensure correctness against PyTorch baselines.

Key Changes

1. Backward Kernels

  • matmul_bwd: Computes gradients for matrix multiplication (grad_A = grad_C @ B.T, grad_B = A.T @ grad_C)
  • addmm_bwd: Computes gradients for the addmm operation with alpha/beta scaling support (both sets of gradient identities are sketched in plain PyTorch below)
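
For reference, the gradient identities these kernels implement can be written in plain PyTorch as follows. This is a sketch for clarity, not the Helion kernels themselves, and the function names are illustrative:

def matmul_bwd_reference(grad_c, a, b):
    # For C = A @ B: dL/dA = dL/dC @ B.T and dL/dB = A.T @ dL/dC
    return grad_c @ b.T, a.T @ grad_c

def addmm_bwd_reference(grad_out, bias, mat1, mat2, alpha=1.0, beta=1.0):
    # For out = beta * bias + alpha * (mat1 @ mat2), with bias already of shape [m, n];
    # a broadcast bias would additionally need a sum over the broadcast dimensions.
    grad_bias = beta * grad_out
    grad_mat1 = alpha * (grad_out @ mat2.T)
    grad_mat2 = alpha * (mat1.T @ grad_out)
    return grad_bias, grad_mat1, grad_mat2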

2. PyTorch Autograd

  • MatMulFunction & AddMMFunction: torch.autograd.Function subclasses whose backward methods use the standard (ctx, *grad_outputs) signature
  • matmul_autograd & addmm_autograd: User-friendly API entry points that wrap these Functions (a minimal sketch of the wiring follows below)
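
A minimal sketch of that wiring, assuming the standard torch.autograd.Function pattern; the actual MatMulFunction / matmul_autograd in examples/matmul.py call the Helion kernels where the comments indicate and may differ in detail:

import torch

class MatMulFunctionSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, mat1, mat2):
        ctx.save_for_backward(mat1, mat2)
        return mat1 @ mat2  # the example calls the Helion matmul kernel here

    @staticmethod
    def backward(ctx, *grad_outputs):  # standard (ctx, *grad_outputs) signature
        (grad_c,) = grad_outputs
        mat1, mat2 = ctx.saved_tensors
        # the example calls the matmul_bwd kernel here; the math is the same
        return grad_c @ mat2.T, mat1.T @ grad_c

def matmul_autograd_sketch(mat1, mat2):
    return MatMulFunctionSketch.apply(mat1, mat2)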

3. Unit Tests

  • test_matmul_bwd: Validates the matrix multiplication backward pass against a PyTorch baseline
  • test_addmm_bwd: Validates the addmm backward pass, including gradient flow for all inputs (the baseline comparison is sketched below)
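
The baseline comparison behind these tests looks roughly like the sketch below; the actual tests in test/test_examples.py additionally check the generated kernel code against test/test_examples.expected:

import torch

def check_matmul_grads(matmul_autograd, m=64, k=32, n=48, device="cuda"):
    # Run the custom autograd path and a plain torch.matmul baseline on
    # identical inputs, then compare the resulting gradients.
    a = torch.randn(m, k, device=device, requires_grad=True)
    b = torch.randn(k, n, device=device, requires_grad=True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    matmul_autograd(a, b).sum().backward()
    torch.matmul(a_ref, b_ref).sum().backward()

    torch.testing.assert_close(a.grad, a_ref.grad)
    torch.testing.assert_close(b.grad, b_ref.grad)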

Usage Example

# Matrix multiplication with gradients
# matmul_autograd / addmm_autograd are defined in examples/matmul.py (import path may vary)
import torch

mat1 = torch.randn([128, 256], requires_grad=True, device='cuda')
mat2 = torch.randn([256, 128], requires_grad=True, device='cuda')
result = matmul_autograd(mat1, mat2)
result.sum().backward()  # Gradients available in mat1.grad, mat2.grad

# AddMM with scaling: out = beta * bias + alpha * (mat1 @ mat2)
bias = torch.randn([128, 128], requires_grad=True, device='cuda')
result = addmm_autograd(bias, mat1, mat2, alpha=2.0, beta=0.5)
result.sum().backward()  # bias.grad, mat1.grad, mat2.grad are populated

Files Modified

  • examples/matmul.py: Added backward kernels, autograd classes, and enhanced testing
  • test/test_examples.py: Added test_matmul_bwd and test_addmm_bwd unit tests
  • test/test_examples.expected: Updated expected outputs for new tests

@meta-cla meta-cla bot added the CLA Signed label Oct 1, 2025
@tianrengao tianrengao changed the title add matmul bwd Add matmul/addmm bwd and add test coverage Oct 1, 2025
@tianrengao tianrengao changed the title Add matmul/addmm bwd and add test coverage Add matmul/addmm bwd examples and add test coverage Oct 1, 2025
@tianrengao tianrengao marked this pull request as ready for review October 2, 2025 00:22
@tianrengao tianrengao requested a review from yf225 October 2, 2025 02:07
Review thread on the diff (the quoted context appears to be one of the *_tritonbench wrappers):

m, k = mat1.size()
k2, n = mat2.size()
bias = torch.broadcast_to(bias, [m, n])
return lambda: matmul(mat1, mat2, lambda acc, tile: acc + bias[tile[0], tile[1]])

Contributor commented:
Would you like to also add integration in benchmarks/run.py (similar to rms_norm-bwd), and test accuracy via tritonbench --metrics accuracy?

Also I believe these two *_tritonbench functions should probably just call the *_autograd functions.
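
The second suggestion would amount to something like the sketch below; the wrapper name and signature are assumptions based on the quoted snippet above, and the import path may differ:

from examples.matmul import addmm_autograd  # import path is an assumption

def addmm_tritonbench_sketch(bias, mat1, mat2):
    # Return a no-arg callable, as in the quoted wrapper above, that runs the
    # autograd-enabled addmm so backward accuracy is exercised as well.
    return lambda: addmm_autograd(bias, mat1, mat2)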


tianrengao (Contributor Author) replied:
Tritonbench PR for adding addmm-bwd and gemm-bwd landed: meta-pytorch/tritonbench#531

All tests passed except for gemm with partition-k. It seems partition-k in fwd is broken.
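
The accuracy tables below come from tritonbench runs of roughly this form; the exact operator names and any backward-mode flag are not shown in this thread, so treat the commands as approximate:

python run.py --op addmm --metrics accuracy
python run.py --op gemm --metrics accuracy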

Addmm fwd

         (M, N, K)    triton_addmm-accuracy    pt2_triton_matmul-accuracy    helion_addmm_tritonbench-accuracy
------------------  -----------------------  ----------------------------  -----------------------------------
(20120, 512, 1536)                        1                             1                                    1
(34579, 512, 1536)                        1                             1                                    1
(34839, 512, 1536)                        1                             1                                    1
           average                        1                             1                                    1

Addmm bwd

         (M, N, K)    triton_addmm-accuracy    pt2_triton_matmul-accuracy    helion_addmm_tritonbench-accuracy
------------------  -----------------------  ----------------------------  -----------------------------------
(20120, 512, 1536)                        1                             1                                    1
(34579, 512, 1536)                        1                             1                                    1
(34839, 512, 1536)                        1                             1                                    1
           average                        1                             1                                    1

gemm fwd

      (M, N, K)    triton_tutorial_matmul-accuracy    matmul_partition_k-accuracy    triton_ops_matmul-accuracy    aten_tunableop_matmul-accuracy    pt2_triton_matmul-accuracy    streamk_matmul-accuracy    pt2_cutlass_matmul-accuracy    helion_matmul_tritonbench-accuracy
---------------  ---------------------------------  -----------------------------  ----------------------------  --------------------------------  ----------------------------  -------------------------  -----------------------------  ------------------------------------
(256, 256, 256)                                  1                       0                                    1                                 1                             1                          1                              1                                     1
(384, 384, 384)                                  1                       0                                    1                                 1                             1                          1                              1                                     1
(512, 512, 512)                                  1                       1                                    1                                 1                             1                          1                              1                                     1
        average                                  1                       0.333333                             1                                 1                             1                          1                              1                                     1

gemm bwd

      (M, N, K)    triton_tutorial_matmul-accuracy    matmul_partition_k-accuracy    triton_ops_matmul-accuracy    aten_tunableop_matmul-accuracy    pt2_triton_matmul-accuracy    streamk_matmul-accuracy    pt2_cutlass_matmul-accuracy    helion_matmul_tritonbench-accuracy
---------------  ---------------------------------  -----------------------------  ----------------------------  --------------------------------  ----------------------------  -------------------------  -----------------------------  ------------------------------------
(256, 256, 256)                                  1                              0                             1                                 1                             1                          1                              1                                     1
(384, 384, 384)                                  1                              0                             1                                 1                             1                          1                              1                                     1
(512, 512, 512)                                  1                              0                             1                                 1                             1                          1                              1                                     1
        average                                  1                              0                             1                                 1                             1                          1                              1                                     1

- Fixed matmul_tritonbench to use addmm_autograd for gemm-with-bias testing
- Updated tritonbench operators for proper requires_grad handling
- gemm-bwd shows 100% accuracy for the Helion implementation
- Some gradient mismatches on larger shapes are still being investigated

Current test results:
- addmm-bwd: 100% accuracy
- gemm-bwd: 100% accuracy for Helion, some framework gradient issues remain

@tianrengao (Contributor Author) commented:

The test failure seems unrelated to my PR. It complains about illegal memory access for 500+ tests. I saw similar failures in other PRs too.

@tianrengao tianrengao requested a review from yf225 October 9, 2025 22:56

@yf225 yf225 left a comment

thanks @tianrengao !


meta-codesync bot commented Oct 10, 2025

@tianrengao has imported this pull request. If you are a Meta employee, you can view this in D84363848.


ngimel commented Oct 10, 2025

I'm sorry this is my pet peeve, but why do we need special matmul_bwd kernels? A backward for matmul is matmul, if you have a forward matmul kernel that's performant for the different sizes and transposes, you already have a backward matmul kernel

@meta-codesync meta-codesync bot merged commit e9e3cae into main Oct 11, 2025
14 of 16 checks passed

tianrengao commented Oct 14, 2025

> I'm sorry this is my pet peeve, but why do we need special matmul_bwd kernels? A backward for matmul is matmul, if you have a forward matmul kernel that's performant for the different sizes and transposes, you already have a backward matmul kernel

Sorry I didn’t see your comment before the PR was merged. Thanks for raising this question—I agree that the backward pass for matmul is another matmul, and reusing the forward kernel’s optimized performance is a good point.

In Helion day2, addmm was on the list for bwd support, so I added matmul along with addmm. I added the dedicated matmul_bwd kernel mainly for ease of use, so users don't have to implement their own backward logic; the backward needs to return gradients for both input matrices, which is why a separate kernel seemed helpful. As for why we don't simply reuse the forward matmul kernel for the backward pass: I followed a structure similar to the forward pass to keep things consistent, but if we find performance differences in the future we can certainly remove it or reuse the forward kernel.


ngimel commented Oct 14, 2025

You are still implementing backward logic in custom autograd function. Custom autograd function should instead just call matmul, instead of a special kernel. Chaining 2 matmuls in one kernel is very rarely profitable, only for the smallest matmul.

tianrengao (Contributor Author) replied:

> You are still implementing backward logic in custom autograd function. Custom autograd function should instead just call matmul, instead of a special kernel. Chaining 2 matmuls in one kernel is very rarely profitable, only for the smallest matmul.

I see, will fix this in a follow-up PR soon. Thanks for pointing it out.
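
For illustration, the suggested follow-up amounts to something like the sketch below, where matmul stands in for the existing Helion forward kernel (torch.matmul is used here so the sketch runs on its own):

import torch

def matmul_backward_via_forward(grad_out, mat1, mat2, matmul=torch.matmul):
    # Backward for C = mat1 @ mat2 expressed as two independent forward matmuls
    # on transposed operands; no dedicated fused matmul_bwd kernel is needed.
    grad_mat1 = matmul(grad_out, mat2.T)   # [m, n] @ [n, k] -> [m, k]
    grad_mat2 = matmul(mat1.T, grad_out)   # [k, m] @ [m, n] -> [k, n]
    return grad_mat1, grad_mat2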
