@@ -51,11 +51,10 @@ class RISCVVLOptimizer : public MachineFunctionPass {
51
51
StringRef getPassName () const override { return PASS_NAME; }
52
52
53
53
private:
54
- MachineOperand getMinimumVLForUser (MachineOperand &UserOp);
55
- // / Computes the minimum demanded VL of \p MI, i.e. the minimum VL that's used
56
- // / by its users downstream.
57
- // / Returns 0 if MI has no users.
58
- MachineOperand computeDemandedVL (const MachineInstr &MI);
54
+ std::optional<MachineOperand> getMinimumVLForUser (MachineOperand &UserOp);
55
+ // / Returns the largest common VL MachineOperand that may be used to optimize
56
+ // / MI. Returns std::nullopt if it failed to find a suitable VL.
57
+ std::optional<MachineOperand> checkUsers (MachineInstr &MI);
59
58
bool tryReduceVL (MachineInstr &MI);
60
59
bool isCandidate (const MachineInstr &MI) const ;
61
60
@@ -1179,14 +1178,15 @@ bool RISCVVLOptimizer::isCandidate(const MachineInstr &MI) const {
1179
1178
return true ;
1180
1179
}
1181
1180
1182
- MachineOperand RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1181
+ std::optional<MachineOperand>
1182
+ RISCVVLOptimizer::getMinimumVLForUser (MachineOperand &UserOp) {
1183
1183
const MachineInstr &UserMI = *UserOp.getParent ();
1184
1184
const MCInstrDesc &Desc = UserMI.getDesc ();
1185
1185
1186
1186
if (!RISCVII::hasVLOp (Desc.TSFlags ) || !RISCVII::hasSEWOp (Desc.TSFlags )) {
1187
1187
LLVM_DEBUG (dbgs () << " Abort due to lack of VL, assume that"
1188
1188
" use VLMAX\n " );
1189
- return MachineOperand::CreateImm (RISCV::VLMaxSentinel) ;
1189
+ return std::nullopt ;
1190
1190
}
1191
1191
1192
1192
// Instructions like reductions may use a vector register as a scalar
@@ -1223,40 +1223,39 @@ MachineOperand RISCVVLOptimizer::getMinimumVLForUser(MachineOperand &UserOp) {
1223
1223
return VLOp;
1224
1224
}
1225
1225
1226
- MachineOperand RISCVVLOptimizer::computeDemandedVL (const MachineInstr &MI) {
1227
- const MachineOperand &VLMAX = MachineOperand::CreateImm (RISCV::VLMaxSentinel);
1228
- MachineOperand DemandedVL = MachineOperand::CreateImm (0 );
1229
-
1226
+ std::optional<MachineOperand> RISCVVLOptimizer::checkUsers (MachineInstr &MI) {
1227
+ std::optional<MachineOperand> CommonVL;
1230
1228
for (auto &UserOp : MRI->use_operands (MI.getOperand (0 ).getReg ())) {
1231
1229
const MachineInstr &UserMI = *UserOp.getParent ();
1232
1230
LLVM_DEBUG (dbgs () << " Checking user: " << UserMI << " \n " );
1233
1231
if (mayReadPastVL (UserMI)) {
1234
1232
LLVM_DEBUG (dbgs () << " Abort because used by unsafe instruction\n " );
1235
- return VLMAX ;
1233
+ return std::nullopt ;
1236
1234
}
1237
1235
1238
1236
// If used as a passthru, elements past VL will be read.
1239
1237
if (UserOp.isTied ()) {
1240
1238
LLVM_DEBUG (dbgs () << " Abort because user used as tied operand\n " );
1241
- return VLMAX ;
1239
+ return std::nullopt ;
1242
1240
}
1243
1241
1244
- const MachineOperand &VLOp = getMinimumVLForUser (UserOp);
1245
-
1246
- // The minimum demanded VL is the largest VL read amongst all the users. If
1247
- // we cannot determine this statically, then we cannot optimize the VL.
1248
- if (RISCV::isVLKnownLE (DemandedVL, VLOp)) {
1249
- DemandedVL = VLOp;
1250
- LLVM_DEBUG (dbgs () << " Demanded VL is: " << VLOp << " \n " );
1251
- } else if (!RISCV::isVLKnownLE (VLOp, DemandedVL)) {
1252
- LLVM_DEBUG (
1253
- dbgs () << " Abort because cannot determine the demanded VL\n " );
1254
- return VLMAX;
1242
+ auto VLOp = getMinimumVLForUser (UserOp);
1243
+ if (!VLOp)
1244
+ return std::nullopt;
1245
+
1246
+ // Use the largest VL among all the users. If we cannot determine this
1247
+ // statically, then we cannot optimize the VL.
1248
+ if (!CommonVL || RISCV::isVLKnownLE (*CommonVL, *VLOp)) {
1249
+ CommonVL = *VLOp;
1250
+ LLVM_DEBUG (dbgs () << " User VL is: " << VLOp << " \n " );
1251
+ } else if (!RISCV::isVLKnownLE (*VLOp, *CommonVL)) {
1252
+ LLVM_DEBUG (dbgs () << " Abort because cannot determine a common VL\n " );
1253
+ return std::nullopt;
1255
1254
}
1256
1255
1257
1256
if (!RISCVII::hasSEWOp (UserMI.getDesc ().TSFlags )) {
1258
1257
LLVM_DEBUG (dbgs () << " Abort due to lack of SEW operand\n " );
1259
- return VLMAX ;
1258
+ return std::nullopt ;
1260
1259
}
1261
1260
1262
1261
std::optional<OperandInfo> ConsumerInfo = getOperandInfo (UserOp, MRI);
@@ -1266,7 +1265,7 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
1266
1265
LLVM_DEBUG (dbgs () << " Abort due to unknown operand information.\n " );
1267
1266
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1268
1267
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1269
- return VLMAX ;
1268
+ return std::nullopt ;
1270
1269
}
1271
1270
1272
1271
// If the operand is used as a scalar operand, then the EEW must be
@@ -1281,52 +1280,53 @@ MachineOperand RISCVVLOptimizer::computeDemandedVL(const MachineInstr &MI) {
1281
1280
<< " Abort due to incompatible information for EMUL or EEW.\n " );
1282
1281
LLVM_DEBUG (dbgs () << " ConsumerInfo is: " << ConsumerInfo << " \n " );
1283
1282
LLVM_DEBUG (dbgs () << " ProducerInfo is: " << ProducerInfo << " \n " );
1284
- return VLMAX ;
1283
+ return std::nullopt ;
1285
1284
}
1286
1285
}
1287
1286
1288
- return DemandedVL ;
1287
+ return CommonVL ;
1289
1288
}
1290
1289
1291
1290
/// Try to replace the VL operand of \p MI with the VL recorded for it in
/// DemandedVLs.
///
/// Returns true if the VL operand of \p MI was rewritten. Returns false if
/// \p MI has no entry in DemandedVLs, if the recorded VL is not known to be
/// less than or equal to the current VL, if it is identical to the current VL,
/// or if it is a register whose definition does not dominate \p MI.
bool RISCVVLOptimizer::tryReduceVL(MachineInstr &MI) {
  LLVM_DEBUG(dbgs() << " Trying to reduce VL for " << MI << " \n ");

  // Instructions without an entry have no usable VL recorded; leave them
  // alone. A single find() serves as both the presence check and the fetch,
  // and binding by reference avoids copying the operand.
  auto DemandedIt = DemandedVLs.find(&MI);
  if (DemandedIt == DemandedVLs.end())
    return false;
  const MachineOperand &CommonVL = DemandedIt->second;

  assert((CommonVL.isImm() || CommonVL.getReg().isVirtual()) &&
         " Expected VL to be an Imm or virtual Reg");

  unsigned VLOpNum = RISCVII::getVLOpNum(MI.getDesc());
  MachineOperand &VLOp = MI.getOperand(VLOpNum);

  // Only rewrite when the recorded VL is provably no larger than the current
  // one.
  if (!RISCV::isVLKnownLE(CommonVL, VLOp)) {
    LLVM_DEBUG(dbgs() << " Abort due to CommonVL not <= VLOp.\n ");
    return false;
  }

  if (CommonVL.isIdenticalTo(VLOp)) {
    LLVM_DEBUG(
        dbgs() << " Abort due to CommonVL == VLOp, no point in reducing.\n ");
    return false;
  }

  // Immediate VLs can be substituted directly.
  if (CommonVL.isImm()) {
    LLVM_DEBUG(dbgs() << " Reduce VL from " << VLOp << " to "
                      << CommonVL.getImm() << " for " << MI << " \n ");
    VLOp.ChangeToImmediate(CommonVL.getImm());
    return true;
  }

  // For a register VL, the defining instruction must dominate MI so that the
  // register is available at MI.
  const MachineInstr *VLMI = MRI->getVRegDef(CommonVL.getReg());
  if (!MDT->dominates(VLMI, &MI))
    return false;
  LLVM_DEBUG(
      dbgs() << " Reduce VL from " << VLOp << " to "
             << printReg(CommonVL.getReg(), MRI->getTargetRegisterInfo())
             << " for " << MI << " \n ");

  // All our checks passed. We can reduce VL.
  VLOp.ChangeToRegister(CommonVL.getReg(), false);
  return true;
}
1332
1332
@@ -1351,10 +1351,11 @@ bool RISCVVLOptimizer::runOnMachineFunction(MachineFunction &MF) {
1351
1351
1352
1352
// For each instruction that defines a vector, compute what VL its
1353
1353
// downstream users demand.
1354
- for (const auto &MI : reverse (MBB)) {
1354
+ for (MachineInstr &MI : reverse (MBB)) {
1355
1355
if (!isCandidate (MI))
1356
1356
continue ;
1357
- DemandedVLs.insert ({&MI, computeDemandedVL (MI)});
1357
+ if (auto DemandedVL = checkUsers (MI))
1358
+ DemandedVLs.insert ({&MI, *DemandedVL});
1358
1359
}
1359
1360
1360
1361
// Then go through and see if we can reduce the VL of any instructions to
0 commit comments