Skip to content

Commit d7954a8

Browse files
committed
[MacroFusion] Support commutable instructions
If the second instruction is commutable, we should be able to check its commutable operands. A field `IsCommutable` is added to indicate whether we should generate code for checking commutable operands. Fixes llvm#82738
1 parent de1f338 commit d7954a8

File tree

3 files changed

+145
-26
lines changed

3 files changed

+145
-26
lines changed

llvm/include/llvm/Target/TargetSchedule.td

+16-5
Original file line numberDiff line numberDiff line change
@@ -622,11 +622,22 @@ class BothFusionPredicateWithMCInstPredicate<MCInstPredicate pred>
622622
// Tie firstOpIdx and secondOpIdx. The operand of `FirstMI` at position
623623
// `firstOpIdx` should be the same as the operand of `SecondMI` at position
624624
// `secondOpIdx`.
625+
// If the fusion has `IsCommutable` being true and the operand at `secondOpIdx`
626+
// has commutable operand, then the commutable operand will be checked too.
625627
class TieReg<int firstOpIdx, int secondOpIdx> : BothFusionPredicate {
626628
int FirstOpIdx = firstOpIdx;
627629
int SecondOpIdx = secondOpIdx;
628630
}
629631

632+
// The operand of `SecondMI` at position `firstOpIdx` should be the same as the
633+
// operand at position `secondOpIdx`.
634+
// If the fusion has `IsCommutable` being true and the operand at `secondOpIdx`
635+
// has commutable operand, then the commutable operand will be checked too.
636+
class SameReg<int firstOpIdx, int secondOpIdx> : SecondFusionPredicate {
637+
int FirstOpIdx = firstOpIdx;
638+
int SecondOpIdx = secondOpIdx;
639+
}
640+
630641
// A predicate for wildcard. The generated code will be like:
631642
// ```
632643
// if (!FirstMI)
@@ -655,9 +666,12 @@ def OneUse : OneUsePred;
655666
// return true;
656667
// }
657668
// ```
669+
//
670+
// `IsCommutable` means whether we should handle commutable operands.
658671
class Fusion<string name, string fieldName, string desc, list<FusionPredicate> predicates>
659672
: SubtargetFeature<name, fieldName, "true", desc> {
660673
list<FusionPredicate> Predicates = predicates;
674+
bit IsCommutable = 0;
661675
}
662676

663677
// The generated predicator will be like:
@@ -671,6 +685,7 @@ class Fusion<string name, string fieldName, string desc, list<FusionPredicate> p
671685
// /* Predicate for `SecondMI` */
672686
// /* Wildcard */
673687
// /* Predicate for `FirstMI` */
688+
// /* Check same registers */
674689
// /* Check One Use */
675690
// /* Tie registers */
676691
// /* Epilog */
@@ -688,11 +703,7 @@ class SimpleFusion<string name, string fieldName, string desc,
688703
SecondFusionPredicateWithMCInstPredicate<secondPred>,
689704
WildcardTrue,
690705
FirstFusionPredicateWithMCInstPredicate<firstPred>,
691-
SecondFusionPredicateWithMCInstPredicate<
692-
CheckAny<[
693-
CheckIsVRegOperand<0>,
694-
CheckSameRegOperand<0, 1>
695-
]>>,
706+
SameReg<0, 1>,
696707
OneUse,
697708
TieReg<0, 1>,
698709
],

llvm/test/TableGen/MacroFusion.td

+58-3
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,21 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
4646
CheckRegOperand<0, X0>
4747
]>>;
4848

