Skip to content

[Rewriter]: Add ∘ MatMul -> Gemm #2356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions onnxscript/rewriter/matmul_add_to_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Does the following transformation:
- Add(MatMul(X, W), B) -> Gemm(X, W, B)
- Add(MatMul(Transpose(X), W), B) -> Gemm(X, W, B, transA=1)
- Add(MatMul(X, Transpose(W)), B) -> Gemm(X, W, B, transB=1)
- Add(MatMul(Transpose(X), Transpose(W)), B) -> Gemm(X, W, B, transA=1, transB=1)
"""

import abc
from typing import ClassVar

from onnxscript.rewriter import pattern as orp


class _MatMulAddToGemmBase(orp.RewriteRuleClassBase, abc.ABC):
    """Shared ``rewrite``/``check`` logic for fusing ``Add(MatMul(a, b), c)`` into ``Gemm``.

    Subclasses provide ``pattern`` and set ``trans_a`` / ``trans_b`` to ``True``
    when the corresponding ``Gemm`` transpose attribute must be emitted.
    """

    # When True, ``rewrite`` emits ``transA=1`` / ``transB=1`` on the Gemm node.
    trans_a: ClassVar[bool] = False
    trans_b: ClassVar[bool] = False

    def rewrite(self, op, input_a, input_b, input_c):
        """Build the replacement ``Gemm(input_a, input_b, input_c)`` node.

        ``transA`` / ``transB`` are only passed when the matching class flag is
        set; otherwise the attribute is left out of the call entirely.
        """
        attributes = {}
        if self.trans_a:
            attributes["transA"] = 1
        if self.trans_b:
            attributes["transB"] = 1
        return op.Gemm(input_a, input_b, input_c, **attributes)

    def check(self, context, input_a, input_b, **_):
        """Accept the match only when both matrix inputs are known to be rank 2.

        Returns:
            A ``MatchResult``; it is marked failed when either shape is unknown
            or either rank differs from 2.
        """
        del context  # Not used
        check_result = orp.MatchResult()
        # A value without shape information has ``shape is None``; ``len(None)``
        # would raise TypeError, so reject the match instead of crashing.
        if input_a.shape is None or input_b.shape is None:
            return check_result.fail("Shape of input_a and input_b must be known")
        # Rank of input_a and input_b must be 2
        if len(input_a.shape) != 2 or len(input_b.shape) != 2:
            return check_result.fail("Rank of input_a and input_b must be 2")
        # NOTE(review): Gemm requires C to be unidirectionally broadcastable to
        # (M, N), while Add broadcasts in both directions; input_c is not
        # validated here -- confirm whether that needs a check as well.
        return check_result


class MatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule for ``Add(MatMul(a, b), c)``, which is rewritten to ``Gemm(a, b, c)``."""

    def pattern(self, op, input_a, input_b, input_c):
        """Match a ``MatMul`` whose result is added to ``input_c``."""
        return op.Add(op.MatMul(input_a, input_b), input_c)


class TransAMatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule for ``Add(MatMul(Transpose(a), b), c)``, rewritten to ``Gemm(a, b, c)`` with ``transA=1``."""

    trans_a: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        """Match a ``MatMul`` whose first input is a ``perm=[1, 0]`` Transpose, followed by ``Add``."""
        transposed_a = op.Transpose(input_a, perm=[1, 0])
        product = op.MatMul(transposed_a, input_b)
        return op.Add(product, input_c)


class TransBMatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule for ``Add(MatMul(a, Transpose(b)), c)``, rewritten to ``Gemm(a, b, c)`` with ``transB=1``."""

    trans_b: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        """Match a ``MatMul`` whose second input is a ``perm=[1, 0]`` Transpose, followed by ``Add``."""
        transposed_b = op.Transpose(input_b, perm=[1, 0])
        product = op.MatMul(input_a, transposed_b)
        return op.Add(product, input_c)


class TransABMatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule for ``Add(MatMul(Transpose(a), Transpose(b)), c)``.

    The match is rewritten to ``Gemm(a, b, c)`` with both ``transA=1`` and ``transB=1``.
    """

    trans_a: ClassVar = True
    trans_b: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        """Match a ``MatMul`` of two ``perm=[1, 0]`` Transposes, followed by ``Add``."""
        transposed_a = op.Transpose(input_a, perm=[1, 0])
        transposed_b = op.Transpose(input_b, perm=[1, 0])
        product = op.MatMul(transposed_a, transposed_b)
        return op.Add(product, input_c)


# Module-level rule instances, one per pattern class, for callers that want to
# apply a single rule rather than a whole rule set.
transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule()
transpose_a_matmul_add_to_gemm_rule = TransAMatMulAddToGemm().rule()
transpose_b_matmul_add_to_gemm_rule = TransBMatMulAddToGemm().rule()
matmul_add_to_gemm_rule = MatMulAddToGemm().rule()


def gemm_rule_set() -> orp.RewriteRuleSet:
    """Build the rule set that fuses MatMul + Add patterns into a single Gemm node.

    The set covers the plain ``MatMul`` case as well as the cases where one or
    both ``MatMul`` inputs are transposed.

    Returns:
        RewriteRuleSet
    """
    # Order is important: the transposed variants come before the plain
    # MatMul + Add rule, whose pattern places no constraint on what produces
    # the MatMul inputs and so would also match a Transpose-fed MatMul.
    ordered_rules = [
        transpose_ab_matmul_add_to_gemm_rule,
        transpose_a_matmul_add_to_gemm_rule,
        transpose_b_matmul_add_to_gemm_rule,
        matmul_add_to_gemm_rule,
    ]
    return orp.RewriteRuleSet(ordered_rules)
Loading
Loading