@@ -33,6 +33,7 @@ namespace {
33
33
class RISCVVLOptimizer : public MachineFunctionPass {
34
34
const MachineRegisterInfo *MRI;
35
35
const MachineDominatorTree *MDT;
36
+ const TargetInstrInfo *TII;
36
37
37
38
public:
38
39
static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
50
51
StringRef getPassName () const override { return PASS_NAME; }
51
52
52
53
private:
53
- std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
54
- /// Returns the largest common VL MachineOperand that may be used to optimize
55
- /// MI. Returns std::nullopt if it failed to find a suitable VL.
56
- std::optional<MachineOperand> checkUsers (MachineInstr &MI);
54
+ MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55
+ /// Computes the VL of \p MI that is actually used by its users.
56
+ MachineOperand computeDemandedVL (const MachineInstr &MI);
57
57
bool tryReduceVL (MachineInstr &MI);
58
58
bool isCandidate (const MachineInstr &MI) const ;
59
+
60
+ /// For a given instruction, records what elements of it are demanded by
61
+ /// downstream users.
62
+ DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
59
63
};
60
64
61
65
} // end anonymous namespace
@@ -1173,15 +1177,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
1173
1177
return true ;
1174
1178
}
1175
1179
1176
- std::optional<MachineOperand>
1177
- RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1180
+ MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1178
1181
const MachineInstr &UserMI = *UserOp.getParent ();
1179
1182
const MCInstrDesc &Desc = UserMI.getDesc ();
1180
1183
1181
1184
if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
1182
1185
LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
1183
1186
" use VLMAX\n " );
1184
- return std::nullopt ;
1187
+ return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
1185
1188
}
1186
1189
1187
1190
// Instructions like reductions may use a vector register as a scalar
@@ -1201,46 +1204,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1201
1204
// Looking for an immediate or a register VL that isn't X0.
1202
1205
assert ((!VLOp.isReg () || VLOp.getReg () != RISCV::X0) &&
1203
1206
" Did not expect X0 VL" );
1207
+
1208
+ // If we know the demanded VL of UserMI, then we can reduce the VL it
1209
+ // requires.
1210
+ if (DemandedVLs.contains (&UserMI)) {
1211
+ // We can only shrink the demanded VL if the elementwise result doesn't
1212
+ // depend on VL (e.g. not vredsum/viota).
1213
+ // Also conservatively restrict to supported instructions for now.
1214
+ // TODO: Can we remove the isSupportedInstr check?
1215
+ if (!RISCVII::elementsDependOnVL (
1216
+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1217
+ isSupportedInstr (UserMI)) {
1218
+ const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
1219
+ if (RISCV::isVLKnownLE (DemandedVL, VLOp))
1220
+ return DemandedVL;
1221
+ }
1222
+ }
1223
+
1204
1224
return VLOp;
1205
1225
}
1206
1226
1207
- std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1208
- // FIXME: Avoid visiting each user for each time we visit something on the
1209
- // worklist, combined with an extra visit from the outer loop. Restructure
1210
- // along lines of an instcombine style worklist which integrates the outer
1211
- // pass.
1212
- std::optional<MachineOperand> CommonVL;
1227
+ MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1228
+ const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1229
+ MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1230
+
1213
1231
for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
1214
1232
const MachineInstr &UserMI = *UserOp.getParent ();
1215
1233
LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
1216
1234
if (mayReadPastVL (UserMI)) {
1217
1235
LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1218
- return std::nullopt ;
1236
+ return VLMAX ;
1219
1237
}
1220
1238
1221
1239
// If used as a passthru, elements past VL will be read.
1222
1240
if (UserOp.isTied ()) {
1223
1241
LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1224
- return std::nullopt ;
1242
+ return VLMAX ;
1225
1243
}
1226
1244
1227
- auto VLOp = getMinimumVLForUser (UserOp);
1228
- if (!VLOp)
1229
- return std::nullopt;
1245
+ const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
1230
1246
1231
1247
// Use the largest VL among all the users. If we cannot determine this
1232
1248
// statically, then we cannot optimize the VL.
1233
- if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, * VLOp)) {
1234
- CommonVL = * VLOp;
1235
- LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1236
- } else if (!RISCV::isVLKnownLE (* VLOp, *CommonVL )) {
1249
+ if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1250
+ DemandedVL = VLOp;
1251
+ LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1252
+ } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL )) {
1237
1253
LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1238
- return std::nullopt ;
1254
+ return VLMAX ;
1239
1255
}
1240
1256
1241
1257
if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
1242
1258
LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1243
- return std::nullopt ;
1259
+ return VLMAX ;
1244
1260
}
1245
1261
1246
1262
std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1250,7 +1266,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1250
1266
LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
1251
1267
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1252
1268
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1253
- return std::nullopt ;
1269
+ return VLMAX ;
1254
1270
}
1255
1271
1256
1272
// If the operand is used as a scalar operand, then the EEW must be
@@ -1265,11 +1281,11 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1265
1281
<< " Abort due to incompatible information for EMUL or EEW.\n " );
1266
1282
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1267
1283
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1268
- return std::nullopt ;
1284
+ return VLMAX ;
1269
1285
}
1270
1286
}
1271
1287
1272
- return CommonVL ;
1288
+ return DemandedVL ;
1273
1289
}
1274
1290
1275
1291
bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
@@ -1285,40 +1301,40 @@ bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
1285
1301
return false ;
1286
1302
}
1287
1303
1288
- auto CommonVL = checkUsers (MI) ;
1304
+ auto CommonVL = DemandedVLs[&MI] ;
1289
1305
if (!CommonVL)
1290
1306
return false ;
1291
1307
1292
- assert ((CommonVL-> isImm () || CommonVL-> getReg ().isVirtual ()) &&
1308
+ assert ((CommonVL. isImm () || CommonVL. getReg ().isVirtual ()) &&
1293
1309
" Expected VL to be an Imm or virtual Reg" );
1294
1310
1295
1311
if (!RISCV::isVLKnownLE (*CommonVL, VLOp)) {
1296
1312
LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
1297
1313
return false ;
1298
1314
}
1299
1315
1300
- if (CommonVL-> isIdenticalTo (VLOp)) {
1316
+ if (CommonVL. isIdenticalTo (VLOp)) {
1301
1317
LLVM_DEBUG (
1302
- dbgs () << " Abort due to CommonVL == VLOp, no point in reducing.\n " );
1318
+ dbgs ()
1319
+ << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
1303
1320
return false ;
1304
1321
}
1305
1322
1306
- if (CommonVL-> isImm ()) {
1323
+ if (CommonVL. isImm ()) {
1307
1324
LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1308
- << CommonVL-> getImm () << " for " << MI << " \n " );
1309
- VLOp.ChangeToImmediate (CommonVL-> getImm ());
1325
+ << CommonVL. getImm () << " for " << MI << " \n " );
1326
+ VLOp.ChangeToImmediate (CommonVL. getImm ());
1310
1327
return true ;
1311
1328
}
1312
- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL-> getReg ());
1329
+ const MachineInstr *VLMI = MRI->getVRegDef (CommonVL. getReg ());
1313
1330
if (!MDT->dominates (VLMI, &MI))
1314
1331
return false ;
1315
- LLVM_DEBUG (
1316
- dbgs () << " Reduce VL from " << VLOp << " to "
1317
- << printReg (CommonVL->getReg (), MRI->getTargetRegisterInfo ())
1318
- << " for " << MI << " \n " );
1332
+ LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1333
+ << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1334
+ << " for " << MI << " \n " );
1319
1335
1320
1336
// All our checks passed. We can reduce VL.
1321
- VLOp.ChangeToRegister (CommonVL-> getReg (), false );
1337
+ VLOp.ChangeToRegister (CommonVL. getReg (), false );
1322
1338
return true ;
1323
1339
}
1324
1340
@@ -1333,52 +1349,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
1333
1349
if (!ST.hasVInstructions ())
1334
1350
return false ;
1335
1351
1336
- SetVector<MachineInstr *> Worklist;
1337
- auto PushOperands = [this , &Worklist](MachineInstr &MI,
1338
- bool IgnoreSameBlock) {
1339
- for (auto &Op : MI.operands ()) {
1340
- if (!Op.isReg () || !Op.isUse () || !Op.getReg ().isVirtual () ||
1341
- !isVectorRegClass (Op.getReg (), MRI))
1342
- continue ;
1352
+ TII = ST.getInstrInfo ();
1343
1353
1344
- MachineInstr *DefMI = MRI->getVRegDef (Op.getReg ());
1345
- if (!isCandidate (*DefMI))
1346
- continue ;
1347
-
1348
- if (IgnoreSameBlock && DefMI->getParent () == MI.getParent ())
1349
- continue ;
1350
-
1351
- Worklist.insert (DefMI);
1352
- }
1353
- };
1354
-
1355
- // Do a first pass eagerly rewriting in roughly reverse instruction
1356
- // order, populate the worklist with any instructions we might need to
1357
- // revisit. We avoid adding definitions to the worklist if they're
1358
- // in the same block - we're about to visit them anyways.
1359
1354
bool MadeChange = false ;
1360
1355
for (MachineBasicBlock &MBB : MF) {
1361
1356
// Avoid unreachable blocks as they have degenerate dominance
1362
1357
if (!MDT->isReachableFromEntry (&MBB))
1363
1358
continue ;
1364
1359
1365
- for (auto &MI : reverse (MBB)) {
1360
+ // For each instruction that defines a vector, compute what VL its
1361
+ // downstream users demand.
1362
+ for (const auto &MI : reverse (MBB)) {
1363
+ if (!isCandidate (MI))
1364
+ continue ;
1365
+ DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1366
+ }
1367
+
1368
+ // Then go through and see if we can reduce the VL of any instructions to
1369
+ // only what's demanded.
1370
+ for (auto &MI : MBB) {
1366
1371
if (!isCandidate (MI))
1367
1372
continue ;
1368
1373
if (!tryReduceVL (MI))
1369
1374
continue ;
1370
1375
MadeChange = true ;
1371
- PushOperands (MI, /* IgnoreSameBlock*/ true );
1372
1376
}
1373
- }
1374
1377
1375
- while (!Worklist.empty ()) {
1376
- assert (MadeChange);
1377
- MachineInstr &MI = *Worklist.pop_back_val ();
1378
- assert (isCandidate (MI));
1379
- if (!tryReduceVL (MI))
1380
- continue ;
1381
- PushOperands (MI, /* IgnoreSameBlock*/ false );
1378
+ DemandedVLs.clear ();
1382
1379
}
1383
1380
1384
1381
return MadeChange;
0 commit comments