49+
let IsCommutable = 1 in
50+
def TestCommutableFusion: SimpleFusion<"test-commutable-fusion", "HasTestCommutableFusion",
51+
"Test Commutable Fusion",
52+
CheckOpcode<[Inst0]>,
53+
CheckAll<[
54+
CheckOpcode<[Inst1]>,
55+
CheckRegOperand<0, X0>
56+
]>>;
57+
4958
// CHECK-PREDICATOR: #ifdef GET_Test_MACRO_FUSION_PRED_DECL
5059
// CHECK-PREDICATOR-NEXT: #undef GET_Test_MACRO_FUSION_PRED_DECL
5160
// CHECK-PREDICATOR-EMPTY:
5261
// CHECK-PREDICATOR-NEXT: namespace llvm {
5362
// CHECK-PREDICATOR-NEXT: bool isTestBothFusionPredicate(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
63+
// CHECK-PREDICATOR-NEXT: bool isTestCommutableFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
5464
// CHECK-PREDICATOR-NEXT: bool isTestFusion(const TargetInstrInfo &, const TargetSubtargetInfo &, const MachineInstr *, const MachineInstr &);
5565
// CHECK-PREDICATOR-NEXT: } // end namespace llvm
5666
// CHECK-PREDICATOR-EMPTY:
@@ -78,7 +88,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
7888
// CHECK-PREDICATOR-NEXT: }
7989
// CHECK-PREDICATOR-NEXT: return true;
8090
// CHECK-PREDICATOR-NEXT: }
81-
// CHECK-PREDICATOR-NEXT: bool isTestFusion(
91+
// CHECK-PREDICATOR-NEXT: bool isTestCommutableFusion(
8292
// CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII,
8393
// CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI,
8494
// CHECK-PREDICATOR-NEXT: const MachineInstr *FirstMI,
@@ -99,14 +109,58 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
99109
// CHECK-PREDICATOR-NEXT: if (( MI->getOpcode() != Test::Inst0 ))
100110
// CHECK-PREDICATOR-NEXT: return false;
101111
// CHECK-PREDICATOR-NEXT: }
112+
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getOperand(0).getReg().isVirtual()) {
113+
// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg()) {
114+
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable())
115+
// CHECK-PREDICATOR-NEXT: return false;
116+
// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
117+
// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
118+
// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
119+
// CHECK-PREDICATOR-NEXT: return false;
120+
// CHECK-PREDICATOR-NEXT: }
121+
// CHECK-PREDICATOR-NEXT: }
122+
// CHECK-PREDICATOR-NEXT: {
123+
// CHECK-PREDICATOR-NEXT: Register FirstDest = FirstMI->getOperand(0).getReg();
124+
// CHECK-PREDICATOR-NEXT: if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))
125+
// CHECK-PREDICATOR-NEXT: return false;
126+
// CHECK-PREDICATOR-NEXT: }
127+
// CHECK-PREDICATOR-NEXT: if (!(FirstMI->getOperand(0).isReg() &&
128+
// CHECK-PREDICATOR-NEXT: SecondMI.getOperand(1).isReg() &&
129+
// CHECK-PREDICATOR-NEXT: FirstMI->getOperand(0).getReg() == SecondMI.getOperand(1).getReg())) {
130+
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getDesc().isCommutable())
131+
// CHECK-PREDICATOR-NEXT: return false;
132+
// CHECK-PREDICATOR-NEXT: unsigned SrcOpIdx1 = 1, SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;
133+
// CHECK-PREDICATOR-NEXT: if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))
134+
// CHECK-PREDICATOR-NEXT: if (FirstMI->getOperand(0).getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())
135+
// CHECK-PREDICATOR-NEXT: return false;
136+
// CHECK-PREDICATOR-NEXT: }
137+
// CHECK-PREDICATOR-NEXT: return true;
138+
// CHECK-PREDICATOR-NEXT: }
139+
// CHECK-PREDICATOR-NEXT: bool isTestFusion(
140+
// CHECK-PREDICATOR-NEXT: const TargetInstrInfo &TII,
141+
// CHECK-PREDICATOR-NEXT: const TargetSubtargetInfo &STI,
142+
// CHECK-PREDICATOR-NEXT: const MachineInstr *FirstMI,
143+
// CHECK-PREDICATOR-NEXT: const MachineInstr &SecondMI) {
144+
// CHECK-PREDICATOR-NEXT: auto &MRI = SecondMI.getMF()->getRegInfo();
102145
// CHECK-PREDICATOR-NEXT: {
103146
// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = &SecondMI;
104147
// CHECK-PREDICATOR-NEXT: if (!(
105-
// CHECK-PREDICATOR-NEXT: MI->getOperand(0).getReg().isVirtual()
106-
// CHECK-PREDICATOR-NEXT: || MI->getOperand(0).getReg() == MI->getOperand(1).getReg()
148+
// CHECK-PREDICATOR-NEXT: ( MI->getOpcode() == Test::Inst1 )
149+
// CHECK-PREDICATOR-NEXT: && MI->getOperand(0).getReg() == Test::X0
107150
// CHECK-PREDICATOR-NEXT: ))
108151
// CHECK-PREDICATOR-NEXT: return false;
109152
// CHECK-PREDICATOR-NEXT: }
153+
// CHECK-PREDICATOR-NEXT: if (!FirstMI)
154+
// CHECK-PREDICATOR-NEXT: return true;
155+
// CHECK-PREDICATOR-NEXT: {
156+
// CHECK-PREDICATOR-NEXT: const MachineInstr *MI = FirstMI;
157+
// CHECK-PREDICATOR-NEXT: if (( MI->getOpcode() != Test::Inst0 ))
158+
// CHECK-PREDICATOR-NEXT: return false;
159+
// CHECK-PREDICATOR-NEXT: }
160+
// CHECK-PREDICATOR-NEXT: if (!SecondMI.getOperand(0).getReg().isVirtual()) {
161+
// CHECK-PREDICATOR-NEXT: if (SecondMI.getOperand(0).getReg() != SecondMI.getOperand(1).getReg())
162+
// CHECK-PREDICATOR-NEXT: return false;
163+
// CHECK-PREDICATOR-NEXT: }
110164
// CHECK-PREDICATOR-NEXT: {
111165
// CHECK-PREDICATOR-NEXT: Register FirstDest = FirstMI->getOperand(0).getReg();
112166
// CHECK-PREDICATOR-NEXT: if (FirstDest.isVirtual() && !MRI.hasOneNonDBGUse(FirstDest))
@@ -131,6 +185,7 @@ def TestFusion: SimpleFusion<"test-fusion", "HasTestFusion", "Test Fusion",
131185
// CHECK-SUBTARGET: std::vector<MacroFusionPredTy> TestGenSubtargetInfo::getMacroFusions() const {
132186
// CHECK-SUBTARGET-NEXT: std::vector<MacroFusionPredTy> Fusions;
133187
// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestBothFusionPredicate)) Fusions.push_back(llvm::isTestBothFusionPredicate);
188+
// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestCommutableFusion)) Fusions.push_back(llvm::isTestCommutableFusion);
134189
// CHECK-SUBTARGET-NEXT: if (hasFeature(Test::TestFusion)) Fusions.push_back(llvm::isTestFusion);
135190
// CHECK-SUBTARGET-NEXT: return Fusions;
136191
// CHECK-SUBTARGET-NEXT: }

