-
Notifications
You must be signed in to change notification settings - Fork 71
[Rewriter]: Add ∘ MatMul -> Gemm #2356
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@justinchuby I wrote the rewriter using functions (as in gemm_to_matmul_add). Tell me if you want me to do it using Classes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR introduces a rewriter rule that transforms a MatMul followed by an Add operation into a Gemm operator for streamlined execution. The changes include implementation of the transformation rule in the rewriter module and accompanying tests to validate both standard and edge-case behaviors.
- Added a new transformation rule in onnxscript/rewriter/matmul_add_to_gemm.py.
- Introduced tests in onnxscript/rewriter/matmul_add_to_gemm_test.py for regular inputs, initializers inclusion, and incompatible shapes.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
File | Description |
---|---|
onnxscript/rewriter/matmul_add_to_gemm_test.py | Test cases ensuring the fusion of MatMul and Add into Gemm are correct. |
onnxscript/rewriter/matmul_add_to_gemm.py | Implementation of the rewriter rule for converting MatMul+Add to Gemm. |
Comments suppressed due to low confidence (1)
onnxscript/rewriter/matmul_add_to_gemm_test.py:12
- [nitpick] The test class name 'MatMulAddToMatMulTest' is misleading given that the transformation targets Gemm. Consider renaming it to 'MatMulAddToGemmTest' for clarity.
class MatMulAddToMatMulTest(unittest.TestCase):
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2356 +/- ##
==========================================
+ Coverage 70.15% 70.36% +0.20%
==========================================
Files 197 199 +2
Lines 24985 25152 +167
Branches 2669 2681 +12
==========================================
+ Hits 17529 17698 +169
+ Misses 6529 6528 -1
+ Partials 927 926 -1 ☔ View full report in Codecov by Sentry. |
Would be nice to use the class based approach just for consistency, thanks! |
90ed11e
to
1a8a641
Compare
Added the support for transposed inputs
I think I dont need to check the users, the rewriter does check the number of users by default (if i am not mistaken)
With the current implementation, I don't think I need it, right ? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nicely written and tested, thanks!
1a8a641
to
b5bd684
Compare
@titaiwangms could you approve? Thanks |
A Rewriter rule that transforms
MatMul(Add)
toGemm
.