Skip to content

Commit d7974ba

Browse files
authored
[Rewriter]: Add ∘ MatMul -> Gemm (#2356)
A Rewriter rule that transforms `Add(MatMul(...))` to `Gemm`.
1 parent ccaefc6 commit d7974ba

File tree

2 files changed

+416
-0
lines changed

2 files changed

+416
-0
lines changed
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
"""Does the following transformation:
4+
- Add(MatMul(X, W), B) -> Gemm
5+
- Add(MatMul(Transpose(X), W), B) -> Gemm
6+
- Add(MatMul(X, Transpose(W)), B) -> Gemm
7+
- Add(MatMul(Transpose(X), Transpose(W)), B) -> Gemm
8+
"""
9+
10+
import abc
11+
from typing import ClassVar
12+
13+
from onnxscript.rewriter import pattern as orp
14+
15+
16+
class _MatMulAddToGemmBase(orp.RewriteRuleClassBase, abc.ABC):
    """Shared ``rewrite``/``check`` logic for the ``Add(MatMul(...), C) -> Gemm`` rules.

    Subclasses provide the ``pattern`` to match and override ``trans_a`` /
    ``trans_b`` to indicate which MatMul operand was wrapped in a Transpose,
    so the emitted ``Gemm`` carries the matching ``transA`` / ``transB``
    attribute instead of the explicit Transpose node.
    """

    # When True, the replacement Gemm is emitted with transA=1 / transB=1.
    trans_a: ClassVar[bool] = False
    trans_b: ClassVar[bool] = False

    def rewrite(self, op, input_a, input_b, input_c):
        """Build the replacement ``Gemm(input_a, input_b, input_c)`` node.

        ``transA`` / ``transB`` are only passed when set, so the untransposed
        case emits a Gemm without those attributes.
        """
        attributes = {}
        if self.trans_a:
            attributes["transA"] = 1
        if self.trans_b:
            attributes["transB"] = 1
        return op.Gemm(input_a, input_b, input_c, **attributes)

    def check(self, context, input_a, input_b, **_):
        """Accept the match only when both MatMul operands are known rank-2.

        Returns:
            A ``MatchResult``; it is marked failed when either operand has an
            unknown shape or a rank other than 2.
        """
        del context  # Not used
        check_result = orp.MatchResult()
        # A shape of None means no shape information is available; the rank
        # cannot be verified, and len(None) would raise TypeError, so reject.
        if input_a.shape is None or input_b.shape is None:
            return check_result.fail("Shape of input_a and input_b must be known")
        # Rank of input_a and input_b must be 2
        if len(input_a.shape) != 2 or len(input_b.shape) != 2:
            return check_result.fail("Rank of input_a and input_b must be 2")
        # NOTE(review): Gemm requires C to be unidirectionally broadcastable
        # to (M, N), whereas Add broadcasts in both directions; input_c is not
        # validated here -- confirm whether a check is needed.
        return check_result
35+
36+
37+
class MatMulAddToGemm(_MatMulAddToGemmBase):
    """Fuses ``Add(MatMul(a, b), c)`` into a single ``Gemm(a, b, c)``."""

    def pattern(self, op, input_a, input_b, input_c):
        # A plain MatMul whose result is the first operand of an Add.
        return op.Add(op.MatMul(input_a, input_b), input_c)
43+
44+
45+
class TransAMatMulAddToGemm(_MatMulAddToGemmBase):
    """Fuses ``Add(MatMul(Transpose(a), b), c)`` into ``Gemm(a, b, c)``.

    The 2-D Transpose on the first operand is absorbed via ``trans_a``.
    """

    trans_a: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        # Only the left MatMul operand is transposed (perm=[1, 0]).
        transposed_a = op.Transpose(input_a, perm=[1, 0])
        product = op.MatMul(transposed_a, input_b)
        return op.Add(product, input_c)
53+
54+
55+
class TransBMatMulAddToGemm(_MatMulAddToGemmBase):
    """Fuses ``Add(MatMul(a, Transpose(b)), c)`` into ``Gemm(a, b, c)``.

    The 2-D Transpose on the second operand is absorbed via ``trans_b``.
    """

    trans_b: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        # Only the right MatMul operand is transposed (perm=[1, 0]).
        transposed_b = op.Transpose(input_b, perm=[1, 0])
        product = op.MatMul(input_a, transposed_b)
        return op.Add(product, input_c)
63+
64+
65+
class TransABMatMulAddToGemm(_MatMulAddToGemmBase):
    """Fuses ``Add(MatMul(Transpose(a), Transpose(b)), c)`` into ``Gemm(a, b, c)``.

    Both 2-D Transposes are absorbed via ``trans_a`` and ``trans_b``.
    """

    trans_a: ClassVar = True
    trans_b: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        # Both MatMul operands are transposed (perm=[1, 0]).
        transposed_a = op.Transpose(input_a, perm=[1, 0])
        transposed_b = op.Transpose(input_b, perm=[1, 0])
        product = op.MatMul(transposed_a, transposed_b)
        return op.Add(product, input_c)
77+
78+
79+
# Ready-to-use rule instances, one per supported MatMul operand layout.
matmul_add_to_gemm_rule = MatMulAddToGemm().rule()
transpose_a_matmul_add_to_gemm_rule = TransAMatMulAddToGemm().rule()
transpose_b_matmul_add_to_gemm_rule = TransBMatMulAddToGemm().rule()
transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule()
83+
84+
85+
def gemm_rule_set() -> orp.RewriteRuleSet:
    """Build the rule set that fuses ``MatMul`` + ``Add`` into one ``Gemm`` node.

    Covers the plain pattern as well as the variants where one or both
    ``MatMul`` inputs are transposed.

    Returns:
        RewriteRuleSet
    """
    # Order is important: the patterns that include Transpose nodes are listed
    # before the plain MatMul + Add pattern (presumably so the generic rule
    # does not claim a match first and leave the Transpose behind).
    ordered_rules = [
        transpose_ab_matmul_add_to_gemm_rule,
        transpose_a_matmul_add_to_gemm_rule,
        transpose_b_matmul_add_to_gemm_rule,
        matmul_add_to_gemm_rule,
    ]
    return orp.RewriteRuleSet(ordered_rules)

0 commit comments

Comments
 (0)