llvm/utils/TableGen/MacroFusionPredicatorEmitter.cpp

+71-18
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,10 @@
4040

4141
#include "CodeGenTarget.h"
4242
#include "PredicateExpander.h"
43-
#include "llvm/ADT/SmallVector.h"
4443
#include "llvm/Support/Debug.h"
4544
#include "llvm/TableGen/Error.h"
4645
#include "llvm/TableGen/Record.h"
4746
#include "llvm/TableGen/TableGenBackend.h"
48-
#include <set>
4947
#include <vector>
5048

5149
using namespace llvm;
@@ -61,14 +59,14 @@ class MacroFusionPredicatorEmitter {
6159
raw_ostream &OS);
6260
void emitMacroFusionImpl(std::vector<Record *> Fusions, PredicateExpander &PE,
6361
raw_ostream &OS);
64-
void emitPredicates(std::vector<Record *> &FirstPredicate,
62+
void emitPredicates(std::vector<Record *> &FirstPredicate, bool IsCommutable,
6563
PredicateExpander &PE, raw_ostream &OS);
66-
void emitFirstPredicate(Record *SecondPredicate, PredicateExpander &PE,
67-
raw_ostream &OS);
68-
void emitSecondPredicate(Record *SecondPredicate, PredicateExpander &PE,
69-
raw_ostream &OS);
70-
void emitBothPredicate(Record *Predicates, PredicateExpander &PE,
71-
raw_ostream &OS);
64+
void emitFirstPredicate(Record *SecondPredicate, bool IsCommutable,
65+
PredicateExpander &PE, raw_ostream &OS);
66+
void emitSecondPredicate(Record *SecondPredicate, bool IsCommutable,
67+
PredicateExpander &PE, raw_ostream &OS);
68+
void emitBothPredicate(Record *Predicates, bool IsCommutable,
69+
PredicateExpander &PE, raw_ostream &OS);
7270

7371
public:
7472
MacroFusionPredicatorEmitter(RecordKeeper &R) : Records(R), Target(R) {}
@@ -103,6 +101,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
103101
for (Record *Fusion : Fusions) {
104102
std::vector<Record *> Predicates =
105103
Fusion->getValueAsListOfDefs("Predicates");
104+
bool IsCommutable = Fusion->getValueAsBit("IsCommutable");
106105

107106
OS << "bool is" << Fusion->getName() << "(\n";
108107
OS.indent(4) << "const TargetInstrInfo &TII,\n";
@@ -111,7 +110,7 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
111110
OS.indent(4) << "const MachineInstr &SecondMI) {\n";
112111
OS.indent(2) << "auto &MRI = SecondMI.getMF()->getRegInfo();\n";
113112

114-
emitPredicates(Predicates, PE, OS);
113+
emitPredicates(Predicates, IsCommutable, PE, OS);
115114

116115
OS.indent(2) << "return true;\n";
117116
OS << "}\n";
@@ -122,22 +121,24 @@ void MacroFusionPredicatorEmitter::emitMacroFusionImpl(
122121
}
123122

124123
void MacroFusionPredicatorEmitter::emitPredicates(
125-
std::vector<Record *> &Predicates, PredicateExpander &PE, raw_ostream &OS) {
124+
std::vector<Record *> &Predicates, bool IsCommutable, PredicateExpander &PE,
125+
raw_ostream &OS) {
126126
for (Record *Predicate : Predicates) {
127127
Record *Target = Predicate->getValueAsDef("Target");
128128
if (Target->getName() == "first_fusion_target")
129-
emitFirstPredicate(Predicate, PE, OS);
129+
emitFirstPredicate(Predicate, IsCommutable, PE, OS);
130130
else if (Target->getName() == "second_fusion_target")
131-
emitSecondPredicate(Predicate, PE, OS);
131+
emitSecondPredicate(Predicate, IsCommutable, PE, OS);
132132
else if (Target->getName() == "both_fusion_target")
133-
emitBothPredicate(Predicate, PE, OS);
133+
emitBothPredicate(Predicate, IsCommutable, PE, OS);
134134
else
135135
PrintFatalError(Target->getLoc(),
136136
"Unsupported 'FusionTarget': " + Target->getName());
137137
}
138138
}
139139

140140
void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
141+
bool IsCommutable,
141142
PredicateExpander &PE,
142143
raw_ostream &OS) {
143144
if (Predicate->isSubClassOf("WildcardPred")) {
@@ -170,6 +171,7 @@ void MacroFusionPredicatorEmitter::emitFirstPredicate(Record *Predicate,
170171
}
171172

172173
void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
174+
bool IsCommutable,
173175
PredicateExpander &PE,
174176
raw_ostream &OS) {
175177
if (Predicate->isSubClassOf("FusionPredicateWithMCInstPredicate")) {
@@ -182,6 +184,36 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
182184
OS << ")\n";
183185
OS.indent(4) << " return false;\n";
184186
OS.indent(2) << "}\n";
187+
} else if (Predicate->isSubClassOf("SameReg")) {
188+
int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
189+
int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
190+
191+
OS.indent(2) << "if (!SecondMI.getOperand(" << FirstOpIdx
192+
<< ").getReg().isVirtual()) {\n";
193+
OS.indent(4) << "if (SecondMI.getOperand(" << FirstOpIdx
194+
<< ").getReg() != SecondMI.getOperand(" << SecondOpIdx
195+
<< ").getReg())";
196+
197+
if (IsCommutable) {
198+
OS << " {\n";
199+
OS.indent(6) << "if (!SecondMI.getDesc().isCommutable())\n";
200+
OS.indent(6) << " return false;\n";
201+
202+
OS.indent(6)
203+
<< "unsigned SrcOpIdx1 = " << SecondOpIdx
204+
<< ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
205+
OS.indent(6)
206+
<< "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
207+
OS.indent(6)
208+
<< " if (SecondMI.getOperand(" << FirstOpIdx
209+
<< ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
210+
OS.indent(6) << " return false;\n";
211+
OS.indent(4) << "}\n";
212+
} else {
213+
OS << "\n";
214+
OS.indent(4) << " return false;\n";
215+
}
216+
OS.indent(2) << "}\n";
185217
} else {
186218
PrintFatalError(Predicate->getLoc(),
187219
"Unsupported predicate for second instruction: " +
@@ -190,13 +222,14 @@ void MacroFusionPredicatorEmitter::emitSecondPredicate(Record *Predicate,
190222
}
191223

192224
void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
225+
bool IsCommutable,
193226
PredicateExpander &PE,
194227
raw_ostream &OS) {
195228
if (Predicate->isSubClassOf("FusionPredicateWithCode"))
196229
OS << Predicate->getValueAsString("Predicate");
197230
else if (Predicate->isSubClassOf("BothFusionPredicateWithMCInstPredicate")) {
198-
emitFirstPredicate(Predicate, PE, OS);
199-
emitSecondPredicate(Predicate, PE, OS);
231+
emitFirstPredicate(Predicate, IsCommutable, PE, OS);
232+
emitSecondPredicate(Predicate, IsCommutable, PE, OS);
200233
} else if (Predicate->isSubClassOf("TieReg")) {
201234
int FirstOpIdx = Predicate->getValueAsInt("FirstOpIdx");
202235
int SecondOpIdx = Predicate->getValueAsInt("SecondOpIdx");
@@ -206,8 +239,28 @@ void MacroFusionPredicatorEmitter::emitBothPredicate(Record *Predicate,
206239
<< ").isReg() &&\n";
207240
OS.indent(2) << " FirstMI->getOperand(" << FirstOpIdx
208241
<< ").getReg() == SecondMI.getOperand(" << SecondOpIdx
209-
<< ").getReg()))\n";
210-
OS.indent(2) << " return false;\n";
242+
<< ").getReg()))";
243+
244+
if (IsCommutable) {
245+
OS << " {\n";
246+
OS.indent(4) << "if (!SecondMI.getDesc().isCommutable())\n";
247+
OS.indent(4) << " return false;\n";
248+
249+
OS.indent(4)
250+
<< "unsigned SrcOpIdx1 = " << SecondOpIdx
251+
<< ", SrcOpIdx2 = TargetInstrInfo::CommuteAnyOperandIndex;\n";
252+
OS.indent(4)
253+
<< "if (TII.findCommutedOpIndices(SecondMI, SrcOpIdx1, SrcOpIdx2))\n";
254+
OS.indent(4)
255+
<< " if (FirstMI->getOperand(" << FirstOpIdx
256+
<< ").getReg() != SecondMI.getOperand(SrcOpIdx2).getReg())\n";
257+
OS.indent(4) << " return false;\n";
258+
OS.indent(2) << "}";
259+
} else {
260+
OS << "\n";
261+
OS.indent(2) << " return false;";
262+
}
263+
OS << "\n";
211264
} else
212265
PrintFatalError(Predicate->getLoc(),
213266
"Unsupported predicate for both instruction: " +

0 commit comments

Comments
 (0)