@@ -33,6 +33,7 @@ namespace {
33
33
class RISCVVLOptimizer : public MachineFunctionPass {
34
34
const MachineRegisterInfo *MRI;
35
35
const MachineDominatorTree *MDT;
36
+ const TargetInstrInfo *TII;
36
37
37
38
public:
38
39
static char ID;
@@ -50,12 +51,15 @@ class RISCVVLOptimizer : public MachineFunctionPass {
50
51
StringRef getPassName () const override { return PASS_NAME; }
51
52
52
53
private:
53
- std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
54
- // / Returns the largest common VL MachineOperand that may be used to optimize
55
- // / MI. Returns std::nullopt if it failed to find a suitable VL.
56
- std::optional<MachineOperand> checkUsers (MachineInstr &MI);
54
+ MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55
+ // / Computes the VL of \p MI that is actually used by its users.
56
+ MachineOperand computeDemandedVL (const MachineInstr &MI);
57
57
bool tryReduceVL (MachineInstr &MI);
58
58
bool isCandidate (const MachineInstr &MI) const ;
59
+
60
+ // / For a given instruction, records what elements of it are demanded by
61
+ // / downstream users.
62
+ DenseMap<const MachineInstr *, MachineOperand> DemandedVLs;
59
63
};
60
64
61
65
} // end anonymous namespace
@@ -1202,15 +1206,14 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
1202
1206
return true ;
1203
1207
}
1204
1208
1205
- std::optional<MachineOperand>
1206
- RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1209
+ MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1207
1210
const MachineInstr &UserMI = *UserOp.getParent ();
1208
1211
const MCInstrDesc &Desc = UserMI.getDesc ();
1209
1212
1210
1213
if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
1211
1214
LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
1212
1215
" use VLMAX\n " );
1213
- return std::nullopt ;
1216
+ return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
1214
1217
}
1215
1218
1216
1219
// Instructions like reductions may use a vector register as a scalar
@@ -1230,46 +1233,59 @@ RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1230
1233
// Looking for an immediate or a register VL that isn't X0.
1231
1234
assert ((!VLOp.isReg () || VLOp.getReg () != RISCV::X0) &&
1232
1235
" Did not expect X0 VL" );
1236
+
1237
+ // If we know the demanded VL of UserMI, then we can reduce the VL it
1238
+ // requires.
1239
+ if (DemandedVLs.contains (&UserMI)) {
1240
+ // We can only shrink the demanded VL if the elementwise result doesn't
1241
+ // depend on VL (i.e. not vredsum/viota etc.)
1242
+ // Also conservatively restrict to supported instructions for now.
1243
+ // TODO: Can we remove the isSupportedInstr check?
1244
+ if (!RISCVII::elementsDependOnVL (
1245
+ TII->get (RISCV::getRVVMCOpcode (UserMI.getOpcode ())).TSFlags ) &&
1246
+ isSupportedInstr (UserMI)) {
1247
+ const MachineOperand &DemandedVL = DemandedVLs.at (&UserMI);
1248
+ if (RISCV::isVLKnownLE (DemandedVL, VLOp))
1249
+ return DemandedVL;
1250
+ }
1251
+ }
1252
+
1233
1253
return VLOp;
1234
1254
}
1235
1255
1236
- std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1237
- // FIXME: Avoid visiting each user for each time we visit something on the
1238
- // worklist, combined with an extra visit from the outer loop. Restructure
1239
- // along lines of an instcombine style worklist which integrates the outer
1240
- // pass.
1241
- std::optional<MachineOperand> CommonVL;
1256
+ MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1257
+ const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1258
+ MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1259
+
1242
1260
for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
1243
1261
const MachineInstr &UserMI = *UserOp.getParent ();
1244
1262
LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
1245
1263
if (mayReadPastVL (UserMI)) {
1246
1264
LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1247
- return std::nullopt ;
1265
+ return VLMAX ;
1248
1266
}
1249
1267
1250
1268
// Tied operands might pass through.
1251
1269
if (UserOp.isTied ()) {
1252
1270
LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1253
- return std::nullopt ;
1271
+ return VLMAX ;
1254
1272
}
1255
1273
1256
- auto VLOp = getMinimumVLForUser (UserOp);
1257
- if (!VLOp)
1258
- return std::nullopt;
1274
+ const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
1259
1275
1260
1276
// Use the largest VL among all the users. If we cannot determine this
1261
1277
// statically, then we cannot optimize the VL.
1262
- if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, * VLOp)) {
1263
- CommonVL = * VLOp;
1264
- LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1265
- } else if (!RISCV::isVLKnownLE (* VLOp, *CommonVL )) {
1278
+ if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1279
+ DemandedVL = VLOp;
1280
+ LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1281
+ } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL )) {
1266
1282
LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1267
- return std::nullopt ;
1283
+ return VLMAX ;
1268
1284
}
1269
1285
1270
1286
if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
1271
1287
LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1272
- return std::nullopt ;
1288
+ return VLMAX ;
1273
1289
}
1274
1290
1275
1291
std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1279,7 +1295,7 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1279
1295
LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
1280
1296
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1281
1297
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1282
- return std::nullopt ;
1298
+ return VLMAX ;
1283
1299
}
1284
1300
1285
1301
// If the operand is used as a scalar operand, then the EEW must be
@@ -1294,53 +1310,51 @@ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers(MachineInstr &MI) {
1294
1310
<< " Abort due to incompatible information for EMUL or EEW.\n " );
1295
1311
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1296
1312
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1297
- return std::nullopt ;
1313
+ return VLMAX ;
1298
1314
}
1299
1315
}
1300
1316
1301
- return CommonVL ;
1317
+ return DemandedVL ;
1302
1318
}
1303
1319
1304
1320
bool RISCVVLOptimizer::tryReduceVL (MachineInstr &MI) {
1305
1321
LLVM_DEBUG (dbgs () << " Trying to reduce VL for " << MI << " \n " );
1306
1322
1307
- auto CommonVL = checkUsers (MI);
1308
- if (!CommonVL)
1309
- return false ;
1323
+ const MachineOperand &CommonVL = DemandedVLs.at (&MI);
1310
1324
1311
- assert ((CommonVL-> isImm () || CommonVL-> getReg ().isVirtual ()) &&
1325
+ assert ((CommonVL. isImm () || CommonVL. getReg ().isVirtual ()) &&
1312
1326
" Expected VL to be an Imm or virtual Reg" );
1313
1327
1314
1328
unsigned VLOpNum = RISCVII::getVLOpNum (MI.getDesc ());
1315
1329
MachineOperand &VLOp = MI.getOperand (VLOpNum);
1316
1330
1317
- if (!RISCV::isVLKnownLE (* CommonVL, VLOp)) {
1318
- LLVM_DEBUG (dbgs () << " Abort due to CommonVL not <= VLOp.\n " );
1331
+ if (!RISCV::isVLKnownLE (CommonVL, VLOp)) {
1332
+ LLVM_DEBUG (dbgs () << " Abort due to DemandedVL not <= VLOp.\n " );
1319
1333
return false ;
1320
1334
}
1321
1335
1322
- if (CommonVL-> isIdenticalTo (VLOp)) {
1336
+ if (CommonVL. isIdenticalTo (VLOp)) {
1323
1337
LLVM_DEBUG (
1324
- dbgs () << " Abort due to CommonVL == VLOp, no point in reducing.\n " );
1338
+ dbgs ()
1339
+ << " Abort due to DemandedVL == VLOp, no point in reducing.\n " );
1325
1340
return false ;
1326
1341
}
1327
1342
1328
- if (CommonVL-> isImm ()) {
1343
+ if (CommonVL. isImm ()) {
1329
1344
LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1330
- << CommonVL-> getImm () << " for " << MI << " \n " );
1331
- VLOp.ChangeToImmediate (CommonVL-> getImm ());
1345
+ << CommonVL. getImm () << " for " << MI << " \n " );
1346
+ VLOp.ChangeToImmediate (CommonVL. getImm ());
1332
1347
return true ;
1333
1348
}
1334
- const MachineInstr *VLMI = MRI->getVRegDef (CommonVL-> getReg ());
1349
+ const MachineInstr *VLMI = MRI->getVRegDef (CommonVL. getReg ());
1335
1350
if (!MDT->dominates (VLMI, &MI))
1336
1351
return false ;
1337
- LLVM_DEBUG (
1338
- dbgs () << " Reduce VL from " << VLOp << " to "
1339
- << printReg (CommonVL->getReg (), MRI->getTargetRegisterInfo ())
1340
- << " for " << MI << " \n " );
1352
+ LLVM_DEBUG (dbgs () << " Reduce VL from " << VLOp << " to "
1353
+ << printReg (CommonVL.getReg (), MRI->getTargetRegisterInfo ())
1354
+ << " for " << MI << " \n " );
1341
1355
1342
1356
// All our checks passed. We can reduce VL.
1343
- VLOp.ChangeToRegister (CommonVL-> getReg (), false );
1357
+ VLOp.ChangeToRegister (CommonVL. getReg (), false );
1344
1358
return true ;
1345
1359
}
1346
1360
@@ -1355,52 +1369,33 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
1355
1369
if (!ST.hasVInstructions ())
1356
1370
return false ;
1357
1371
1358
- SetVector<MachineInstr *> Worklist;
1359
- auto PushOperands = [this , &Worklist](MachineInstr &MI,
1360
- bool IgnoreSameBlock) {
1361
- for (auto &Op : MI.operands ()) {
1362
- if (!Op.isReg () || !Op.isUse () || !Op.getReg ().isVirtual () ||
1363
- !isVectorRegClass (Op.getReg (), MRI))
1364
- continue ;
1365
-
1366
- MachineInstr *DefMI = MRI->getVRegDef (Op.getReg ());
1367
- if (!isCandidate (*DefMI))
1368
- continue ;
1369
-
1370
- if (IgnoreSameBlock && DefMI->getParent () == MI.getParent ())
1371
- continue ;
1372
-
1373
- Worklist.insert (DefMI);
1374
- }
1375
- };
1372
+ TII = ST.getInstrInfo ();
1376
1373
1377
- // Do a first pass eagerly rewriting in roughly reverse instruction
1378
- // order, populate the worklist with any instructions we might need to
1379
- // revisit. We avoid adding definitions to the worklist if they're
1380
- // in the same block - we're about to visit them anyways.
1381
1374
bool MadeChange = false ;
1382
1375
for (MachineBasicBlock &MBB : MF) {
1383
1376
// Avoid unreachable blocks as they have degenerate dominance
1384
1377
if (!MDT->isReachableFromEntry (&MBB))
1385
1378
continue ;
1386
1379
1387
- for (auto &MI : make_range (MBB.rbegin (), MBB.rend ())) {
1380
+ // For each instruction that defines a vector, compute what VL its
1381
+ // downstream users demand.
1382
+ for (const auto &MI : reverse (MBB)) {
1383
+ if (!isCandidate (MI))
1384
+ continue ;
1385
+ DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1386
+ }
1387
+
1388
+ // Then go through and see if we can reduce the VL of any instructions to
1389
+ // only what's demanded.
1390
+ for (auto &MI : MBB) {
1388
1391
if (!isCandidate (MI))
1389
1392
continue ;
1390
1393
if (!tryReduceVL (MI))
1391
1394
continue ;
1392
1395
MadeChange = true ;
1393
- PushOperands (MI, /* IgnoreSameBlock*/ true );
1394
1396
}
1395
- }
1396
1397
1397
- while (!Worklist.empty ()) {
1398
- assert (MadeChange);
1399
- MachineInstr &MI = *Worklist.pop_back_val ();
1400
- assert (isCandidate (MI));
1401
- if (!tryReduceVL (MI))
1402
- continue ;
1403
- PushOperands (MI, /* IgnoreSameBlock*/ false );
1398
+ DemandedVLs.clear ();
1404
1399
}
1405
1400
1406
1401
return MadeChange;
0 commit comments