|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT License. |
"""Fuses a ``MatMul`` followed by an ``Add`` into a single ``Gemm`` node.

The following transformations are performed:
- Add(MatMul(X, W), B) -> Gemm(X, W, B)
- Add(MatMul(Transpose(X), W), B) -> Gemm(X, W, B, transA=1)
- Add(MatMul(X, Transpose(W)), B) -> Gemm(X, W, B, transB=1)
- Add(MatMul(Transpose(X), Transpose(W)), B) -> Gemm(X, W, B, transA=1, transB=1)
"""
| 9 | + |
| 10 | +import abc |
| 11 | +from typing import ClassVar |
| 12 | + |
| 13 | +from onnxscript.rewriter import pattern as orp |
| 14 | + |
| 15 | + |
class _MatMulAddToGemmBase(orp.RewriteRuleClassBase, abc.ABC):
    """Shared ``rewrite``/``check`` logic for the MatMul + Add -> Gemm rules.

    Subclasses provide ``pattern`` and override ``trans_a`` / ``trans_b`` to
    state which MatMul operand is wrapped in a ``Transpose`` in their pattern.
    """

    # Whether the replacement Gemm gets ``transA=1`` / ``transB=1``.
    trans_a: ClassVar[bool] = False
    trans_b: ClassVar[bool] = False

    def rewrite(self, op, input_a, input_b, input_c):
        """Build the replacement ``Gemm(input_a, input_b, input_c)`` node.

        ``transA`` / ``transB`` are only passed when the corresponding class
        flag is set; otherwise the attribute is omitted from the node.
        """
        attributes = {}
        if self.trans_a:
            attributes["transA"] = 1
        if self.trans_b:
            attributes["transB"] = 1
        return op.Gemm(input_a, input_b, input_c, **attributes)

    def check(self, context, input_a, input_b, **_):
        """Accept the match only when both MatMul operands are known rank-2.

        Returns:
            A ``MatchResult``; it is marked failed when either shape is
            unavailable or either rank differs from 2.
        """
        del context  # Not used
        check_result = orp.MatchResult()
        shape_a = input_a.shape
        shape_b = input_b.shape
        # A value without shape information reports ``None``; calling ``len``
        # on it would raise TypeError, so reject the match explicitly instead.
        if shape_a is None or shape_b is None:
            return check_result.fail("Shape of input_a and input_b must be known")
        # Rank of input_a and input_b must be 2
        if len(shape_a) != 2 or len(shape_b) != 2:
            return check_result.fail("Rank of input_a and input_b must be 2")
        return check_result
| 35 | + |
| 36 | + |
class MatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule matching ``Add(MatMul(a, b), c)``; rewritten to ``Gemm(a, b, c)``."""

    def pattern(self, op, input_a, input_b, input_c):
        # MatMul of the two operands feeding the first input of Add.
        return op.Add(op.MatMul(input_a, input_b), input_c)
| 43 | + |
| 44 | + |
class TransAMatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule matching ``Add(MatMul(Transpose(a), b), c)``.

    The match is rewritten to ``Gemm(a, b, c)`` with ``transA=1``.
    """

    trans_a: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        # Only a 2-D axis swap (perm=[1, 0]) on the first operand is matched.
        transposed_a = op.Transpose(input_a, perm=[1, 0])
        product = op.MatMul(transposed_a, input_b)
        return op.Add(product, input_c)
| 53 | + |
| 54 | + |
class TransBMatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule matching ``Add(MatMul(a, Transpose(b)), c)``.

    The match is rewritten to ``Gemm(a, b, c)`` with ``transB=1``.
    """

    trans_b: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        # Only a 2-D axis swap (perm=[1, 0]) on the second operand is matched.
        transposed_b = op.Transpose(input_b, perm=[1, 0])
        product = op.MatMul(input_a, transposed_b)
        return op.Add(product, input_c)
| 63 | + |
| 64 | + |
class TransABMatMulAddToGemm(_MatMulAddToGemmBase):
    """Rule matching ``Add(MatMul(Transpose(a), Transpose(b)), c)``.

    The match is rewritten to ``Gemm(a, b, c)`` with ``transA=1`` and
    ``transB=1``.
    """

    trans_a: ClassVar = True
    trans_b: ClassVar = True

    def pattern(self, op, input_a, input_b, input_c):
        # Both operands must pass through a 2-D axis swap (perm=[1, 0]).
        transposed_a = op.Transpose(input_a, perm=[1, 0])
        transposed_b = op.Transpose(input_b, perm=[1, 0])
        product = op.MatMul(transposed_a, transposed_b)
        return op.Add(product, input_c)
| 77 | + |
| 78 | + |
# Module-level rule instances, one per pattern class, listed from the most
# specific pattern (both operands transposed) to the plain MatMul + Add.
transpose_ab_matmul_add_to_gemm_rule = TransABMatMulAddToGemm().rule()
transpose_a_matmul_add_to_gemm_rule = TransAMatMulAddToGemm().rule()
transpose_b_matmul_add_to_gemm_rule = TransBMatMulAddToGemm().rule()
matmul_add_to_gemm_rule = MatMulAddToGemm().rule()
| 83 | + |
| 84 | + |
def gemm_rule_set() -> orp.RewriteRuleSet:
    """Build the rule set that fuses ``MatMul`` + ``Add`` into one ``Gemm`` node.

    Covers the plain pattern as well as the variants where one or both
    MatMul inputs are transposed.

    Returns:
        A ``RewriteRuleSet`` holding the four fusion rules.
    """
    # Order is important: the plain rule's pattern puts no constraint on what
    # produces the MatMul inputs, so it also matches the transposed variants.
    # It is therefore listed last (presumably rules are tried in list order —
    # confirm against RewriteRuleSet).
    ordered_rules = [
        transpose_ab_matmul_add_to_gemm_rule,
        transpose_a_matmul_add_to_gemm_rule,
        transpose_b_matmul_add_to_gemm_rule,
        matmul_add_to_gemm_rule,
    ]
    return orp.RewriteRuleSet(ordered_rules)
0 commit comments