Skip to content

Commit 0259669

Browse files
author
Jian Cai
committed
[mlir] Add a postprocessing parameter in Pattern
This adds a parameter SupplementalPatterns in tablegen class Pattern for postprocessing code. For example, this can be used to ensure ops are placed in the correct device by copying the atttributes that decide devicement placement in Tensorflow dialect to prevent performance regression. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D157032
1 parent 5ed62c7 commit 0259669

File tree

5 files changed

+69
-2
lines changed

5 files changed

+69
-2
lines changed

mlir/docs/DeclarativeRewrites.md

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,13 @@ features:
5959
## Rule Definition
6060

6161
The core construct for defining a rewrite rule is defined in
62-
[`OpBase.td`][OpBase] as
62+
[`PatternBase.td`][PatternBase] as
6363

6464
```tablegen
6565
class Pattern<
6666
dag sourcePattern, list<dag> resultPatterns,
6767
list<dag> additionalConstraints = [],
68+
list<dag> supplementalPatterns = [],
6869
dag benefitsAdded = (addBenefit 0)>;
6970
```
7071

@@ -678,6 +679,36 @@ You can
678679
* Apply constraints on multiple bound symbols (`$input` and `TwoResultOp`'s
679680
first result must have the same element type).
680681

682+
### Supplying additional result patterns
683+
684+
Sometimes we need to add additional code after the result patterns, e.g. coping
685+
the attributes of the source op to the result ops. These can be specified via
686+
`SupplementalPatterns` parameter. Similar to auxiliary patterns, they are not
687+
for replacing results in the source pattern.
688+
689+
For example, we can write
690+
691+
```tablegen
692+
def GetOwner: NativeCodeCall<"$0.getOwner()">;
693+
694+
def CopyAttrFoo: NativeCodeCallVoid<
695+
"$1->setAttr($_builder.getStringAttr(\"foo\"), $0->getAttr(\"foo\"))">;
696+
697+
def CopyAttrBar: NativeCodeCallVoid<
698+
"$1->setAttr($_builder.getStringAttr(\"bar\"), $0->getAttr(\"bar\"))">;
699+
700+
701+
def : Pattern<
702+
(ThreeResultOp:$src ...),
703+
[(ZeroResultOp:$dest1 ...), (ThreeResultOp:$dest2 ...)],
704+
[(CopyAttrFoo (GetOwner $src), $dest1),
705+
(CopyAttrBar (GetOwner $src), (GetOwner $dest2))]>;
706+
```
707+
708+
This will copy the attribute `foo` and `bar` of `ThreeResultOp` in the source
709+
pattern to `TwoResultOp` and `OneResultOp` in the result patterns respectively.
710+
The patterns are executed in the order they are specified.
711+
681712
### Adjusting benefits
682713

683714
The benefit of a `Pattern` is an integer value indicating the benefit of

mlir/include/mlir/IR/PatternBase.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,7 @@ def addBenefit;
9090
// * `FiveResultOp`#3: `TwoResultOp2`#1
9191
// * `FiveResultOp`#4: `TwoResultOp2`#1
9292
class Pattern<dag source, list<dag> results, list<dag> preds = [],
93+
list<dag> supplemental_results = [],
9394
dag benefitAdded = (addBenefit 0)> {
9495
dag sourcePattern = source;
9596
// Result patterns. Each result pattern is expected to replace one result
@@ -103,6 +104,11 @@ class Pattern<dag source, list<dag> results, list<dag> preds = [],
103104
// matched in source pattern and places further constraints on them as a
104105
// whole.
105106
list<dag> constraints = preds;
107+
// Optional patterns that are executed after the result patterns. Similar to
108+
// auxiliary patterns, they are not used for replacement. These patterns can
109+
// be used to invoke additional code after the result patterns, e.g. copy
110+
// the attributes from the source op to the result ops.
111+
list<dag> supplementalPatterns = supplemental_results;
106112
// The delta value added to the default benefit value. The default value is
107113
// the number of ops in the source pattern. The rule with the highest final
108114
// benefit value will be applied first if there are multiple rules matches.
@@ -112,8 +118,9 @@ class Pattern<dag source, list<dag> results, list<dag> preds = [],
112118

113119
// Form of a pattern which produces a single result.
114120
class Pat<dag pattern, dag result, list<dag> preds = [],
121+
list<dag> supplemental_results = [],
115122
dag benefitAdded = (addBenefit 0)> :
116-
Pattern<pattern, [result], preds, benefitAdded>;
123+
Pattern<pattern, [result], preds, supplemental_results, benefitAdded>;
117124

118125
// Native code call wrapper. This allows invoking an arbitrary C++ expression
119126
// to create an op operand/attribute or replace an op result.

mlir/include/mlir/TableGen/Pattern.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,14 @@ class Pattern {
482482
// Returns the constraints.
483483
std::vector<AppliedConstraint> getConstraints() const;
484484

485+
// Returns the number of supplemental auxiliary patterns generated by applying
486+
// this rewrite rule.
487+
int getNumSupplementalPatterns() const;
488+
489+
// Returns the DAG tree root node of the `index`-th supplemental result
490+
// pattern.
491+
DagNode getSupplementalPattern(unsigned index) const;
492+
485493
// Returns the benefit score of the pattern.
486494
int getBenefit() const;
487495

mlir/lib/TableGen/Pattern.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,16 @@ std::vector<AppliedConstraint> Pattern::getConstraints() const {
675675
return ret;
676676
}
677677

678+
int Pattern::getNumSupplementalPatterns() const {
679+
auto *results = def.getValueAsListInit("supplementalPatterns");
680+
return results->size();
681+
}
682+
683+
DagNode Pattern::getSupplementalPattern(unsigned index) const {
684+
auto *results = def.getValueAsListInit("supplementalPatterns");
685+
return DagNode(cast<llvm::DagInit>(results->getElement(index)));
686+
}
687+
678688
int Pattern::getBenefit() const {
679689
// The initial benefit value is a heuristic with number of ops in the source
680690
// pattern.

mlir/tools/mlir-tblgen/RewriterGen.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,17 @@ void PatternEmitter::emitRewriteLogic() {
11051105
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
11061106
}
11071107

1108+
// Process supplemtal patterns.
1109+
int numSupplementalPatterns = pattern.getNumSupplementalPatterns();
1110+
for (int i = 0, offset = -numSupplementalPatterns;
1111+
i < numSupplementalPatterns; ++i) {
1112+
DagNode resultTree = pattern.getSupplementalPattern(i);
1113+
auto val = handleResultPattern(resultTree, offset++, 0);
1114+
if (resultTree.isNativeCodeCall() &&
1115+
resultTree.getNumReturnsOfNativeCode() == 0)
1116+
os << val << ";\n";
1117+
}
1118+
11081119
LLVM_DEBUG(llvm::dbgs() << "--- done emitting rewrite logic ---\n");
11091120
}
11101121

0 commit comments

Comments
 (0)